mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-18 23:15:48 +00:00
Separately delay ready event for each shard
This commit is contained in:
parent
89eb86ecdc
commit
2dbf14bb72
@ -484,7 +484,6 @@ class Client:
|
|||||||
self.loop = loop
|
self.loop = loop
|
||||||
self.http.loop = loop
|
self.http.loop = loop
|
||||||
self._connection.loop = loop
|
self._connection.loop = loop
|
||||||
await self._connection.async_setup()
|
|
||||||
|
|
||||||
self._ready = asyncio.Event()
|
self._ready = asyncio.Event()
|
||||||
|
|
||||||
|
@ -546,8 +546,6 @@ class DiscordWebSocket:
|
|||||||
self._trace = trace = data.get('_trace', [])
|
self._trace = trace = data.get('_trace', [])
|
||||||
self.sequence = msg['s']
|
self.sequence = msg['s']
|
||||||
self.session_id = data['session_id']
|
self.session_id = data['session_id']
|
||||||
# pass back shard ID to ready handler
|
|
||||||
data['__shard_id__'] = self.shard_id
|
|
||||||
_log.info(
|
_log.info(
|
||||||
'Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
'Shard ID %s has connected to Gateway: %s (Session ID: %s).',
|
||||||
self.shard_id,
|
self.shard_id,
|
||||||
|
@ -423,8 +423,6 @@ class AutoShardedClient(Client):
|
|||||||
initial = shard_id == shard_ids[0]
|
initial = shard_id == shard_ids[0]
|
||||||
await self.launch_shard(gateway, shard_id, initial=initial)
|
await self.launch_shard(gateway, shard_id, initial=initial)
|
||||||
|
|
||||||
self._connection.shards_launched.set()
|
|
||||||
|
|
||||||
async def _async_setup_hook(self) -> None:
|
async def _async_setup_hook(self) -> None:
|
||||||
await super()._async_setup_hook()
|
await super()._async_setup_hook()
|
||||||
self.__queue = asyncio.PriorityQueue()
|
self.__queue = asyncio.PriorityQueue()
|
||||||
|
163
discord/state.py
163
discord/state.py
@ -27,7 +27,6 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
from collections import deque, OrderedDict
|
from collections import deque, OrderedDict
|
||||||
import copy
|
import copy
|
||||||
import itertools
|
|
||||||
import logging
|
import logging
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
@ -302,9 +301,6 @@ class ConnectionState:
|
|||||||
else:
|
else:
|
||||||
await coro(*args, **kwargs)
|
await coro(*args, **kwargs)
|
||||||
|
|
||||||
async def async_setup(self) -> None:
|
|
||||||
pass
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def self_id(self) -> Optional[int]:
|
def self_id(self) -> Optional[int]:
|
||||||
u = self.user
|
u = self.user
|
||||||
@ -561,7 +557,7 @@ class ConnectionState:
|
|||||||
if self._ready_task is not None:
|
if self._ready_task is not None:
|
||||||
self._ready_task.cancel()
|
self._ready_task.cancel()
|
||||||
|
|
||||||
self._ready_state = asyncio.Queue()
|
self._ready_state: asyncio.Queue[Guild] = asyncio.Queue()
|
||||||
self.clear(views=False)
|
self.clear(views=False)
|
||||||
self.user = user = ClientUser(state=self, data=data['user'])
|
self.user = user = ClientUser(state=self, data=data['user'])
|
||||||
self._users[user.id] = user # type: ignore
|
self._users[user.id] = user # type: ignore
|
||||||
@ -1111,6 +1107,15 @@ class ConnectionState:
|
|||||||
else:
|
else:
|
||||||
self.dispatch('guild_join', guild)
|
self.dispatch('guild_join', guild)
|
||||||
|
|
||||||
|
def _add_ready_state(self, guild: Guild) -> bool:
|
||||||
|
try:
|
||||||
|
# Notify the on_ready state, if any, that this guild is complete.
|
||||||
|
self._ready_state.put_nowait(guild)
|
||||||
|
except AttributeError:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None:
|
def parse_guild_create(self, data: gw.GuildCreateEvent) -> None:
|
||||||
unavailable = data.get('unavailable')
|
unavailable = data.get('unavailable')
|
||||||
if unavailable is True:
|
if unavailable is True:
|
||||||
@ -1119,14 +1124,8 @@ class ConnectionState:
|
|||||||
|
|
||||||
guild = self._get_create_guild(data)
|
guild = self._get_create_guild(data)
|
||||||
|
|
||||||
try:
|
if self._add_ready_state(guild):
|
||||||
# Notify the on_ready state, if any, that this guild is complete.
|
return # We're waiting for the ready event, put the rest on hold
|
||||||
self._ready_state.put_nowait(guild)
|
|
||||||
except AttributeError:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
# If we're waiting for the event, put the rest on hold
|
|
||||||
return
|
|
||||||
|
|
||||||
# check if it requires chunking
|
# check if it requires chunking
|
||||||
if self._guild_needs_chunking(guild):
|
if self._guild_needs_chunking(guild):
|
||||||
@ -1510,8 +1509,12 @@ class ConnectionState:
|
|||||||
class AutoShardedConnectionState(ConnectionState):
|
class AutoShardedConnectionState(ConnectionState):
|
||||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
|
|
||||||
self.shard_ids: Union[List[int], range] = []
|
self.shard_ids: Union[List[int], range] = []
|
||||||
|
|
||||||
|
self._ready_tasks: Dict[int, asyncio.Task[None]] = {}
|
||||||
|
self._ready_states: Dict[int, asyncio.Queue[Guild]] = {}
|
||||||
|
|
||||||
def _update_message_references(self) -> None:
|
def _update_message_references(self) -> None:
|
||||||
# self._messages won't be None when this is called
|
# self._messages won't be None when this is called
|
||||||
for msg in self._messages: # type: ignore
|
for msg in self._messages: # type: ignore
|
||||||
@ -1525,9 +1528,6 @@ class AutoShardedConnectionState(ConnectionState):
|
|||||||
# channel will either be a TextChannel, Thread or Object
|
# channel will either be a TextChannel, Thread or Object
|
||||||
msg._rebind_cached_references(new_guild, channel) # type: ignore
|
msg._rebind_cached_references(new_guild, channel) # type: ignore
|
||||||
|
|
||||||
async def async_setup(self) -> None:
|
|
||||||
self.shards_launched: asyncio.Event = asyncio.Event()
|
|
||||||
|
|
||||||
async def chunker(
|
async def chunker(
|
||||||
self,
|
self,
|
||||||
guild_id: int,
|
guild_id: int,
|
||||||
@ -1541,76 +1541,80 @@ class AutoShardedConnectionState(ConnectionState):
|
|||||||
ws = self._get_websocket(guild_id, shard_id=shard_id)
|
ws = self._get_websocket(guild_id, shard_id=shard_id)
|
||||||
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
|
await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce)
|
||||||
|
|
||||||
async def _delay_ready(self) -> None:
|
def _add_ready_state(self, guild: Guild) -> bool:
|
||||||
await self.shards_launched.wait()
|
|
||||||
processed = []
|
|
||||||
max_concurrency = len(self.shard_ids) * 2
|
|
||||||
current_bucket = []
|
|
||||||
while True:
|
|
||||||
# this snippet of code is basically waiting N seconds
|
|
||||||
# until the last GUILD_CREATE was sent
|
|
||||||
try:
|
|
||||||
guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
break
|
|
||||||
else:
|
|
||||||
if self._guild_needs_chunking(guild):
|
|
||||||
_log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id)
|
|
||||||
if len(current_bucket) >= max_concurrency:
|
|
||||||
try:
|
|
||||||
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d'
|
|
||||||
_log.warning(fmt, guild.shard_id, len(current_bucket))
|
|
||||||
finally:
|
|
||||||
current_bucket = []
|
|
||||||
|
|
||||||
# Chunk the guild in the background while we wait for GUILD_CREATE streaming
|
|
||||||
future = asyncio.ensure_future(self.chunk_guild(guild))
|
|
||||||
current_bucket.append(future)
|
|
||||||
else:
|
|
||||||
future = self.loop.create_future()
|
|
||||||
future.set_result([])
|
|
||||||
|
|
||||||
processed.append((guild, future))
|
|
||||||
|
|
||||||
guilds = sorted(processed, key=lambda g: g[0].shard_id)
|
|
||||||
for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id):
|
|
||||||
children, futures = zip(*info)
|
|
||||||
# 110 reqs/minute w/ 1 req/guild plus some buffer
|
|
||||||
timeout = 61 * (len(children) / 110)
|
|
||||||
try:
|
|
||||||
await utils.sane_wait_for(futures, timeout=timeout)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
_log.warning(
|
|
||||||
'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds)
|
|
||||||
)
|
|
||||||
for guild in children:
|
|
||||||
if guild.unavailable is False:
|
|
||||||
self.dispatch('guild_available', guild)
|
|
||||||
else:
|
|
||||||
self.dispatch('guild_join', guild)
|
|
||||||
|
|
||||||
self.dispatch('shard_ready', shard_id)
|
|
||||||
|
|
||||||
# remove the state
|
|
||||||
try:
|
try:
|
||||||
del self._ready_state
|
# Notify the on_ready state, if any, that this guild is complete.
|
||||||
except AttributeError:
|
self._ready_states[guild.shard_id].put_nowait(guild)
|
||||||
pass # already been deleted somehow
|
except KeyError:
|
||||||
|
return False
|
||||||
|
else:
|
||||||
|
return True
|
||||||
|
|
||||||
# regular users cannot shard so we won't worry about it here.
|
async def _delay_ready(self) -> None:
|
||||||
|
await asyncio.gather(*self._ready_tasks.values())
|
||||||
|
|
||||||
# clear the current task
|
# clear the current tasks
|
||||||
self._ready_task = None
|
self._ready_task = None
|
||||||
|
self._ready_tasks = {}
|
||||||
|
|
||||||
# dispatch the event
|
# dispatch the event
|
||||||
self.call_handlers('ready')
|
self.call_handlers('ready')
|
||||||
self.dispatch('ready')
|
self.dispatch('ready')
|
||||||
|
|
||||||
|
async def _delay_shard_ready(self, shard_id: int) -> None:
|
||||||
|
try:
|
||||||
|
states = []
|
||||||
|
while True:
|
||||||
|
# this snippet of code is basically waiting N seconds
|
||||||
|
# until the last GUILD_CREATE was sent
|
||||||
|
try:
|
||||||
|
guild = await asyncio.wait_for(self._ready_states[shard_id].get(), timeout=self.guild_ready_timeout)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
if self._guild_needs_chunking(guild):
|
||||||
|
future = await self.chunk_guild(guild, wait=False)
|
||||||
|
states.append((guild, future))
|
||||||
|
else:
|
||||||
|
if guild.unavailable is False:
|
||||||
|
self.dispatch('guild_available', guild)
|
||||||
|
else:
|
||||||
|
self.dispatch('guild_join', guild)
|
||||||
|
|
||||||
|
for guild, future in states:
|
||||||
|
try:
|
||||||
|
await asyncio.wait_for(future, timeout=5.0)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
_log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
|
||||||
|
|
||||||
|
if guild.unavailable is False:
|
||||||
|
self.dispatch('guild_available', guild)
|
||||||
|
else:
|
||||||
|
self.dispatch('guild_join', guild)
|
||||||
|
|
||||||
|
# remove the state
|
||||||
|
try:
|
||||||
|
del self._ready_states[shard_id]
|
||||||
|
except KeyError:
|
||||||
|
pass # already been deleted somehow
|
||||||
|
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
# dispatch the event
|
||||||
|
self.dispatch('shard_ready', shard_id)
|
||||||
|
|
||||||
def parse_ready(self, data: gw.ReadyEvent) -> None:
|
def parse_ready(self, data: gw.ReadyEvent) -> None:
|
||||||
if not hasattr(self, '_ready_state'):
|
if self._ready_task is not None:
|
||||||
self._ready_state = asyncio.Queue()
|
self._ready_task.cancel()
|
||||||
|
|
||||||
|
shard_id = data['shard'][0] # shard_id, num_shards
|
||||||
|
|
||||||
|
if shard_id in self._ready_tasks:
|
||||||
|
self._ready_tasks[shard_id].cancel()
|
||||||
|
|
||||||
|
if shard_id not in self._ready_states:
|
||||||
|
self._ready_states[shard_id] = asyncio.Queue()
|
||||||
|
|
||||||
self.user: Optional[ClientUser]
|
self.user: Optional[ClientUser]
|
||||||
self.user = user = ClientUser(state=self, data=data['user'])
|
self.user = user = ClientUser(state=self, data=data['user'])
|
||||||
@ -1633,9 +1637,12 @@ class AutoShardedConnectionState(ConnectionState):
|
|||||||
self._update_message_references()
|
self._update_message_references()
|
||||||
|
|
||||||
self.dispatch('connect')
|
self.dispatch('connect')
|
||||||
self.dispatch('shard_connect', data['__shard_id__']) # type: ignore # This is an internal discord.py key
|
self.dispatch('shard_connect', shard_id)
|
||||||
|
|
||||||
if self._ready_task is None:
|
self._ready_tasks[shard_id] = asyncio.create_task(self._delay_shard_ready(shard_id))
|
||||||
|
|
||||||
|
# The delay task for every shard has been started
|
||||||
|
if len(self._ready_tasks) == len(self.shard_ids):
|
||||||
self._ready_task = asyncio.create_task(self._delay_ready())
|
self._ready_task = asyncio.create_task(self._delay_ready())
|
||||||
|
|
||||||
def parse_resumed(self, data: gw.ResumedEvent) -> None:
|
def parse_resumed(self, data: gw.ResumedEvent) -> None:
|
||||||
|
@ -60,17 +60,12 @@ class GatewayBot(Gateway):
|
|||||||
session_start_limit: SessionStartLimit
|
session_start_limit: SessionStartLimit
|
||||||
|
|
||||||
|
|
||||||
class ShardInfo(TypedDict):
|
|
||||||
shard_id: int
|
|
||||||
shard_count: int
|
|
||||||
|
|
||||||
|
|
||||||
class ReadyEvent(TypedDict):
|
class ReadyEvent(TypedDict):
|
||||||
v: int
|
v: int
|
||||||
user: User
|
user: User
|
||||||
guilds: List[UnavailableGuild]
|
guilds: List[UnavailableGuild]
|
||||||
session_id: str
|
session_id: str
|
||||||
shard: ShardInfo
|
shard: List[int] # shard_id, num_shards
|
||||||
application: GatewayAppInfo
|
application: GatewayAppInfo
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user