Merge pull request #1 from Rapptz/feature/intents

Feature/intents
This commit is contained in:
iDutchy 2020-09-13 05:44:24 +02:00 committed by GitHub
commit 4b612aeece
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 633 additions and 229 deletions

View File

@ -141,6 +141,10 @@ class Client:
Integer starting at ``0`` and less than :attr:`.shard_count`. Integer starting at ``0`` and less than :attr:`.shard_count`.
shard_count: Optional[:class:`int`] shard_count: Optional[:class:`int`]
The total number of shards. The total number of shards.
intents: :class:`Intents`
A list of intents that you want to enable for the session. This is a way of
disabling and enabling certain gateway events from triggering and being sent.
Currently, if no intents are passed then you will receive all data.
fetch_offline_members: :class:`bool` fetch_offline_members: :class:`bool`
Indicates if :func:`.on_ready` should be delayed to fetch all offline Indicates if :func:`.on_ready` should be delayed to fetch all offline
members from the guilds the client belongs to. If this is ``False``\, then members from the guilds the client belongs to. If this is ``False``\, then
@ -375,6 +379,7 @@ class Client:
print('Ignoring exception in {}'.format(event_method), file=sys.stderr) print('Ignoring exception in {}'.format(event_method), file=sys.stderr)
traceback.print_exc() traceback.print_exc()
@utils.deprecated('Guild.chunk')
async def request_offline_members(self, *guilds): async def request_offline_members(self, *guilds):
r"""|coro| r"""|coro|
@ -388,6 +393,10 @@ class Client:
in the guild is larger than 250. You can check if a guild is large in the guild is larger than 250. You can check if a guild is large
if :attr:`.Guild.large` is ``True``. if :attr:`.Guild.large` is ``True``.
.. warning::
This method is deprecated. Use :meth:`Guild.chunk` instead.
Parameters Parameters
----------- -----------
\*guilds: :class:`.Guild` \*guilds: :class:`.Guild`
@ -396,12 +405,13 @@ class Client:
Raises Raises
------- -------
:exc:`.InvalidArgument` :exc:`.InvalidArgument`
If any guild is unavailable or not large in the collection. If any guild is unavailable in the collection.
""" """
if any(not g.large or g.unavailable for g in guilds): if any(g.unavailable for g in guilds):
raise InvalidArgument('An unavailable or non-large guild was passed.') raise InvalidArgument('An unavailable guild was passed.')
await self._connection.request_offline_members(guilds) for guild in guilds:
await self._connection.chunk_guild(guild)
# hooks # hooks

View File

