Make fetch_x methods cache when applicable

This commit is contained in:
Gnome 2021-10-02 19:34:56 +01:00
parent 3260ec6643
commit c9983d0248
10 changed files with 97 additions and 55 deletions

View File

@ -1434,7 +1434,7 @@ class Messageable:
components=components, components=components,
) )
ret = state.create_message(channel=channel, data=data) ret = state.store_message(channel=channel, data=data)
if view: if view:
state.store_view(view, ret.id) state.store_view(view, ret.id)
@ -1501,7 +1501,7 @@ class Messageable:
channel = await self._get_channel() channel = await self._get_channel()
data = await self._state.http.get_message(channel.id, id) data = await self._state.http.get_message(channel.id, id)
return self._state.create_message(channel=channel, data=data) return self._state.store_message(channel=channel, data=data)
async def pins(self) -> List[Message]: async def pins(self) -> List[Message]:
"""|coro| """|coro|
@ -1528,7 +1528,7 @@ class Messageable:
channel = await self._get_channel() channel = await self._get_channel()
state = self._state state = self._state
data = await state.http.pins_from(channel.id) data = await state.http.pins_from(channel.id)
return [state.create_message(channel=channel, data=m) for m in data] return [state.store_message(channel=channel, data=m) for m in data]
def history( def history(
self, self,

View File

@ -53,7 +53,7 @@ from .widget import Widget
from .guild import Guild from .guild import Guild
from .emoji import Emoji from .emoji import Emoji
from .channel import _threaded_channel_factory, PartialMessageable from .channel import _threaded_channel_factory, PartialMessageable
from .enums import ChannelType from .enums import ChannelType, StickerType
from .mentions import AllowedMentions from .mentions import AllowedMentions
from .errors import * from .errors import *
from .enums import Status, VoiceRegion from .enums import Status, VoiceRegion
@ -76,7 +76,8 @@ from .threads import Thread
from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory from .sticker import GuildSticker, StandardSticker, StickerPack, _sticker_factory
if TYPE_CHECKING: if TYPE_CHECKING:
from .abc import SnowflakeTime, PrivateChannel, GuildChannel, Snowflake from .abc import SnowflakeTime, PrivateChannel, GuildChannel as GuildChannelABC, Snowflake
from .guild import GuildChannel
from .channel import DMChannel from .channel import DMChannel
from .message import Message from .message import Message
from .member import Member from .member import Member
@ -780,7 +781,7 @@ class Client:
"""List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" """List[:class:`~discord.User`]: Returns a list of all the users the bot can see."""
return list(self._connection._users.values()) return list(self._connection._users.values())
def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: def get_channel(self, id: int, /) -> Optional[Union[GuildChannelABC, Thread, PrivateChannel]]:
"""Returns a channel or thread with the given ID. """Returns a channel or thread with the given ID.
Parameters Parameters
@ -933,7 +934,7 @@ class Client:
""" """
return self._connection.get_sticker(id) return self._connection.get_sticker(id)
def get_all_channels(self) -> Generator[GuildChannel, None, None]: def get_all_channels(self) -> Generator[GuildChannelABC, None, None]:
"""A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'. """A generator that retrieves every :class:`.abc.GuildChannel` the client can 'access'.
This is equivalent to: :: This is equivalent to: ::
@ -1488,6 +1489,7 @@ class Client:
"""|coro| """|coro|
Retrieves the bot's application information. Retrieves the bot's application information.
This will fill up :attr:`application_id` and :attr:`application_flags`.
Raises Raises
------- -------
@ -1502,6 +1504,8 @@ class Client:
data = await self.http.application_info() data = await self.http.application_info()
if "rpc_origins" not in data: if "rpc_origins" not in data:
data["rpc_origins"] = None data["rpc_origins"] = None
self._connection.store_appinfo(data)
return AppInfo(self._connection, data) return AppInfo(self._connection, data)
async def fetch_user(self, user_id: int, /) -> User: async def fetch_user(self, user_id: int, /) -> User:
@ -1535,10 +1539,11 @@ class Client:
data = await self.http.get_user(user_id) data = await self.http.get_user(user_id)
return User(state=self._connection, data=data) return User(state=self._connection, data=data)
async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]: async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannelABC, PrivateChannel, Thread]:
"""|coro| """|coro|
Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID.
If found, will store the Channel in the internal cache, meaning :meth:``get_channel`` will succeed afterwards.
.. note:: .. note::
@ -1570,14 +1575,18 @@ class Client:
if ch_type in (ChannelType.group, ChannelType.private): if ch_type in (ChannelType.group, ChannelType.private):
# the factory will be a DMChannel or GroupChannel here # the factory will be a DMChannel or GroupChannel here
channel = factory(me=self.user, data=data, state=self._connection) # type: ignore channel: PrivateChannel = factory(me=self.user, data=data, state=self._connection) # type: ignore
self._connection._add_private_channel(channel) # type: ignore
else: else:
# the factory can't be a DMChannel or GroupChannel here # the factory can't be a DMChannel or GroupChannel here
guild_id = int(data["guild_id"]) # type: ignore guild_id = int(data["guild_id"]) # type: ignore
guild = self.get_guild(guild_id) or Object(id=guild_id) guild = self.get_guild(guild_id)
# GuildChannels expect a Guild, we may be passing an Object
channel = factory(guild=guild, state=self._connection, data=data) # type: ignore
if guild is None:
return factory(guild=Object(guild_id), state=self._connection, data=data) # type: ignore
channel: GuildChannel = factory(guild=guild, state=self._connection, data=data) # type: ignore
guild._add_channel(channel)
return channel return channel
async def fetch_webhook(self, webhook_id: int, /) -> Webhook: async def fetch_webhook(self, webhook_id: int, /) -> Webhook:
@ -1606,6 +1615,7 @@ class Client:
"""|coro| """|coro|
Retrieves a :class:`.Sticker` with the specified ID. Retrieves a :class:`.Sticker` with the specified ID.
If found, will store the sticker in the internal cache, meaning :meth:``get_sticker`` will succeed afterwards.
.. versionadded:: 2.0 .. versionadded:: 2.0
@ -1623,7 +1633,11 @@ class Client:
""" """
data = await self.http.get_sticker(sticker_id) data = await self.http.get_sticker(sticker_id)
cls, _ = _sticker_factory(data["type"]) # type: ignore cls, _ = _sticker_factory(data["type"]) # type: ignore
return cls(state=self._connection, data=data) # type: ignore
if data["type"] == StickerType.guild: # type: ignore
return self._connection.store_sticker(data) # type: ignore
else:
return cls(state=self._connection, data=data) # type: ignore
async def fetch_premium_sticker_packs(self) -> List[StickerPack]: async def fetch_premium_sticker_packs(self) -> List[StickerPack]:
"""|coro| """|coro|

View File

@ -1694,7 +1694,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]):
"name": self.name, "name": self.name,
"type": int(not (nested - 1)) + 1, "type": int(not (nested - 1)) + 1,
"description": self.short_doc or "no description", "description": self.short_doc or "no description",
"options": [cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name)], "options": [
cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name)
],
} }

View File

@ -455,7 +455,7 @@ class BadInviteArgument(BadArgument):
This inherits from :exc:`BadArgument` This inherits from :exc:`BadArgument`
.. versionadded:: 1.5 .. versionadded:: 1.5
Attributes Attributes
----------- -----------
argument: :class:`str` argument: :class:`str`

View File

@ -951,6 +951,15 @@ class MemberCacheFlags(BaseFlags):
""" """
return 2 return 2
@flag_value
def fetched(self):
""":class:`bool`: Whether to cache members that are fetched via :meth:``Guild.fetch_member``
or :meth:``Guild.fetch_members``
Members that leave the guild are no longer cached.
"""
return 4
@classmethod @classmethod
def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags: def from_intents(cls: Type[MemberCacheFlags], intents: Intents) -> MemberCacheFlags:
"""A factory method that creates a :class:`MemberCacheFlags` based on """A factory method that creates a :class:`MemberCacheFlags` based on
@ -968,6 +977,8 @@ class MemberCacheFlags(BaseFlags):
""" """
self = cls.none() self = cls.none()
self.fetched = True
if intents.members: if intents.members:
self.joined = True self.joined = True
if intents.voice_states: if intents.voice_states:

