Fix timeout issues with fetching members via query_members

This uses the nonce field to properly disambiguate queries. There's
also some redesigning going on behind the scenes and minor clean-up.
Originally I planned on working on this more to account for the more
widespread chunking changes planned for gateway v7 but I realized that
this would indiscriminately slow down everyone else who isn't planning
on working with intents for now.

I will work on the larger chunking changes in the future, should time
allow for it.
This commit is contained in:
Rapptz
2020-05-10 19:30:46 -04:00
parent 5769511779
commit 13a3f760e6
4 changed files with 42 additions and 58 deletions

View File

@ -35,6 +35,9 @@ import weakref
import inspect
import gc
import os
import binascii
from .guild import Guild
from .activity import BaseActivity
from .user import User, ClientUser
@ -62,7 +65,7 @@ log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
class ConnectionState:
def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options):
def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
self.loop = loop
self.http = http
self.max_messages = options.get('max_messages', 1000)
@ -70,7 +73,6 @@ class ConnectionState:
self.max_messages = 1000
self.dispatch = dispatch
self.chunker = chunker
self.syncer = syncer
self.is_bot = None
self.handlers = handlers
@ -132,6 +134,9 @@ class ConnectionState:
# to reconnect loops which cause mass allocations and deallocations.
gc.collect()
def get_nonce(self):
return binascii.hexlify(os.urandom(16)).decode('ascii')
def process_listeners(self, listener_type, argument, result):
removed = []
for i, listener in enumerate(self._listeners):
@ -298,6 +303,10 @@ class ConnectionState:
return channel or Object(id=channel_id), guild
async def chunker(self, guild_id, query='', limit=0, *, nonce=None):
ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
async def request_offline_members(self, guilds):
# get all the chunks
chunks = []
@ -307,7 +316,7 @@ class ConnectionState:
# we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits:
await self.chunker(split)
await self.chunker([g.id for g in split])
# wait for the chunks
if chunks:
@ -329,10 +338,11 @@ class ConnectionState:
# and they don't receive GUILD_MEMBER events which make computing
# member_count impossible. The only way to fix it is by limiting
# the limit parameter to 1 to 1000.
future = self.receive_member_query(guild_id, query)
nonce = self.get_nonce()
future = self.receive_member_query(guild_id, nonce)
try:
# start the query operation
await ws.request_chunks(guild_id, query, limit)
await ws.request_chunks(guild_id, query, limit, nonce=nonce)
members = await asyncio.wait_for(future, timeout=5.0)
if cache:
@ -894,8 +904,7 @@ class ConnectionState:
guild._add_member(member)
self.process_listeners(ListenerType.chunk, guild, len(members))
names = [x.name.lower() for x in members]
self.process_listeners(ListenerType.query_members, (guild_id, names), members)
self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members)
def parse_guild_integrations_update(self, data):
guild = self._get_guild(int(data['guild_id']))
@ -1025,10 +1034,10 @@ class ConnectionState:
self._listeners.append(listener)
return future
def receive_member_query(self, guild_id, query):
def predicate(args, *, guild_id=guild_id, query=query.lower()):
request_guild_id, names = args
return request_guild_id == guild_id and all(n.startswith(query) for n in names)
def receive_member_query(self, guild_id, nonce):
def predicate(args, *, guild_id=guild_id, nonce=nonce):
return args == (guild_id, nonce)
future = self.loop.create_future()
listener = Listener(ListenerType.query_members, future, predicate)
self._listeners.append(listener)
@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState):
self._ready_task = None
self.shard_ids = ()
async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None):
ws = self._get_websocket(shard_id=shard_id)
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
async def request_offline_members(self, guilds, *, shard_id):
# get all the chunks
chunks = []
@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState):
# we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits:
await self.chunker(split, shard_id=shard_id)
await self.chunker([g.id for g in split], shard_id=shard_id)
# wait for the chunks
if chunks: