mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-16 06:03:11 +00:00
Separately delay ready event for each shard
This commit is contained in:
parent
89eb86ecdc
commit
2dbf14bb72
@ -484,7 +484,6 @@ class Client:
|
||||
self.loop = loop
|
||||
self.http.loop = loop
|
||||
self._connection.loop = loop
|
||||
await self._connection.async_setup()
|
||||
|
||||
self._ready = asyncio.Event()
|
||||
|
||||
|
@ -546,8 +546,6 @@ class DiscordWebSocket:
|
||||
self._trace = trace = data.get('_trace', [])
|
||||
self.sequence = msg['s']
|
||||
self.session_id = data['session_id']
|
||||
# pass back shard ID to ready handler
|
||||
data['__shard_id__'] = self.shard_id
|
||||
_log.info(
|
||||
'Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||
self.shard_id,
|
||||
|
@ -423,8 +423,6 @@ class AutoShardedClient(Client):
|
||||
initial = shard_id == shard_ids[0]
|
||||
await self.launch_shard(gateway, shard_id, initial=initial)
|
||||
|
||||
self._connection.shards_launched.set()
|
||||
|
||||
async def _async_setup_hook(self) -> None:
|
||||
await super()._async_setup_hook()
|
||||
self.__queue = asyncio.PriorityQueue()
|
||||
|
163
discord/state.py
163
discord/state.py
@ -27,7 +27,6 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
from collections import deque, OrderedDict
|
||||
import copy
|
||||
import itertools
|
||||
import logging
|
||||
from typing import (
|
||||
Dict,
|
||||
@ -302,9 +301,6 @@ class ConnectionState:
|
||||
else:
|
||||
await coro(*args, **kwargs)
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
pass
|
||||
|
||||
@property
|
||||
def self_id(self) -> Optional[int]:
|
||||
u = self.user
|
||||
@ -561,7 +557,7 @@ class ConnectionState:
|
||||
if self._ready_task is not None:
|
||||
self._ready_task.cancel()
|
||||
|
||||
self._ready_state = asyncio.Queue()
|
||||
self._ready_state: asyncio.Queue[Guild] = asyncio.Queue()
|
||||
self.clear(views=False)
|
||||
self.user = user = ClientUser(state=self, data=data['user'])
|
||||
self._users[user.id] = user # type: ignore
|
||||
@ -1111,6 +1107,15 @@ class ConnectionState:
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
def _add_ready_state(self, guild: Guild) -> bool:
|
||||
try:
|
||||
# Notify the on_ready state, if any, that this guild is complete.
|
||||
self._ready_state.put_nowait(guild)
|
||||
except AttributeError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None:
|
||||
unavailable = data.get('unavailable')
|
||||
if unavailable is True:
|
||||
@ -1119,14 +1124,8 @@ class ConnectionState:
|
||||
|
||||
guild = self._get_create_guild(data)
|
||||
|
||||
try:
|
||||
# Notify the on_ready state, if any, that this guild is complete.
|
||||
self._ready_state.put_nowait(guild)
|
||||
except AttributeError:
|
||||
pass
|
||||
else:
|
||||
# If we're waiting for the event, put the rest on hold
|
||||
return
|
||||
if self._add_ready_state(guild):
|
||||
return # We're waiting for the ready event, put the rest on hold
|
||||
|
||||
# check if it requires chunking
|
||||
if self._guild_needs_chunking(guild):
|
||||
@ -1510,8 +1509,12 @@ class ConnectionState:
|
||||
class AutoShardedConnectionState(ConnectionState):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
self.shard_ids: Union[List[int], range] = []
|
||||
|
||||
self._ready_tasks: Dict[int, asyncio.Task[None]] = {}
|
||||
self._ready_states: Dict[int, asyncio.Queue[Guild]] = {}
|
||||
|
||||
def _update_message_references(self) -> None:
|
||||
# self._messages won't be None when this is called
|
||||
for msg in self._messages: # type: ignore
|
||||
@ -1525,9 +1528,6 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
# channel will either be a TextChannel, Thread or Object
|
||||
msg._rebind_cached_references(new_guild, channel) # type: ignore
|
||||
|
||||
async def async_setup(self) -> None:
|
||||
self.shards_launched: asyncio.Event = asyncio.Event()
|
||||
|
||||
async def chunker(
|
||||
self,
|
||||
guild_id: int,
|
||||
@ -1541,76 +1541,80 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
ws = self._get_websocket(guild_id, shard_id=shard_id)
|
||||
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
|
||||
|
||||
async def _delay_ready(self) -> None:
|
||||
await self.shards_launched.wait()
|
||||
processed = []
|
||||
max_concurrency = len(self.shard_ids) * 2
|
||||
current_bucket = []
|
||||
while True:
|
||||
# this snippet of code is basically waiting N seconds
|
||||
# until the last GUILD_CREATE was sent
|
||||
try:
|
||||
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
else:
|
||||
if self._guild_needs_chunking(guild):
|
||||
_log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id)
|
||||
if len(current_bucket) >= max_concurrency:
|
||||
try:
|
||||
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0)
|
||||
except asyncio.TimeoutError:
|
||||
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d'
|
||||
_log.warning(fmt, guild.shard_id, len(current_bucket))
|
||||
finally:
|
||||
current_bucket = []
|
||||
|
||||
# Chunk the guild in the background while we wait for GUILD_CREATE streaming
|
||||
future = asyncio.ensure_future(self.chunk_guild(guild))
|
||||
current_bucket.append(future)
|
||||
else:
|
||||
future = self.loop.create_future()
|
||||
future.set_result([])
|
||||
|
||||
processed.append((guild, future))
|
||||
|
||||
guilds = sorted(processed, key=lambda g: g[0].shard_id)
|
||||
for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id):
|
||||
children, futures = zip(*info)
|
||||
# 110 reqs/minute w/ 1 req/guild plus some buffer
|
||||
timeout = 61 * (len(children) / 110)
|
||||
try:
|
||||
await utils.sane_wait_for(futures, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
_log.warning(
|
||||
'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds)
|
||||
)
|
||||
for guild in children:
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
self.dispatch('shard_ready', shard_id)
|
||||
|
||||
# remove the state
|
||||
def _add_ready_state(self, guild: Guild) -> bool:
|
||||
try:
|
||||
del self._ready_state
|
||||
except AttributeError:
|
||||
pass # already been deleted somehow
|
||||
# Notify the on_ready state, if any, that this guild is complete.
|
||||
self._ready_states[guild.shard_id].put_nowait(guild)
|
||||
except KeyError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
# regular users cannot shard so we won't worry about it here.
|
||||
async def _delay_ready(self) -> None:
|
||||
await asyncio.gather(*self._ready_tasks.values())
|
||||
|
||||
# clear the current task
|
||||
# clear the current tasks
|
||||
self._ready_task = None
|
||||
self._ready_tasks = {}
|
||||
|
||||
# dispatch the event
|
||||
self.call_handlers('ready')
|
||||
self.dispatch('ready')
|
||||
|
||||
async def _delay_shard_ready(self, shard_id: int) -> None:
|
||||
try:
|
||||
states = []
|
||||
while True:
|
||||
# this snippet of code is basically waiting N seconds
|
||||
# until the last GUILD_CREATE was sent
|
||||
try:
|
||||
guild = await asyncio.wait_for(self._ready_states[shard_id].get(), timeout=self.guild_ready_timeout)
|
||||
except asyncio.TimeoutError:
|
||||
break
|
||||
else:
|
||||
if self._guild_needs_chunking(guild):
|
||||
future = await self.chunk_guild(guild, wait=False)
|
||||
states.append((guild, future))
|
||||
else:
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
for guild, future in states:
|
||||
try:
|
||||
await asyncio.wait_for(future, timeout=5.0)
|
||||
except asyncio.TimeoutError:
|
||||
_log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
|
||||
|
||||
if guild.unavailable is False:
|
||||
self.dispatch('guild_available', guild)
|
||||
else:
|
||||
self.dispatch('guild_join', guild)
|
||||
|
||||
# remove the state
|
||||
try:
|
||||
del self._ready_states[shard_id]
|
||||
except KeyError:
|
||||
pass # already been deleted somehow
|
||||
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
else:
|
||||
# dispatch the event
|
||||
self.dispatch('shard_ready', shard_id)
|
||||
|
||||
def parse_ready(self, data: gw.ReadyEvent) -> None:
|
||||
if not hasattr(self, '_ready_state'):
|
||||
self._ready_state = asyncio.Queue()
|
||||
if self._ready_task is not None:
|
||||
self._ready_task.cancel()
|
||||
|
||||
shard_id = data['shard'][0] # shard_id, num_shards
|
||||
|
||||
if shard_id in self._ready_tasks:
|
||||
self._ready_tasks[shard_id].cancel()
|
||||
|
||||
if shard_id not in self._ready_states:
|
||||
self._ready_states[shard_id] = asyncio.Queue()
|
||||
|
||||
self.user: Optional[ClientUser]
|
||||
self.user = user = ClientUser(state=self, data=data['user'])
|
||||
@ -1633,9 +1637,12 @@ class AutoShardedConnectionState(ConnectionState):
|
||||
self._update_message_references()
|
||||
|
||||
self.dispatch('connect')
|
||||
self.dispatch('shard_connect', data['__shard_id__']) # type: ignore # This is an internal discord.py key
|
||||
self.dispatch('shard_connect', shard_id)
|
||||
|
||||
if self._ready_task is None:
|
||||
self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id))
|
||||
|
||||
# The delay task for every shard has been started
|
||||
if len(self._ready_tasks) == len(self.shard_ids):
|
||||
self._ready_task = asyncio.create_task(self._delay_ready())
|
||||
|
||||
def parse_resumed(self, data: gw.ResumedEvent) -> None:
|
||||
|
@ -60,17 +60,12 @@ class GatewayBot(Gateway):
|
||||
session_start_limit: SessionStartLimit
|
||||
|
||||
|
||||
class ShardInfo(TypedDict):
|
||||
shard_id: int
|
||||
shard_count: int
|
||||
|
||||
|
||||
class ReadyEvent(TypedDict):
|
||||
v: int
|
||||
user: User
|
||||
guilds: List[UnavailableGuild]
|
||||
session_id: str
|
||||
shard: ShardInfo
|
||||
shard: List[int] # shard_id, num_shards
|
||||
application: GatewayAppInfo
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user