Add before_identify_hook to have finer control over IDENTIFY syncing
This commit is contained in:
@@ -223,8 +223,12 @@ class Client:
|
||||
'ready': self._handle_ready
|
||||
}
|
||||
|
||||
self._hooks = {
|
||||
'before_identify': self._call_before_identify_hook
|
||||
}
|
||||
|
||||
self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
||||
syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
||||
hooks=self._hooks, syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
||||
|
||||
self._connection.shard_count = self.shard_count
|
||||
self._closed = False
|
||||
@@ -394,6 +398,36 @@ class Client:
|
||||
|
||||
await self._connection.request_offline_members(guilds)
|
||||
|
||||
# hooks
|
||||
|
||||
async def _call_before_identify_hook(self, shard_id, *, initial=False):
|
||||
# This hook is an internal hook that actually calls the public one.
|
||||
# It allows the library to have its own hook without stepping on the
|
||||
# toes of those who need to override their own hook.
|
||||
await self.before_identify_hook(shard_id, initial=initial)
|
||||
|
||||
async def before_identify_hook(self, shard_id, *, initial=False):
|
||||
"""|coro|
|
||||
|
||||
A hook that is called before IDENTIFYing a session. This is useful
|
||||
if you wish to have more control over the synchronization of multiple
|
||||
IDENTIFYing clients.
|
||||
|
||||
The default implementation sleeps for 5 seconds.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
|
||||
Parameters
|
||||
------------
|
||||
shard_id: :class:`int`
|
||||
The shard ID that requested being IDENTIFY'd
|
||||
initial: :class:`bool`
|
||||
Whether this IDENTIFY is the first initial IDENTIFY.
|
||||
"""
|
||||
|
||||
if not initial:
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
# login state management
|
||||
|
||||
async def login(self, token, *, bot=True):
|
||||
@@ -447,7 +481,7 @@ class Client:
|
||||
await self.close()
|
||||
|
||||
async def _connect(self):
|
||||
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id)
|
||||
coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
while True:
|
||||
try:
|
||||
@@ -455,9 +489,6 @@ class Client:
|
||||
except ReconnectWebSocket as e:
|
||||
log.info('Got a request to %s the websocket.', e.op)
|
||||
self.dispatch('disconnect')
|
||||
if not e.resume:
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
|
||||
sequence=self.ws.sequence, resume=e.resume)
|
||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
|
@@ -250,7 +250,7 @@ class DiscordWebSocket:
|
||||
return not self.socket.closed
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
This is for internal use only.
|
||||
@@ -265,6 +265,8 @@ class DiscordWebSocket:
|
||||
ws._discord_parsers = client._connection.parsers
|
||||
ws._dispatch = client.dispatch
|
||||
ws.gateway = gateway
|
||||
ws.call_hooks = client._connection.call_hooks
|
||||
ws._initial_identify = initial
|
||||
ws.shard_id = shard_id
|
||||
ws.shard_count = client._connection.shard_count
|
||||
ws.session_id = session
|
||||
@@ -345,6 +347,7 @@ class DiscordWebSocket:
|
||||
'afk': False
|
||||
}
|
||||
|
||||
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
|
||||
await self.send_as_json(payload)
|
||||
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
|
||||
|
@@ -96,9 +96,6 @@ class Shard:
|
||||
self._task.cancel()
|
||||
|
||||
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||
if not exc.resume:
|
||||
await asyncio.sleep(5.0)
|
||||
|
||||
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||
self._dispatch('disconnect')
|
||||
@@ -144,7 +141,7 @@ class AutoShardedClient(Client):
|
||||
|
||||
self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
|
||||
handlers=self._handlers, syncer=self._syncer,
|
||||
http=self.http, loop=self.loop, **kwargs)
|
||||
hooks=self._hooks, http=self.http, loop=self.loop, **kwargs)
|
||||
|
||||
# instead of a single websocket, we have multiple
|
||||
# the key is the shard_id
|
||||
@@ -208,12 +205,12 @@ class AutoShardedClient(Client):
|
||||
sub_guilds = list(sub_guilds)
|
||||
await self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
|
||||
|
||||
async def launch_shard(self, gateway, shard_id):
|
||||
async def launch_shard(self, gateway, shard_id, *, initial=False):
|
||||
try:
|
||||
coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id)
|
||||
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
|
||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||
except Exception:
|
||||
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||
await asyncio.sleep(5.0)
|
||||
return await self.launch_shard(gateway, shard_id)
|
||||
|
||||
@@ -232,11 +229,9 @@ class AutoShardedClient(Client):
|
||||
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
|
||||
self._connection.shard_ids = shard_ids
|
||||
|
||||
last_shard_id = shard_ids[-1]
|
||||
for shard_id in shard_ids:
|
||||
await self.launch_shard(gateway, shard_id)
|
||||
if shard_id != last_shard_id:
|
||||
await asyncio.sleep(5.0)
|
||||
initial = shard_id == shard_ids[0]
|
||||
await self.launch_shard(gateway, shard_id, initial=initial)
|
||||
|
||||
self._connection.shards_launched.set()
|
||||
|
||||
|
@@ -64,7 +64,7 @@ log = logging.getLogger(__name__)
|
||||
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
|
||||
|
||||
class ConnectionState:
|
||||
def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
|
||||
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
|
||||
self.loop = loop
|
||||
self.http = http
|
||||
self.max_messages = options.get('max_messages', 1000)
|
||||
@@ -75,6 +75,7 @@ class ConnectionState:
|
||||
self.syncer = syncer
|
||||
self.is_bot = None
|
||||
self.handlers = handlers
|
||||
self.hooks = hooks
|
||||
self.shard_count = None
|
||||
self._ready_task = None
|
||||
self._fetch_offline = options.get('fetch_offline_members', True)
|
||||
@@ -170,6 +171,14 @@ class ConnectionState:
|
||||
else:
|
||||
func(*args, **kwargs)
|
||||
|
||||
async def call_hooks(self, key, *args, **kwargs):
|
||||
try:
|
||||
coro = self.hooks[key]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
await coro(*args, **kwargs)
|
||||
|
||||
@property
|
||||
def self_id(self):
|
||||
u = self.user
|
||||
|
Reference in New Issue
Block a user