mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-31 05:23:03 +00:00 
			
		
		
		
	Add before_identify_hook to have finer control over IDENTIFY syncing
This commit is contained in:
		| @@ -223,8 +223,12 @@ class Client: | ||||
|             'ready': self._handle_ready | ||||
|         } | ||||
|  | ||||
|         self._hooks = { | ||||
|             'before_identify': self._call_before_identify_hook | ||||
|         } | ||||
|  | ||||
|         self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers, | ||||
|                                            syncer=self._syncer, http=self.http, loop=self.loop, **options) | ||||
|                                            hooks=self._hooks, syncer=self._syncer, http=self.http, loop=self.loop, **options) | ||||
|  | ||||
|         self._connection.shard_count = self.shard_count | ||||
|         self._closed = False | ||||
| @@ -394,6 +398,36 @@ class Client: | ||||
|  | ||||
|         await self._connection.request_offline_members(guilds) | ||||
|  | ||||
|     # hooks | ||||
|  | ||||
|     async def _call_before_identify_hook(self, shard_id, *, initial=False): | ||||
|         # This hook is an internal hook that actually calls the public one. | ||||
|         # It allows the library to have its own hook without stepping on the | ||||
|         # toes of those who need to override their own hook. | ||||
|         await self.before_identify_hook(shard_id, initial=initial) | ||||
|  | ||||
|     async def before_identify_hook(self, shard_id, *, initial=False): | ||||
|         """|coro| | ||||
|  | ||||
|         A hook that is called before IDENTIFYing a session. This is useful | ||||
|         if you wish to have more control over the synchronization of multiple | ||||
|         IDENTIFYing clients. | ||||
|  | ||||
|         The default implementation sleeps for 5 seconds. | ||||
|  | ||||
|         .. versionadded:: 1.4 | ||||
|  | ||||
|         Parameters | ||||
|         ------------ | ||||
|         shard_id: :class:`int` | ||||
|             The shard ID that requested being IDENTIFY'd | ||||
|         initial: :class:`bool` | ||||
|             Whether this IDENTIFY is the first initial IDENTIFY. | ||||
|         """ | ||||
|  | ||||
|         if not initial: | ||||
|             await asyncio.sleep(5.0) | ||||
|  | ||||
|     # login state management | ||||
|  | ||||
|     async def login(self, token, *, bot=True): | ||||
| @@ -447,7 +481,7 @@ class Client: | ||||
|         await self.close() | ||||
|  | ||||
|     async def _connect(self): | ||||
|         coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id) | ||||
|         coro = DiscordWebSocket.from_client(self, initial=True, shard_id=self.shard_id) | ||||
|         self.ws = await asyncio.wait_for(coro, timeout=180.0) | ||||
|         while True: | ||||
|             try: | ||||
| @@ -455,11 +489,8 @@ class Client: | ||||
|             except ReconnectWebSocket as e: | ||||
|                 log.info('Got a request to %s the websocket.', e.op) | ||||
|                 self.dispatch('disconnect') | ||||
|                 if not e.resume: | ||||
|                     await asyncio.sleep(5.0) | ||||
|  | ||||
|                 coro = DiscordWebSocket.from_client(self, shard_id=self.shard_id, session=self.ws.session_id, | ||||
|                                                     sequence=self.ws.sequence, resume=e.resume) | ||||
|                                                           sequence=self.ws.sequence, resume=e.resume) | ||||
|                 self.ws = await asyncio.wait_for(coro, timeout=180.0) | ||||
|  | ||||
|     async def connect(self, *, reconnect=True): | ||||
|   | ||||
| @@ -250,7 +250,7 @@ class DiscordWebSocket: | ||||
|         return not self.socket.closed | ||||
|  | ||||
|     @classmethod | ||||
|     async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False): | ||||
|     async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): | ||||
|         """Creates a main websocket for Discord from a :class:`Client`. | ||||
|  | ||||
|         This is for internal use only. | ||||
| @@ -265,6 +265,8 @@ class DiscordWebSocket: | ||||
|         ws._discord_parsers = client._connection.parsers | ||||
|         ws._dispatch = client.dispatch | ||||
|         ws.gateway = gateway | ||||
|         ws.call_hooks = client._connection.call_hooks | ||||
|         ws._initial_identify = initial | ||||
|         ws.shard_id = shard_id | ||||
|         ws.shard_count = client._connection.shard_count | ||||
|         ws.session_id = session | ||||
| @@ -345,6 +347,7 @@ class DiscordWebSocket: | ||||
|                 'afk': False | ||||
|             } | ||||
|  | ||||
|         await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) | ||||
|         await self.send_as_json(payload) | ||||
|         log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) | ||||
|  | ||||
|   | ||||
| @@ -96,9 +96,6 @@ class Shard: | ||||
|             self._task.cancel() | ||||
|  | ||||
|         log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) | ||||
|         if not exc.resume: | ||||
|             await asyncio.sleep(5.0) | ||||
|  | ||||
|         coro = DiscordWebSocket.from_client(self._client, resume=exc.resume, shard_id=self.id, | ||||
|                                             session=self.ws.session_id, sequence=self.ws.sequence) | ||||
|         self._dispatch('disconnect') | ||||
| @@ -144,7 +141,7 @@ class AutoShardedClient(Client): | ||||
|  | ||||
|         self._connection = AutoShardedConnectionState(dispatch=self.dispatch, | ||||
|                                                       handlers=self._handlers, syncer=self._syncer, | ||||
|                                                       http=self.http, loop=self.loop, **kwargs) | ||||
|                                                       hooks=self._hooks, http=self.http, loop=self.loop, **kwargs) | ||||
|  | ||||
|         # instead of a single websocket, we have multiple | ||||
|         # the key is the shard_id | ||||
| @@ -208,12 +205,12 @@ class AutoShardedClient(Client): | ||||
|             sub_guilds = list(sub_guilds) | ||||
|             await self._connection.request_offline_members(sub_guilds, shard_id=shard_id) | ||||
|  | ||||
|     async def launch_shard(self, gateway, shard_id): | ||||
|     async def launch_shard(self, gateway, shard_id, *, initial=False): | ||||
|         try: | ||||
|             coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id) | ||||
|             coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) | ||||
|             ws = await asyncio.wait_for(coro, timeout=180.0) | ||||
|         except Exception: | ||||
|             log.info('Failed to connect for shard_id: %s. Retrying...', shard_id) | ||||
|             log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) | ||||
|             await asyncio.sleep(5.0) | ||||
|             return await self.launch_shard(gateway, shard_id) | ||||
|  | ||||
| @@ -232,11 +229,9 @@ class AutoShardedClient(Client): | ||||
|         shard_ids = self.shard_ids if self.shard_ids else range(self.shard_count) | ||||
|         self._connection.shard_ids = shard_ids | ||||
|  | ||||
|         last_shard_id = shard_ids[-1] | ||||
|         for shard_id in shard_ids: | ||||
|             await self.launch_shard(gateway, shard_id) | ||||
|             if shard_id != last_shard_id: | ||||
|                 await asyncio.sleep(5.0) | ||||
|             initial = shard_id == shard_ids[0] | ||||
|             await self.launch_shard(gateway, shard_id, initial=initial) | ||||
|  | ||||
|         self._connection.shards_launched.set() | ||||
|  | ||||
|   | ||||
| @@ -64,7 +64,7 @@ log = logging.getLogger(__name__) | ||||
| ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) | ||||
|  | ||||
| class ConnectionState: | ||||
|     def __init__(self, *, dispatch, handlers, syncer, http, loop, **options): | ||||
|     def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): | ||||
|         self.loop = loop | ||||
|         self.http = http | ||||
|         self.max_messages = options.get('max_messages', 1000) | ||||
| @@ -75,6 +75,7 @@ class ConnectionState: | ||||
|         self.syncer = syncer | ||||
|         self.is_bot = None | ||||
|         self.handlers = handlers | ||||
|         self.hooks = hooks | ||||
|         self.shard_count = None | ||||
|         self._ready_task = None | ||||
|         self._fetch_offline = options.get('fetch_offline_members', True) | ||||
| @@ -170,6 +171,14 @@ class ConnectionState: | ||||
|         else: | ||||
|             func(*args, **kwargs) | ||||
|  | ||||
|     async def call_hooks(self, key, *args, **kwargs): | ||||
|         try: | ||||
|             coro = self.hooks[key] | ||||
|         except KeyError: | ||||
|             pass | ||||
|         else: | ||||
|             await coro(*args, **kwargs) | ||||
|  | ||||
|     @property | ||||
|     def self_id(self): | ||||
|         u = self.user | ||||
|   | ||||
		Reference in New Issue
	
	Block a user