From c9983d0248a7ff4c2f2489c090a37c8d35220b42 Mon Sep 17 00:00:00 2001 From: Gnome Date: Sat, 2 Oct 2021 19:34:56 +0100 Subject: [PATCH] Make fetch_x methods cache when applicable --- discord/abc.py | 6 +++--- discord/client.py | 34 ++++++++++++++++++++++--------- discord/ext/commands/core.py | 4 +++- discord/ext/commands/errors.py | 2 +- discord/flags.py | 11 ++++++++++ discord/guild.py | 16 +++++++++++---- discord/interactions.py | 2 +- discord/iterators.py | 37 +++++++++++++++++----------------- discord/message.py | 6 +++--- discord/state.py | 34 +++++++++++++++++++------------ 10 files changed, 97 insertions(+), 55 deletions(-) diff --git a/discord/abc.py b/discord/abc.py index 196043e3..5e7f85eb 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -1434,7 +1434,7 @@ class Messageable: components=components, ) - ret = state.create_message(channel=channel, data=data) + ret = state.store_message(channel=channel, data=data) if view: state.store_view(view, ret.id) @@ -1501,7 +1501,7 @@ class Messageable: channel = await self._get_channel() data = await self._state.http.get_message(channel.id, id) - return self._state.create_message(channel=channel, data=data) + return self._state.store_message(channel=channel, data=data) async def pins(self) -> List[Message]: """|coro| @@ -1528,7 +1528,7 @@ class Messageable: channel = await self._get_channel() state = self._state data = await state.http.pins_from(channel.id) - return [state.create_message(channel=channel, data=m) for m in data] + return [state.store_message(channel=channel, data=m) for m in data] def history( self, diff --git a/discord/client.py b/discord/client.py index 3d0216d8..640470f0 100644 --- a/discord/client.py +++ b/discord/client.py @@ -53,7 +53,7 @@ from .widget import Widget from .guild import Guild from .emoji import Emoji from .channel import _threaded_channel_factory, PartialMessageable -from .enums import ChannelType +from .enums import ChannelType, StickerType from .mentions import AllowedMentions from .errors import * from .enums import Status, VoiceRegion @@ -76,7 +76,8 @@ from .threads import Thread from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory if TYPE_CHECKING: - from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake + from .abc import SnowflakeTime, PrivateChannel, GuildChannel as GuildChannelABC, Snowflake + from .guild import GuildChannel from .channel import DMChannel from .message import Message from .member import Member @@ -780,7 +781,7 @@ class Client: """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: + def get_channel(self, id: int, /) -> Optional[Union[GuildChannelABC, Thread, PrivateChannel]]: """Returns a channel or thread with the given ID. Parameters @@ -933,7 +934,7 @@ class Client: """ return self._connection.get_sticker(id) - def get_all_channels(self) -> Generator[GuildChannel, None, None]: + def get_all_channels(self) -> Generator[GuildChannelABC, None, None]: """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. This is equivalent to: :: @@ -1488,6 +1489,7 @@ class Client: """|coro| Retrieves the bot's application information. + This will fill up :attr:`application_id` and :attr:`application_flags`. Raises ------- @@ -1502,6 +1504,8 @@ class Client: data = await self.http.application_info() if "rpc_origins" not in data: data["rpc_origins"] = None + + self._connection.store_appinfo(data) return AppInfo(self._connection, data) async def fetch_user(self, user_id: int, /) -> User: @@ -1535,10 +1539,11 @@ class Client: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]: + async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannelABC, PrivateChannel, Thread]: """|coro| Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. + If found, will store the Channel in the internal cache, meaning :meth:``get_channel`` will succeed afterwards. .. note:: @@ -1570,14 +1575,18 @@ class Client: if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here - channel = factory(me=self.user, data=data, state=self._connection) # type: ignore + channel: PrivateChannel = factory(me=self.user, data=data, state=self._connection) # type: ignore + self._connection._add_private_channel(channel) # type: ignore else: # the factory can't be a DMChannel or GroupChannel here guild_id = int(data["guild_id"]) # type: ignore - guild = self.get_guild(guild_id) or Object(id=guild_id) - # GuildChannels expect a Guild, we may be passing an Object - channel = factory(guild=guild, state=self._connection, data=data) # type: ignore + guild = self.get_guild(guild_id) + if guild is None: + return factory(guild=Object(guild_id), state=self._connection, data=data) # type: ignore + + channel: GuildChannel = factory(guild=guild, state=self._connection, data=data) # type: ignore + guild._add_channel(channel) return channel async def fetch_webhook(self, webhook_id: int, /) -> Webhook: @@ -1606,6 +1615,7 @@ class Client: """|coro| Retrieves a :class:`.Sticker` with the specified ID. + If found, will store the sticker in the internal cache, meaning :meth:``get_sticker`` will succeed afterwards. .. versionadded:: 2.0 @@ -1623,7 +1633,11 @@ class Client: """ data = await self.http.get_sticker(sticker_id) cls, _ = _sticker_factory(data["type"]) # type: ignore - return cls(state=self._connection, data=data) # type: ignore + + if data["type"] == StickerType.guild: # type: ignore + return self._connection.store_sticker(data) # type: ignore + else: + return cls(state=self._connection, data=data) # type: ignore async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """|coro| diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 178b252e..9650bd6f 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -1694,7 +1694,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): "name": self.name, "type": int(not (nested - 1)) + 1, "description": self.short_doc or "no description", - "options": [cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name)], + "options": [ + cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name) + ], } diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index c8dde033..25162925 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -455,7 +455,7 @@ class BadInviteArgument(BadArgument): This inherits from :exc:`BadArgument` .. versionadded:: 1.5 - + Attributes ----------- argument: :class:`str` diff --git a/discord/flags.py b/discord/flags.py index 920c190f..e71ddf4e 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -951,6 +951,15 @@ class MemberCacheFlags(BaseFlags): """ return 2 + @flag_value + def fetched(self): + """:class:`bool`: Whether to cache members that are fetched via :meth:``Guild.fetch_member`` + or :meth:``Guild.fetch_members`` + + Members that leave the guild are no longer cached. + """ + return 4 + @classmethod def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags: """A factory method that creates a :class:`MemberCacheFlags` based on @@ -968,6 +977,8 @@ class MemberCacheFlags(BaseFlags): """ self = cls.none() + self.fetched = True + if intents.members: self.joined = True if intents.voice_states: diff --git a/discord/guild.py b/discord/guild.py index a5dfafaf..a3a50a3d 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -428,7 +428,7 @@ class Guild(Hashable): self.mfa_level: MFALevel = guild.get("mfa_level") self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", []))) self.stickers: Tuple[GuildSticker, ...] = tuple( - map(lambda d: state.store_sticker(self, d), guild.get("stickers", [])) + map(lambda d: state.store_sticker(d), guild.get("stickers", [])) ) self.features: List[GuildFeature] = guild.get("features", []) self._splash: Optional[str] = guild.get("splash") @@ -1594,6 +1594,7 @@ class Guild(Hashable): """|coro| Retrieves all :class:`abc.GuildChannel` that the guild has. + Will store the Channels in the internal cache, meaning :meth:``get_channel`` will succeed afterwards. .. note:: @@ -1616,11 +1617,12 @@ class Guild(Hashable): data = await self._state.http.get_all_guild_channels(self.id) def convert(d): - factory, ch_type = _guild_channel_factory(d["type"]) + factory, _ = _guild_channel_factory(d["type"]) if factory is None: raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d)) channel = factory(guild=self, state=self._state, data=d) + self._add_channel(channel) return channel return [convert(d) for d in data] @@ -1712,6 +1714,8 @@ class Guild(Hashable): """|coro| Retrieves a :class:`Member` from a guild ID, and a member ID. + If found, will store the Member in the internal cache, filling up :attr:`members` + and meaning :meth:``get_member`` will succeed afterwards. .. note:: @@ -1737,7 +1741,11 @@ class Guild(Hashable): The member from the member ID. """ data = await self._state.http.get_member(self.id, member_id) - return Member(data=data, state=self._state, guild=self) + member = Member(data=data, state=self._state, guild=self) + if self._state.member_cache_flags.fetched: + self._add_member(member) + + return member async def try_member(self, member_id: int, /) -> Optional[Member]: """|coro| @@ -2257,7 +2265,7 @@ class Guild(Hashable): payload["tags"] = emoji data = await self._state.http.create_guild_sticker(self.id, payload, file, reason) - return self._state.store_sticker(self, data) + return self._state.store_sticker(data) async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| diff --git a/discord/interactions.py b/discord/interactions.py index 83a61a3b..804d9a75 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -137,7 +137,7 @@ class Interaction: self.message: Optional[Message] try: - self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore + self.message = self._state.store_message(channel=self.channel, data=data["message"]) # type: ignore except KeyError: self.message = None diff --git a/discord/iterators.py b/discord/iterators.py index f5a94ae1..f65925fd 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -42,22 +42,12 @@ __all__ = ( ) if TYPE_CHECKING: - from .types.audit_log import ( - AuditLog as AuditLogPayload, - ) - from .types.guild import ( - Guild as GuildPayload, - ) - from .types.message import ( - Message as MessagePayload, - ) - from .types.user import ( - PartialUser as PartialUserPayload, - ) - - from .types.threads import ( - Thread as ThreadPayload, - ) + from .types.member import MemberWithUser as MemberWithUserPayload + from .types.user import PartialUser as PartialUserPayload + from .types.audit_log import AuditLog as AuditLogPayload + from .types.message import Message as MessagePayload + from .types.threads import Thread as ThreadPayload + from .types.guild import Guild as GuildPayload from .member import Member from .user import User @@ -354,7 +344,7 @@ class HistoryIterator(_AsyncIterator["Message"]): channel = self.channel for element in data: - await self.messages.put(self.state.create_message(channel=channel, data=element)) + await self.messages.put(self.state.store_message(channel=channel, data=element)) async def _retrieve_messages(self, retrieve) -> List[Message]: """Retrieve messages and update next parameters.""" @@ -615,14 +605,18 @@ class MemberIterator(_AsyncIterator["Member"]): if isinstance(after, datetime.datetime): after = Object(id=time_snowflake(after, high=True)) - self.guild = guild self.limit = limit + self.guild: Guild = guild self.after = after or OLDEST_OBJECT self.state = self.guild._state self.get_members = self.state.http.get_members self.members = asyncio.Queue() + self.create_member = ( + self.create_member_cache if self.state.member_cache_flags.fetched else self.create_member_no_cache + ) + async def next(self) -> Member: if self.members.empty(): await self.fill_members() @@ -657,11 +651,16 @@ class MemberIterator(_AsyncIterator["Member"]): for element in reversed(data): await self.members.put(self.create_member(element)) - def create_member(self, data): + def create_member_no_cache(self, data: MemberWithUserPayload) -> Member: from .member import Member return Member(data=data, guild=self.guild, state=self.state) + def create_member_cache(self, data: MemberWithUserPayload) -> Member: + member = self.create_member_no_cache(data) + self.guild._add_member(member) + return member + class ArchivedThreadIterator(_AsyncIterator["Thread"]): def __init__( diff --git a/discord/message.py b/discord/message.py index d8757022..0ab4ed7f 100644 --- a/discord/message.py +++ b/discord/message.py @@ -1331,7 +1331,7 @@ class Message(Hashable): payload["components"] = [] data = await self._state.http.edit_message(self.channel.id, self.id, **payload) - message = Message(state=self._state, channel=self.channel, data=data) + message = self._state.store_message(channel=self.channel, data=data) if view and not view.is_finished(): self._state.store_view(view, self.id) @@ -1756,7 +1756,7 @@ class PartialMessage(Hashable): """ data = await self._state.http.get_message(self.channel.id, self.id) - return self._state.create_message(channel=self.channel, data=data) + return self._state.store_message(channel=self.channel, data=data) async def edit(self, **fields: Any) -> Optional[Message]: """|coro| @@ -1873,7 +1873,7 @@ class PartialMessage(Hashable): if fields: # data isn't unbound - msg = self._state.create_message(channel=self.channel, data=data) # type: ignore + msg = self._state.store_message(channel=self.channel, data=data) # type: ignore if view and not view.is_finished(): self._state.store_view(view, self.id) return msg diff --git a/discord/state.py b/discord/state.py index edf39263..5bab2e65 100644 --- a/discord/state.py +++ b/discord/state.py @@ -27,7 +27,6 @@ from __future__ import annotations import asyncio from collections import deque, OrderedDict import copy -import datetime import itertools import logging from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque @@ -75,6 +74,7 @@ if TYPE_CHECKING: from .types.sticker import GuildSticker as GuildStickerPayload from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload + from .types.appinfo import AppInfo as AppInfoPayload T = TypeVar("T") CS = TypeVar("CS", bound="ConnectionState") @@ -323,6 +323,13 @@ class ConnectionState: for vc in self.voice_clients: vc.main_ws = ws # type: ignore + def store_message(self, channel: MessageableChannel, data: MessagePayload) -> Message: + message = Message(state=self, channel=channel, data=data) + if self._messages is not None: + self._messages.append(message) + + return message + def store_user(self, data: UserPayload) -> User: user_id = int(data["id"]) try: @@ -353,11 +360,20 @@ class ConnectionState: self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji - def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: + def store_sticker(self, data: GuildStickerPayload) -> GuildSticker: sticker_id = int(data["id"]) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker + def store_appinfo(self, data: AppInfoPayload): + self.application_id = utils._get_as_snowflake(data, "id") + + flags = data.get("flags") + if flags is not None: + self.application_flags = ApplicationFlags._from_value(flags) + + return data + def store_view(self, view: View, message_id: Optional[int] = None) -> None: self._view_store.add_view(view, message_id) @@ -563,9 +579,7 @@ class ConnectionState: except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, "id") - # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore + self.store_appinfo(application) for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) @@ -581,8 +595,7 @@ class ConnectionState: # channel would be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore self.dispatch("message", message) - if self._messages is not None: - self._messages.append(message) + self.store_message(channel, data) # we ensure that the channel is either a TextChannel or Thread if channel and channel.__class__ in (TextChannel, Thread): channel.last_message_id = message.id # type: ignore @@ -1032,7 +1045,7 @@ class ConnectionState: for emoji in before_stickers: self._stickers.pop(emoji.id, None) # guild won't be None here - guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data["stickers"])) # type: ignore + guild.stickers = tuple(map(lambda d: self.store_sticker(d), data["stickers"])) # type: ignore self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers) def _get_create_guild(self, data): @@ -1403,11 +1416,6 @@ class ConnectionState: if channel is not None: return channel - def create_message( - self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload - ) -> Message: - return Message(state=self, channel=channel, data=data) - class AutoShardedConnectionState(ConnectionState): def __init__(self, *args: Any, **kwargs: Any) -> None: