mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-22 00:13:01 +00:00 
			
		
		
		
	Fix timeout issues with fetching members via query_members
This uses the nonce field to properly disambiguate queries. There's also some redesigning going on behind the scenes and minor clean-up. Originally I planned on working on this more to account for the more widespread chunking changes planned for gateway v7 but I realized that this would indiscriminately slow down everyone else who isn't planning on working with intents for now. I will work on the larger chunking changes in the future, should time allow for it.
This commit is contained in:
		| @@ -223,13 +223,13 @@ class Client: | ||||
|             'ready': self._handle_ready | ||||
|         } | ||||
|  | ||||
|         self._connection = ConnectionState(dispatch=self.dispatch, chunker=self._chunker, handlers=self._handlers, | ||||
|         self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers, | ||||
|                                            syncer=self._syncer, http=self.http, loop=self.loop, **options) | ||||
|  | ||||
|         self._connection.shard_count = self.shard_count | ||||
|         self._closed = False | ||||
|         self._ready = asyncio.Event() | ||||
|         self._connection._get_websocket = lambda g: self.ws | ||||
|         self._connection._get_websocket = self._get_websocket | ||||
|  | ||||
|         if VoiceClient.warn_nacl: | ||||
|             VoiceClient.warn_nacl = False | ||||
| @@ -237,26 +237,12 @@ class Client: | ||||
|  | ||||
|     # internals | ||||
|  | ||||
|     def _get_websocket(self, guild_id=None, *, shard_id=None): | ||||
|         return self.ws | ||||
|  | ||||
|     async def _syncer(self, guilds): | ||||
|         await self.ws.request_sync(guilds) | ||||
|  | ||||
|     async def _chunker(self, guild): | ||||
|         try: | ||||
|             guild_id = guild.id | ||||
|         except AttributeError: | ||||
|             guild_id = [s.id for s in guild] | ||||
|  | ||||
|         payload = { | ||||
|             'op': 8, | ||||
|             'd': { | ||||
|                 'guild_id': guild_id, | ||||
|                 'query': '', | ||||
|                 'limit': 0 | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         await self.ws.send_as_json(payload) | ||||
|  | ||||
|     def _handle_ready(self): | ||||
|         self._ready.set() | ||||
|  | ||||
|   | ||||
| @@ -535,15 +535,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|         } | ||||
|         await self.send_as_json(payload) | ||||
|  | ||||
|     async def request_chunks(self, guild_id, query, limit): | ||||
|     async def request_chunks(self, guild_id, query, limit, *, nonce=None): | ||||
|         payload = { | ||||
|             'op': self.REQUEST_MEMBERS, | ||||
|             'd': { | ||||
|                 'guild_id': str(guild_id), | ||||
|                 'guild_id': guild_id, | ||||
|                 'query': query, | ||||
|                 'limit': limit | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         if nonce: | ||||
|             payload['d']['nonce'] = nonce | ||||
|  | ||||
|         await self.send_as_json(payload) | ||||
|  | ||||
|     async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): | ||||
|   | ||||
| @@ -126,38 +126,19 @@ class AutoShardedClient(Client): | ||||
|             elif not isinstance(self.shard_ids, (list, tuple)): | ||||
|                 raise ClientException('shard_ids parameter must be a list or a tuple.') | ||||
|  | ||||
|         self._connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self._chunker, | ||||
|         self._connection = AutoShardedConnectionState(dispatch=self.dispatch, | ||||
|                                                       handlers=self._handlers, syncer=self._syncer, | ||||
|                                                       http=self.http, loop=self.loop, **kwargs) | ||||
|  | ||||
|         # instead of a single websocket, we have multiple | ||||
|         # the key is the shard_id | ||||
|         self.shards = {} | ||||
|         self._connection._get_websocket = self._get_websocket | ||||
|  | ||||
|         def _get_websocket(guild_id): | ||||
|             i = (guild_id >> 22) % self.shard_count | ||||
|             return self.shards[i].ws | ||||
|  | ||||
|         self._connection._get_websocket = _get_websocket | ||||
|  | ||||
|     async def _chunker(self, guild, *, shard_id=None): | ||||
|         try: | ||||
|             guild_id = guild.id | ||||
|             shard_id = shard_id or guild.shard_id | ||||
|         except AttributeError: | ||||
|             guild_id = [s.id for s in guild] | ||||
|  | ||||
|         payload = { | ||||
|             'op': 8, | ||||
|             'd': { | ||||
|                 'guild_id': guild_id, | ||||
|                 'query': '', | ||||
|                 'limit': 0 | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         ws = self.shards[shard_id].ws | ||||
|         await ws.send_as_json(payload) | ||||
|     def _get_websocket(self, guild_id=None, *, shard_id=None): | ||||
|         if shard_id is None: | ||||
|             shard_id = (guild_id >> 22) % self.shard_count | ||||
|         return self.shards[shard_id].ws | ||||
|  | ||||
|     @property | ||||
|     def latency(self): | ||||
|   | ||||
| @@ -35,6 +35,9 @@ import weakref | ||||
| import inspect | ||||
| import gc | ||||
|  | ||||
| import os | ||||
| import binascii | ||||
|  | ||||
| from .guild import Guild | ||||
| from .activity import BaseActivity | ||||
| from .user import User, ClientUser | ||||
| @@ -62,7 +65,7 @@ log = logging.getLogger(__name__) | ||||
| ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) | ||||
|  | ||||
| class ConnectionState: | ||||
|     def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options): | ||||
|     def __init__(self, *, dispatch, handlers, syncer, http, loop, **options): | ||||
|         self.loop = loop | ||||
|         self.http = http | ||||
|         self.max_messages = options.get('max_messages', 1000) | ||||
| @@ -70,7 +73,6 @@ class ConnectionState: | ||||
|             self.max_messages = 1000 | ||||
|  | ||||
|         self.dispatch = dispatch | ||||
|         self.chunker = chunker | ||||
|         self.syncer = syncer | ||||
|         self.is_bot = None | ||||
|         self.handlers = handlers | ||||
| @@ -132,6 +134,9 @@ class ConnectionState: | ||||
|         # to reconnect loops which cause mass allocations and deallocations. | ||||
|         gc.collect() | ||||
|  | ||||
|     def get_nonce(self): | ||||
|         return binascii.hexlify(os.urandom(16)).decode('ascii') | ||||
|  | ||||
|     def process_listeners(self, listener_type, argument, result): | ||||
|         removed = [] | ||||
|         for i, listener in enumerate(self._listeners): | ||||
| @@ -298,6 +303,10 @@ class ConnectionState: | ||||
|  | ||||
|         return channel or Object(id=channel_id), guild | ||||
|  | ||||
|     async def chunker(self, guild_id, query='', limit=0, *, nonce=None): | ||||
|         ws = self._get_websocket(guild_id) # This is ignored upstream | ||||
|         await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) | ||||
|  | ||||
|     async def request_offline_members(self, guilds): | ||||
|         # get all the chunks | ||||
|         chunks = [] | ||||
| @@ -307,7 +316,7 @@ class ConnectionState: | ||||
|         # we only want to request ~75 guilds per chunk request. | ||||
|         splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] | ||||
|         for split in splits: | ||||
|             await self.chunker(split) | ||||
|             await self.chunker([g.id for g in split]) | ||||
|  | ||||
|         # wait for the chunks | ||||
|         if chunks: | ||||
| @@ -329,10 +338,11 @@ class ConnectionState: | ||||
|         # and they don't receive GUILD_MEMBER events which make computing | ||||
|         # member_count impossible. The only way to fix it is by limiting | ||||
|         # the limit parameter to 1 to 1000. | ||||
|         future = self.receive_member_query(guild_id, query) | ||||
|         nonce = self.get_nonce() | ||||
|         future = self.receive_member_query(guild_id, nonce) | ||||
|         try: | ||||
|             # start the query operation | ||||
|             await ws.request_chunks(guild_id, query, limit) | ||||
|             await ws.request_chunks(guild_id, query, limit, nonce=nonce) | ||||
|             members = await asyncio.wait_for(future, timeout=5.0) | ||||
|  | ||||
|             if cache: | ||||
| @@ -894,8 +904,7 @@ class ConnectionState: | ||||
|                     guild._add_member(member) | ||||
|  | ||||
|         self.process_listeners(ListenerType.chunk, guild, len(members)) | ||||
|         names = [x.name.lower() for x in members] | ||||
|         self.process_listeners(ListenerType.query_members, (guild_id, names), members) | ||||
|         self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members) | ||||
|  | ||||
|     def parse_guild_integrations_update(self, data): | ||||
|         guild = self._get_guild(int(data['guild_id'])) | ||||
| @@ -1025,10 +1034,10 @@ class ConnectionState: | ||||
|         self._listeners.append(listener) | ||||
|         return future | ||||
|  | ||||
|     def receive_member_query(self, guild_id, query): | ||||
|         def predicate(args, *, guild_id=guild_id, query=query.lower()): | ||||
|             request_guild_id, names = args | ||||
|             return request_guild_id == guild_id and all(n.startswith(query) for n in names) | ||||
|     def receive_member_query(self, guild_id, nonce): | ||||
|         def predicate(args, *, guild_id=guild_id, nonce=nonce): | ||||
|             return args == (guild_id, nonce) | ||||
|  | ||||
|         future = self.loop.create_future() | ||||
|         listener = Listener(ListenerType.query_members, future, predicate) | ||||
|         self._listeners.append(listener) | ||||
| @@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState): | ||||
|         self._ready_task = None | ||||
|         self.shard_ids = () | ||||
|  | ||||
|     async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None): | ||||
|         ws = self._get_websocket(shard_id=shard_id) | ||||
|         await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) | ||||
|  | ||||
|     async def request_offline_members(self, guilds, *, shard_id): | ||||
|         # get all the chunks | ||||
|         chunks = [] | ||||
| @@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState): | ||||
|         # we only want to request ~75 guilds per chunk request. | ||||
|         splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] | ||||
|         for split in splits: | ||||
|             await self.chunker(split, shard_id=shard_id) | ||||
|             await self.chunker([g.id for g in split], shard_id=shard_id) | ||||
|  | ||||
|         # wait for the chunks | ||||
|         if chunks: | ||||
|   | ||||
		Reference in New Issue
	
	Block a user