Fix timeout issues with fetching members via query_members
This uses the nonce field to properly disambiguate queries. There's also some redesigning going on behind the scenes and minor clean-up. Originally I planned on working on this more to account for the more widespread chunking changes planned for gateway v7 but I realized that this would indiscriminately slow down everyone else who isn't planning on working with intents for now. I will work on the larger chunking changes in the future, should time allow for it.
This commit is contained in:
parent
5769511779
commit
13a3f760e6
@ -223,13 +223,13 @@ class Client:
|
|||||||
'ready': self._handle_ready
|
'ready': self._handle_ready
|
||||||
}
|
}
|
||||||
|
|
||||||
self._connection = ConnectionState(dispatch=self.dispatch, chunker=self._chunker, handlers=self._handlers,
|
self._connection = ConnectionState(dispatch=self.dispatch, handlers=self._handlers,
|
||||||
syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
syncer=self._syncer, http=self.http, loop=self.loop, **options)
|
||||||
|
|
||||||
self._connection.shard_count = self.shard_count
|
self._connection.shard_count = self.shard_count
|
||||||
self._closed = False
|
self._closed = False
|
||||||
self._ready = asyncio.Event()
|
self._ready = asyncio.Event()
|
||||||
self._connection._get_websocket = lambda g: self.ws
|
self._connection._get_websocket = self._get_websocket
|
||||||
|
|
||||||
if VoiceClient.warn_nacl:
|
if VoiceClient.warn_nacl:
|
||||||
VoiceClient.warn_nacl = False
|
VoiceClient.warn_nacl = False
|
||||||
@ -237,26 +237,12 @@ class Client:
|
|||||||
|
|
||||||
# internals
|
# internals
|
||||||
|
|
||||||
|
def _get_websocket(self, guild_id=None, *, shard_id=None):
|
||||||
|
return self.ws
|
||||||
|
|
||||||
async def _syncer(self, guilds):
|
async def _syncer(self, guilds):
|
||||||
await self.ws.request_sync(guilds)
|
await self.ws.request_sync(guilds)
|
||||||
|
|
||||||
async def _chunker(self, guild):
|
|
||||||
try:
|
|
||||||
guild_id = guild.id
|
|
||||||
except AttributeError:
|
|
||||||
guild_id = [s.id for s in guild]
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
'op': 8,
|
|
||||||
'd': {
|
|
||||||
'guild_id': guild_id,
|
|
||||||
'query': '',
|
|
||||||
'limit': 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
await self.ws.send_as_json(payload)
|
|
||||||
|
|
||||||
def _handle_ready(self):
|
def _handle_ready(self):
|
||||||
self._ready.set()
|
self._ready.set()
|
||||||
|
|
||||||
|
@ -535,15 +535,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
}
|
}
|
||||||
await self.send_as_json(payload)
|
await self.send_as_json(payload)
|
||||||
|
|
||||||
async def request_chunks(self, guild_id, query, limit):
|
async def request_chunks(self, guild_id, query, limit, *, nonce=None):
|
||||||
payload = {
|
payload = {
|
||||||
'op': self.REQUEST_MEMBERS,
|
'op': self.REQUEST_MEMBERS,
|
||||||
'd': {
|
'd': {
|
||||||
'guild_id': str(guild_id),
|
'guild_id': guild_id,
|
||||||
'query': query,
|
'query': query,
|
||||||
'limit': limit
|
'limit': limit
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nonce:
|
||||||
|
payload['d']['nonce'] = nonce
|
||||||
|
|
||||||
await self.send_as_json(payload)
|
await self.send_as_json(payload)
|
||||||
|
|
||||||
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
||||||
|
@ -126,38 +126,19 @@ class AutoShardedClient(Client):
|
|||||||
elif not isinstance(self.shard_ids, (list, tuple)):
|
elif not isinstance(self.shard_ids, (list, tuple)):
|
||||||
raise ClientException('shard_ids parameter must be a list or a tuple.')
|
raise ClientException('shard_ids parameter must be a list or a tuple.')
|
||||||
|
|
||||||
self._connection = AutoShardedConnectionState(dispatch=self.dispatch, chunker=self._chunker,
|
self._connection = AutoShardedConnectionState(dispatch=self.dispatch,
|
||||||
handlers=self._handlers, syncer=self._syncer,
|
handlers=self._handlers, syncer=self._syncer,
|
||||||
http=self.http, loop=self.loop, **kwargs)
|
http=self.http, loop=self.loop, **kwargs)
|
||||||
|
|
||||||
# instead of a single websocket, we have multiple
|
# instead of a single websocket, we have multiple
|
||||||
# the key is the shard_id
|
# the key is the shard_id
|
||||||
self.shards = {}
|
self.shards = {}
|
||||||
|
self._connection._get_websocket = self._get_websocket
|
||||||
|
|
||||||
def _get_websocket(guild_id):
|
def _get_websocket(self, guild_id=None, *, shard_id=None):
|
||||||
i = (guild_id >> 22) % self.shard_count
|
if shard_id is None:
|
||||||
return self.shards[i].ws
|
shard_id = (guild_id >> 22) % self.shard_count
|
||||||
|
return self.shards[shard_id].ws
|
||||||
self._connection._get_websocket = _get_websocket
|
|
||||||
|
|
||||||
async def _chunker(self, guild, *, shard_id=None):
|
|
||||||
try:
|
|
||||||
guild_id = guild.id
|
|
||||||
shard_id = shard_id or guild.shard_id
|
|
||||||
except AttributeError:
|
|
||||||
guild_id = [s.id for s in guild]
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
'op': 8,
|
|
||||||
'd': {
|
|
||||||
'guild_id': guild_id,
|
|
||||||
'query': '',
|
|
||||||
'limit': 0
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ws = self.shards[shard_id].ws
|
|
||||||
await ws.send_as_json(payload)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def latency(self):
|
def latency(self):
|
||||||
|
@ -35,6 +35,9 @@ import weakref
|
|||||||
import inspect
|
import inspect
|
||||||
import gc
|
import gc
|
||||||
|
|
||||||
|
import os
|
||||||
|
import binascii
|
||||||
|
|
||||||
from .guild import Guild
|
from .guild import Guild
|
||||||
from .activity import BaseActivity
|
from .activity import BaseActivity
|
||||||
from .user import User, ClientUser
|
from .user import User, ClientUser
|
||||||
@ -62,7 +65,7 @@ log = logging.getLogger(__name__)
|
|||||||
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
|
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
|
||||||
|
|
||||||
class ConnectionState:
|
class ConnectionState:
|
||||||
def __init__(self, *, dispatch, chunker, handlers, syncer, http, loop, **options):
|
def __init__(self, *, dispatch, handlers, syncer, http, loop, **options):
|
||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.http = http
|
self.http = http
|
||||||
self.max_messages = options.get('max_messages', 1000)
|
self.max_messages = options.get('max_messages', 1000)
|
||||||
@ -70,7 +73,6 @@ class ConnectionState:
|
|||||||
self.max_messages = 1000
|
self.max_messages = 1000
|
||||||
|
|
||||||
self.dispatch = dispatch
|
self.dispatch = dispatch
|
||||||
self.chunker = chunker
|
|
||||||
self.syncer = syncer
|
self.syncer = syncer
|
||||||
self.is_bot = None
|
self.is_bot = None
|
||||||
self.handlers = handlers
|
self.handlers = handlers
|
||||||
@ -132,6 +134,9 @@ class ConnectionState:
|
|||||||
# to reconnect loops which cause mass allocations and deallocations.
|
# to reconnect loops which cause mass allocations and deallocations.
|
||||||
gc.collect()
|
gc.collect()
|
||||||
|
|
||||||
|
def get_nonce(self):
|
||||||
|
return binascii.hexlify(os.urandom(16)).decode('ascii')
|
||||||
|
|
||||||
def process_listeners(self, listener_type, argument, result):
|
def process_listeners(self, listener_type, argument, result):
|
||||||
removed = []
|
removed = []
|
||||||
for i, listener in enumerate(self._listeners):
|
for i, listener in enumerate(self._listeners):
|
||||||
@ -298,6 +303,10 @@ class ConnectionState:
|
|||||||
|
|
||||||
return channel or Object(id=channel_id), guild
|
return channel or Object(id=channel_id), guild
|
||||||
|
|
||||||
|
async def chunker(self, guild_id, query='', limit=0, *, nonce=None):
|
||||||
|
ws = self._get_websocket(guild_id) # This is ignored upstream
|
||||||
|
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
|
||||||
|
|
||||||
async def request_offline_members(self, guilds):
|
async def request_offline_members(self, guilds):
|
||||||
# get all the chunks
|
# get all the chunks
|
||||||
chunks = []
|
chunks = []
|
||||||
@ -307,7 +316,7 @@ class ConnectionState:
|
|||||||
# we only want to request ~75 guilds per chunk request.
|
# we only want to request ~75 guilds per chunk request.
|
||||||
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
|
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
|
||||||
for split in splits:
|
for split in splits:
|
||||||
await self.chunker(split)
|
await self.chunker([g.id for g in split])
|
||||||
|
|
||||||
# wait for the chunks
|
# wait for the chunks
|
||||||
if chunks:
|
if chunks:
|
||||||
@ -329,10 +338,11 @@ class ConnectionState:
|
|||||||
# and they don't receive GUILD_MEMBER events which make computing
|
# and they don't receive GUILD_MEMBER events which make computing
|
||||||
# member_count impossible. The only way to fix it is by limiting
|
# member_count impossible. The only way to fix it is by limiting
|
||||||
# the limit parameter to 1 to 1000.
|
# the limit parameter to 1 to 1000.
|
||||||
future = self.receive_member_query(guild_id, query)
|
nonce = self.get_nonce()
|
||||||
|
future = self.receive_member_query(guild_id, nonce)
|
||||||
try:
|
try:
|
||||||
# start the query operation
|
# start the query operation
|
||||||
await ws.request_chunks(guild_id, query, limit)
|
await ws.request_chunks(guild_id, query, limit, nonce=nonce)
|
||||||
members = await asyncio.wait_for(future, timeout=5.0)
|
members = await asyncio.wait_for(future, timeout=5.0)
|
||||||
|
|
||||||
if cache:
|
if cache:
|
||||||
@ -894,8 +904,7 @@ class ConnectionState:
|
|||||||
guild._add_member(member)
|
guild._add_member(member)
|
||||||
|
|
||||||
self.process_listeners(ListenerType.chunk, guild, len(members))
|
self.process_listeners(ListenerType.chunk, guild, len(members))
|
||||||
names = [x.name.lower() for x in members]
|
self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members)
|
||||||
self.process_listeners(ListenerType.query_members, (guild_id, names), members)
|
|
||||||
|
|
||||||
def parse_guild_integrations_update(self, data):
|
def parse_guild_integrations_update(self, data):
|
||||||
guild = self._get_guild(int(data['guild_id']))
|
guild = self._get_guild(int(data['guild_id']))
|
||||||
@ -1025,10 +1034,10 @@ class ConnectionState:
|
|||||||
self._listeners.append(listener)
|
self._listeners.append(listener)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
def receive_member_query(self, guild_id, query):
|
def receive_member_query(self, guild_id, nonce):
|
||||||
def predicate(args, *, guild_id=guild_id, query=query.lower()):
|
def predicate(args, *, guild_id=guild_id, nonce=nonce):
|
||||||
request_guild_id, names = args
|
return args == (guild_id, nonce)
|
||||||
return request_guild_id == guild_id and all(n.startswith(query) for n in names)
|
|
||||||
future = self.loop.create_future()
|
future = self.loop.create_future()
|
||||||
listener = Listener(ListenerType.query_members, future, predicate)
|
listener = Listener(ListenerType.query_members, future, predicate)
|
||||||
self._listeners.append(listener)
|
self._listeners.append(listener)
|
||||||
@ -1040,6 +1049,10 @@ class AutoShardedConnectionState(ConnectionState):
|
|||||||
self._ready_task = None
|
self._ready_task = None
|
||||||
self.shard_ids = ()
|
self.shard_ids = ()
|
||||||
|
|
||||||
|
async def chunker(self, guild_id, query='', limit=0, *, shard_id, nonce=None):
|
||||||
|
ws = self._get_websocket(shard_id=shard_id)
|
||||||
|
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
|
||||||
|
|
||||||
async def request_offline_members(self, guilds, *, shard_id):
|
async def request_offline_members(self, guilds, *, shard_id):
|
||||||
# get all the chunks
|
# get all the chunks
|
||||||
chunks = []
|
chunks = []
|
||||||
@ -1049,7 +1062,7 @@ class AutoShardedConnectionState(ConnectionState):
|
|||||||
# we only want to request ~75 guilds per chunk request.
|
# we only want to request ~75 guilds per chunk request.
|
||||||
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
|
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
|
||||||
for split in splits:
|
for split in splits:
|
||||||
await self.chunker(split, shard_id=shard_id)
|
await self.chunker([g.id for g in split], shard_id=shard_id)
|
||||||
|
|
||||||
# wait for the chunks
|
# wait for the chunks
|
||||||
if chunks:
|
if chunks:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user