Add before_identify_hook to have finer control over IDENTIFY syncing
This commit is contained in:
parent
9c7ae6b9dc
commit
394b514cc9
@ -223,8 +223,12 @@ class Client:
|
|||||||
'ready': self._handle_ready
|
'ready': self._handle_ready
|
||||||
}
|
}
|
||||||
|
|
||||||
|
self._hooks = {
|
||||||
|
'before_identify': self._call_before_identify_hook
|
||||||
|
}
|
||||||
|
|
||||||
self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
||||||
syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
hooks=self._hooks, syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
||||||
|
|
||||||
self._connection.shard_count = self.shard_count
|
self._connection.shard_count = self.shard_count
|
||||||
self._closed = False
|
self._closed = False
|
||||||
@ -394,6 +398,36 @@ class Client:
|
|||||||
|
|
||||||
await self._connection.request_offline_members(guilds)
|
await self._connection.request_offline_members(guilds)
|
||||||
|
|
||||||
|
# hooks
|
||||||
|
|
||||||
|
async def _call_before_identify_hook(self, shard_id, *, initial=False):
|
||||||
|
# This hook is an internal hook that actually calls the public one.
|
||||||
|
# It allows the library to have its own hook without stepping on the
|
||||||
|
# toes of those who need to override their own hook.
|
||||||
|
await self.before_identify_hook(shard_id, initial=initial)
|
||||||
|
|
||||||
|
async def before_identify_hook(self, shard_id, *, initial=False):
|
||||||
|
"""|coro|
|
||||||
|
|
||||||
|
A hook that is called before IDENTIFYing a session. This is useful
|
||||||
|
if you wish to have more control over the synchronization of multiple
|
||||||
|
IDENTIFYing clients.
|
||||||
|
|
||||||
|
The default implementation sleeps for 5 seconds.
|
||||||
|
|
||||||
|
.. versionadded:: 1.4
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
------------
|
||||||
|
shard_id: :class:`int`
|
||||||
|
The shard ID that requested being IDENTIFY'd
|
||||||
|
initial: :class:`bool`
|
||||||
|
Whether this IDENTIFY is the first initial IDENTIFY.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not initial:
|
||||||
|
await asyncio.sleep(5.0)
|
||||||
|
|
||||||
# login state management
|
# login state management
|
||||||
|
|
||||||
async def login(self, token, *, bot=True):
|
async def login(self, token, *, bot=True):
|
||||||
@ -447,7 +481,7 @@ class Client:
|
|||||||
await self.close()
|
await self.close()
|
||||||
|
|
||||||
async def _connect(self):
|
async def _connect(self):
|
||||||
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id)
|
coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id)
|
||||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
@ -455,11 +489,8 @@ class Client:
|
|||||||
except ReconnectWebSocket as e:
|
except ReconnectWebSocket as e:
|
||||||
log.info('Got a request to %s the websocket.', e.op)
|
log.info('Got a request to %s the websocket.', e.op)
|
||||||
self.dispatch('disconnect')
|
self.dispatch('disconnect')
|
||||||
if not e.resume:
|
|
||||||
await asyncio.sleep(5.0)
|
|
||||||
|
|
||||||
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
|
coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id,
|
||||||
sequence=self.ws.sequence, resume=e.resume)
|
sequence=self.ws.sequence, resume=e.resume)
|
||||||
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
self.ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||||
|
|
||||||
async def connect(self, *, reconnect=True):
|
async def connect(self, *, reconnect=True):
|
||||||
|
@ -250,7 +250,7 @@ class DiscordWebSocket:
|
|||||||
return not self.socket.closed
|
return not self.socket.closed
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||||
|
|
||||||
This is for internal use only.
|
This is for internal use only.
|
||||||
@ -265,6 +265,8 @@ class DiscordWebSocket:
|
|||||||
ws._discord_parsers = client._connection.parsers
|
ws._discord_parsers = client._connection.parsers
|
||||||
ws._dispatch = client.dispatch
|
ws._dispatch = client.dispatch
|
||||||
ws.gateway = gateway
|
ws.gateway = gateway
|
||||||
|
ws.call_hooks = client._connection.call_hooks
|
||||||
|
ws._initial_identify = initial
|
||||||
ws.shard_id = shard_id
|
ws.shard_id = shard_id
|
||||||
ws.shard_count = client._connection.shard_count
|
ws.shard_count = client._connection.shard_count
|
||||||
ws.session_id = session
|
ws.session_id = session
|
||||||
@ -345,6 +347,7 @@ class DiscordWebSocket:
|
|||||||
'afk': False
|
'afk': False
|
||||||
}
|
}
|
||||||
|
|
||||||
|
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
|
||||||
await self.send_as_json(payload)
|
await self.send_as_json(payload)
|
||||||
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||||
|
|
||||||
|
@ -96,9 +96,6 @@ class Shard:
|
|||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id)
|
||||||
if not exc.resume:
|
|
||||||
await asyncio.sleep(5.0)
|
|
||||||
|
|
||||||
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id,
|
||||||
session=self.ws.session_id, sequence=self.ws.sequence)
|
session=self.ws.session_id, sequence=self.ws.sequence)
|
||||||
self._dispatch('disconnect')
|
self._dispatch('disconnect')
|
||||||
@ -144,7 +141,7 @@ class AutoShardedClient(Client):
|
|||||||
|
|
||||||
self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
|
self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
|
||||||
handlers=self._handlers, syncer=self._syncer,
|
handlers=self._handlers, syncer=self._syncer,
|
||||||
http=self.http, loop=self.loop, **kwargs)
|
hooks=self._hooks, http=self.http, loop=self.loop, **kwargs)
|
||||||
|
|
||||||
# instead of a single websocket, we have multiple
|
# instead of a single websocket, we have multiple
|
||||||
# the key is the shard_id
|
# the key is the shard_id
|
||||||
@ -208,12 +205,12 @@ class AutoShardedClient(Client):
|
|||||||
sub_guilds = list(sub_guilds)
|
sub_guilds = list(sub_guilds)
|
||||||
await self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
|
await self._connection.request_offline_members(sub_guilds, shard_id=shard_id)
|
||||||
|
|
||||||
async def launch_shard(self, gateway, shard_id):
|
async def launch_shard(self, gateway, shard_id, *, initial=False):
|
||||||
try:
|
try:
|
||||||
coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id)
|
coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id)
|
||||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||||
await asyncio.sleep(5.0)
|
await asyncio.sleep(5.0)
|
||||||
return await self.launch_shard(gateway, shard_id)
|
return await self.launch_shard(gateway, shard_id)
|
||||||
|
|
||||||
@ -232,11 +229,9 @@ class AutoShardedClient(Client):
|
|||||||
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
|
shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count)
|
||||||
self._connection.shard_ids = shard_ids
|
self._connection.shard_ids = shard_ids
|
||||||
|
|
||||||
last_shard_id = shard_ids[-1]
|
|
||||||
for shard_id in shard_ids:
|
for shard_id in shard_ids:
|
||||||
await self.launch_shard(gateway, shard_id)
|
initial = shard_id == shard_ids[0]
|
||||||
if shard_id != last_shard_id:
|
await self.launch_shard(gateway, shard_id, initial=initial)
|
||||||
await asyncio.sleep(5.0)
|
|
||||||
|
|
||||||
self._connection.shards_launched.set()
|
self._connection.shards_launched.set()
|
||||||
|
|
||||||
|
@ -64,7 +64,7 @@ log = logging.getLogger(__name__)
|
|||||||
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
|
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
|
||||||
|
|
||||||
class ConnectionState:
|
class ConnectionState:
|
||||||
def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
|
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.http = http
|
self.http = http
|
||||||
self.max_messages = options.get('max_messages', 1000)
|
self.max_messages = options.get('max_messages', 1000)
|
||||||
@ -75,6 +75,7 @@ class ConnectionState:
|
|||||||
self.syncer = syncer
|
self.syncer = syncer
|
||||||
self.is_bot = None
|
self.is_bot = None
|
||||||
self.handlers = handlers
|
self.handlers = handlers
|
||||||
|
self.hooks = hooks
|
||||||
self.shard_count = None
|
self.shard_count = None
|
||||||
self._ready_task = None
|
self._ready_task = None
|
||||||
self._fetch_offline = options.get('fetch_offline_members', True)
|
self._fetch_offline = options.get('fetch_offline_members', True)
|
||||||
@ -170,6 +171,14 @@ class ConnectionState:
|
|||||||
else:
|
else:
|
||||||
func(*args, **kwargs)
|
func(*args, **kwargs)
|
||||||
|
|
||||||
|
async def call_hooks(self, key, *args, **kwargs):
|
||||||
|
try:
|
||||||
|
coro = self.hooks[key]
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
await coro(*args, **kwargs)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def self_id(self):
|
def self_id(self):
|
||||||
u = self.user
|
u = self.user
|
||||||
|
Loading…
x
Reference in New Issue
Block a user