@ -29,7 +29,8 @@ from .enums import UserFlags
__all__ = ( __all__ = (
'SystemChannelFlags', 'SystemChannelFlags',
'MessageFlags', 'MessageFlags',
'PublicUserFlags' 'PublicUserFlags',
'Intents',
) )
class flag_value: class flag_value:
@ -327,3 +328,326 @@ class PublicUserFlags(BaseFlags):
def all(self): def all(self):
"""List[:class:`UserFlags`]: Returns all public flags the user has.""" """List[:class:`UserFlags`]: Returns all public flags the user has."""
return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)]
@fill_with_flags()
class Intents(BaseFlags):
r"""Wraps up a Discord gateway intent flag.
Similar to :class:`Permissions`\, the properties provided are two way.
You can set and retrieve individual bits using the properties as if they
were regular bools.
To construct an object you can pass keyword arguments denoting the flags
to enable or disable.
This is used to disable certain gateway features that are unnecessary to
run your bot. To make use of this, it is passed to the ``intents`` keyword
argument of :class:`Client`.
A default instance of this class has everything enabled except :attr:`presences`
and :attr:`members`.
.. versionadded:: 1.5
.. container:: operations
.. describe:: x == y
Checks if two flags are equal.
.. describe:: x != y
Checks if two flags are not equal.
.. describe:: hash(x)
Return the flag's hash.
.. describe:: iter(x)
Returns an iterator of ``(name, value)`` pairs. This allows it
to be, for example, constructed as a dict or a list of pairs.
Attributes
-----------
value: :class:`int`
The raw value. You should query flags via the properties
rather than using this raw value.
"""
__slots__ = ()
def __init__(self, **kwargs):
# Change the default value to everything being enabled
# except presences and members
bits = max(self.VALID_FLAGS.values()).bit_length()
self.value = (1 << bits) - 1
self.presences = False
self.members = False
for key, value in kwargs.items():
if key not in self.VALID_FLAGS:
raise TypeError('%r is not a valid flag name.' % key)
setattr(self, key, value)
@classmethod
def all(cls):
"""A factory method that creates a :class:`Intents` with everything enabled."""
bits = max(cls.VALID_FLAGS.values()).bit_length()
value = (1 << bits) - 1
self = cls.__new__(cls)
self.value = value
return self
@classmethod
def none(cls):
"""A factory method that creates a :class:`Intents` with everything disabled."""
self = cls.__new__(cls)
self.value = self.DEFAULT_VALUE
return self
@flag_value
def guilds(self):
""":class:`bool`: Whether guild related events are enabled.
This corresponds to the following events:
- :func:`on_guild_join`
- :func:`on_guild_remove`
- :func:`on_guild_available`
- :func:`on_guild_unavailable`
- :func:`on_guild_channel_update`
- :func:`on_guild_channel_create`
- :func:`on_guild_channel_delete`
- :func:`on_guild_channel_pins_update`
"""
return 1 << 0
@flag_value
def members(self):
""":class:`bool`: Whether guild member related events are enabled.
This corresponds to the following events:
- :func:`on_member_join`
- :func:`on_member_remove`
- :func:`on_member_update` (nickname, roles)
- :func:`on_user_update`
.. note::
Currently, this requires opting in explicitly via the dev portal as well.
Bots in over 100 guilds will need to apply to Discord for verification.
"""
return 1 << 1
@flag_value
def bans(self):
""":class:`bool`: Whether guild ban related events are enabled.
This corresponds to the following events:
- :func:`on_member_ban`
- :func:`on_member_unban`
"""
return 1 << 2
@flag_value
def emojis(self):
""":class:`bool`: Whether guild emoji related events are enabled.
This corresponds to the following events:
- :func:`on_guild_emojis_update`
"""
return 1 << 3
@flag_value
def integrations(self):
""":class:`bool`: Whether guild integration related events are enabled.
This corresponds to the following events:
- :func:`on_guild_integrations_update`
"""
return 1 << 4
@flag_value
def webhooks(self):
""":class:`bool`: Whether guild webhook related events are enabled.
This corresponds to the following events:
- :func:`on_webhooks_update`
"""
return 1 << 5
@flag_value
def invites(self):
""":class:`bool`: Whether guild invite related events are enabled.
This corresponds to the following events:
- :func:`on_invite_create`
- :func:`on_invite_delete`
"""
return 1 << 6
@flag_value
def voice_states(self):
""":class:`bool`: Whether guild voice state related events are enabled.
This corresponds to the following events:
- :func:`on_voice_state_update`
"""
return 1 << 7
@flag_value
def presences(self):
""":class:`bool`: Whether guild voice state related events are enabled.
This corresponds to the following events:
- :func:`on_member_update` (activities, status)
.. note::
Currently, this requires opting in explicitly via the dev portal as well.
Bots in over 100 guilds will need to apply to Discord for verification.
"""
return 1 << 8
@flag_value
def messages(self):
""":class:`bool`: Whether guild and direct message related events are enabled.
This is a shortcut to set or get both :attr:`guild_messages` and :attr:`dm_messages`.
This corresponds to the following events:
- :func:`on_message` (both guilds and DMs)
- :func:`on_message_update` (both guilds and DMs)
- :func:`on_message_delete` (both guilds and DMs)
- :func:`on_raw_message_delete` (both guilds and DMs)
- :func:`on_raw_message_update` (both guilds and DMs)
- :func:`on_private_channel_create`
"""
return (1 << 9) | (1 << 12)
@flag_value
def guild_messages(self):
""":class:`bool`: Whether guild message related events are enabled.
See also :attr:`dm_messages` for DMs or :attr:`messages` for both.
This corresponds to the following events:
- :func:`on_message` (only for guilds)
- :func:`on_message_update` (only for guilds)
- :func:`on_message_delete` (only for guilds)
- :func:`on_raw_message_delete` (only for guilds)
- :func:`on_raw_message_update` (only for guilds)
"""
return 1 << 9
@flag_value
def dm_messages(self):
""":class:`bool`: Whether direct message related events are enabled.
See also :attr:`guild_messages` for guilds or :attr:`messages` for both.
This corresponds to the following events:
- :func:`on_message` (only for DMs)
- :func:`on_message_update` (only for DMs)
- :func:`on_message_delete` (only for DMs)
- :func:`on_raw_message_delete` (only for DMs)
- :func:`on_raw_message_update` (only for DMs)
- :func:`on_private_channel_create`
"""
return 1 << 12
@flag_value
def reactions(self):
""":class:`bool`: Whether guild and direct message reaction related events are enabled.
This is a shortcut to set or get both :attr:`guild_reactions` and :attr:`dm_reactions`.
This corresponds to the following events:
- :func:`on_reaction_add` (both guilds and DMs)
- :func:`on_reaction_remove` (both guilds and DMs)
- :func:`on_reaction_clear` (both guilds and DMs)
- :func:`on_raw_reaction_add` (both guilds and DMs)
- :func:`on_raw_reaction_remove` (both guilds and DMs)
- :func:`on_raw_reaction_clear` (both guilds and DMs)
"""
return (1 << 10) | (1 << 13)
@flag_value
def guild_reactions(self):
""":class:`bool`: Whether guild message reaction related events are enabled.
See also :attr:`dm_reactions` for DMs or :attr:`reactions` for both.
This corresponds to the following events:
- :func:`on_reaction_add` (only for guilds)
- :func:`on_reaction_remove` (only for guilds)
- :func:`on_reaction_clear` (only for guilds)
- :func:`on_raw_reaction_add` (only for guilds)
- :func:`on_raw_reaction_remove` (only for guilds)
- :func:`on_raw_reaction_clear` (only for guilds)
"""
return 1 << 10
@flag_value
def dm_reactions(self):
""":class:`bool`: Whether direct message reaction related events are enabled.
See also :attr:`guild_reactions` for guilds or :attr:`reactions` for both.
This corresponds to the following events:
- :func:`on_reaction_add` (only for DMs)
- :func:`on_reaction_remove` (only for DMs)
- :func:`on_reaction_clear` (only for DMs)
- :func:`on_raw_reaction_add` (only for DMs)
- :func:`on_raw_reaction_remove` (only for DMs)
- :func:`on_raw_reaction_clear` (only for DMs)
"""
return 1 << 13
@flag_value
def typing(self):
""":class:`bool`: Whether guild and direct message typing related events are enabled.
This is a shortcut to set or get both :attr:`guild_typing` and :attr:`dm_typing`.
This corresponds to the following events:
- :func:`on_typing` (both guilds and DMs)
"""
return (1 << 11) | (1 << 14)
@flag_value
def guild_typing(self):
""":class:`bool`: Whether guild and direct message typing related events are enabled.
See also :attr:`dm_typing` for DMs or :attr:`typing` for both.
This corresponds to the following events:
- :func:`on_typing` (only for guilds)
"""
return 1 << 11
@flag_value
def dm_typing(self):
""":class:`bool`: Whether guild and direct message typing related events are enabled.
See also :attr:`guild_typing` for guilds or :attr:`typing` for both.
This corresponds to the following events:
- :func:`on_typing` (only for DMs)
"""
return 1 << 14

