diff --git a/discord/client.py b/discord/client.py index 59be489a..23809424 100644 --- a/discord/client.py +++ b/discord/client.py @@ -141,6 +141,10 @@ class Client: Integer starting at ``0`` and less than :attr:`.shard_count`. shard_count: Optional[:class:`int`] The total number of shards. + intents: :class:`Intents` + A list of intents that you want to enable for the session. This is a way of + disabling and enabling certain gateway events from triggering and being sent. + Currently, if no intents are passed then you will receive all data. fetch_offline_members: :class:`bool` Indicates if :func:`.on_ready` should be delayed to fetch all offline members from the guilds the client belongs to. If this is ``False``\, then @@ -375,6 +379,7 @@ class Client: print('Ignoring exception in {}'.format(event_method), file=sys.stderr) traceback.print_exc() + @utils.deprecated('Guild.chunk') async def request_offline_members(self, *guilds): r"""|coro| @@ -388,6 +393,10 @@ class Client: in the guild is larger than 250. You can check if a guild is large if :attr:`.Guild.large` is ``True``. + .. warning:: + + This method is deprecated. Use :meth:`Guild.chunk` instead. + Parameters ----------- \*guilds: :class:`.Guild` @@ -396,12 +405,13 @@ class Client: Raises ------- :exc:`.InvalidArgument` - If any guild is unavailable or not large in the collection. + If any guild is unavailable in the collection. 
""" - if any(not g.large or g.unavailable for g in guilds): - raise InvalidArgument('An unavailable or non-large guild was passed.') + if any(g.unavailable for g in guilds): + raise InvalidArgument('An unavailable guild was passed.') - await self._connection.request_offline_members(guilds) + for guild in guilds: + await self._connection.chunk_guild(guild) # hooks diff --git a/discord/flags.py b/discord/flags.py index 448bced2..bc2a52ed 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -29,7 +29,8 @@ from .enums import UserFlags __all__ = ( 'SystemChannelFlags', 'MessageFlags', - 'PublicUserFlags' + 'PublicUserFlags', + 'Intents', ) class flag_value: @@ -327,3 +328,326 @@ class PublicUserFlags(BaseFlags): def all(self): """List[:class:`UserFlags`]: Returns all public flags the user has.""" return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] + + +@fill_with_flags() +class Intents(BaseFlags): + r"""Wraps up a Discord gateway intent flag. + + Similar to :class:`Permissions`\, the properties provided are two way. + You can set and retrieve individual bits using the properties as if they + were regular bools. + + To construct an object you can pass keyword arguments denoting the flags + to enable or disable. + + This is used to disable certain gateway features that are unnecessary to + run your bot. To make use of this, it is passed to the ``intents`` keyword + argument of :class:`Client`. + + A default instance of this class has everything enabled except :attr:`presences` + and :attr:`members`. + + .. versionadded:: 1.5 + + .. container:: operations + + .. describe:: x == y + + Checks if two flags are equal. + .. describe:: x != y + + Checks if two flags are not equal. + .. describe:: hash(x) + + Return the flag's hash. + .. describe:: iter(x) + + Returns an iterator of ``(name, value)`` pairs. This allows it + to be, for example, constructed as a dict or a list of pairs. 
+ + Attributes + ----------- + value: :class:`int` + The raw value. You should query flags via the properties + rather than using this raw value. + """ + + __slots__ = () + + def __init__(self, **kwargs): + # Change the default value to everything being enabled + # except presences and members + bits = max(self.VALID_FLAGS.values()).bit_length() + self.value = (1 << bits) - 1 + self.presences = False + self.members = False + for key, value in kwargs.items(): + if key not in self.VALID_FLAGS: + raise TypeError('%r is not a valid flag name.' % key) + setattr(self, key, value) + + @classmethod + def all(cls): + """A factory method that creates a :class:`Intents` with everything enabled.""" + bits = max(cls.VALID_FLAGS.values()).bit_length() + value = (1 << bits) - 1 + self = cls.__new__(cls) + self.value = value + return self + + @classmethod + def none(cls): + """A factory method that creates a :class:`Intents` with everything disabled.""" + self = cls.__new__(cls) + self.value = self.DEFAULT_VALUE + return self + + @flag_value + def guilds(self): + """:class:`bool`: Whether guild related events are enabled. + + This corresponds to the following events: + + - :func:`on_guild_join` + - :func:`on_guild_remove` + - :func:`on_guild_available` + - :func:`on_guild_unavailable` + - :func:`on_guild_channel_update` + - :func:`on_guild_channel_create` + - :func:`on_guild_channel_delete` + - :func:`on_guild_channel_pins_update` + """ + return 1 << 0 + + @flag_value + def members(self): + """:class:`bool`: Whether guild member related events are enabled. + + This corresponds to the following events: + + - :func:`on_member_join` + - :func:`on_member_remove` + - :func:`on_member_update` (nickname, roles) + - :func:`on_user_update` + + .. note:: + + Currently, this requires opting in explicitly via the dev portal as well. + Bots in over 100 guilds will need to apply to Discord for verification. 
+ """ + return 1 << 1 + + @flag_value + def bans(self): + """:class:`bool`: Whether guild ban related events are enabled. + + This corresponds to the following events: + + - :func:`on_member_ban` + - :func:`on_member_unban` + """ + return 1 << 2 + + @flag_value + def emojis(self): + """:class:`bool`: Whether guild emoji related events are enabled. + + This corresponds to the following events: + + - :func:`on_guild_emojis_update` + """ + return 1 << 3 + + @flag_value + def integrations(self): + """:class:`bool`: Whether guild integration related events are enabled. + + This corresponds to the following events: + + - :func:`on_guild_integrations_update` + """ + return 1 << 4 + + @flag_value + def webhooks(self): + """:class:`bool`: Whether guild webhook related events are enabled. + + This corresponds to the following events: + + - :func:`on_webhooks_update` + """ + return 1 << 5 + + @flag_value + def invites(self): + """:class:`bool`: Whether guild invite related events are enabled. + + This corresponds to the following events: + + - :func:`on_invite_create` + - :func:`on_invite_delete` + """ + return 1 << 6 + + @flag_value + def voice_states(self): + """:class:`bool`: Whether guild voice state related events are enabled. + + This corresponds to the following events: + + - :func:`on_voice_state_update` + """ + return 1 << 7 + + @flag_value + def presences(self): + """:class:`bool`: Whether guild presence related events are enabled. + + This corresponds to the following events: + + - :func:`on_member_update` (activities, status) + + .. note:: + + Currently, this requires opting in explicitly via the dev portal as well. + Bots in over 100 guilds will need to apply to Discord for verification. + """ + return 1 << 8 + + @flag_value + def messages(self): + """:class:`bool`: Whether guild and direct message related events are enabled. + + This is a shortcut to set or get both :attr:`guild_messages` and :attr:`dm_messages`.
+ + This corresponds to the following events: + + - :func:`on_message` (both guilds and DMs) + - :func:`on_message_update` (both guilds and DMs) + - :func:`on_message_delete` (both guilds and DMs) + - :func:`on_raw_message_delete` (both guilds and DMs) + - :func:`on_raw_message_update` (both guilds and DMs) + - :func:`on_private_channel_create` + """ + return (1 << 9) | (1 << 12) + + @flag_value + def guild_messages(self): + """:class:`bool`: Whether guild message related events are enabled. + + See also :attr:`dm_messages` for DMs or :attr:`messages` for both. + + This corresponds to the following events: + + - :func:`on_message` (only for guilds) + - :func:`on_message_update` (only for guilds) + - :func:`on_message_delete` (only for guilds) + - :func:`on_raw_message_delete` (only for guilds) + - :func:`on_raw_message_update` (only for guilds) + """ + return 1 << 9 + + @flag_value + def dm_messages(self): + """:class:`bool`: Whether direct message related events are enabled. + + See also :attr:`guild_messages` for guilds or :attr:`messages` for both. + + This corresponds to the following events: + + - :func:`on_message` (only for DMs) + - :func:`on_message_update` (only for DMs) + - :func:`on_message_delete` (only for DMs) + - :func:`on_raw_message_delete` (only for DMs) + - :func:`on_raw_message_update` (only for DMs) + - :func:`on_private_channel_create` + """ + return 1 << 12 + + @flag_value + def reactions(self): + """:class:`bool`: Whether guild and direct message reaction related events are enabled. + + This is a shortcut to set or get both :attr:`guild_reactions` and :attr:`dm_reactions`. 
+ + This corresponds to the following events: + + - :func:`on_reaction_add` (both guilds and DMs) + - :func:`on_reaction_remove` (both guilds and DMs) + - :func:`on_reaction_clear` (both guilds and DMs) + - :func:`on_raw_reaction_add` (both guilds and DMs) + - :func:`on_raw_reaction_remove` (both guilds and DMs) + - :func:`on_raw_reaction_clear` (both guilds and DMs) + """ + return (1 << 10) | (1 << 13) + + @flag_value + def guild_reactions(self): + """:class:`bool`: Whether guild message reaction related events are enabled. + + See also :attr:`dm_reactions` for DMs or :attr:`reactions` for both. + + This corresponds to the following events: + + - :func:`on_reaction_add` (only for guilds) + - :func:`on_reaction_remove` (only for guilds) + - :func:`on_reaction_clear` (only for guilds) + - :func:`on_raw_reaction_add` (only for guilds) + - :func:`on_raw_reaction_remove` (only for guilds) + - :func:`on_raw_reaction_clear` (only for guilds) + """ + return 1 << 10 + + @flag_value + def dm_reactions(self): + """:class:`bool`: Whether direct message reaction related events are enabled. + + See also :attr:`guild_reactions` for guilds or :attr:`reactions` for both. + + This corresponds to the following events: + + - :func:`on_reaction_add` (only for DMs) + - :func:`on_reaction_remove` (only for DMs) + - :func:`on_reaction_clear` (only for DMs) + - :func:`on_raw_reaction_add` (only for DMs) + - :func:`on_raw_reaction_remove` (only for DMs) + - :func:`on_raw_reaction_clear` (only for DMs) + """ + return 1 << 13 + + @flag_value + def typing(self): + """:class:`bool`: Whether guild and direct message typing related events are enabled. + + This is a shortcut to set or get both :attr:`guild_typing` and :attr:`dm_typing`. + + This corresponds to the following events: + + - :func:`on_typing` (both guilds and DMs) + """ + return (1 << 11) | (1 << 14) + + @flag_value + def guild_typing(self): + """:class:`bool`: Whether guild typing related events are enabled.
+ + See also :attr:`dm_typing` for DMs or :attr:`typing` for both. + + This corresponds to the following events: + + - :func:`on_typing` (only for guilds) + """ + return 1 << 11 + + @flag_value + def dm_typing(self): + """:class:`bool`: Whether direct message typing related events are enabled. + + See also :attr:`guild_typing` for guilds or :attr:`typing` for both. + + This corresponds to the following events: + + - :func:`on_typing` (only for DMs) + """ + return 1 << 14 diff --git a/discord/gateway.py b/discord/gateway.py index 81ff69b8..9db98301 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -66,6 +66,42 @@ class WebSocketClosure(Exception): EventListener = namedtuple('EventListener', 'predicate event result future') +class GatewayRatelimiter: + def __init__(self, count=110, per=60.0): + # The default is 110 to give room for at least 10 heartbeats per minute + self.max = count + self.remaining = count + self.window = 0.0 + self.per = per + self.lock = asyncio.Lock() + self.shard_id = None + + def get_delay(self): + current = time.time() + + if current > self.window + self.per: + self.remaining = self.max + + if self.remaining == self.max: + self.window = current + + if self.remaining == 0: + return self.per - (current - self.window) + + self.remaining -= 1 + if self.remaining == 0: + self.window = current + + return 0.0 + + async def block(self): + async with self.lock: + delta = self.get_delay() + if delta: + log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta) + await asyncio.sleep(delta) + + class KeepAliveHandler(threading.Thread): def __init__(self, *args, **kwargs): ws = kwargs.pop('ws', None) @@ -83,12 +119,13 @@ class KeepAliveHandler(threading.Thread): self._stop_ev = threading.Event() self._last_ack = time.perf_counter() self._last_send = time.perf_counter() + self._last_recv = time.perf_counter() self.latency = float('inf') self.heartbeat_timeout = ws._max_heartbeat_timeout def
run(self): while not self._stop_ev.wait(self.interval): - if self._last_ack + self.heartbeat_timeout < time.perf_counter(): + if self._last_recv + self.heartbeat_timeout < time.perf_counter(): log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) coro = self.ws.close(4000) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) @@ -103,7 +140,7 @@ class KeepAliveHandler(threading.Thread): data = self.get_payload() log.debug(self.msg, self.shard_id, data['d']) - coro = self.ws.send_as_json(data) + coro = self.ws.send_heartbeat(data) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: # block until sending is complete @@ -137,6 +174,9 @@ class KeepAliveHandler(threading.Thread): def stop(self): self._stop_ev.set() + def tick(self): + self._last_recv = time.perf_counter() + def ack(self): ack_time = time.perf_counter() self._last_ack = ack_time @@ -161,6 +201,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler): def ack(self): ack_time = time.perf_counter() self._last_ack = ack_time + self._last_recv = ack_time self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) @@ -240,6 +281,7 @@ class DiscordWebSocket: self._zlib = zlib.decompressobj() self._buffer = bytearray() self._close_code = None + self._rate_limiter = GatewayRatelimiter() @property def open(self): @@ -264,6 +306,7 @@ class DiscordWebSocket: ws.call_hooks = client._connection.call_hooks ws._initial_identify = initial ws.shard_id = shard_id + ws._rate_limiter.shard_id = shard_id ws.shard_count = client._connection.shard_count ws.session_id = session ws.sequence = sequence @@ -343,6 +386,9 @@ class DiscordWebSocket: 'afk': False } + if state._intents is not None: + payload['d']['intents'] = state._intents.value + await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) await self.send_as_json(payload) log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) @@ 
-388,6 +434,9 @@ class DiscordWebSocket: if seq is not None: self.sequence = seq + if self._keep_alive: + self._keep_alive.tick() + if op != self.DISPATCH: if op == self.RECONNECT: # "reconnect" can only be handled by the Client @@ -488,7 +537,7 @@ class DiscordWebSocket: def _can_handle_close(self): code = self._close_code or self.socket.close_code - return code not in (1000, 4004, 4010, 4011) + return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) async def poll_event(self): """Polls for a DISPATCH event and handles the general gateway loop. @@ -529,6 +578,7 @@ class DiscordWebSocket: raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None async def send(self, data): + await self._rate_limiter.block() self._dispatch('socket_raw_send', data) await self.socket.send_str(data) @@ -539,6 +589,14 @@ class DiscordWebSocket: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc + async def send_heartbeat(self, data): + # This bypasses the rate limit handling code since it has a higher priority + try: + await self.socket.send_str(utils.to_json(data)) + except RuntimeError as exc: + if not self._can_handle_close(): + raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc + async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): if activity is not None: if not isinstance(activity, BaseActivity): @@ -666,6 +724,8 @@ class DiscordVoiceWebSocket: log.debug('Sending voice websocket frame: %s.', data) await self.ws.send_str(utils.to_json(data)) + send_heartbeat = send_as_json + async def resume(self): state = self._connection payload = { diff --git a/discord/guild.py b/discord/guild.py index 15731363..b2cdbb1d 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -305,9 +305,12 @@ class Guild(Hashable): self._rules_channel_id = utils._get_as_snowflake(guild, 'rules_channel_id') self._public_updates_channel_id = utils._get_as_snowflake(guild, 
'public_updates_channel_id') + cache_members = self._state._cache_members + self_id = self._state.self_id for mdata in guild.get('members', []): member = Member(data=mdata, guild=self, state=state) - self._add_member(member) + if cache_members or member.id == self_id: + self._add_member(member) self._sync(guild) self._large = None if member_count is None else self._member_count >= 250 @@ -2047,6 +2050,32 @@ class Guild(Hashable): return Widget(state=self._state, data=data) + async def chunk(self, *, cache=True): + """|coro| + + Requests all members that belong to this guild. In order to use this, + :attr:`Intents.members` must be enabled. + + This is a websocket operation and can be slow. + + .. versionadded:: 1.5 + + Parameters + ----------- + cache: :class:`bool` + Whether to cache the members as well. + + Raises + ------- + ClientException + The members intent is not enabled. + """ + + if self._state._intents is not None and not self._state._intents.members: + raise ClientException('Intents.members must be enabled to use this.') + + return await self._state.chunk_guild(self, cache=cache) + async def query_members(self, query=None, *, limit=5, user_ids=None, cache=True): """|coro| @@ -2055,25 +2084,19 @@ class Guild(Hashable): This is a websocket operation and can be slow. - .. warning:: - - Most bots do not need to use this. It's mainly a helper - for bots who have disabled ``guild_subscriptions``. - .. versionadded:: 1.3 Parameters ----------- - query: :class:`str` - The string that the username's start with. An empty string - requests all members. + query: Optional[:class:`str`] + The string that the username's start with. limit: :class:`int` The maximum number of members to send back. This must be - a number between 1 and 1000. + a number between 5 and 100. cache: :class:`bool` Whether to cache the members internally. This makes operations such as :meth:`get_member` work for those that matched.
- user_ids: List[:class:`int`] + user_ids: Optional[List[:class:`int`]] List of user IDs to search for. If the user ID is not in the guild then it won't be returned. .. versionadded:: 1.4 @@ -2083,19 +2106,26 @@ class Guild(Hashable): ------- asyncio.TimeoutError The query timed out waiting for the members. + ValueError + Invalid parameters were passed to the function Returns -------- List[:class:`Member`] The list of members that have matched the query. """ + + if query == '': + raise ValueError('Cannot pass empty query string.') + + if query is None: + if user_ids is None: + raise ValueError('Must pass either query or user_ids') + if user_ids is not None and query is not None: - raise TypeError('Cannot pass both query and user_ids') + raise ValueError('Cannot pass both query and user_ids') - if user_ids is None and query is None: - raise TypeError('Must pass either query or user_ids') - - limit = limit or 5 + limit = min(100, limit or 5) return await self._state.query_members(self, query=query, limit=limit, user_ids=user_ids, cache=cache) async def change_voice_state(self, *, channel, self_mute=False, self_deaf=False): diff --git a/discord/member.py b/discord/member.py index 1fb11e63..0b89d59f 100644 --- a/discord/member.py +++ b/discord/member.py @@ -266,17 +266,20 @@ class Member(discord.abc.Messageable, _BaseUser): self._client_status[None] = data['status'] if len(user) > 1: - u = self._user - original = (u.name, u.avatar, u.discriminator) - # These keys seem to always be available - modified = (user['username'], user['avatar'], user['discriminator']) - if original != modified: - to_return = User._copy(self._user) - u.name, u.avatar, u.discriminator = modified - # Signal to dispatch on_user_update - return to_return, u + return self._update_inner_user(user) return False + def _update_inner_user(self, user): + u = self._user + original = (u.name, u.avatar, u.discriminator) + # These keys seem to always be available + modified = (user['username'],
user['avatar'], user['discriminator']) + if original != modified: + to_return = User._copy(self._user) + u.name, u.avatar, u.discriminator = modified + # Signal to dispatch on_user_update + return to_return, u + @property def status(self): """:class:`Status`: The member's overall status. If the value is unknown, then it will be a :class:`str` instead.""" @@ -433,7 +436,7 @@ class Member(discord.abc.Messageable, _BaseUser): guild = self.guild if len(self._roles) == 0: return guild.default_role - + return max(guild.get_role(rid) or guild.default_role for rid in self._roles) @property diff --git a/discord/shard.py b/discord/shard.py index f6320678..2ed7724d 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -333,6 +333,7 @@ class AutoShardedClient(Client): """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" return { shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items() } + @utils.deprecated('Guild.chunk') async def request_offline_members(self, *guilds): r"""|coro| @@ -346,6 +347,10 @@ class AutoShardedClient(Client): in the guild is larger than 250. You can check if a guild is large if :attr:`Guild.large` is ``True``. + .. warning:: + + This method is deprecated. Use :meth:`Guild.chunk` instead. + Parameters ----------- \*guilds: :class:`Guild` @@ -354,15 +359,15 @@ class AutoShardedClient(Client): Raises ------- InvalidArgument - If any guild is unavailable or not large in the collection. + If any guild is unavailable in the collection. 
""" - if any(not g.large or g.unavailable for g in guilds): - raise InvalidArgument('An unavailable or non-large guild was passed.') + if any(g.unavailable for g in guilds): + raise InvalidArgument('An unavailable guild was passed.') _guilds = sorted(guilds, key=lambda g: g.shard_id) for shard_id, sub_guilds in itertools.groupby(_guilds, key=lambda g: g.shard_id): - sub_guilds = list(sub_guilds) - await self._connection.request_offline_members(sub_guilds, shard_id=shard_id) + for guild in sub_guilds: + await self._connection.chunk_guild(guild) async def launch_shard(self, gateway, shard_id, *, initial=False): try: diff --git a/discord/state.py b/discord/state.py index f0e93d35..23200da3 100644 --- a/discord/state.py +++ b/discord/state.py @@ -25,7 +25,7 @@ DEALINGS IN THE SOFTWARE. """ import asyncio -from collections import deque, namedtuple, OrderedDict +from collections import deque, OrderedDict import copy import datetime import itertools @@ -49,19 +49,38 @@ from .channel import * from .raw_models import * from .member import Member from .role import Role -from .enums import ChannelType, try_enum, Status, Enum +from .enums import ChannelType, try_enum, Status from .
import utils +from .flags import Intents from .embeds import Embed from .object import Object from .invite import Invite -class ListenerType(Enum): - chunk = 0 - query_members = 1 +class ChunkRequest: + def __init__(self, guild_id, future, resolver, *, cache=True): + self.guild_id = guild_id + self.resolver = resolver + self.cache = cache + self.nonce = os.urandom(16).hex() + self.future = future + self.buffer = [] # List[Member] + + def add_members(self, members): + self.buffer.extend(members) + if self.cache: + guild = self.resolver(self.guild_id) + if guild is None: + return + + for member in members: + existing = guild.get_member(member.id) + if existing is None or existing.joined_at is None: + guild._add_member(member) + + def done(self): + self.future.set_result(self.buffer) -Listener = namedtuple('Listener', ('type', 'future', 'predicate')) log = logging.getLogger(__name__) -ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) class ConnectionState: def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): @@ -93,7 +112,7 @@ class ConnectionState: self.allowed_mentions = allowed_mentions # Only disable cache if both fetch_offline and guild_subscriptions are off. 
self._cache_members = (self._fetch_offline or self.guild_subscriptions) - self._listeners = [] + self._chunk_requests = [] activity = options.get('activity', None) if activity: @@ -109,8 +128,17 @@ class ConnectionState: else: status = str(status) + intents = options.get('intents', None) + if intents is not None: + if not isinstance(intents, Intents): + raise TypeError('intents parameter must be Intent not %r' % type(intents)) + + if not intents.members and self._fetch_offline: + raise ValueError('Intents.members has be enabled to fetch offline members.') + self._activity = activity self._status = status + self._intents = intents self.parsers = parsers = {} for attr, func in inspect.getmembers(self): @@ -138,34 +166,22 @@ class ConnectionState: # to reconnect loops which cause mass allocations and deallocations. gc.collect() - def get_nonce(self): - return os.urandom(16).hex() - - def process_listeners(self, listener_type, argument, result): + def process_chunk_requests(self, guild_id, nonce, members, complete): removed = [] - for i, listener in enumerate(self._listeners): - if listener.type != listener_type: - continue - - future = listener.future + for i, request in enumerate(self._chunk_requests): + future = request.future if future.cancelled(): removed.append(i) continue - try: - passed = listener.predicate(argument) - except Exception as exc: - future.set_exception(exc) - removed.append(i) - else: - if passed: - future.set_result(result) + if request.guild_id == guild_id and request.nonce == nonce: + request.add_members(members) + if complete: + request.done() removed.append(i) - if listener.type == ListenerType.chunk: - break for index in reversed(removed): - del self._listeners[index] + del self._chunk_requests[index] def call_handlers(self, key, *args, **kwargs): try: @@ -299,9 +315,9 @@ class ConnectionState: self._add_guild(guild) return guild - def chunks_needed(self, guild): - for _ in range(math.ceil(guild._member_count / 1000)): - yield 
self.receive_chunk(guild.id) + def _guild_needs_chunking(self, guild): + # If presences are enabled then we get back the old guild.large behaviour + return self._fetch_offline and not guild.chunked and not ((self._intents is None or self._intents.presences) and not guild.large) def _get_guild_channel(self, data): channel_id = int(data['channel_id']) @@ -319,78 +335,56 @@ class ConnectionState: ws = self._get_websocket(guild_id) # This is ignored upstream await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) - async def request_offline_members(self, guilds): - # get all the chunks - chunks = [] - for guild in guilds: - chunks.extend(self.chunks_needed(guild)) - - # we only want to request ~75 guilds per chunk request. - splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] - for split in splits: - await self.chunker([g.id for g in split]) - - # wait for the chunks - if chunks: - try: - await utils.sane_wait_for(chunks, timeout=len(chunks) * 30.0) - except asyncio.TimeoutError: - log.warning('Somehow timed out waiting for chunks.') - else: - log.info('Finished requesting guild member chunks for %d guilds.', len(guilds)) - async def query_members(self, guild, query, limit, user_ids, cache): guild_id = guild.id ws = self._get_websocket(guild_id) if ws is None: raise RuntimeError('Somehow do not have a websocket for this guild_id') - # Limits over 1000 cannot be supported since - # the main use case for this is guild_subscriptions being disabled - # and they don't receive GUILD_MEMBER events which make computing - # member_count impossible. The only way to fix it is by limiting - # the limit parameter to 1 to 1000.
- nonce = self.get_nonce() - future = self.receive_member_query(guild_id, nonce) + future = self.loop.create_future() + request = ChunkRequest(guild.id, future, self._get_guild, cache=cache) + self._chunk_requests.append(request) + try: # start the query operation - await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=nonce) - members = await asyncio.wait_for(future, timeout=5.0) - - if cache: - for member in members: - guild._add_member(member) - - return members + await ws.request_chunks(guild_id, query=query, limit=limit, user_ids=user_ids, nonce=request.nonce) + return await asyncio.wait_for(future, timeout=30.0) except asyncio.TimeoutError: log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) raise async def _delay_ready(self): try: - launch = self._ready_state.launch - # only real bots wait for GUILD_CREATE streaming if self.is_bot: + states = [] while True: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - await asyncio.wait_for(launch.wait(), timeout=self.guild_ready_timeout) + guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) except asyncio.TimeoutError: break else: - launch.clear() + if self._guild_needs_chunking(guild): + future = await self.chunk_guild(guild, wait=False) + states.append((guild, future)) + else: + if guild.unavailable is False: + self.dispatch('guild_available', guild) + else: + self.dispatch('guild_join', guild) - guilds = next(zip(*self._ready_state.guilds), []) - if self._fetch_offline: - await self.request_offline_members(guilds) + for guild, future in states: + try: + await asyncio.wait_for(future, timeout=5.0) + except asyncio.TimeoutError: + log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id) - for guild, unavailable in self._ready_state.guilds: - if unavailable is False: - self.dispatch('guild_available', 
guild) - else: - self.dispatch('guild_join', guild) + if guild.unavailable is False: + self.dispatch('guild_available', guild) + else: + self.dispatch('guild_join', guild) # remove the state try: @@ -415,16 +409,13 @@ class ConnectionState: if self._ready_task is not None: self._ready_task.cancel() - self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + self._ready_state = asyncio.Queue() self.clear() self.user = user = ClientUser(state=self, data=data['user']) self._users[user.id] = user - guilds = self._ready_state.guilds for guild_data in data['guilds']: - guild = self._add_guild_from_data(guild_data) - if (not self.is_bot and not guild.unavailable) or guild.large: - guilds.append((guild, guild.unavailable)) + self._add_guild_from_data(guild_data) for relationship in data.get('relationships', []): try: @@ -561,7 +552,7 @@ class ConnectionState: guild_id = utils._get_as_snowflake(data, 'guild_id') guild = self._get_guild(guild_id) if guild is None: - log.warning('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) return user = data['user'] @@ -629,14 +620,14 @@ class ConnectionState: channel._update(guild, data) self.dispatch('guild_channel_update', old_channel, channel) else: - log.warning('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) else: - log.warning('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) def parse_channel_create(self, data): factory, ch_type = _channel_factory(data['type']) if factory is None: - log.warning('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) + log.debug('CHANNEL_CREATE referencing an unknown channel type %s. 
Discarding.', data['type']) return channel = None @@ -655,14 +646,14 @@ class ConnectionState: guild._add_channel(channel) self.dispatch('guild_channel_create', channel) else: - log.warning('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) return def parse_channel_pins_update(self, data): channel_id = int(data['channel_id']) channel = self.get_channel(channel_id) if channel is None: - log.warning('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + log.debug('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) return last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None @@ -696,7 +687,7 @@ class ConnectionState: def parse_guild_member_add(self, data): guild = self._get_guild(int(data['guild_id'])) if guild is None: - log.warning('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) return member = Member(guild=guild, data=data, state=self) @@ -715,28 +706,32 @@ class ConnectionState: guild._remove_member(member) self.dispatch('member_remove', member) else: - log.warning('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) def parse_guild_member_update(self, data): guild = self._get_guild(int(data['guild_id'])) user = data['user'] user_id = int(user['id']) if guild is None: - log.warning('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. 
Discarding.', data['guild_id']) return member = guild.get_member(user_id) if member is not None: - old_member = copy.copy(member) + old_member = Member._copy(member) member._update(data) + user_update = member._update_inner_user(user) + if user_update: + self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch('member_update', old_member, member) else: - log.warning('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) def parse_guild_emojis_update(self, data): guild = self._get_guild(int(data['guild_id'])) if guild is None: - log.warning('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) return before_emojis = guild.emojis @@ -758,14 +753,21 @@ class ConnectionState: return self._add_guild_from_data(data) + async def chunk_guild(self, guild, *, wait=True, cache=None): + cache = cache or self._cache_members + future = self.loop.create_future() + request = ChunkRequest(guild.id, future, self._get_guild, cache=cache) + self._chunk_requests.append(request) + await self.chunker(guild.id, nonce=request.nonce) + if wait: + return await request.future + return request.future + async def _chunk_and_dispatch(self, guild, unavailable): - chunks = list(self.chunks_needed(guild)) - await self.chunker(guild.id) - if chunks: - try: - await utils.sane_wait_for(chunks, timeout=len(chunks)) - except asyncio.TimeoutError: - log.info('Somehow timed out waiting for chunks.') + try: + await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) + except asyncio.TimeoutError: + log.info('Somehow timed out waiting for chunks.') if unavailable is False: self.dispatch('guild_available', guild) @@ -780,30 +782,19 @@ class ConnectionState: guild = self._get_create_guild(data) - # check if it requires chunking - 
if guild.large: - if unavailable is False: - # check if we're waiting for 'useful' READY - # and if we are, we don't want to dispatch any - # event such as guild_join or guild_available - # because we're still in the 'READY' phase. Or - # so we say. - try: - state = self._ready_state - state.launch.set() - state.guilds.append((guild, unavailable)) - except AttributeError: - # the _ready_state attribute is only there during - # processing of useful READY. - pass - else: - return + try: + # Notify the on_ready state, if any, that this guild is complete. + self._ready_state.put_nowait(guild) + except AttributeError: + pass + else: + # If we're waiting for the event, put the rest on hold + return - # since we're not waiting for 'useful' READY we'll just - # do the chunk request here if wanted - if self._fetch_offline: - asyncio.ensure_future(self._chunk_and_dispatch(guild, unavailable), loop=self.loop) - return + # check if it requires chunking + if self._guild_needs_chunking(guild): + asyncio.ensure_future(self._chunk_and_dispatch(guild, unavailable), loop=self.loop) + return # Dispatch available if newly available if unavailable is False: @@ -822,12 +813,12 @@ class ConnectionState: guild._from_data(data) self.dispatch('guild_update', old_guild, guild) else: - log.warning('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) + log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) def parse_guild_delete(self, data): guild = self._get_guild(int(data['id'])) if guild is None: - log.warning('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) + log.debug('GUILD_DELETE referencing an unknown guild ID: %s. 
Discarding.', data['id']) return if data.get('unavailable', False) and guild is not None: @@ -870,7 +861,7 @@ class ConnectionState: def parse_guild_role_create(self, data): guild = self._get_guild(int(data['guild_id'])) if guild is None: - log.warning('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) return role_data = data['role'] @@ -889,7 +880,7 @@ class ConnectionState: else: self.dispatch('guild_role_delete', role) else: - log.warning('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) def parse_guild_role_update(self, data): guild = self._get_guild(int(data['guild_id'])) @@ -902,35 +893,29 @@ class ConnectionState: role._update(role_data) self.dispatch('guild_role_update', old_role, role) else: - log.warning('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. 
Discarding.', data['guild_id']) def parse_guild_members_chunk(self, data): guild_id = int(data['guild_id']) guild = self._get_guild(guild_id) members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) - if self._cache_members: - for member in members: - existing = guild.get_member(member.id) - if existing is None or existing.joined_at is None: - guild._add_member(member) - - self.process_listeners(ListenerType.chunk, guild, len(members)) - self.process_listeners(ListenerType.query_members, (guild_id, data.get('nonce')), members) + complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count') + self.process_chunk_requests(guild_id, data.get('nonce'), members, complete) def parse_guild_integrations_update(self, data): guild = self._get_guild(int(data['guild_id'])) if guild is not None: self.dispatch('guild_integrations_update', guild) else: - log.warning('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) def parse_webhooks_update(self, data): channel = self.get_channel(int(data['channel_id'])) if channel is not None: self.dispatch('webhooks_update', channel) else: - log.warning('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) + log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) def parse_voice_state_update(self, data): guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) @@ -947,7 +932,7 @@ class ConnectionState: if member is not None: self.dispatch('voice_state_update', member, before, after) else: - log.warning('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) + log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. 
Discarding.', data['user_id']) else: # in here we're either at private or group calls call = self._calls.get(channel_id) @@ -1040,21 +1025,6 @@ class ConnectionState: def create_message(self, *, channel, data): return Message(state=self, channel=channel, data=data) - def receive_chunk(self, guild_id): - future = self.loop.create_future() - listener = Listener(ListenerType.chunk, future, lambda s: s.id == guild_id) - self._listeners.append(listener) - return future - - def receive_member_query(self, guild_id, nonce): - def predicate(args, *, guild_id=guild_id, nonce=nonce): - return args == (guild_id, nonce) - - future = self.loop.create_future() - listener = Listener(ListenerType.query_members, future, predicate) - self._listeners.append(listener) - return future - class AutoShardedConnectionState(ConnectionState): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -1077,51 +1047,56 @@ class AutoShardedConnectionState(ConnectionState): ws = self._get_websocket(guild_id, shard_id=shard_id) await ws.request_chunks(guild_id, query=query, limit=limit, nonce=nonce) - async def request_offline_members(self, guilds, *, shard_id): - # get all the chunks - chunks = [] - for guild in guilds: - chunks.extend(self.chunks_needed(guild)) - - # we only want to request ~75 guilds per chunk request. 
- splits = [guilds[i:i + 75] for i in range(0, len(guilds), 75)] - for split in splits: - await self.chunker([g.id for g in split], shard_id=shard_id) - - # wait for the chunks - if chunks: - try: - await utils.sane_wait_for(chunks, timeout=len(chunks) * 30.0) - except asyncio.TimeoutError: - log.info('Somehow timed out waiting for chunks.') - else: - log.info('Finished requesting guild member chunks for %d guilds.', len(guilds)) - async def _delay_ready(self): await self.shards_launched.wait() - launch = self._ready_state.launch + processed = [] + max_concurrency = len(self.shard_ids) * 2 + current_bucket = [] while True: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - await asyncio.wait_for(launch.wait(), timeout=self.guild_ready_timeout) + guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) except asyncio.TimeoutError: break else: - launch.clear() + if self._guild_needs_chunking(guild): + log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) + if len(current_bucket) >= max_concurrency: + try: + await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) + except asyncio.TimeoutError: + fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' + log.warning(fmt, self.shard_id, len(current_bucket)) + finally: + current_bucket = [] - guilds = sorted(self._ready_state.guilds, key=lambda g: g[0].shard_id) + # Chunk the guild in the background while we wait for GUILD_CREATE streaming + future = asyncio.ensure_future(self.chunk_guild(guild)) + current_bucket.append(future) + else: + future = self.loop.create_future() + future.set_result([]) - for shard_id, sub_guilds_info in itertools.groupby(guilds, key=lambda g: g[0].shard_id): - sub_guilds, sub_available = zip(*sub_guilds_info) - if self._fetch_offline: - await self.request_offline_members(sub_guilds, shard_id=shard_id) + processed.append((guild, future)) - for 
guild, unavailable in zip(sub_guilds, sub_available): - if unavailable is False: + guilds = sorted(processed, key=lambda g: g[0].shard_id) + for shard_id, info in itertools.groupby(guilds, key=lambda g: g[0].shard_id): + children, futures = zip(*info) + # 110 reqs/minute w/ 1 req/guild plus some buffer + timeout = 61 * (len(children) / 110) + try: + await utils.sane_wait_for(futures, timeout=timeout) + except asyncio.TimeoutError: + log.warning('Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', self.shard_id, + timeout, + len(guilds)) + for guild in children: + if guild.unavailable is False: self.dispatch('guild_available', guild) else: self.dispatch('guild_join', guild) + self.dispatch('shard_ready', shard_id) # remove the state @@ -1141,16 +1116,13 @@ class AutoShardedConnectionState(ConnectionState): def parse_ready(self, data): if not hasattr(self, '_ready_state'): - self._ready_state = ReadyState(launch=asyncio.Event(), guilds=[]) + self._ready_state = asyncio.Queue() self.user = user = ClientUser(state=self, data=data['user']) self._users[user.id] = user - guilds = self._ready_state.guilds for guild_data in data['guilds']: - guild = self._add_guild_from_data(guild_data) - if guild.large: - guilds.append((guild, guild.unavailable)) + self._add_guild_from_data(guild_data) if self._messages: self._update_message_references()