View File

@ -428,7 +428,7 @@ class Guild(Hashable):
self.mfa_level: MFALevel = guild.get("mfa_level") self.mfa_level: MFALevel = guild.get("mfa_level")
self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", []))) self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", [])))
self.stickers: Tuple[GuildSticker, ...] = tuple( self.stickers: Tuple[GuildSticker, ...] = tuple(
map(lambda d: state.store_sticker(self, d), guild.get("stickers", [])) map(lambda d: state.store_sticker(d), guild.get("stickers", []))
) )
self.features: List[GuildFeature] = guild.get("features", []) self.features: List[GuildFeature] = guild.get("features", [])
self._splash: Optional[str] = guild.get("splash") self._splash: Optional[str] = guild.get("splash")
@ -1594,6 +1594,7 @@ class Guild(Hashable):
"""|coro| """|coro|
Retrieves all :class:`abc.GuildChannel` that the guild has. Retrieves all :class:`abc.GuildChannel` that the guild has.
Will store the Channels in the internal cache, meaning :meth:``get_channel`` will succeed afterwards.
.. note:: .. note::
@ -1616,11 +1617,12 @@ class Guild(Hashable):
data = await self._state.http.get_all_guild_channels(self.id) data = await self._state.http.get_all_guild_channels(self.id)
def convert(d): def convert(d):
factory, ch_type = _guild_channel_factory(d["type"]) factory, _ = _guild_channel_factory(d["type"])
if factory is None: if factory is None:
raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d)) raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d))
channel = factory(guild=self, state=self._state, data=d) channel = factory(guild=self, state=self._state, data=d)
self._add_channel(channel)
return channel return channel
return [convert(d) for d in data] return [convert(d) for d in data]
@ -1712,6 +1714,8 @@ class Guild(Hashable):
"""|coro| """|coro|
Retrieves a :class:`Member` from a guild ID, and a member ID. Retrieves a :class:`Member` from a guild ID, and a member ID.
If found, will store the Member in the internal cache, filling up :attr:`members`
and meaning :meth:``get_member`` will succeed afterwards.
.. note:: .. note::
@ -1737,7 +1741,11 @@ class Guild(Hashable):
The member from the member ID. The member from the member ID.
""" """
data = await self._state.http.get_member(self.id, member_id) data = await self._state.http.get_member(self.id, member_id)
return Member(data=data, state=self._state, guild=self) member = Member(data=data, state=self._state, guild=self)
if self._state.member_cache_flags.fetched:
self._add_member(member)
return member
async def try_member(self, member_id: int, /) -> Optional[Member]: async def try_member(self, member_id: int, /) -> Optional[Member]:
"""|coro| """|coro|
@ -2257,7 +2265,7 @@ class Guild(Hashable):
payload["tags"] = emoji payload["tags"] = emoji
data = await self._state.http.create_guild_sticker(self.id, payload, file, reason) data = await self._state.http.create_guild_sticker(self.id, payload, file, reason)
return self._state.store_sticker(self, data) return self._state.store_sticker(data)
async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None: async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None:
"""|coro| """|coro|

View File

@ -137,7 +137,7 @@ class Interaction:
self.message: Optional[Message] self.message: Optional[Message]
try: try:
self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore self.message = self._state.store_message(channel=self.channel, data=data["message"]) # type: ignore
except KeyError: except KeyError:
self.message = None self.message = None

View File

@ -42,22 +42,12 @@ __all__ = (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from .types.audit_log import ( from .types.member import MemberWithUser as MemberWithUserPayload
AuditLog as AuditLogPayload, from .types.user import PartialUser as PartialUserPayload
) from .types.audit_log import AuditLog as AuditLogPayload
from .types.guild import ( from .types.message import Message as MessagePayload
Guild as GuildPayload, from .types.threads import Thread as ThreadPayload
) from .types.guild import Guild as GuildPayload
from .types.message import (
Message as MessagePayload,
)
from .types.user import (
PartialUser as PartialUserPayload,
)
from .types.threads import (
Thread as ThreadPayload,
)
from .member import Member from .member import Member
from .user import User from .user import User
@ -354,7 +344,7 @@ class HistoryIterator(_AsyncIterator["Message"]):
channel = self.channel channel = self.channel
for element in data: for element in data:
await self.messages.put(self.state.create_message(channel=channel, data=element)) await self.messages.put(self.state.store_message(channel=channel, data=element))
async def _retrieve_messages(self, retrieve) -> List[Message]: async def _retrieve_messages(self, retrieve) -> List[Message]:
"""Retrieve messages and update next parameters.""" """Retrieve messages and update next parameters."""
@ -615,14 +605,18 @@ class MemberIterator(_AsyncIterator["Member"]):
if isinstance(after, datetime.datetime): if isinstance(after, datetime.datetime):
after = Object(id=time_snowflake(after, high=True)) after = Object(id=time_snowflake(after, high=True))
self.guild = guild
self.limit = limit self.limit = limit
self.guild: Guild = guild
self.after = after or OLDEST_OBJECT self.after = after or OLDEST_OBJECT
self.state = self.guild._state self.state = self.guild._state
self.get_members = self.state.http.get_members self.get_members = self.state.http.get_members
self.members = asyncio.Queue() self.members = asyncio.Queue()
self.create_member = (
self.create_member_cache if self.state.member_cache_flags.fetched else self.create_member_no_cache
)
async def next(self) -> Member: async def next(self) -> Member:
if self.members.empty(): if self.members.empty():
await self.fill_members() await self.fill_members()
@ -657,11 +651,16 @@ class MemberIterator(_AsyncIterator["Member"]):
for element in reversed(data): for element in reversed(data):
await self.members.put(self.create_member(element)) await self.members.put(self.create_member(element))
def create_member(self, data): def create_member_no_cache(self, data: MemberWithUserPayload) -> Member:
from .member import Member from .member import Member
return Member(data=data, guild=self.guild, state=self.state) return Member(data=data, guild=self.guild, state=self.state)
def create_member_cache(self, data: MemberWithUserPayload) -> Member:
member = self.create_member_no_cache(data)
self.guild._add_member(member)
return member
class ArchivedThreadIterator(_AsyncIterator["Thread"]): class ArchivedThreadIterator(_AsyncIterator["Thread"]):
def __init__( def __init__(

View File

@ -1331,7 +1331,7 @@ class Message(Hashable):
payload["components"] = [] payload["components"] = []
data = await self._state.http.edit_message(self.channel.id, self.id, **payload) data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
message = Message(state=self._state, channel=self.channel, data=data) message = self._state.store_message(channel=self.channel, data=data)
if view and not view.is_finished(): if view and not view.is_finished():
self._state.store_view(view, self.id) self._state.store_view(view, self.id)
@ -1756,7 +1756,7 @@ class PartialMessage(Hashable):
""" """
data = await self._state.http.get_message(self.channel.id, self.id) data = await self._state.http.get_message(self.channel.id, self.id)
return self._state.create_message(channel=self.channel, data=data) return self._state.store_message(channel=self.channel, data=data)
async def edit(self, **fields: Any) -> Optional[Message]: async def edit(self, **fields: Any) -> Optional[Message]:
"""|coro| """|coro|
@ -1873,7 +1873,7 @@ class PartialMessage(Hashable):
if fields: if fields:
# data isn't unbound # data isn't unbound
msg = self._state.create_message(channel=self.channel, data=data) # type: ignore msg = self._state.store_message(channel=self.channel, data=data) # type: ignore
if view and not view.is_finished(): if view and not view.is_finished():
self._state.store_view(view, self.id) self._state.store_view(view, self.id)
return msg return msg

View File

@ -27,7 +27,6 @@ from __future__ import annotations
import asyncio import asyncio
from collections import deque, OrderedDict from collections import deque, OrderedDict
import copy import copy
import datetime
import itertools import itertools
import logging import logging
from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque
@ -75,6 +74,7 @@ if TYPE_CHECKING:
from .types.sticker import GuildSticker as GuildStickerPayload from .types.sticker import GuildSticker as GuildStickerPayload
from .types.guild import Guild as GuildPayload from .types.guild import Guild as GuildPayload
from .types.message import Message as MessagePayload from .types.message import Message as MessagePayload
from .types.appinfo import AppInfo as AppInfoPayload
T = TypeVar("T") T = TypeVar("T")
CS = TypeVar("CS", bound="ConnectionState") CS = TypeVar("CS", bound="ConnectionState")
@ -323,6 +323,13 @@ class ConnectionState:
for vc in self.voice_clients: for vc in self.voice_clients:
vc.main_ws = ws # type: ignore vc.main_ws = ws # type: ignore
def store_message(self, channel: MessageableChannel, data: MessagePayload) -> Message:
message = Message(state=self, channel=channel, data=data)
if self._messages is not None:
self._messages.append(message)
return message
def store_user(self, data: UserPayload) -> User: def store_user(self, data: UserPayload) -> User:
user_id = int(data["id"]) user_id = int(data["id"])
try: try:
@ -353,11 +360,20 @@ class ConnectionState:
self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data)
return emoji return emoji
def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: def store_sticker(self, data: GuildStickerPayload) -> GuildSticker:
sticker_id = int(data["id"]) sticker_id = int(data["id"])
self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data)
return sticker return sticker
def store_appinfo(self, data: AppInfoPayload):
self.application_id = utils._get_as_snowflake(data, "id")
flags = data.get("flags")
if flags is not None:
self.application_flags = ApplicationFlags._from_value(flags)
return data
def store_view(self, view: View, message_id: Optional[int] = None) -> None: def store_view(self, view: View, message_id: Optional[int] = None) -> None:
self._view_store.add_view(view, message_id) self._view_store.add_view(view, message_id)
@ -563,9 +579,7 @@ class ConnectionState:
except KeyError: except KeyError:
pass pass
else: else:
self.application_id = utils._get_as_snowflake(application, "id") self.store_appinfo(application)
# flags will always be present here
self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore
for guild_data in data["guilds"]: for guild_data in data["guilds"]:
self._add_guild_from_data(guild_data) self._add_guild_from_data(guild_data)
@ -581,8 +595,7 @@ class ConnectionState:
# channel would be the correct type here # channel would be the correct type here
message = Message(channel=channel, data=data, state=self) # type: ignore message = Message(channel=channel, data=data, state=self) # type: ignore
self.dispatch("message", message) self.dispatch("message", message)
if self._messages is not None: self.store_message(channel, data)
self._messages.append(message)
# we ensure that the channel is either a TextChannel or Thread # we ensure that the channel is either a TextChannel or Thread
if channel and channel.__class__ in (TextChannel, Thread): if channel and channel.__class__ in (TextChannel, Thread):
channel.last_message_id = message.id # type: ignore channel.last_message_id = message.id # type: ignore
@ -1032,7 +1045,7 @@ class ConnectionState:
for emoji in before_stickers: for emoji in before_stickers:
self._stickers.pop(emoji.id, None) self._stickers.pop(emoji.id, None)
# guild won't be None here # guild won't be None here
guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data["stickers"])) # type: ignore guild.stickers = tuple(map(lambda d: self.store_sticker(d), data["stickers"])) # type: ignore
self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers) self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers)
def _get_create_guild(self, data): def _get_create_guild(self, data):
@ -1403,11 +1416,6 @@ class ConnectionState:
if channel is not None: if channel is not None:
return channel return channel
def create_message(
self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload
) -> Message:
return Message(state=self, channel=channel, data=data)
class AutoShardedConnectionState(ConnectionState): class AutoShardedConnectionState(ConnectionState):
def __init__(self, *args: Any, **kwargs: Any) -> None: def __init__(self, *args: Any, **kwargs: Any) -> None: