Implement AutoShardedClient for transparent sharding.
This allows people to run their >2,500 guild bot in a single process without the headaches of IPC/RPC or much difficulty.
This commit is contained in:
@ -43,6 +43,7 @@ import datetime
|
||||
import asyncio
|
||||
import logging
|
||||
import weakref
|
||||
import itertools
|
||||
|
||||
class ListenerType(enum.Enum):
|
||||
chunk = 0
|
||||
@ -60,13 +61,12 @@ class ConnectionState:
|
||||
self.chunker = chunker
|
||||
self.syncer = syncer
|
||||
self.is_bot = None
|
||||
self.shard_count = None
|
||||
self._listeners = []
|
||||
self.clear()
|
||||
|
||||
def clear(self):
|
||||
self.user = None
|
||||
self.sequence = None
|
||||
self.session_id = None
|
||||
self._users = weakref.WeakValueDictionary()
|
||||
self._calls = {}
|
||||
self._emojis = {}
|
||||
@ -355,7 +355,8 @@ class ConnectionState:
|
||||
# the reason we're doing this is so it's also removed from the
|
||||
# private channel by user cache as well
|
||||
channel = self._get_private_channel(channel_id)
|
||||
self._remove_private_channel(channel)
|
||||
if channel is not None:
|
||||
self._remove_private_channel(channel)
|
||||
|
||||
def parse_channel_update(self, data):
|
||||
channel_type = try_enum(ChannelType, data.get('type'))
|
||||
@ -701,3 +702,76 @@ class ConnectionState:
|
||||
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
|
||||
self._listeners.append(listener)
|
||||
return future
|
||||
|
||||
class AutoShardedConnectionState(ConnectionState):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
|
||||
self._ready_task = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def _delay_ready(self):
|
||||
launch = self._ready_state.launch
|
||||
while not launch.is_set():
|
||||
# this snippet of code is basically waiting 2 seconds
|
||||
# until the last GUILD_CREATE was sent
|
||||
launch.set()
|
||||
yield from asyncio.sleep(2.0 * self.shard_count, loop=self.loop)
|
||||
|
||||
guilds = sorted(self._ready_state.guilds, key=lambda g: g.shard_id)
|
||||
|
||||
# we only want to request ~75 guilds per chunk request.
|
||||
# we also want to split the chunks per shard_id
|
||||
for shard_id, sub_guilds in itertools.groupby(guilds, key=lambda g: g.shard_id):
|
||||
sub_guilds = list(sub_guilds)
|
||||
|
||||
# split chunks by shard ID
|
||||
chunks = []
|
||||
for guild in sub_guilds:
|
||||
chunks.extend(self.chunks_needed(guild))
|
||||
|
||||
splits = [sub_guilds[i:i + 75] for i in range(0, len(sub_guilds), 75)]
|
||||
for split in splits:
|
||||
yield from self.chunker(split, shard_id=shard_id)
|
||||
|
||||
# wait for the chunks
|
||||
if chunks:
|
||||
try:
|
||||
yield from asyncio.wait(chunks, timeout=len(chunks) * 30.0, loop=self.loop)
|
||||
except asyncio.TimeoutError:
|
||||
log.info('Somehow timed out waiting for chunks for %s shard_id' % shard_id)
|
||||
|
||||
self.dispatch('shard_ready', shard_id)
|
||||
|
||||
# sleep a second for every shard ID.
|
||||
# yield from asyncio.sleep(1.0, loop=self.loop)
|
||||
|
||||
# remove the state
|
||||
try:
|
||||
del self._ready_state
|
||||
except AttributeError:
|
||||
pass # already been deleted somehow
|
||||
|
||||
# regular users cannot shard so we won't worry about it here.
|
||||
|
||||
# dispatch the event
|
||||
self.dispatch('ready')
|
||||
|
||||
def parse_ready(self, data):
|
||||
if not hasattr(self, '_ready_state'):
|
||||
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[])
|
||||
|
||||
self.user = self.store_user(data['user'])
|
||||
|
||||
guilds = self._ready_state.guilds
|
||||
for guild_data in data['guilds']:
|
||||
guild = self._add_guild_from_data(guild_data)
|
||||
if not self.is_bot or guild.large:
|
||||
guilds.append(guild)
|
||||
|
||||
for pm in data.get('private_channels', []):
|
||||
factory, _ = _channel_factory(pm['type'])
|
||||
self._add_private_channel(factory(me=self.user, data=pm, state=self))
|
||||
|
||||
if self._ready_task is None:
|
||||
self._ready_task = compat.create_task(self._delay_ready(), loop=self.loop)
|
||||
|
Reference in New Issue
Block a user