View File

@ -66,6 +66,42 @@ class WebSocketClosure(Exception):
EventListener = namedtuple('EventListener', 'predicate event result future') EventListener = namedtuple('EventListener', 'predicate event result future')
class GatewayRatelimiter:
def __init__(self, count=110, per=60.0):
# The default is 110 to give room for at least 10 heartbeats per minute
self.max = count
self.remaining = count
self.window = 0.0
self.per = per
self.lock = asyncio.Lock()
self.shard_id = None
def get_delay(self):
current = time.time()
if current > self.window + self.per:
self.remaining = self.max
if self.remaining == self.max:
self.window = current
if self.remaining == 0:
return self.per - (current - self.window)
self.remaining -= 1
if self.remaining == 0:
self.window = current
return 0.0
async def block(self):
async with self.lock:
delta = self.get_delay()
if delta:
log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
await asyncio.sleep(delta)
class KeepAliveHandler(threading.Thread): class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
ws = kwargs.pop('ws', None) ws = kwargs.pop('ws', None)
@ -83,12 +119,13 @@ class KeepAliveHandler(threading.Thread):
self._stop_ev = threading.Event() self._stop_ev = threading.Event()
self._last_ack = time.perf_counter() self._last_ack = time.perf_counter()
self._last_send = time.perf_counter() self._last_send = time.perf_counter()
self._last_recv = time.perf_counter()
self.latency = float('inf') self.latency = float('inf')
self.heartbeat_timeout = ws._max_heartbeat_timeout self.heartbeat_timeout = ws._max_heartbeat_timeout
def run(self): def run(self):
while not self._stop_ev.wait(self.interval): while not self._stop_ev.wait(self.interval):
if self._last_ack + self.heartbeat_timeout < time.perf_counter(): if self._last_recv + self.heartbeat_timeout < time.perf_counter():
log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
coro = self.ws.close(4000) coro = self.ws.close(4000)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
@ -103,7 +140,7 @@ class KeepAliveHandler(threading.Thread):
data = self.get_payload() data = self.get_payload()
log.debug(self.msg, self.shard_id, data['d']) log.debug(self.msg, self.shard_id, data['d'])
coro = self.ws.send_as_json(data) coro = self.ws.send_heartbeat(data)
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try: try:
# block until sending is complete # block until sending is complete
@ -137,6 +174,9 @@ class KeepAliveHandler(threading.Thread):
def stop(self): def stop(self):
self._stop_ev.set() self._stop_ev.set()
def tick(self):
self._last_recv = time.perf_counter()
def ack(self): def ack(self):
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
@ -161,6 +201,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
def ack(self): def ack(self):
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self._last_recv = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
@ -240,6 +281,7 @@ class DiscordWebSocket:
self._zlib = zlib.decompressobj() self._zlib = zlib.decompressobj()
self._buffer = bytearray() self._buffer = bytearray()
self._close_code = None self._close_code = None
self._rate_limiter = GatewayRatelimiter()
@property @property
def open(self): def open(self):
@ -264,6 +306,7 @@ class DiscordWebSocket:
ws.call_hooks = client._connection.call_hooks ws.call_hooks = client._connection.call_hooks
ws._initial_identify = initial ws._initial_identify = initial
ws.shard_id = shard_id ws.shard_id = shard_id
ws._rate_limiter.shard_id = shard_id
ws.shard_count = client._connection.shard_count ws.shard_count = client._connection.shard_count
ws.session_id = session ws.session_id = session
ws.sequence = sequence ws.sequence = sequence
@ -343,6 +386,9 @@ class DiscordWebSocket:
'afk': False 'afk': False
} }
if state._intents is not None:
payload['d']['intents'] = state._intents.value
await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
await self.send_as_json(payload) await self.send_as_json(payload)
log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
@ -388,6 +434,9 @@ class DiscordWebSocket:
if seq is not None: if seq is not None:
self.sequence = seq self.sequence = seq
if self._keep_alive:
self._keep_alive.tick()
if op != self.DISPATCH: if op != self.DISPATCH:
if op == self.RECONNECT: if op == self.RECONNECT:
# "reconnect" can only be handled by the Client # "reconnect" can only be handled by the Client
@ -488,7 +537,7 @@ class DiscordWebSocket:
def _can_handle_close(self): def _can_handle_close(self):
code = self._close_code or self.socket.close_code code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011) return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self): async def poll_event(self):
"""Polls for a DISPATCH event and handles the general gateway loop. """Polls for a DISPATCH event and handles the general gateway loop.
@ -529,6 +578,7 @@ class DiscordWebSocket:
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def send(self, data): async def send(self, data):
await self._rate_limiter.block()
self._dispatch('socket_raw_send', data) self._dispatch('socket_raw_send', data)
await self.socket.send_str(data) await self.socket.send_str(data)
@ -539,6 +589,14 @@ class DiscordWebSocket:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def send_heartbeat(self, data):
# This bypasses the rate limit handling code since it has a higher priority
try:
await self.socket.send_str(utils.to_json(data))
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
if activity is not None: if activity is not None:
if not isinstance(activity, BaseActivity): if not isinstance(activity, BaseActivity):
@ -666,6 +724,8 @@ class DiscordVoiceWebSocket:
log.debug('Sending voice websocket frame: %s.', data) log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils.to_json(data)) await self.ws.send_str(utils.to_json(data))
send_heartbeat = send_as_json
async def resume(self): async def resume(self):
state = self._connection state = self._connection
payload = { payload = {

View File

@ -305,9 +305,12 @@ class Guild(Hashable):
self._rules_channel_id = utils._get_as_snowflake(guild, 'rules_channel_id') self._rules_channel_id = utils._get_as_snowflake(guild, 'rules_channel_id')
self._public_updates_channel_id = utils._get_as_snowflake(guild, 'public_updates_channel_id') self._public_updates_channel_id = utils._get_as_snowflake(guild, 'public_updates_channel_id')
cache_members = self._state._cache_members
self_id = self._state.self_id
for mdata in guild.get('members', []): for mdata in guild.get('members', []):
member = Member(data=mdata, guild=self, state=state) member = Member(data=mdata, guild=self, state=state)
self._add_member(member) if cache_members or member.id == self_id:
self._add_member(member)
self._sync(guild) self._sync(guild)
self._large = None if member_count is None else self._member_count >= 250 self._large = None if member_count is None else self._member_count >= 250
@ -2047,6 +2050,32 @@ class Guild(Hashable):
return Widget(state=self._state, data=data) return Widget(state=self._state, data=data)
async def chunk(self, *, cache=True):
"""|coro|
Requests all members that belong to this guild. In order to use this,
:meth:`Intents.members` must be enabled.
This is a websocket operation and can be slow.
.. versionadded:: 1.5
Parameters
-----------
cache: :class:`bool`
Whether to cache the members as well.
Raises
-------
ClientException
The members intent is not enabled.
"""
if not self._state._intents.members:
raise ClientException('Intents.members must be enabled to use this.')
return await self._state.chunk_guild(self, cache=cache)
async def query_members(self, query=None, *, limit=5, user_ids=None, cache=True): async def query_members(self, query=None, *, limit=5, user_ids=None, cache=True):
"""|coro| """|coro|
@ -2055,25 +2084,19 @@ class Guild(Hashable):
This is a websocket operation and can be slow. This is a websocket operation and can be slow.
.. warning::
Most bots do not need to use this. It's mainly a helper
for bots who have disabled ``guild_subscriptions``.
.. versionadded:: 1.3 .. versionadded:: 1.3
Parameters Parameters
----------- -----------
query: :class:`str` query: Optional[:class:`str`]
The string that the username's start with. An empty string The string that the username's start with.
requests all members.
limit: :class:`int` limit: :class:`int`
The maximum number of members to send back. This must be The maximum number of members to send back. This must be
a number between 1 and 1000. a number between 5 and 100.
cache: :class:`bool` cache: :class:`bool`
Whether to cache the members internally. This makes operations Whether to cache the members internally. This makes operations
such as :meth:`get_member` work for those that matched. such as :meth:`get_member` work for those that matched.
user_ids: List[:class:`int`] user_ids: Optional[List[:class:`int`]]
List of user IDs to search for. If the user ID is not in the guild then it won't be returned. List of user IDs to search for. If the user ID is not in the guild then it won't be returned.
.. versionadded:: 1.4 .. versionadded:: 1.4
@ -2083,19 +2106,26 @@ class Guild(Hashable):
------- -------
asyncio.TimeoutError asyncio.TimeoutError
The query timed out waiting for the members. The query timed out waiting for the members.
ValueError
Invalid parameters were passed to the function
Returns Returns
-------- --------
List[:class:`Member`] List[:class:`Member`]
The list of members that have matched the query. The list of members that have matched the query.
""" """
if query is None:
if query == '':
raise ValueError('Cannot pass empty query string.')
if user_ids is None:
raise ValueError('Must pass either query or user_ids')
if user_ids is not None and query is not None: if user_ids is not None and query is not None:
raise TypeError('Cannot pass both query and user_ids') raise ValueError('Cannot pass both query and user_ids')
if user_ids is None and query is None: limit = min(100, limit or 5)
raise TypeError('Must pass either query or user_ids')
limit = limit or 5
return await self._state.query_members(self, query=query, limit=limit, user_ids=user_ids, cache=cache) return await self._state.query_members(self, query=query, limit=limit, user_ids=user_ids, cache=cache)
async def change_voice_state(self, *, channel, self_mute=False, self_deaf=False): async def change_voice_state(self, *, channel, self_mute=False, self_deaf=False):

View File

@ -266,17 +266,20 @@ class Member(discord.abc.Messageable, _BaseUser):
self._client_status[None] = data['status'] self._client_status[None] = data['status']
if len(user) > 1: if len(user) > 1:
u = self._user return self._update_inner_user(user)
original = (u.name, u.avatar, u.discriminator)
# These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'])
if original != modified:
to_return = User._copy(self._user)
u.name, u.avatar, u.discriminator = modified
# Signal to dispatch on_user_update
return to_return, u
return False return False
def _update_inner_user(self, user):
u = self._user
original = (u.name, u.avatar, u.discriminator)
# These keys seem to always be available
modified = (user['username'], user['avatar'], user['discriminator'])
if original != modified:
to_return = User._copy(self._user)
u.name, u.avatar, u.discriminator = modified
# Signal to dispatch on_user_update
return to_return, u
@property @property
def status(self): def status(self):
""":class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead.""" """:class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead."""

View File

@ -333,6 +333,7 @@ class AutoShardedClient(Client):
"""Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object."""
return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() } return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() }
@utils.deprecated('Guild.chunk')
async def request_offline_members(self, *guilds): async def request_offline_members(self, *guilds):
r"""|coro| r"""|coro|
@ -346,6 +347,10 @@ class AutoShardedClient(Client):
in the guild is larger than 250. You can check if a guild is large in the guild is larger than 250. You can check if a guild is large
if :attr:`Guild.large` is ``True``. if :attr:`Guild.large` is ``True``.
.. warning::
This method is deprecated. Use :meth:`Guild.chunk` instead.
Parameters Parameters
----------- -----------
\*guilds: :class:`Guild` \*guilds: :class:`Guild`
@ -354,15 +359,15 @@ class AutoShardedClient(Client):
Raises Raises
------- -------
InvalidArgument InvalidArgument
If any guild is unavailable or not large in the collection. If any guild is unavailable in the collection.
""" """
if any(not g.large or g.unavailable for g in guilds): if any(g.unavailable for g in guilds):
raise InvalidArgument('An unavailable or non-large guild was passed.') raise InvalidArgument('An unavailable or non-large guild was passed.')
_guilds = sorted(guilds, key=lambda g: g.shard_id) _guilds = sorted(guilds, key=lambda g: g.shard_id)
for shard_id, sub_guilds in itertools.groupby(_guilds, key=lambda g: g.shard_id): for shard_id, sub_guilds in itertools.groupby(_guilds, key=lambda g: g.shard_id):
sub_guilds = list(sub_guilds) for guild in sub_guilds:
await self._connection.request_offline_members(sub_guilds, shard_id=shard_id) await self._connection.chunk_guild(guild)
async def launch_shard(self, gateway, shard_id, *, initial=False): async def launch_shard(self, gateway, shard_id, *, initial=False):
try: try:

View File

@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE.
""" """
import asyncio import asyncio
from collections import deque, namedtuple, OrderedDict from collections import deque, OrderedDict
import copy import copy
import datetime import datetime
import itertools import itertools
@ -49,19 +49,38 @@ from .channel import *
from .raw_models import * from .raw_models import *
from .member import Member from .member import Member
from .role import Role from .role import Role
from .enums import ChannelType, try_enum, Status, Enum from .enums import ChannelType, try_enum, Status
from . import utils from . import utils
from .flags import Intents
from .embeds import Embed from .embeds import Embed
from .object import Object from .object import Object
from .invite import Invite from .invite import Invite
class ListenerType(Enum): class ChunkRequest:
chunk = 0 def __init__(self, guild_id, future, resolver, *, cache=True):
query_members = 1 self.guild_id = guild_id
self.resolver = resolver
self.cache = cache
self.nonce = os.urandom(16).hex()
self.future = future
self.buffer = [] # List[Member]
def add_members(self, members):
self.buffer.extend(members)
if self.cache:
guild = self.resolver(self.guild_id)
if guild is None:
return
for member in members:
existing = guild.get_member(member.id)
if existing is None or existing.joined_at is None:
guild._add_member(member)
def done(self):
self.future.set_result(self.buffer)
Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
class ConnectionState: class ConnectionState:
def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
@ -93,7 +112,7 @@ class ConnectionState:
self.allowed_mentions = allowed_mentions self.allowed_mentions = allowed_mentions
# Only disable cache if both fetch_offline and guild_subscriptions are off. # Only disable cache if both fetch_offline and guild_subscriptions are off.
self._cache_members = (self._fetch_offline or self.guild_subscriptions) self._cache_members = (self._fetch_offline or self.guild_subscriptions)
self._listeners = [] self._chunk_requests = []
activity = options.get('activity', None) activity = options.get('activity', None)
if activity: if activity:
@ -109,8 +128,17 @@ class ConnectionState:
else: else:
status = str(status) status = str(status)
intents = options.get('intents', None)
if intents is not None:
if not isinstance(intents, Intents):
raise TypeError('intents parameter must be Intent not %r' % type(intents))
if not intents.members and self._fetch_offline:
raise ValueError('Intents.members has be enabled to fetch offline members.')
self._activity = activity self._activity = activity
self._status = status self._status = status
self._intents = intents
self.parsers = parsers = {} self.parsers = parsers = {}
for attr, func in inspect.getmembers(self): for attr, func in inspect.getmembers(self):
@ -138,34 +166,22 @@ class ConnectionState:
# to reconnect loops which cause mass allocations and deallocations. # to reconnect loops which cause mass allocations and deallocations.
gc.collect() gc.collect()
def get_nonce(self): def process_chunk_requests(self, guild_id, nonce, members, complete):
return os.urandom(16).hex()
def process_listeners(self, listener_type, argument, result):
removed = [] removed = []
for i, listener in enumerate(self._listeners): for i, request in enumerate(self._chunk_requests):
if listener.type != listener_type: future = request.future
continue
future = listener.future
if future.cancelled(): if future.cancelled():
removed.append(i) removed.append(i)
continue continue
try: if request.guild_id == guild_id and request.nonce == nonce:
passed = listener.predicate(argument) request.add_members(members)
except Exception as exc: if complete:
future.set_exception(exc) request.done()
removed.append(i)
else:
if passed:
future.set_result(result)
removed.append(i) removed.append(i)
if listener.type == ListenerType.chunk:
break
for index in reversed(removed): for index in reversed(removed):
del self._listeners[index] del self._chunk_requests[index]
def call_handlers(self, key, *args, **kwargs): def call_handlers(self, key, *args, **kwargs):
try: try:
@ -299,9 +315,9 @@ class ConnectionState:
self._add_guild(guild) self._add_guild(guild)
return guild return guild
def chunks_needed(self, guild): def _guild_needs_chunking(self, guild):
for _ in range(math.ceil(guild._member_count / 1000)): # If presences are enabled then we get back the old guild.large behaviour
yield self.receive_chunk(guild.id) return self._fetch_offline and not guild.chunked and not (self._intents.presences and not guild.large)
def _get_guild_channel(self, data): def _get_guild_channel(self, data):
channel_id = int(data['channel_id']) channel_id = int(data['channel_id'])
@ -319,78 +335,56 @@ class ConnectionState:
ws = self._get_websocket(guild_id) # This is ignored upstream ws = self._get_websocket(guild_id) # This is ignored upstream
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
async def request_offline_members(self, guilds):
# get all the chunks
chunks = []
for guild in guilds:
chunks.extend(self.chunks_needed(guild))
# we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits:
await self.chunker([g.id for g in split])
# wait for the chunks
if chunks:
try:
await utils.sane_wait_for(chunks, timeout=len(chunks) * 30.0)
except asyncio.TimeoutError:
log.warning('Somehow timed out waiting for chunks.')
else:
log.info('Finished requesting guild member chunks for %d guilds.', len(guilds))
async def query_members(self, guild, query, limit, user_ids, cache): async def query_members(self, guild, query, limit, user_ids, cache):
guild_id = guild.id guild_id = guild.id
ws = self._get_websocket(guild_id) ws = self._get_websocket(guild_id)
if ws is None: if ws is None:
raise RuntimeError('Somehow do not have a websocket for this guild_id') raise RuntimeError('Somehow do not have a websocket for this guild_id')
# Limits over 1000 cannot be supported since future = self.loop.create_future()
# the main use case for this is guild_subscriptions being disabled request = ChunkRequest(guild.id, future, self._get_guild, cache=cache)
# and they don't receive GUILD_MEMBER events which make computing self._chunk_requests.append(request)
# member_count impossible. The only way to fix it is by limiting
# the limit parameter to 1 to 1000.
nonce = self.get_nonce()
future = self.receive_member_query(guild_id, nonce)
try: try:
# start the query operation # start the query operation
await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=nonce) await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=request.nonce)
members = await asyncio.wait_for(future, timeout=5.0) return await asyncio.wait_for(future, timeout=30.0)
if cache:
for member in members:
guild._add_member(member)
return members
except asyncio.TimeoutError: except asyncio.TimeoutError:
log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id)
raise raise
async def _delay_ready(self): async def _delay_ready(self):
try: try:
launch = self._ready_state.launch
# only real bots wait for GUILD_CREATE streaming # only real bots wait for GUILD_CREATE streaming
if self.is_bot: if self.is_bot:
states = []
while True: while True:
# this snippet of code is basically waiting N seconds # this snippet of code is basically waiting N seconds
# until the last GUILD_CREATE was sent # until the last GUILD_CREATE was sent
try: try:
await asyncio.wait_for(launch.wait(), timeout=self.guild_ready_timeout) guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
break break
else: else:
launch.clear() if self._guild_needs_chunking(guild):
future = await self.chunk_guild(guild, wait=False)
states.append((guild, future))
else:
if guild.unavailable is False:
self.dispatch('guild_available', guild)
else:
self.dispatch('guild_join', guild)
guilds = next(zip(*self._ready_state.guilds), []) for guild, future in states:
if self._fetch_offline: try:
await self.request_offline_members(guilds) await asyncio.wait_for(future, timeout=5.0)
except asyncio.TimeoutError:
log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id)
for guild, unavailable in self._ready_state.guilds: if guild.unavailable is False:
if unavailable is False: self.dispatch('guild_available', guild)
self.dispatch('guild_available', guild) else:
else: self.dispatch('guild_join', guild)
self.dispatch('guild_join', guild)
# remove the state # remove the state
try: try:
@ -415,16 +409,13 @@ class ConnectionState:
if self._ready_task is not None: if self._ready_task is not None:
self._ready_task.cancel() self._ready_task.cancel()
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) self._ready_state = asyncio.Queue()
self.clear() self.clear()
self.user = user = ClientUser(state=self, data=data['user']) self.user = user = ClientUser(state=self, data=data['user'])
self._users[user.id] = user self._users[user.id] = user
guilds = self._ready_state.guilds
for guild_data in data['guilds']: for guild_data in data['guilds']:
guild = self._add_guild_from_data(guild_data) self._add_guild_from_data(guild_data)
if (not self.is_bot and not guild.unavailable) or guild.large:
guilds.append((guild, guild.unavailable))
for relationship in data.get('relationships', []): for relationship in data.get('relationships', []):
try: try:
@ -561,7 +552,7 @@ class ConnectionState:
guild_id = utils._get_as_snowflake(data, 'guild_id') guild_id = utils._get_as_snowflake(data, 'guild_id')
guild = self._get_guild(guild_id) guild = self._get_guild(guild_id)
if guild is None: if guild is None:
log.warning('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return return
user = data['user'] user = data['user']
@ -629,14 +620,14 @@ class ConnectionState:
channel._update(guild, data) channel._update(guild, data)
self.dispatch('guild_channel_update', old_channel, channel) self.dispatch('guild_channel_update', old_channel, channel)
else: else:
log.warning('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id)
else: else:
log.warning('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id)
def parse_channel_create(self, data): def parse_channel_create(self, data):
factory, ch_type = _channel_factory(data['type']) factory, ch_type = _channel_factory(data['type'])
if factory is None: if factory is None:
log.warning('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type'])
return return
channel = None channel = None
@ -655,14 +646,14 @@ class ConnectionState:
guild._add_channel(channel) guild._add_channel(channel)
self.dispatch('guild_channel_create', channel) self.dispatch('guild_channel_create', channel)
else: else:
log.warning('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id)
return return
def parse_channel_pins_update(self, data): def parse_channel_pins_update(self, data):
channel_id = int(data['channel_id']) channel_id = int(data['channel_id'])
channel = self.get_channel(channel_id) channel = self.get_channel(channel_id)
if channel is None: if channel is None:
log.warning('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) log.debug('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id)
return return
last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None
@ -696,7 +687,7 @@ class ConnectionState:
def parse_guild_member_add(self, data): def parse_guild_member_add(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
if guild is None: if guild is None:
log.warning('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
return return
member = Member(guild=guild, data=data, state=self) member = Member(guild=guild, data=data, state=self)
@ -715,28 +706,32 @@ class ConnectionState:
guild._remove_member(member) guild._remove_member(member)
self.dispatch('member_remove', member) self.dispatch('member_remove', member)
else: else:
log.warning('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_guild_member_update(self, data): def parse_guild_member_update(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
user = data['user'] user = data['user']
user_id = int(user['id']) user_id = int(user['id'])
if guild is None: if guild is None:
log.warning('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
return return
member = guild.get_member(user_id) member = guild.get_member(user_id)
if member is not None: if member is not None:
old_member = copy.copy(member) old_member = Member._copy(member)
member._update(data) member._update(data)
user_update = member._update_inner_user(user)
if user_update:
self.dispatch('user_update', user_update[0], user_update[1])
self.dispatch('member_update', old_member, member) self.dispatch('member_update', old_member, member)
else: else:
log.warning('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id)
def parse_guild_emojis_update(self, data): def parse_guild_emojis_update(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
if guild is None: if guild is None:
log.warning('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
return return
before_emojis = guild.emojis before_emojis = guild.emojis
@ -758,14 +753,21 @@ class ConnectionState:
return self._add_guild_from_data(data) return self._add_guild_from_data(data)
async def chunk_guild(self, guild, *, wait=True, cache=None):
cache = cache or self._cache_members
future = self.loop.create_future()
request = ChunkRequest(guild.id, future, self._get_guild, cache=cache)
self._chunk_requests.append(request)
await self.chunker(guild.id, nonce=request.nonce)
if wait:
return await request.future
return request.future
async def _chunk_and_dispatch(self, guild, unavailable): async def _chunk_and_dispatch(self, guild, unavailable):
chunks = list(self.chunks_needed(guild)) try:
await self.chunker(guild.id) await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0)
if chunks: except asyncio.TimeoutError:
try: log.info('Somehow timed out waiting for chunks.')
await utils.sane_wait_for(chunks, timeout=len(chunks))
except asyncio.TimeoutError:
log.info('Somehow timed out waiting for chunks.')
if unavailable is False: if unavailable is False:
self.dispatch('guild_available', guild) self.dispatch('guild_available', guild)
@ -780,30 +782,19 @@ class ConnectionState:
guild = self._get_create_guild(data) guild = self._get_create_guild(data)
# check if it requires chunking try:
if guild.large: # Notify the on_ready state, if any, that this guild is complete.
if unavailable is False: self._ready_state.put_nowait(guild)
# check if we're waiting for 'useful' READY except AttributeError:
# and if we are, we don't want to dispatch any pass
# event such as guild_join or guild_available else:
# because we're still in the 'READY' phase. Or # If we're waiting for the event, put the rest on hold
# so we say. return
try:
state = self._ready_state
state.launch.set()
state.guilds.append((guild, unavailable))
except AttributeError:
# the _ready_state attribute is only there during
# processing of useful READY.
pass
else:
return
# since we're not waiting for 'useful' READY we'll just # check if it requires chunking
# do the chunk request here if wanted if self._guild_needs_chunking(guild):
if self._fetch_offline: asyncio.ensure_future(self._chunk_and_dispatch(guild, unavailable), loop=self.loop)
asyncio.ensure_future(self._chunk_and_dispatch(guild, unavailable), loop=self.loop) return
return
# Dispatch available if newly available # Dispatch available if newly available
if unavailable is False: if unavailable is False:
@ -822,12 +813,12 @@ class ConnectionState:
guild._from_data(data) guild._from_data(data)
self.dispatch('guild_update', old_guild, guild) self.dispatch('guild_update', old_guild, guild)
else: else:
log.warning('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id'])
def parse_guild_delete(self, data): def parse_guild_delete(self, data):
guild = self._get_guild(int(data['id'])) guild = self._get_guild(int(data['id']))
if guild is None: if guild is None:
log.warning('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id'])
return return
if data.get('unavailable', False) and guild is not None: if data.get('unavailable', False) and guild is not None:
@ -870,7 +861,7 @@ class ConnectionState:
def parse_guild_role_create(self, data): def parse_guild_role_create(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
if guild is None: if guild is None:
log.warning('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
return return
role_data = data['role'] role_data = data['role']
@ -889,7 +880,7 @@ class ConnectionState:
else: else:
self.dispatch('guild_role_delete', role) self.dispatch('guild_role_delete', role)
else: else:
log.warning('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_guild_role_update(self, data): def parse_guild_role_update(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
@ -902,35 +893,29 @@ class ConnectionState:
role._update(role_data) role._update(role_data)
self.dispatch('guild_role_update', old_role, role) self.dispatch('guild_role_update', old_role, role)
else: else:
log.warning('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_guild_members_chunk(self, data): def parse_guild_members_chunk(self, data):
guild_id = int(data['guild_id']) guild_id = int(data['guild_id'])
guild = self._get_guild(guild_id) guild = self._get_guild(guild_id)
members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])]
log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id)
if self._cache_members: complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count')
for member in members: self.process_chunk_requests(guild_id, data.get('nonce'), members, complete)
existing = guild.get_member(member.id)
if existing is None or existing.joined_at is None:
guild._add_member(member)
self.process_listeners(ListenerType.chunk, guild, len(members))
self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members)
def parse_guild_integrations_update(self, data): def parse_guild_integrations_update(self, data):
guild = self._get_guild(int(data['guild_id'])) guild = self._get_guild(int(data['guild_id']))
if guild is not None: if guild is not None:
self.dispatch('guild_integrations_update', guild) self.dispatch('guild_integrations_update', guild)
else: else:
log.warning('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id'])
def parse_webhooks_update(self, data): def parse_webhooks_update(self, data):
channel = self.get_channel(int(data['channel_id'])) channel = self.get_channel(int(data['channel_id']))
if channel is not None: if channel is not None:
self.dispatch('webhooks_update', channel) self.dispatch('webhooks_update', channel)
else: else:
log.warning('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id'])
def parse_voice_state_update(self, data): def parse_voice_state_update(self, data):
guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id'))
@ -947,7 +932,7 @@ class ConnectionState:
if member is not None: if member is not None:
self.dispatch('voice_state_update', member, before, after) self.dispatch('voice_state_update', member, before, after)
else: else:
log.warning('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id'])
else: else:
# in here we're either at private or group calls # in here we're either at private or group calls
call = self._calls.get(channel_id) call = self._calls.get(channel_id)
@ -1040,21 +1025,6 @@ class ConnectionState:
def create_message(self, *, channel, data): def create_message(self, *, channel, data):
return Message(state=self, channel=channel, data=data) return Message(state=self, channel=channel, data=data)
def receive_chunk(self, guild_id):
future = self.loop.create_future()
listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id)
self._listeners.append(listener)
return future
def receive_member_query(self, guild_id, nonce):
def predicate(args, *, guild_id=guild_id, nonce=nonce):
return args == (guild_id, nonce)
future = self.loop.create_future()
listener = Listener(ListenerType.query_members, future, predicate)
self._listeners.append(listener)
return future
class AutoShardedConnectionState(ConnectionState): class AutoShardedConnectionState(ConnectionState):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -1077,51 +1047,56 @@ class AutoShardedConnectionState(ConnectionState):
ws = self._get_websocket(guild_id, shard_id=shard_id) ws = self._get_websocket(guild_id, shard_id=shard_id)
await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce)
async def request_offline_members(self, guilds, *, shard_id):
# get all the chunks
chunks = []
for guild in guilds:
chunks.extend(self.chunks_needed(guild))
# we only want to request ~75 guilds per chunk request.
splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)]
for split in splits:
await self.chunker([g.id for g in split], shard_id=shard_id)
# wait for the chunks
if chunks:
try:
await utils.sane_wait_for(chunks, timeout=len(chunks) * 30.0)
except asyncio.TimeoutError:
log.info('Somehow timed out waiting for chunks.')
else:
log.info('Finished requesting guild member chunks for %d guilds.', len(guilds))
async def _delay_ready(self): async def _delay_ready(self):
await self.shards_launched.wait() await self.shards_launched.wait()
launch = self._ready_state.launch processed = []
max_concurrency = len(self.shard_ids) * 2
current_bucket = []
while True: while True:
# this snippet of code is basically waiting N seconds # this snippet of code is basically waiting N seconds
# until the last GUILD_CREATE was sent # until the last GUILD_CREATE was sent
try: try:
await asyncio.wait_for(launch.wait(), timeout=self.guild_ready_timeout) guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout)
except asyncio.TimeoutError: except asyncio.TimeoutError:
break break
else: else:
launch.clear() if self._guild_needs_chunking(guild):
log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id)
if len(current_bucket) >= max_concurrency:
try:
await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0)
except asyncio.TimeoutError:
fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d'
log.warning(fmt, self.shard_id, len(current_bucket))
finally:
current_bucket = []
guilds = sorted(self._ready_state.guilds, key=lambda g: g[0].shard_id) # Chunk the guild in the background while we wait for GUILD_CREATE streaming
future = asyncio.ensure_future(self.chunk_guild(guild))
current_bucket.append(future)
else:
future = self.loop.create_future()
future.set_result([])
for shard_id, sub_guilds_info in itertools.groupby(guilds, key=lambda g: g[0].shard_id): processed.append((guild, future))
sub_guilds, sub_available = zip(*sub_guilds_info)
if self._fetch_offline:
await self.request_offline_members(sub_guilds, shard_id=shard_id)
for guild, unavailable in zip(sub_guilds, sub_available): guilds = sorted(processed, key=lambda g: g[0].shard_id)
if unavailable is False: for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id):
children, futures = zip(*info)
# 110 reqs/minute w/ 1 req/guild plus some buffer
timeout = 61 * (len(children) / 110)
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
log.warning('Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', self.shard_id,
timeout,
len(guilds))
for guild in children:
if guild.unavailable is False:
self.dispatch('guild_available', guild) self.dispatch('guild_available', guild)
else: else:
self.dispatch('guild_join', guild) self.dispatch('guild_join', guild)
self.dispatch('shard_ready', shard_id) self.dispatch('shard_ready', shard_id)
# remove the state # remove the state
@ -1141,16 +1116,13 @@ class AutoShardedConnectionState(ConnectionState):
def parse_ready(self, data): def parse_ready(self, data):
if not hasattr(self, '_ready_state'): if not hasattr(self, '_ready_state'):
self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) self._ready_state = asyncio.Queue()
self.user = user = ClientUser(state=self, data=data['user']) self.user = user = ClientUser(state=self, data=data['user'])
self._users[user.id] = user self._users[user.id] = user
guilds = self._ready_state.guilds
for guild_data in data['guilds']: for guild_data in data['guilds']:
guild = self._add_guild_from_data(guild_data) self._add_guild_from_data(guild_data)
if guild.large:
guilds.append((guild, guild.unavailable))
if self._messages: if self._messages:
self._update_message_references() self._update_message_references()