From 8cdc1f4ad9628196dcaf3f9d78ada4aad055f266 Mon Sep 17 00:00:00 2001 From: Arthur Jovart Date: Sun, 29 Aug 2021 00:26:26 +0200 Subject: [PATCH 01/64] Alias admin to administrators in permissions. This needs to be tested, but should be working. --- discord/permissions.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/discord/permissions.py b/discord/permissions.py index 9d40ca33..ff383431 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -299,6 +299,8 @@ class Permissions(BaseFlags): """ return 1 << 3 + admin = administrator + @flag_value def manage_channels(self) -> int: """:class:`bool`: Returns ``True`` if a user can edit, delete, or create channels in the guild. -- 2.47.2 From 10d8a03d7139798a748ba0aa3c64af11b052a417 Mon Sep 17 00:00:00 2001 From: privacy lulz <80280930+slyberries@users.noreply.github.com> Date: Wed, 1 Sep 2021 19:36:37 +0000 Subject: [PATCH 02/64] Add base -> colour support, and optimization in errors.py. --- discord/colour.py | 18 ++++++++++++++++++ discord/ext/commands/errors.py | 2 +- 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/discord/colour.py b/discord/colour.py index 927addc1..96f0f0ee 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -83,6 +83,24 @@ class Colour: raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.') self.value: int = value + + @staticmethod + def from_hex(self, value: str, base:int = None): + """ + Initiate self.value from different base(hexidecimal, binary) + ======= + + value `str` : + + value in different base, e.g. white in hexidecimal, 0xffffff + + base `int` (optional) : + + base of your value, if you don't supply this, you have to add a prefix to your number, e.g. 0x or 0b + + """ + + return Colour(value = int(value, base=base)) def _get_byte(self, byte: int) -> int: return (self.value >> (8 * byte)) & 0xff diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 93834385..9ea1e81a 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -110,7 +110,7 @@ class CommandError(DiscordException): from :class:`.Bot`\, :func:`.on_command_error`. """ def __init__(self, message: Optional[str] = None, *args: Any) -> None: - if message is not None: + if message: # replace 'if not none' with 'if message' # clean-up @everyone and @here mentions m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') super().__init__(m, *args) -- 2.47.2 From ec5d350fbb4e30f3707c7a05922d67a5ee0ddab3 Mon Sep 17 00:00:00 2001 From: privacy lulz <80280930+slyberries@users.noreply.github.com> Date: Wed, 1 Sep 2021 19:39:45 +0000 Subject: [PATCH 03/64] Add optimizations in errors --- discord/errors.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/errors.py b/discord/errors.py index bc2398d5..9b33c260 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -143,7 +143,7 @@ class HTTPException(DiscordException): self.code = 0 fmt = '{0.status} {0.reason} (error code: {1})' - if len(self.text): + if self.text: fmt += ': {2}' super().__init__(fmt.format(self.response, self.code, self.text)) -- 2.47.2 From ee58f2d36c96fecd6128e2b70ec8171b8e914492 Mon Sep 17 00:00:00 2001 From: privacy lulz <80280930+slyberries@users.noreply.github.com> Date: Wed, 1 Sep 2021 19:47:18 +0000 Subject: [PATCH 04/64] Add different base support --- discord/embeds.py | 19 +++++++++++++++++-- 1 file changed, 17 insertions(+), 2 deletions(-) diff --git a/discord/embeds.py b/discord/embeds.py index 25f05aef..c118aadc 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -315,13 +315,28 @@ class Embed: return getattr(self, '_colour', EmptyEmbed) @colour.setter - def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore + def colour(self, value: Union[int, Colour, _EmptyEmbed, str], base:int=None): # type: ignore + """ + Set colour + ============ + + value `Union[int, Colour, _EmptyEmbed, str]`: + + value of the colour. If you want to use different number systems such as hexidecimal or binary, use Colour.from_base or set it as a string in that system + + base `int` (optional): + + if value is a string without a prefix(0x, 0b), set base here + """ + if isinstance(value, (Colour, _EmptyEmbed)): self._colour = value elif isinstance(value, int): self._colour = Colour(value=value) + elif isinstance(value, str): + self._colour = int(value, base=base) else: - raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.') + raise TypeError(f'Expected discord.Colour, int, str, or Embed.Empty but received {value.__class__.__name__} instead.') color = colour -- 2.47.2 From 7ca4e650fb389405e1003ce20b9be1577ee99433 Mon Sep 17 00:00:00 2001 From: privacy lulz <80280930+slyberries@users.noreply.github.com> Date: Wed, 1 Sep 2021 20:17:52 +0000 Subject: [PATCH 05/64] Bare bones for JSON database --- discord/__init__.py | 1 - discord/ext/commands/cog.py | 4 +- discord/ext/json/__init__.py | 0 discord/ext/json/types.py | 95 ++++++++++++++++++++++++++++++++++++ 4 files changed, 97 insertions(+), 3 deletions(-) create mode 100644 discord/ext/json/__init__.py create mode 100644 discord/ext/json/types.py diff --git a/discord/__init__.py b/discord/__init__.py index 1e74cf91..818abc4a 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -68,7 +68,6 @@ class VersionInfo(NamedTuple): releaselevel: Literal["alpha", "beta", "candidate", "final"] serial: int - version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0) logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 9931557d..3d3dc1c0 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -208,7 +208,7 @@ class Cog(metaclass=CogMeta): for command in self.__cog_commands__: setattr(self, command.callback.__name__, command) parent = command.parent - if parent is not None: + if parent: # Get the latest parent reference parent = lookup[parent.qualified_name] # type: ignore @@ -230,7 +230,7 @@ class Cog(metaclass=CogMeta): This does not include subcommands. """ - return [c for c in self.__cog_commands__ if c.parent is None] + return [c for c in self.__cog_commands__ if not c.parent] @property def qualified_name(self) -> str: diff --git a/discord/ext/json/__init__.py b/discord/ext/json/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/discord/ext/json/types.py b/discord/ext/json/types.py new file mode 100644 index 00000000..076afb1b --- /dev/null +++ b/discord/ext/json/types.py @@ -0,0 +1,95 @@ +import json + +from typing import Union, Any + +class Serializer: + def bind_file(self, file:str): + """Bind file to be serialized(loaded or dumped)""" + + self.file = file + + async def load(self, content:str): + raise NotImplementedError("Serializer has to be implemented!") + + async def dump(self): + raise NotImplementedError("Serializer hasn't been implemented") + +class File: + def __init__(self, path:str): + """ + Initialize a file. + + ======= + + path `str` : + + path to the file + """ + + self.path : str = path + + @staticmethod + def construct(path:str): + """ + Construct a File object statically + """ + + if not isinstance(path, str): + raise TypeError(f"Expected path to be type str, got {type(path)} instead.") + + return File(path) + + async def access(self, mode:str): + """ + Open object to access file. Don't use this in your code. + """ + + if not isinstance(mode, str): + raise TypeError(f"Expected mode to be type str, got {type(mode)} instead.") + + return open(self.path, mode) + + async def open_writer(self): + """Open plain writer for file""" + + return await self.access("w+") + + async def open_reader(self): + """Open plain reader for file""" + + return await self.access("r+") + + async def open_binary_writer(self): + """Open binary writer for file""" + + return await self.access("wb+") + + async def open_binary_reader(self): + """Open binary reader for file""" + + return await self.access("rb+") + + async def read_contents(self, binary=False) -> Union[str, bytes]: + """ + Dump out file contents + ====== + + binary `bool` (optional): + + Set to False normally. Controls whether to read file as binary or read as normal text. + """ + + reader = None # initialize + + if binary: + reader = await self.open_binary_reader() + else: + reader = await self.open_reader() + + contents = reader.read() + + reader.close() # safely close file + + return contents + + async def serialize(self, serializer:Serializer) -> Union[Any]: -- 2.47.2 From dba9a8abb9a27674c0bd2830152a12539bd3dea4 Mon Sep 17 00:00:00 2001 From: classerase <60030956+classerase@users.noreply.github.com> Date: Wed, 1 Sep 2021 23:30:15 +0100 Subject: [PATCH 06/64] Update README.rst (#48) --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index b8fdf99a..c9c2b6e6 100644 --- a/README.rst +++ b/README.rst @@ -59,7 +59,7 @@ To install the development version, do the following: .. code:: sh $ git clone https://github.com/iDevision/enhanced-discord.py - $ cd discord.py + $ cd enhanced-discord.py $ python3 -m pip install -U .[voice] -- 2.47.2 From c485e08ea012be73fc44b136080d651a54896452 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 02:47:15 +0200 Subject: [PATCH 07/64] Add try_member to guild. (#14) * Add try_member to guild. This also fix an omission in the fetch_member docs. fetch_member raises NotFound if the given user isn't in the guild. * Optimize imports. --- discord/guild.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/discord/guild.py b/discord/guild.py index 4ed89821..c9132f5a 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -46,7 +46,7 @@ from . import utils, abc from .role import Role from .member import Member, VoiceState from .emoji import Emoji -from .errors import InvalidData +from .errors import InvalidData, NotFound from .permissions import PermissionOverwrite from .colour import Colour from .errors import InvalidArgument, ClientException @@ -1723,6 +1723,8 @@ class Guild(Hashable): You do not have access to the guild. HTTPException Fetching the member failed. + NotFound + A member with that ID does not exist. Returns -------- @@ -1732,6 +1734,34 @@ class Guild(Hashable): data = await self._state.http.get_member(self.id, member_id) return Member(data=data, state=self._state, guild=self) + async def try_member(self, member_id: int, /) -> Optional[Member]: + """|coro| + + Returns a member with the given ID. This uses the cache first, and if not found, it'll request using :meth:`fetch_member`. + + .. note:: + This method might result in an API call. + + Parameters + ----------- + member_id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`Member`] + The member or ``None`` if not found. + """ + member = self.get_member(member_id) + + if member: + return member + else: + try: + return await self.fetch_member(member_id) + except NotFound: + return None + async def fetch_ban(self, user: Snowflake) -> BanEntry: """|coro| -- 2.47.2 From 630a842556f695e346b6f7af1cca0af745b9ba03 Mon Sep 17 00:00:00 2001 From: NightSlasher35 <49624805+NightSlasher35@users.noreply.github.com> Date: Thu, 2 Sep 2021 01:49:41 +0100 Subject: [PATCH 08/64] Update CONTRIBUTING.md correctly (#29) * Update CONTRIBUTING.md * Update CONTRIBUTING.md Co-authored-by: Tom <47765953+IAmTomahawkx@users.noreply.github.com> --- .github/CONTRIBUTING.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index 68f037c3..bb457a35 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -1,5 +1,7 @@ ## Contributing to discord.py +Credits to the `original lib` by Rapptz + First off, thanks for taking the time to contribute. It makes the library substantially better. :+1: The following is a set of guidelines for contributing to the repository. These are guidelines, not hard rules. @@ -8,9 +10,9 @@ The following is a set of guidelines for contributing to the repository. These a Generally speaking questions are better suited in our resources below. -- The official support server: https://discord.gg/r3sSKJJ +- The official support server: https://discord.gg/TvqYBrGXEm - The Discord API server under #python_discord-py: https://discord.gg/discord-api -- [The FAQ in the documentation](https://discordpy.readthedocs.io/en/latest/faq.html) +- [The FAQ in the documentation](https://enhanced-dpy.readthedocs.io/en/latest/faq.html) - [StackOverflow's `discord.py` tag](https://stackoverflow.com/questions/tagged/discord.py) Please try your best not to ask questions in our issue tracker. Most of them don't belong there unless they provide value to a larger audience. -- 2.47.2 From b75be64044706d87c4117e135b5882e114b139dd Mon Sep 17 00:00:00 2001 From: iDutchy <42503862+iDutchy@users.noreply.github.com> Date: Thu, 2 Sep 2021 03:05:49 +0200 Subject: [PATCH 09/64] Update permissions.py A better implementation :) --- discord/permissions.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/discord/permissions.py b/discord/permissions.py index ff383431..4b3d9830 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -299,7 +299,12 @@ class Permissions(BaseFlags): """ return 1 << 3 - admin = administrator + @make_permission_alias('administrator') + def admin(self) -> int: + """:class:`bool`: An alias for :attr:`administrator`. + .. versionadded:: 2.0 + """ + return 1 << 3 @flag_value def manage_channels(self) -> int: -- 2.47.2 From 08d012d6dd653f8c19fe339f4de6bf09aec6df19 Mon Sep 17 00:00:00 2001 From: privacy lulz <80280930+slyberries@users.noreply.github.com> Date: Thu, 2 Sep 2021 15:45:46 +0000 Subject: [PATCH 10/64] Add JSON database support --- discord/ext/json/database.py | 50 +++++++++++++++++++++++++++ discord/ext/json/types.py | 66 +++++++++++++++++++++++++++++------- 2 files changed, 104 insertions(+), 12 deletions(-) create mode 100644 discord/ext/json/database.py diff --git a/discord/ext/json/database.py b/discord/ext/json/database.py new file mode 100644 index 00000000..de2503f4 --- /dev/null +++ b/discord/ext/json/database.py @@ -0,0 +1,50 @@ +from .types import JSONFile, Entry +from typing import Union, Any + +class Database: + def __init__(self, file:JSONFile): + self.file = file; + self.database = {}; + + async def getData(self) -> dict: + """ + Get contents from a JSONFile + """ + + contents = await self.file.serialize("load") + + return contents + + async def dumpData(self, data:dict) -> None: + """ + Dump a dict into file + ===== + + data `dict` - + + The data to be dumped into the file. + """ + + await self.file.serialize("dump", contents=data) + + async def loadFile(self) -> None: + """ + Load JSON from file to self.database + """ + + self.database = await self.getData(); + + async def getEntry(self, name:str) -> Entry: + value : Union[Any] = self.database[name] + + return Entry(name, value) + + async def editEntry(self, name:str, value:Union[Any]): + self.database[name] = value; + + async def saveData(self): + """ + Save current database to file + """ + + await self.file.serialize("dump", contents=self.database) \ No newline at end of file diff --git a/discord/ext/json/types.py b/discord/ext/json/types.py index 076afb1b..de4dc0d7 100644 --- a/discord/ext/json/types.py +++ b/discord/ext/json/types.py @@ -2,17 +2,6 @@ import json from typing import Union, Any -class Serializer: - def bind_file(self, file:str): - """Bind file to be serialized(loaded or dumped)""" - - self.file = file - - async def load(self, content:str): - raise NotImplementedError("Serializer has to be implemented!") - - async def dump(self): - raise NotImplementedError("Serializer hasn't been implemented") class File: def __init__(self, path:str): @@ -92,4 +81,57 @@ class File: return contents - async def serialize(self, serializer:Serializer) -> Union[Any]: + async def serialize(self, serializer, *args, **kwargs) -> Union[Any]:# + return serializer(*args, **kwargs) + +class JSONFile(File): + load = json.load + dump = json.dump + + async def serialize(self, serializer:str, *args, **kwargs) -> Union[dict, int]: + """ + Serialize JSON data + ===== + + serializer `str`: + + dump or load data, set to "load" for loading file, "dump" to dump json to file + + contents `dict` of kwargs: + + what to dump. + """ + + if not serializer: + raise ValueError("Argument serializer should be either load or dump, got NoneType instead.") + + if serializer == "load": + reader = self.open_reader() + contents = JSONFile.load(reader) + reader.close() + + return contents + elif serializer == "dump": + writer = self.open_writer() + + JSONFile.dump(kwargs["contents"], writer) + + writer.close() + + return 0; + +class Entry: + """ + Class representing a JSON Entry + """ + + def __init__(self, name:str, value:Union[Any]): + self.name : str = name; + self.value : Union[Any] = value; + + async def browse(self, name:str): + """ + Browse entry and get a new Entry object if it finds something successfully + """ + + return Entry(name, self.value[name]) \ No newline at end of file -- 2.47.2 From 550a0ef19f34c95b35a38e9ad3b6e0bc8bba3c23 Mon Sep 17 00:00:00 2001 From: privacy lulz <80280930+slyberries@users.noreply.github.com> Date: Thu, 2 Sep 2021 15:54:37 +0000 Subject: [PATCH 11/64] Add color and embed utilities --- discord/colour.py | 2 +- discord/utils.py | 37 +++++++++++++++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 1 deletion(-) diff --git a/discord/colour.py b/discord/colour.py index 96f0f0ee..e0bcf700 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -85,7 +85,7 @@ class Colour: self.value: int = value @staticmethod - def from_hex(self, value: str, base:int = None): + def from_base(self, value: str, base:int = None): """ Initiate self.value from different base(hexidecimal, binary) ======= diff --git a/discord/utils.py b/discord/utils.py index 4360b77a..586896e7 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -60,6 +60,8 @@ import re import sys import types import warnings +from .embeds import Embed +from .colors import Colour from .errors import InvalidArgument @@ -168,6 +170,41 @@ class CachedSlotProperty(Generic[T, T_co]): setattr(instance, self.name, value) return value +def generate_embed(header:str, content:str, footer:str, color=None): # Courtesy of Dank HadocK, my friend who coded this for my rewrite. + """ + Easy way to form embeds + + This was made by Dank Had0cK, a valuable contributor on my fork. Thanks. + + ===== + + header `str`: + + Header of your embed + + content `str`: + + Content of your embed + + footer `str`: + + Footer of your embed + + color `str` optional: + + If it is None, color will default to 2F3136 + Hexstring of your color, uses color converter by Arkae to convert hex to int. + """ + + embed = Embed() + embed.title = header + embed.description = content + if color is None: + embed.color = Colour.from_base("0x2f3136") + else: + embed.color = Colour.from_base(color, 16) + embed.set_footer(text=footer) + return embed class classproperty(Generic[T_co]): def __init__(self, fget: Callable[[Any], T_co]) -> None: -- 2.47.2 From fc0188d7bcbd4fe93bd41331e405a2eb69583b42 Mon Sep 17 00:00:00 2001 From: Daud <78969148+Daudd@users.noreply.github.com> Date: Fri, 3 Sep 2021 02:17:19 +0700 Subject: [PATCH 12/64] Merge pull request #49 * Change README title to enhanced-discord.py --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index c9c2b6e6..9f222e10 100644 --- a/README.rst +++ b/README.rst @@ -1,5 +1,5 @@ -discord.py -========== +enhanced-discord.py +=================== .. image:: https://discord.com/api/guilds/514232441498763279/embed.png :target: https://discord.gg/PYAfZzpsjG -- 2.47.2 From 5d1038457618576290dd23dbbe983293284f1441 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 21:18:26 +0200 Subject: [PATCH 13/64] Merge pull request #27 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add author_permissions to the Context object as a shortcut to return … --- discord/ext/commands/context.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 38a24d1d..fa16c74a 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -32,6 +32,7 @@ import discord.abc import discord.utils from discord.message import Message +from discord import Permissions if TYPE_CHECKING: from typing_extensions import ParamSpec @@ -314,6 +315,13 @@ class Context(discord.abc.Messageable, Generic[BotT]): g = self.guild return g.voice_client if g else None + def author_permissions(self) -> Permissions: + """Returns the author permissions in the given channel. + + .. versionadded:: 2.0 + """ + return self.channel.permissions_for(self.author) + async def send_help(self, *args: Any) -> Any: """send_help(entity=) -- 2.47.2 From 13834d114726f7acc4d093d754568e4d5107c78f Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 21:24:52 +0200 Subject: [PATCH 14/64] Merge pull request #7 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add try_user to get a user from cache or from the gateway. * Extract populate_owners into a new coroutine. * Add a try_owners coroutine to get a list of owners of the bot. * Fix coding-style. * Fix a bug where None would be returned in try_owners if the cache was… * Fix docstring * Add spacing in the code --- discord/client.py | 32 +++++++++++++++++++++ discord/ext/commands/bot.py | 57 +++++++++++++++++++++++++++++++++---- 2 files changed, 83 insertions(+), 6 deletions(-) diff --git a/discord/client.py b/discord/client.py index b4f1db17..746dd995 100644 --- a/discord/client.py +++ b/discord/client.py @@ -829,6 +829,38 @@ class Client: """ return self._connection.get_user(id) + async def try_user(self, id: int, /) -> Optional[User]: + """|coro| + Returns a user with the given ID. If not from cache, the user will be requested from the API. + + You do not have to share any guilds with the user to get this information from the API, + however many operations do require that you do. + + .. note:: + This method is an API call. If you have :attr:`discord.Intents.members` and member cache enabled, consider :meth:`get_user` instead. + + .. versionadded:: 2.0 + + Parameters + ----------- + id: :class:`int` + The ID to search for. + + Returns + -------- + Optional[:class:`~discord.User`] + The user or ``None`` if not found. + """ + maybe_user = self.get_user(id) + + if maybe_user is not None: + return maybe_user + + try: + return await self.fetch_user(id) + except NotFound: + return None + def get_emoji(self, id: int, /) -> Optional[Emoji]: """Returns an emoji with the given ID. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index c089b87d..e169577f 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -344,14 +344,59 @@ class BotBase(GroupMixin): elif self.owner_ids: return user.id in self.owner_ids else: + # Populate the used fields, then retry the check. This is only done at-most once in the bot lifetime. + await self.populate_owners() + return await self.is_owner(user) - app = await self.application_info() # type: ignore - if app.team: - self.owner_ids = ids = {m.id for m in app.team.members} - return user.id in ids + async def try_owners(self) -> List[discord.User]: + """|coro| + + Returns a list of :class:`~discord.User` representing the owners of the bot. + It uses the :attr:`owner_id` and :attr:`owner_ids`, if set. + + .. versionadded:: 2.0 + The function also checks if the application is team-owned if + :attr:`owner_ids` is not set. + + Returns + -------- + List[:class:`~discord.User`] + List of owners of the bot. + """ + if self.owner_id: + owner = await self.try_user(self.owner_id) + + if owner: + return [owner] else: - self.owner_id = owner_id = app.owner.id - return user.id == owner_id + return [] + + elif self.owner_ids: + owners = [] + + for owner_id in self.owner_ids: + owner = await self.try_user(owner_id) + if owner: + owners.append(owner) + + return owners + else: + # We didn't have owners cached yet, cache them and retry. + await self.populate_owners() + return await self.try_owners() + + async def populate_owners(self): + """|coro| + + Populate the :attr:`owner_id` and :attr:`owner_ids` through the use of :meth:`~.Bot.application_info`. + + .. versionadded:: 2.0 + """ + app = await self.application_info() # type: ignore + if app.team: + self.owner_ids = {m.id for m in app.team.members} + else: + self.owner_id = app.owner.id def before_invoke(self, coro: CFT) -> CFT: """A decorator that registers a coroutine as a pre-invoke hook. -- 2.47.2 From 092fbca08f7bd2de8dda191c9a7007c9d3649c68 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 21:28:03 +0200 Subject: [PATCH 15/64] Merge pull request #21 * [BREAKING] Make case_insensitive default to True on groups and commands --- discord/ext/commands/bot.py | 2 +- discord/ext/commands/core.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index e169577f..e03562b6 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -1120,7 +1120,7 @@ class Bot(BotBase, discord.Client): when passing an empty string, it should always be last as no prefix after it will be matched. case_insensitive: :class:`bool` - Whether the commands should be case insensitive. Defaults to ``False``. This + Whether the commands should be case insensitive. Defaults to ``True``. This attribute does not carry over to groups. You must set it to every group if you require group commands to be case insensitive as well. description: :class:`str` diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 35b7e840..f122e9ad 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -1135,10 +1135,10 @@ class GroupMixin(Generic[CogT]): A mapping of command name to :class:`.Command` objects. case_insensitive: :class:`bool` - Whether the commands should be case insensitive. Defaults to ``False``. + Whether the commands should be case insensitive. Defaults to ``True``. """ def __init__(self, *args: Any, **kwargs: Any) -> None: - case_insensitive = kwargs.get('case_insensitive', False) + case_insensitive = kwargs.get('case_insensitive', True) self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) -- 2.47.2 From 42c0a8d8a5840c00185e367933e61e2565bf7305 Mon Sep 17 00:00:00 2001 From: chillymosh <86857777+chillymosh@users.noreply.github.com> Date: Thu, 2 Sep 2021 20:32:46 +0100 Subject: [PATCH 16/64] Merge pull request #12 * Clean up python * Clean up bot python * revert lists * revert commands.bot completely * extract raise_expected_coro further * add new lines * removed erroneous import * remove hashed line --- discord/activity.py | 10 ++++----- discord/ext/commands/bot.py | 16 ++++++-------- discord/ext/commands/context.py | 6 ++---- discord/ext/commands/converter.py | 21 +++++++++--------- discord/ext/commands/view.py | 15 ++++++------- discord/ext/tasks/__init__.py | 25 ++++++--------------- discord/player.py | 13 ++++++----- discord/ui/button.py | 16 +++++++------- discord/utils.py | 36 +++++++++++++++++-------------- 9 files changed, 74 insertions(+), 84 deletions(-) diff --git a/discord/activity.py b/discord/activity.py index 51205377..cba61f38 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -794,13 +794,13 @@ class CustomActivity(BaseActivity): return hash((self.name, str(self.emoji))) def __str__(self) -> str: - if self.emoji: - if self.name: - return f'{self.emoji} {self.name}' - return str(self.emoji) - else: + if not self.emoji: return str(self.name) + if self.name: + return f'{self.emoji} {self.name}' + return str(self.emoji) + def __repr__(self) -> str: return f'' diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index e03562b6..b3a1fb57 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -43,6 +43,7 @@ from .context import Context from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog +from discord.utils import raise_expected_coro if TYPE_CHECKING: import importlib.machinery @@ -424,11 +425,9 @@ class BotBase(GroupMixin): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError('The pre-invoke hook must be a coroutine.') - - self._before_invoke = coro - return coro + return raise_expected_coro( + coro, 'The pre-invoke hook must be a coroutine.' + ) def after_invoke(self, coro: CFT) -> CFT: r"""A decorator that registers a coroutine as a post-invoke hook. @@ -457,11 +456,10 @@ class BotBase(GroupMixin): TypeError The coroutine passed is not actually a coroutine. """ - if not asyncio.iscoroutinefunction(coro): - raise TypeError('The post-invoke hook must be a coroutine.') + return raise_expected_coro( + coro, 'The post-invoke hook must be a coroutine.' + ) - self._after_invoke = coro - return coro # listener registration diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index fa16c74a..158c84ea 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import inspect @@ -61,10 +62,7 @@ T = TypeVar('T') BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") CogT = TypeVar('CogT', bound="Cog") -if TYPE_CHECKING: - P = ParamSpec('P') -else: - P = TypeVar('P') +P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P') class Context(discord.abc.Messageable, Generic[BotT]): diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 5740a188..ce7037a4 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -353,15 +353,15 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: - if guild_id is not None: - guild = ctx.bot.get_guild(guild_id) - if guild is not None and channel_id is not None: - return guild._resolve_channel(channel_id) # type: ignore - else: - return None - else: + if guild_id is None: return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel + guild = ctx.bot.get_guild(guild_id) + if guild is not None and channel_id is not None: + return guild._resolve_channel(channel_id) # type: ignore + else: + return None + async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) channel = self._resolve_channel(ctx, guild_id, channel_id) @@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]): if result is None: result = discord.utils.get(ctx.bot.guilds, name=argument) - if result is None: - raise GuildNotFound(argument) + if result is None: + raise GuildNotFound(argument) return result @@ -939,8 +939,7 @@ class clean_content(Converter[str]): def repl(match: re.Match) -> str: type = match[1] id = int(match[2]) - transformed = transforms[type](id) - return transformed + return transforms[type](id) result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) if self.escape_markdown: diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index a7dc7236..39cc35f7 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -82,9 +82,7 @@ class StringView: def skip_string(self, string): strlen = len(string) if self.buffer[self.index:self.index + strlen] == string: - self.previous = self.index - self.index += strlen - return True + return self._return_index(strlen, True) return False def read_rest(self): @@ -95,9 +93,7 @@ class StringView: def read(self, n): result = self.buffer[self.index:self.index + n] - self.previous = self.index - self.index += n - return result + return self._return_index(n, result) def get(self): try: @@ -105,9 +101,12 @@ class StringView: except IndexError: result = None + return self._return_index(1, result) + + def _return_index(self, arg0, arg1): self.previous = self.index - self.index += 1 - return result + self.index += arg0 + return arg1 def get_word(self): pos = 0 diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 5b78f10e..9518390e 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -46,7 +46,9 @@ import traceback from collections.abc import Sequence from discord.backoff import ExponentialBackoff -from discord.utils import MISSING +from discord.utils import MISSING, raise_expected_coro + + __all__ = ( 'loop', @@ -488,11 +490,7 @@ class Loop(Generic[LF]): The function was not a coroutine. """ - if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') - - self._before_loop = coro - return coro + return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') def after_loop(self, coro: FT) -> FT: """A decorator that register a coroutine to be called after the loop finished running. @@ -516,11 +514,7 @@ class Loop(Generic[LF]): The function was not a coroutine. """ - if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') - - self._after_loop = coro - return coro + return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') def error(self, coro: ET) -> ET: """A decorator that registers a coroutine to be called if the task encounters an unhandled exception. @@ -542,11 +536,7 @@ class Loop(Generic[LF]): TypeError The function was not a coroutine. """ - if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') - - self._error = coro # type: ignore - return coro + return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') def _get_next_sleep_time(self) -> datetime.datetime: if self._sleep is not MISSING: @@ -614,8 +604,7 @@ class Loop(Generic[LF]): ) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) - ret = sorted(set(ret)) # de-dupe and sort times - return ret + return sorted(set(ret)) def change_interval( self, diff --git a/discord/player.py b/discord/player.py index 8098d3e3..79579c8d 100644 --- a/discord/player.py +++ b/discord/player.py @@ -21,6 +21,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from __future__ import annotations import threading @@ -63,10 +64,7 @@ __all__ = ( CREATE_NO_WINDOW: int -if sys.platform != 'win32': - CREATE_NO_WINDOW = 0 -else: - CREATE_NO_WINDOW = 0x08000000 +CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000 class AudioSource: """Represents an audio stream. @@ -526,7 +524,12 @@ class FFmpegOpusAudio(FFmpegAudio): @staticmethod def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable + exe = ( + executable[:2] + 'probe' + if executable in {'ffmpeg', 'avconv'} + else executable + ) + args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] output = subprocess.check_output(args, timeout=20) codec = bitrate = None diff --git a/discord/ui/button.py b/discord/ui/button.py index fedeac68..0b16e87e 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -185,16 +185,16 @@ class Button(Item[V]): @emoji.setter def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore - if value is not None: - if isinstance(value, str): - self._underlying.emoji = PartialEmoji.from_str(value) - elif isinstance(value, _EmojiTag): - self._underlying.emoji = value._to_partial() - else: - raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') - else: + if value is None: self._underlying.emoji = None + elif isinstance(value, str): + self._underlying.emoji = PartialEmoji.from_str(value) + elif isinstance(value, _EmojiTag): + self._underlying.emoji = value._to_partial() + else: + raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') + @classmethod def from_component(cls: Type[B], button: ButtonComponent) -> B: return cls( diff --git a/discord/utils.py b/discord/utils.py index 4360b77a..cad99da6 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -499,14 +499,14 @@ else: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') - if use_clock or not reset_after: - utc = datetime.timezone.utc - now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) - return (reset - now).total_seconds() - else: + if not use_clock and reset_after: return float(reset_after) + utc = datetime.timezone.utc + now = datetime.datetime.now(utc) + reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) + return (reset - now).total_seconds() + async def maybe_coroutine(f, *args, **kwargs): value = f(*args, **kwargs) @@ -659,11 +659,10 @@ def resolve_invite(invite: Union[Invite, str]) -> str: if isinstance(invite, Invite): return invite.code - else: - rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' - m = re.match(rx, invite) - if m: - return m.group(1) + rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' + m = re.match(rx, invite) + if m: + return m.group(1) return invite @@ -687,11 +686,10 @@ def resolve_template(code: Union[Template, str]) -> str: if isinstance(code, Template): return code.code - else: - rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' - m = re.match(rx, code) - if m: - return m.group(1) + rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' + m = re.match(rx, code) + if m: + return m.group(1) return code @@ -1017,3 +1015,9 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) if style is None: return f'' return f'' + + +def raise_expected_coro(coro, error: str)-> TypeError: + if not asyncio.iscoroutinefunction(coro): + raise TypeError(error) + return coro -- 2.47.2 From 0f6db99c597a629e739f057ec838827c6d51ef12 Mon Sep 17 00:00:00 2001 From: NightSlasher35 <49624805+NightSlasher35@users.noreply.github.com> Date: Thu, 2 Sep 2021 20:34:41 +0100 Subject: [PATCH 17/64] Merge pull request #22 * add nitro booster color * Update discord/colour.py --- discord/colour.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/discord/colour.py b/discord/colour.py index 927addc1..43ad6c6f 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -251,6 +251,13 @@ class Colour: def red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" return cls(0xe74c3c) + + @classmethod + def nitro_booster(cls): + """A factory method that returns a :class:`Colour` with a value of ``0xf47fff``. + + .. versionadded:: 2.0""" + return cls(0xf47fff) @classmethod def dark_red(cls: Type[CT]) -> CT: -- 2.47.2 From f37be7961a05e76d0cda3fc76c6bc79bdcc3c52d Mon Sep 17 00:00:00 2001 From: Ahmad Ansori Palembani <46041660+null2264@users.noreply.github.com> Date: Fri, 3 Sep 2021 02:46:56 +0700 Subject: [PATCH 18/64] Merge pull request #41 * Fixed `TypeError` * Handles `EmptyEmbed` inside setter instead of set_ * Remove return and setter docstring --- discord/embeds.py | 38 +++++++++++++++++--------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/discord/embeds.py b/discord/embeds.py index 25f05aef..41f2be40 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -404,10 +404,13 @@ class Embed: return EmbedProxy(getattr(self, '_image', {})) # type: ignore @image.setter - def image(self: E, *, url: Any): - self._image = { - 'url': str(url), - } + def image(self: E, url: Any): + if url is EmptyEmbed: + del self._image + else: + self._image = { + 'url': str(url), + } @image.deleter def image(self: E): @@ -431,10 +434,7 @@ class Embed: The source URL for the image. Only HTTP(S) is supported. """ - if url is EmptyEmbed: - del self.image - else: - self.image = url + self.image = url return self @@ -454,15 +454,13 @@ class Embed: return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore @thumbnail.setter - def thumbnail(self: E, *, url: Any): - """Sets the thumbnail for the embed content. - """ - - self._thumbnail = { - 'url': str(url), - } - - return + def thumbnail(self: E, url: Any): + if url is EmptyEmbed: + del self._thumbnail + else: + self._thumbnail = { + 'url': str(url), + } @thumbnail.deleter def thumbnail(self): @@ -485,10 +483,8 @@ class Embed: url: :class:`str` The source URL for the thumbnail. Only HTTP(S) is supported. """ - if url is EmptyEmbed: - del self.thumbnail - else: - self.thumbnail = url + + self.thumbnail = url return self -- 2.47.2 From 152b61aabb33bb13dae54a3f63e4e97aae81aa6e Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Thu, 2 Sep 2021 12:49:38 -0700 Subject: [PATCH 19/64] fix recursionerror caused by a Pull Request --- discord/embeds.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/embeds.py b/discord/embeds.py index 41f2be40..52d71ef4 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -465,7 +465,7 @@ class Embed: @thumbnail.deleter def thumbnail(self): try: - del self.thumbnail + del self._thumbnail except AttributeError: pass -- 2.47.2 From 4055bafaa5b35416f568254b89402d7bb218af01 Mon Sep 17 00:00:00 2001 From: Astrea Date: Thu, 2 Sep 2021 16:34:39 -0400 Subject: [PATCH 20/64] Merge pull request #47 * Added `on_raw_typing` event --- discord/raw_models.py | 38 +++++++++++++++++++++++++++++++++- discord/state.py | 41 ++++++++++++++++++++++--------------- discord/types/raw_models.py | 10 +++++++++ docs/api.rst | 19 +++++++++++++++++ 4 files changed, 91 insertions(+), 17 deletions(-) diff --git a/discord/raw_models.py b/discord/raw_models.py index cda754d1..3c9360ba 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -24,6 +24,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations +import datetime from typing import TYPE_CHECKING, Optional, Set, List if TYPE_CHECKING: @@ -34,7 +35,8 @@ if TYPE_CHECKING: MessageUpdateEvent, ReactionClearEvent, ReactionClearEmojiEvent, - IntegrationDeleteEvent + IntegrationDeleteEvent, + TypingEvent ) from .message import Message from .partial_emoji import PartialEmoji @@ -49,6 +51,7 @@ __all__ = ( 'RawReactionClearEvent', 'RawReactionClearEmojiEvent', 'RawIntegrationDeleteEvent', + 'RawTypingEvent' ) @@ -276,3 +279,36 @@ class RawIntegrationDeleteEvent(_RawReprMixin): self.application_id: Optional[int] = int(data['application_id']) except KeyError: self.application_id: Optional[int] = None + + +class RawTypingEvent(_RawReprMixin): + """Represents the payload for a :func:`on_raw_typing` event. + + .. versionadded:: 2.0 + + Attributes + ----------- + channel_id: :class:`int` + The channel ID where the typing originated from. + user_id: :class:`int` + The ID of the user that started typing. + when: :class:`datetime.datetime` + When the typing started as an aware datetime in UTC. + guild_id: Optional[:class:`int`] + The guild ID where the typing originated from, if applicable. + member: Optional[:class:`Member`] + The member who started typing. Only available if the member started typing in a guild. + """ + + __slots__ = ("channel_id", "user_id", "when", "guild_id", "member") + + def __init__(self, data: TypingEvent) -> None: + self.channel_id: int = int(data['channel_id']) + self.user_id: int = int(data['user_id']) + self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) + self.member: Optional[Member] = None + + try: + self.guild_id: Optional[int] = int(data['guild_id']) + except KeyError: + self.guild_id: Optional[int] = None \ No newline at end of file diff --git a/discord/state.py b/discord/state.py index 0a9feac1..09777008 100644 --- a/discord/state.py +++ b/discord/state.py @@ -1327,28 +1327,37 @@ class ConnectionState: asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) def parse_typing_start(self, data) -> None: + raw = RawTypingEvent(data) + + member_data = data.get('member') + if member_data: + guild = self._get_guild(raw.guild_id) + if guild is not None: + raw.member = Member(data=member_data, guild=guild, state=self) + else: + raw.member = None + else: + raw.member = None + self.dispatch('raw_typing', raw) + channel, guild = self._get_guild_channel(data) if channel is not None: - member = None - user_id = utils._get_as_snowflake(data, 'user_id') - if isinstance(channel, DMChannel): - member = channel.recipient + user = raw.member or self._get_typing_user(channel, raw.user_id) - elif isinstance(channel, (Thread, TextChannel)) and guild is not None: - # user_id won't be None - member = guild.get_member(user_id) # type: ignore + if user is not None: + self.dispatch('typing', channel, user, raw.when) - if member is None: - member_data = data.get('member') - if member_data: - member = Member(data=member_data, state=self, guild=guild) + def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]: + if isinstance(channel, DMChannel): + return channel.recipient - elif isinstance(channel, GroupChannel): - member = utils.find(lambda x: x.id == user_id, channel.recipients) + elif isinstance(channel, (Thread, TextChannel)) and channel.guild is not None: + return channel.guild.get_member(user_id) # type: ignore - if member is not None: - timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) - self.dispatch('typing', channel, member, timestamp) + elif isinstance(channel, GroupChannel): + return utils.find(lambda x: x.id == user_id, channel.recipients) + + return self.get_user(user_id) def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, TextChannel): diff --git a/discord/types/raw_models.py b/discord/types/raw_models.py index 3c45b299..2d779e51 100644 --- a/discord/types/raw_models.py +++ b/discord/types/raw_models.py @@ -85,3 +85,13 @@ class _IntegrationDeleteEventOptional(TypedDict, total=False): class IntegrationDeleteEvent(_IntegrationDeleteEventOptional): id: Snowflake guild_id: Snowflake + + +class _TypingEventOptional(TypedDict, total=False): + guild_id: Snowflake + member: Member + +class TypingEvent(_TypingEventOptional): + channel_id: Snowflake + user_id: Snowflake + timestamp: int \ No newline at end of file diff --git a/docs/api.rst b/docs/api.rst index 0a9ba5cc..5fc56af1 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -369,6 +369,17 @@ to handle it, which defaults to print a traceback and ignoring the exception. :param when: When the typing started as an aware datetime in UTC. :type when: :class:`datetime.datetime` +.. function:: on_raw_typing(payload) + + Called when someone begins typing a message. Unlike :func:`on_typing`, this is + called regardless if the user can be found or not. This most often happens + when a user types in DMs. + + This requires :attr:`Intents.typing` to be enabled. + + :param payload: The raw typing payload. + :type payload: :class:`RawTypingEvent` + .. function:: on_message(message) Called when a :class:`Message` is created and sent. @@ -3846,6 +3857,14 @@ GuildSticker .. autoclass:: GuildSticker() :members: +RawTypingEvent +~~~~~~~~~~~~~~~~~~~~~~~ + +.. attributetable:: RawTypingEvent + +.. autoclass:: RawTypingEvent() + :members: + RawMessageDeleteEvent ~~~~~~~~~~~~~~~~~~~~~~~ -- 2.47.2 From 47e42d164839a9bfff76091fe9604399875a8e01 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 22:40:11 +0200 Subject: [PATCH 21/64] Merge pull request #42 * implement WelcomeScreen * copy over the kwargs issue. * readable variable names * modernise code * modernise pt2 * Update discord/welcome_screen.py * make pylance not cry from my onions * type http.py * remove extraneous import --- discord/__init__.py | 1 + discord/guild.py | 76 ++++++++++++++ discord/http.py | 15 +++ discord/welcome_screen.py | 216 ++++++++++++++++++++++++++++++++++++++ docs/api.rst | 16 +++ 5 files changed, 324 insertions(+) create mode 100644 discord/welcome_screen.py diff --git a/discord/__init__.py b/discord/__init__.py index 1e74cf91..288da41b 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -40,6 +40,7 @@ from .colour import * from .integrations import * from .invite import * from .template import * +from .welcome_screen import * from .widget import * from .object import * from .reaction import * diff --git a/discord/guild.py b/discord/guild.py index c9132f5a..cb53f44c 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -76,6 +76,7 @@ from .stage_instance import StageInstance from .threads import Thread, ThreadMember from .sticker import GuildSticker from .file import File +from .welcome_screen import WelcomeScreen, WelcomeChannel __all__ = ( @@ -2604,6 +2605,81 @@ class Guild(Hashable): return roles + async def welcome_screen(self) -> WelcomeScreen: + """|coro| + + Returns the guild's welcome screen. + + The guild must have ``COMMUNITY`` in :attr:`~Guild.features`. + + You must have the :attr:`~Permissions.manage_guild` permission to use + this as well. + + .. versionadded:: 2.0 + + Raises + ------- + Forbidden + You do not have the proper permissions to get this. + HTTPException + Retrieving the welcome screen failed. + + Returns + -------- + :class:`WelcomeScreen` + The welcome screen. + """ + data = await self._state.http.get_welcome_screen(self.id) + return WelcomeScreen(data=data, guild=self) + + @overload + async def edit_welcome_screen( + self, + *, + description: Optional[str] = ..., + welcome_channels: Optional[List[WelcomeChannel]] = ..., + enabled: Optional[bool] = ..., + ) -> WelcomeScreen: + ... + + @overload + async def edit_welcome_screen(self) -> None: + ... + + async def edit_welcome_screen(self, **kwargs): + """|coro| + + A shorthand method of :attr:`WelcomeScreen.edit` without needing + to fetch the welcome screen beforehand. + + The guild must have ``COMMUNITY`` in :attr:`~Guild.features`. + + You must have the :attr:`~Permissions.manage_guild` permission to use + this as well. + + .. versionadded:: 2.0 + + Returns + -------- + :class:`WelcomeScreen` + The edited welcome screen. + """ + try: + welcome_channels = kwargs['welcome_channels'] + except KeyError: + pass + else: + welcome_channels_serialised = [] + for wc in welcome_channels: + if not isinstance(wc, WelcomeChannel): + raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel') + welcome_channels_serialised.append(wc.to_dict()) + kwargs['welcome_channels'] = welcome_channels_serialised + + if kwargs: + data = await self._state.http.edit_welcome_screen(self.id, kwargs) + return WelcomeScreen(data=data, guild=self) + async def kick(self, user: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| diff --git a/discord/http.py b/discord/http.py index 7a4c2adc..4f86fc87 100644 --- a/discord/http.py +++ b/discord/http.py @@ -84,6 +84,7 @@ if TYPE_CHECKING: threads, voice, sticker, + welcome_screen, ) from .types.snowflake import Snowflake, SnowflakeList @@ -1116,6 +1117,20 @@ class HTTPClient: payload['icon'] = icon return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload) + def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]: + return self.request(Route('GET', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id)) + + def edit_welcome_screen(self, guild_id: Snowflake, payload: Any) -> Response[welcome_screen.WelcomeScreen]: + valid_keys = ( + 'description', + 'welcome_channels', + 'enabled', + ) + payload = { + k: v for k, v in payload.items() if k in valid_keys + } + return self.request(Route('PATCH', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id), json=payload) + def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id)) diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py new file mode 100644 index 00000000..de8c3c27 --- /dev/null +++ b/discord/welcome_screen.py @@ -0,0 +1,216 @@ +# -*- coding: utf-8 -*- + +""" +The MIT License (MIT) + +Copyright (c) 2015-present Rapptz + +Permission is hereby granted, free of charge, to any person obtaining a +copy of this software and associated documentation files (the "Software"), +to deal in the Software without restriction, including without limitation +the rights to use, copy, modify, merge, publish, distribute, sublicense, +and/or sell copies of the Software, and to permit persons to whom the +Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS +OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER +DEALINGS IN THE SOFTWARE. +""" + +from __future__ import annotations + +from typing import Dict, List, Optional, TYPE_CHECKING, Union, overload +from .utils import _get_as_snowflake, get +from .errors import InvalidArgument +from .partial_emoji import _EmojiTag + +__all__ = ( + 'WelcomeChannel', + 'WelcomeScreen', +) + +if TYPE_CHECKING: + from .types.welcome_screen import ( + WelcomeScreen as WelcomeScreenPayload, + WelcomeScreenChannel as WelcomeScreenChannelPayload, + ) + from .abc import Snowflake + from .guild import Guild + from .partial_emoji import PartialEmoji + from .emoji import Emoji + + +class WelcomeChannel: + """Represents a :class:`WelcomeScreen` welcome channel. + + .. versionadded:: 2.0 + + Attributes + ----------- + channel: :class:`abc.Snowflake` + The guild channel that is being referenced. + description: :class:`str` + The description shown of the channel. + emoji: Optional[:class:`PartialEmoji`, :class:`Emoji`, :class:`str`] + The emoji used beside the channel description. + """ + + def __init__(self, *, channel: Snowflake, description: str, emoji: Union[PartialEmoji, Emoji, str] = None): + self.channel = channel + self.description = description + self.emoji = emoji + + def __repr__(self) -> str: + return f'' + + @classmethod + def _from_dict(cls, *, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeChannel: + channel_id = _get_as_snowflake(data, 'channel_id') + channel = guild.get_channel(channel_id) + description = data['description'] + _emoji_id = _get_as_snowflake(data, 'emoji_id') + _emoji_name = data['emoji_name'] + + if _emoji_id: + # custom + emoji = get(guild.emojis, id=_emoji_id) + else: + # unicode or None + emoji = _emoji_name + + return cls(channel=channel, description=description, emoji=emoji) # type: ignore + + def to_dict(self) -> WelcomeScreenChannelPayload: + ret: WelcomeScreenChannelPayload = { + 'channel_id': self.channel.id, + 'description': self.description, + 'emoji_id': None, + 'emoji_name': None, + } + + if isinstance(self.emoji, _EmojiTag): + ret['emoji_id'] = self.emoji.id # type: ignore + ret['emoji_name'] = self.emoji.name # type: ignore + else: + # unicode or None + ret['emoji_name'] = self.emoji + + return ret + + +class WelcomeScreen: + """Represents a :class:`Guild` welcome screen. + + .. versionadded:: 2.0 + + Attributes + ----------- + description: :class:`str` + The description shown on the welcome screen. + welcome_channels: List[:class:`WelcomeChannel`] + The channels shown on the welcome screen. + """ + + def __init__(self, *, data: WelcomeScreenPayload, guild: Guild): + self._state = guild._state + self._guild = guild + self._store(data) + + def _store(self, data: WelcomeScreenPayload) -> None: + self.description = data['description'] + welcome_channels = data.get('welcome_channels', []) + self.welcome_channels = [WelcomeChannel._from_dict(data=wc, guild=self._guild) for wc in welcome_channels] + + def __repr__(self) -> str: + return f'' + + @property + def enabled(self) -> bool: + """:class:`bool`: Whether the welcome screen is displayed. + + This is equivalent to checking if ``WELCOME_SCREEN_ENABLED`` + is present in :attr:`Guild.features`. + """ + return 'WELCOME_SCREEN_ENABLED' in self._guild.features + + @overload + async def edit( + self, + *, + description: Optional[str] = ..., + welcome_channels: Optional[List[WelcomeChannel]] = ..., + enabled: Optional[bool] = ..., + ) -> None: + ... + + @overload + async def edit(self) -> None: + ... + + async def edit(self, **kwargs): + """|coro| + + Edit the welcome screen. + + You must have the :attr:`~Permissions.manage_guild` permission in the + guild to do this. + + Usage: :: + + rules_channel = guild.get_channel(12345678) + announcements_channel = guild.get_channel(87654321) + + custom_emoji = utils.get(guild.emojis, name='loudspeaker') + + await welcome_screen.edit( + description='This is a very cool community server!', + welcome_channels=[ + WelcomeChannel(channel=rules_channel, description='Read the rules!', emoji='👨‍🏫'), + WelcomeChannel(channel=announcements_channel, description='Watch out for announcements!', emoji=custom_emoji), + ] + ) + + .. note:: + + Welcome channels can only accept custom emojis if :attr:`~Guild.premium_tier` is level 2 or above. + + Parameters + ------------ + description: Optional[:class:`str`] + The template's description. + welcome_channels: Optional[List[:class:`WelcomeChannel`]] + The welcome channels, in their respective order. + enabled: Optional[:class:`bool`] + Whether the welcome screen should be displayed. + + Raises + ------- + HTTPException + Editing the welcome screen failed failed. + Forbidden + You don't have permissions to edit the welcome screen. + NotFound + This welcome screen does not exist. + """ + try: + welcome_channels = kwargs['welcome_channels'] + except KeyError: + pass + else: + welcome_channels_serialised = [] + for wc in welcome_channels: + if not isinstance(wc, WelcomeChannel): + raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel') + welcome_channels_serialised.append(wc.to_dict()) + kwargs['welcome_channels'] = welcome_channels_serialised + + if kwargs: + data = await self._state.http.edit_welcome_screen(self._guild.id, kwargs) + self._store(data) diff --git a/docs/api.rst b/docs/api.rst index 5fc56af1..069dd1d4 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -3792,6 +3792,22 @@ Template .. autoclass:: Template() :members: +WelcomeScreen +~~~~~~~~~~~~~~~ + +.. attributetable:: WelcomeScreen + +.. autoclass:: WelcomeScreen() + :members: + +WelcomeChannel +~~~~~~~~~~~~~~~ + +.. attributetable:: WelcomeChannel + +.. autoclass:: WelcomeChannel() + :members: + WidgetChannel ~~~~~~~~~~~~~~~ -- 2.47.2 From 33470ff1960285b5e5d5e0b277ee76a1e71dc1e2 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 22:41:26 +0200 Subject: [PATCH 22/64] Merge pull request #31 * Add bots and humans to TextChannel --- discord/channel.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/discord/channel.py b/discord/channel.py index dc3967c4..f467d1cc 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -228,6 +228,16 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """List[:class:`Member`]: Returns all members that can see this channel.""" return [m for m in self.guild.members if self.permissions_for(m).read_messages] + @property + def bots(self) -> List[Member]: + """List[:class:`Member`]: Returns all bots that can see this channel.""" + return [m for m in self.guild.members if m.bot and self.permissions_for(m).read_messages] + + @property + def humans(self) -> List[Member]: + """List[:class:`Member`]: Returns all human members that can see this channel.""" + return [m for m in self.guild.members if not m.bot and self.permissions_for(m).read_messages] + @property def threads(self) -> List[Thread]: """List[:class:`Thread`]: Returns all the threads that you can see. -- 2.47.2 From 1032728311ef71f26d4fc8aa5e3ba942d9ae775a Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 22:43:19 +0200 Subject: [PATCH 23/64] Merge pull request #32 * Add get/fetch_member to ThreadMember objects --- discord/threads.py | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/discord/threads.py b/discord/threads.py index 6f8ffb00..a06ef219 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -808,3 +808,39 @@ class ThreadMember(Hashable): def thread(self) -> Thread: """:class:`Thread`: The thread this member belongs to.""" return self.parent + + async def fetch_member(self) -> Member: + """|coro| + + Retrieves a :class:`Member` from the ThreadMember object. + + .. note:: + + This method is an API call. If you have :attr:`Intents.members` and member cache enabled, consider :meth:`get_member` instead. + + Raises + ------- + Forbidden + You do not have access to the guild. + HTTPException + Fetching the member failed. + + Returns + -------- + :class:`Member` + The member. + """ + + return await self.thread.guild.fetch_member(self.id) + + def get_member(self) -> Optional[Member]: + """ + Get the :class:`Member` from cache for the ThreadMember object. + + Returns + -------- + Optional[:class:`Member`] + The member or ``None`` if not found. + """ + + return await self.thread.guild.get_member(self.id) -- 2.47.2 From 3ffe1348956ebc1e2512439b532fdb6516c267c6 Mon Sep 17 00:00:00 2001 From: Arthur Date: Thu, 2 Sep 2021 22:50:19 +0200 Subject: [PATCH 24/64] Merge pull request #44 * Typehint gateway.py * Add relevant typehints to gateway.py to voice_client.py * Change EventListener to subclass NamedTuple * Add return type for DiscordWebSocket.wait_for * Correct deque typehint * Remove unnecessary typehints for literals * Use type aliases * Merge branch '2.0' into pr7422 --- discord/gateway.py | 257 ++++++++++++++++++++++++---------------- discord/voice_client.py | 3 + 2 files changed, 161 insertions(+), 99 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index aa0c6ba0..fbbc3c5e 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque + import asyncio -from collections import namedtuple, deque +from collections import deque import concurrent.futures import logging import struct @@ -38,9 +42,25 @@ import aiohttp from . import utils from .activity import BaseActivity from .enums import SpeakingState -from .errors import ConnectionClosed, InvalidArgument +from .errors import ConnectionClosed, InvalidArgument + +if TYPE_CHECKING: + from .client import Client + from .state import ConnectionState + from .voice_client import VoiceClient + + T = TypeVar('T') + DWS = TypeVar('DWS', bound='DiscordWebSocket') + DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') + + Coro = Callable[..., Coroutine[Any, Any, Any]] + Predicate = Callable[[Dict[str, Any]], bool] + DataCallable = Callable[[Dict[str, Any]], T] + Result = Optional[DataCallable[Any]] + + +_log: logging.Logger = logging.getLogger(__name__) -_log = logging.getLogger(__name__) __all__ = ( 'DiscordWebSocket', @@ -50,36 +70,49 @@ __all__ = ( 'ReconnectWebSocket', ) + +class Heartbeat(TypedDict): + op: int + d: int + + class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" - def __init__(self, shard_id, *, resume=True): - self.shard_id = shard_id - self.resume = resume + def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None: + self.shard_id: Optional[int] = shard_id + self.resume: bool = resume self.op = 'RESUME' if resume else 'IDENTIFY' + class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" pass -EventListener = namedtuple('EventListener', 'predicate event result future') + +class EventListener(NamedTuple): + predicate: Predicate + event: str + result: Result + future: asyncio.Future + class GatewayRatelimiter: - def __init__(self, count=110, per=60.0): + def __init__(self, count: int = 110, per: float = 60.0) -> None: # The default is 110 to give room for at least 10 heartbeats per minute - self.max = count - self.remaining = count - self.window = 0.0 - self.per = per - self.lock = asyncio.Lock() - self.shard_id = None + self.max: int = count + self.remaining: int = count + self.window: float = 0.0 + self.per: float = per + self.lock: asyncio.Lock = asyncio.Lock() + self.shard_id: Optional[int] = None - def is_ratelimited(self): + def is_ratelimited(self) -> bool: current = time.time() if current > self.window + self.per: return False return self.remaining == 0 - def get_delay(self): + def get_delay(self) -> float: current = time.time() if current > self.window + self.per: @@ -97,7 +130,7 @@ class GatewayRatelimiter: return 0.0 - async def block(self): + async def block(self) -> None: async with self.lock: delta = self.get_delay() if delta: @@ -106,27 +139,27 @@ class GatewayRatelimiter: class KeepAliveHandler(threading.Thread): - def __init__(self, *args, **kwargs): - ws = kwargs.pop('ws', None) + def __init__(self, *args: Any, **kwargs: Any) -> None: + ws = kwargs.pop('ws') interval = kwargs.pop('interval', None) shard_id = kwargs.pop('shard_id', None) threading.Thread.__init__(self, *args, **kwargs) - self.ws = ws - self._main_thread_id = ws.thread_id - self.interval = interval - self.daemon = True - self.shard_id = shard_id - self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' - self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' - self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' - self._stop_ev = threading.Event() - self._last_ack = time.perf_counter() - self._last_send = time.perf_counter() - self._last_recv = time.perf_counter() - self.latency = float('inf') - self.heartbeat_timeout = ws._max_heartbeat_timeout + self.ws: DiscordWebSocket = ws + self._main_thread_id: int = ws.thread_id + self.interval: Optional[float] = interval + self.daemon: bool = True + self.shard_id: Optional[int] = shard_id + self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.' + self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.' + self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' + self._stop_ev: threading.Event = threading.Event() + self._last_ack: float = time.perf_counter() + self._last_send: float = time.perf_counter() + self._last_recv: float = time.perf_counter() + self.latency: float = float('inf') + self.heartbeat_timeout: float = ws._max_heartbeat_timeout - def run(self): + def run(self) -> None: while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) @@ -168,19 +201,20 @@ class KeepAliveHandler(threading.Thread): else: self._last_send = time.perf_counter() - def get_payload(self): + def get_payload(self) -> Heartbeat: return { 'op': self.ws.HEARTBEAT, - 'd': self.ws.sequence + # the websocket's sequence won't be None here + 'd': self.ws.sequence # type: ignore } - def stop(self): + def stop(self) -> None: self._stop_ev.set() - def tick(self): + def tick(self) -> None: self._last_recv = time.perf_counter() - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self.latency = ack_time - self._last_send @@ -188,30 +222,32 @@ class KeepAliveHandler(threading.Thread): _log.warning(self.behind_msg, self.shard_id, self.latency) class VoiceKeepAliveHandler(KeepAliveHandler): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.recent_ack_latencies = deque(maxlen=20) + self.recent_ack_latencies: Deque[float] = deque(maxlen=20) self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' - def get_payload(self): + def get_payload(self) -> Heartbeat: return { 'op': self.ws.HEARTBEAT, 'd': int(time.time() * 1000) } - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self._last_recv = ack_time self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) + class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: return await super().close(code=code, message=message) + class DiscordWebSocket: """Implements a WebSocket for Discord's gateway v6. @@ -266,41 +302,53 @@ class DiscordWebSocket: HEARTBEAT_ACK = 11 GUILD_SYNC = 12 - def __init__(self, socket, *, loop): - self.socket = socket - self.loop = loop + def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: + self.socket: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop # an empty dispatcher to prevent crashes self._dispatch = lambda *args: None # generic event listeners - self._dispatch_listeners = [] + self._dispatch_listeners: List[EventListener] = [] # the keep alive - self._keep_alive = None - self.thread_id = threading.get_ident() + self._keep_alive: Optional[KeepAliveHandler] = None + self.thread_id: int = threading.get_ident() # ws related stuff - self.session_id = None - self.sequence = None + self.session_id: Optional[str] = None + self.sequence: Optional[int] = None self._zlib = zlib.decompressobj() - self._buffer = bytearray() - self._close_code = None - self._rate_limiter = GatewayRatelimiter() + self._buffer: bytearray = bytearray() + self._close_code: Optional[int] = None + self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() + + # attributes that get set in from_client + self.token: str = utils.MISSING + self._connection: ConnectionState = utils.MISSING + self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING + self.gateway: str = utils.MISSING + self.call_hooks: Coro = utils.MISSING + self._initial_identify: bool = utils.MISSING + self.shard_id: Optional[int] = utils.MISSING + self.shard_count: Optional[int] = utils.MISSING + self.session_id: Optional[str] = utils.MISSING + self._max_heartbeat_timeout: float = utils.MISSING @property - def open(self): + def open(self) -> bool: return not self.socket.closed - def is_ratelimited(self): + def is_ratelimited(self) -> bool: return self._rate_limiter.is_ratelimited() - def debug_log_receive(self, data, /): + def debug_log_receive(self, data, /) -> None: self._dispatch('socket_raw_receive', data) - def log_receive(self, _, /): + def log_receive(self, _, /) -> None: pass @classmethod - async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): + async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -310,7 +358,9 @@ class DiscordWebSocket: ws = cls(socket, loop=client.loop) # dynamically add attributes needed - ws.token = client.http.token + + # the token won't be None here + ws.token = client.http.token # type: ignore ws._connection = client._connection ws._discord_parsers = client._connection.parsers ws._dispatch = client.dispatch @@ -342,7 +392,7 @@ class DiscordWebSocket: await ws.resume() return ws - def wait_for(self, event, predicate, result=None): + def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future: """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -367,7 +417,7 @@ class DiscordWebSocket: self._dispatch_listeners.append(entry) return future - async def identify(self): + async def identify(self) -> None: """Sends the IDENTIFY packet.""" payload = { 'op': self.IDENTIFY, @@ -405,7 +455,7 @@ class DiscordWebSocket: await self.send_as_json(payload) _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) - async def resume(self): + async def resume(self) -> None: """Sends the RESUME packet.""" payload = { 'op': self.RESUME, @@ -419,7 +469,8 @@ class DiscordWebSocket: await self.send_as_json(payload) _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) - async def received_message(self, msg, /): + + async def received_message(self, msg, /) -> None: if type(msg) is bytes: self._buffer.extend(msg) @@ -537,16 +588,16 @@ class DiscordWebSocket: del self._dispatch_listeners[index] @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency - def _can_handle_close(self): + def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) - async def poll_event(self): + async def poll_event(self) -> None: """Polls for a DISPATCH event and handles the general gateway loop. Raises @@ -584,23 +635,23 @@ class DiscordWebSocket: _log.info('Websocket closed with %s, cannot reconnect.', code) raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None - async def debug_send(self, data, /): + async def debug_send(self, data, /) -> None: await self._rate_limiter.block() self._dispatch('socket_raw_send', data) await self.socket.send_str(data) - async def send(self, data, /): + async def send(self, data, /) -> None: await self._rate_limiter.block() await self.socket.send_str(data) - async def send_as_json(self, data): + async def send_as_json(self, data) -> None: try: await self.send(utils._to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def send_heartbeat(self, data): + async def send_heartbeat(self, data: Heartbeat) -> None: # This bypasses the rate limit handling code since it has a higher priority try: await self.socket.send_str(utils._to_json(data)) @@ -608,13 +659,13 @@ class DiscordWebSocket: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def change_presence(self, *, activity=None, status=None, since=0.0): + async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None: if activity is not None: if not isinstance(activity, BaseActivity): raise InvalidArgument('activity must derive from BaseActivity.') - activity = [activity.to_dict()] + activities = [activity.to_dict()] else: - activity = [] + activities = [] if status == 'idle': since = int(time.time() * 1000) @@ -622,7 +673,7 @@ class DiscordWebSocket: payload = { 'op': self.PRESENCE, 'd': { - 'activities': activity, + 'activities': activities, 'afk': False, 'since': since, 'status': status @@ -633,7 +684,7 @@ class DiscordWebSocket: _log.debug('Sending "%s" to change status', sent) await self.send(sent) - async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): + async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None: payload = { 'op': self.REQUEST_MEMBERS, 'd': { @@ -655,7 +706,7 @@ class DiscordWebSocket: await self.send_as_json(payload) - async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None: payload = { 'op': self.VOICE_STATE, 'd': { @@ -669,7 +720,7 @@ class DiscordWebSocket: _log.debug('Updating our voice state to %s.', payload) await self.send_as_json(payload) - async def close(self, code=4000): + async def close(self, code: int = 4000) -> None: if self._keep_alive: self._keep_alive.stop() self._keep_alive = None @@ -721,25 +772,31 @@ class DiscordVoiceWebSocket: CLIENT_CONNECT = 12 CLIENT_DISCONNECT = 13 - def __init__(self, socket, loop, *, hook=None): - self.ws = socket - self.loop = loop - self._keep_alive = None - self._close_code = None - self.secret_key = None + def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None: + self.ws: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop + self._keep_alive: VoiceKeepAliveHandler = utils.MISSING + self._close_code: Optional[int] = None + self.secret_key: Optional[List[int]] = None + self.gateway: str = utils.MISSING + self._connection: VoiceClient = utils.MISSING + self._max_heartbeat_timeout: float = utils.MISSING + self.thread_id: int = utils.MISSING if hook: - self._hook = hook + # we want to redeclare self._hook + self._hook = hook # type: ignore - async def _hook(self, *args): + async def _hook(self, *args: Any) -> Any: pass - async def send_as_json(self, data): + + async def send_as_json(self, data) -> None: _log.debug('Sending voice websocket frame: %s.', data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json - async def resume(self): + async def resume(self) -> None: state = self._connection payload = { 'op': self.RESUME, @@ -765,7 +822,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) @classmethod - async def from_client(cls, client, *, resume=False, hook=None): + async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS: """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' http = client._state.http @@ -783,7 +840,7 @@ class DiscordVoiceWebSocket: return ws - async def select_protocol(self, ip, port, mode): + async def select_protocol(self, ip, port, mode) -> None: payload = { 'op': self.SELECT_PROTOCOL, 'd': { @@ -798,7 +855,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def client_connect(self): + async def client_connect(self) -> None: payload = { 'op': self.CLIENT_CONNECT, 'd': { @@ -808,7 +865,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def speak(self, state=SpeakingState.voice): + async def speak(self, state=SpeakingState.voice) -> None: payload = { 'op': self.SPEAKING, 'd': { @@ -819,7 +876,8 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def received_message(self, msg): + + async def received_message(self, msg) -> None: _log.debug('Voice websocket frame received: %s', msg) op = msg['op'] data = msg.get('d') @@ -840,7 +898,7 @@ class DiscordVoiceWebSocket: await self._hook(self, msg) - async def initial_connection(self, data): + async def initial_connection(self, data) -> None: state = self._connection state.ssrc = data['ssrc'] state.voice_port = data['port'] @@ -871,13 +929,13 @@ class DiscordVoiceWebSocket: _log.info('selected the voice protocol for use (%s)', mode) @property - def latency(self): + def latency(self) -> float: """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency @property - def average_latency(self): + def average_latency(self) -> float: """:class:`list`: Average of last 20 HEARTBEAT latencies.""" heartbeat = self._keep_alive if heartbeat is None or not heartbeat.recent_ack_latencies: @@ -885,13 +943,14 @@ class DiscordVoiceWebSocket: return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) - async def load_secret_key(self, data): + + async def load_secret_key(self, data) -> None: _log.info('received secret key for voice connection') self.secret_key = self._connection.secret_key = data.get('secret_key') await self.speak() await self.speak(False) - async def poll_event(self): + async def poll_event(self) -> None: # This exception is handled up the chain msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) if msg.type is aiohttp.WSMsgType.TEXT: @@ -903,7 +962,7 @@ class DiscordVoiceWebSocket: _log.debug('Received %s', msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) - async def close(self, code=1000): + async def close(self, code: int = 1000) -> None: if self._keep_alive is not None: self._keep_alive.stop() diff --git a/discord/voice_client.py b/discord/voice_client.py index d382a74d..123dd29b 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -255,6 +255,9 @@ class VoiceClient(VoiceProtocol): self.encoder: Encoder = MISSING self._lite_nonce: int = 0 self.ws: DiscordVoiceWebSocket = MISSING + self.ip: str = MISSING + self.port: Tuple[Any, ...] = MISSING + warn_nacl = not has_nacl supported_modes: Tuple[SupportedModes, ...] = ( -- 2.47.2 From 14b3188bb897003666ea641370043e23ea0682dd Mon Sep 17 00:00:00 2001 From: Moksej <58531286+TheMoksej@users.noreply.github.com> Date: Sun, 5 Sep 2021 13:58:10 +0200 Subject: [PATCH 25/64] remove unnecessary await --- discord/threads.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/discord/threads.py b/discord/threads.py index a06ef219..49bba831 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -843,4 +843,4 @@ class ThreadMember(Hashable): The member or ``None`` if not found. """ - return await self.thread.guild.get_member(self.id) + return self.thread.guild.get_member(self.id) -- 2.47.2 From 53a6b2cb4577ae7c4787bd9dabea72e5551ecc32 Mon Sep 17 00:00:00 2001 From: Gnome! <45660393+Gnome-py@users.noreply.github.com> Date: Sun, 5 Sep 2021 18:37:51 +0100 Subject: [PATCH 26/64] Revert "Merge pull request #12" (#56) This reverts commit 42c0a8d8a5840c00185e367933e61e2565bf7305. --- discord/activity.py | 10 ++++----- discord/ext/commands/bot.py | 16 ++++++++------ discord/ext/commands/context.py | 6 ++++-- discord/ext/commands/converter.py | 21 +++++++++--------- discord/ext/commands/view.py | 15 +++++++------ discord/ext/tasks/__init__.py | 25 +++++++++++++++------ discord/player.py | 13 +++++------ discord/ui/button.py | 16 +++++++------- discord/utils.py | 36 ++++++++++++++----------------- 9 files changed, 84 insertions(+), 74 deletions(-) diff --git a/discord/activity.py b/discord/activity.py index cba61f38..51205377 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -794,13 +794,13 @@ class CustomActivity(BaseActivity): return hash((self.name, str(self.emoji))) def __str__(self) -> str: - if not self.emoji: + if self.emoji: + if self.name: + return f'{self.emoji} {self.name}' + return str(self.emoji) + else: return str(self.name) - if self.name: - return f'{self.emoji} {self.name}' - return str(self.emoji) - def __repr__(self) -> str: return f'' diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index b3a1fb57..e03562b6 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -43,7 +43,6 @@ from .context import Context from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog -from discord.utils import raise_expected_coro if TYPE_CHECKING: import importlib.machinery @@ -425,9 +424,11 @@ class BotBase(GroupMixin): TypeError The coroutine passed is not actually a coroutine. """ - return raise_expected_coro( - coro, 'The pre-invoke hook must be a coroutine.' - ) + if not asyncio.iscoroutinefunction(coro): + raise TypeError('The pre-invoke hook must be a coroutine.') + + self._before_invoke = coro + return coro def after_invoke(self, coro: CFT) -> CFT: r"""A decorator that registers a coroutine as a post-invoke hook. @@ -456,10 +457,11 @@ class BotBase(GroupMixin): TypeError The coroutine passed is not actually a coroutine. """ - return raise_expected_coro( - coro, 'The post-invoke hook must be a coroutine.' - ) + if not asyncio.iscoroutinefunction(coro): + raise TypeError('The post-invoke hook must be a coroutine.') + self._after_invoke = coro + return coro # listener registration diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 158c84ea..fa16c74a 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from __future__ import annotations import inspect @@ -62,7 +61,10 @@ T = TypeVar('T') BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") CogT = TypeVar('CogT', bound="Cog") -P = ParamSpec('P') if TYPE_CHECKING else TypeVar('P') +if TYPE_CHECKING: + P = ParamSpec('P') +else: + P = TypeVar('P') class Context(discord.abc.Messageable, Generic[BotT]): diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index ce7037a4..5740a188 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -353,14 +353,14 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: - if guild_id is None: - return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel - - guild = ctx.bot.get_guild(guild_id) - if guild is not None and channel_id is not None: - return guild._resolve_channel(channel_id) # type: ignore + if guild_id is not None: + guild = ctx.bot.get_guild(guild_id) + if guild is not None and channel_id is not None: + return guild._resolve_channel(channel_id) # type: ignore + else: + return None else: - return None + return ctx.bot.get_channel(channel_id) if channel_id else ctx.channel async def convert(self, ctx: Context, argument: str) -> discord.PartialMessage: guild_id, message_id, channel_id = self._get_id_matches(ctx, argument) @@ -754,8 +754,8 @@ class GuildConverter(IDConverter[discord.Guild]): if result is None: result = discord.utils.get(ctx.bot.guilds, name=argument) - if result is None: - raise GuildNotFound(argument) + if result is None: + raise GuildNotFound(argument) return result @@ -939,7 +939,8 @@ class clean_content(Converter[str]): def repl(match: re.Match) -> str: type = match[1] id = int(match[2]) - return transforms[type](id) + transformed = transforms[type](id) + return transformed result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) if self.escape_markdown: diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index 39cc35f7..a7dc7236 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -82,7 +82,9 @@ class StringView: def skip_string(self, string): strlen = len(string) if self.buffer[self.index:self.index + strlen] == string: - return self._return_index(strlen, True) + self.previous = self.index + self.index += strlen + return True return False def read_rest(self): @@ -93,7 +95,9 @@ class StringView: def read(self, n): result = self.buffer[self.index:self.index + n] - return self._return_index(n, result) + self.previous = self.index + self.index += n + return result def get(self): try: @@ -101,12 +105,9 @@ class StringView: except IndexError: result = None - return self._return_index(1, result) - - def _return_index(self, arg0, arg1): self.previous = self.index - self.index += arg0 - return arg1 + self.index += 1 + return result def get_word(self): pos = 0 diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 9518390e..5b78f10e 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -46,9 +46,7 @@ import traceback from collections.abc import Sequence from discord.backoff import ExponentialBackoff -from discord.utils import MISSING, raise_expected_coro - - +from discord.utils import MISSING __all__ = ( 'loop', @@ -490,7 +488,11 @@ class Loop(Generic[LF]): The function was not a coroutine. """ - return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') + if not inspect.iscoroutinefunction(coro): + raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + + self._before_loop = coro + return coro def after_loop(self, coro: FT) -> FT: """A decorator that register a coroutine to be called after the loop finished running. @@ -514,7 +516,11 @@ class Loop(Generic[LF]): The function was not a coroutine. """ - return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') + if not inspect.iscoroutinefunction(coro): + raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + + self._after_loop = coro + return coro def error(self, coro: ET) -> ET: """A decorator that registers a coroutine to be called if the task encounters an unhandled exception. @@ -536,7 +542,11 @@ class Loop(Generic[LF]): TypeError The function was not a coroutine. """ - return raise_expected_coro(coro, f'Expected coroutine function, received {coro.__class__.__name__!r}.') + if not inspect.iscoroutinefunction(coro): + raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + + self._error = coro # type: ignore + return coro def _get_next_sleep_time(self) -> datetime.datetime: if self._sleep is not MISSING: @@ -604,7 +614,8 @@ class Loop(Generic[LF]): ) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) - return sorted(set(ret)) + ret = sorted(set(ret)) # de-dupe and sort times + return ret def change_interval( self, diff --git a/discord/player.py b/discord/player.py index 79579c8d..8098d3e3 100644 --- a/discord/player.py +++ b/discord/player.py @@ -21,7 +21,6 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - from __future__ import annotations import threading @@ -64,7 +63,10 @@ __all__ = ( CREATE_NO_WINDOW: int -CREATE_NO_WINDOW = 0 if sys.platform != 'win32' else 0x08000000 +if sys.platform != 'win32': + CREATE_NO_WINDOW = 0 +else: + CREATE_NO_WINDOW = 0x08000000 class AudioSource: """Represents an audio stream. @@ -524,12 +526,7 @@ class FFmpegOpusAudio(FFmpegAudio): @staticmethod def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - exe = ( - executable[:2] + 'probe' - if executable in {'ffmpeg', 'avconv'} - else executable - ) - + exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] output = subprocess.check_output(args, timeout=20) codec = bitrate = None diff --git a/discord/ui/button.py b/discord/ui/button.py index 0b16e87e..fedeac68 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -185,15 +185,15 @@ class Button(Item[V]): @emoji.setter def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: ignore - if value is None: - self._underlying.emoji = None - - elif isinstance(value, str): - self._underlying.emoji = PartialEmoji.from_str(value) - elif isinstance(value, _EmojiTag): - self._underlying.emoji = value._to_partial() + if value is not None: + if isinstance(value, str): + self._underlying.emoji = PartialEmoji.from_str(value) + elif isinstance(value, _EmojiTag): + self._underlying.emoji = value._to_partial() + else: + raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') else: - raise TypeError(f'expected str, Emoji, or PartialEmoji, received {value.__class__} instead') + self._underlying.emoji = None @classmethod def from_component(cls: Type[B], button: ButtonComponent) -> B: diff --git a/discord/utils.py b/discord/utils.py index cad99da6..4360b77a 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -499,14 +499,14 @@ else: def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: reset_after: Optional[str] = request.headers.get('X-Ratelimit-Reset-After') - if not use_clock and reset_after: + if use_clock or not reset_after: + utc = datetime.timezone.utc + now = datetime.datetime.now(utc) + reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) + return (reset - now).total_seconds() + else: return float(reset_after) - utc = datetime.timezone.utc - now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp(float(request.headers['X-Ratelimit-Reset']), utc) - return (reset - now).total_seconds() - async def maybe_coroutine(f, *args, **kwargs): value = f(*args, **kwargs) @@ -659,10 +659,11 @@ def resolve_invite(invite: Union[Invite, str]) -> str: if isinstance(invite, Invite): return invite.code - rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' - m = re.match(rx, invite) - if m: - return m.group(1) + else: + rx = r'(?:https?\:\/\/)?discord(?:\.gg|(?:app)?\.com\/invite)\/(.+)' + m = re.match(rx, invite) + if m: + return m.group(1) return invite @@ -686,10 +687,11 @@ def resolve_template(code: Union[Template, str]) -> str: if isinstance(code, Template): return code.code - rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' - m = re.match(rx, code) - if m: - return m.group(1) + else: + rx = r'(?:https?\:\/\/)?discord(?:\.new|(?:app)?\.com\/template)\/(.+)' + m = re.match(rx, code) + if m: + return m.group(1) return code @@ -1015,9 +1017,3 @@ def format_dt(dt: datetime.datetime, /, style: Optional[TimestampStyle] = None) if style is None: return f'' return f'' - - -def raise_expected_coro(coro, error: str)-> TypeError: - if not asyncio.iscoroutinefunction(coro): - raise TypeError(error) - return coro -- 2.47.2 From 1833e984cefa112eac4b218e12f4b7e3dd504590 Mon Sep 17 00:00:00 2001 From: IAmTomahawkx Date: Sun, 5 Sep 2021 14:32:51 -0700 Subject: [PATCH 27/64] add black workflow, change our code formats. closes #43 --- .github/CONTRIBUTING.md | 6 +++--- .github/PULL_REQUEST_TEMPLATE.md | 2 ++ .github/workflows/black.yml | 25 +++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/black.yml diff --git a/.github/CONTRIBUTING.md b/.github/CONTRIBUTING.md index bb457a35..2311afca 100644 --- a/.github/CONTRIBUTING.md +++ b/.github/CONTRIBUTING.md @@ -34,13 +34,13 @@ If the bug report is missing this information then it'll take us longer to fix t ## Submitting a Pull Request -Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows PEP-8 guidelines (mostly) with a column limit of 125. +Submitting a pull request is fairly simple, just make sure it focuses on a single aspect and doesn't manage to have scope creep, and it's probably good to go. It would be incredibly lovely if the style is consistent to that found in the project. This project follows the black code format, with a line length limit of `120` ### Git Commit Guidelines - Use present tense (e.g. "Add feature" not "Added feature") -- Limit all lines to 72 characters or less. -- Reference issues or pull requests outside of the first line. +- Limit all lines to 120 characters or fewer. +- Reference issues or pull requests outside the first line. - Please use the shorthand `#123` and not the full URL. - Commits regarding the commands extension must be prefixed with `[commands]` diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md index 55941f4e..212d0392 100644 --- a/.github/PULL_REQUEST_TEMPLATE.md +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -1,3 +1,5 @@ + + ## Summary diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 00000000..b95fb1dd --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,25 @@ +name: Lint + +on: [push] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - name: checkout + uses: actions/checkout@v2 + + - name: Setup Python + uses: actions/setup-python@v1 + with: + python-version: 3.8 + + - name: install black + run: pip install black + + - name: run linter + uses: wearerequired/lint-action@v1 + with: + black: true + black_args: ". --line-length 120" + auto_fix: true \ No newline at end of file -- 2.47.2 From 7513c2138f4e99d677dd258e6fe9d3efcdb3fa86 Mon Sep 17 00:00:00 2001 From: Lint Action Date: Sun, 5 Sep 2021 21:34:20 +0000 Subject: [PATCH 28/64] Fix code style issues with Black --- discord/__init__.py | 14 +- discord/__main__.py | 159 +++-- discord/abc.py | 148 ++-- discord/activity.py | 247 +++---- discord/appinfo.py | 118 ++-- discord/asset.py | 83 ++- discord/audit_logs.py | 141 ++-- discord/backoff.py | 9 +- discord/channel.py | 272 ++++---- discord/client.py | 197 +++--- discord/colour.py | 64 +- discord/components.py | 148 ++-- discord/context_managers.py | 14 +- discord/embeds.py | 166 +++-- discord/emoji.py | 60 +- discord/enums.py | 171 ++--- discord/errors.py | 68 +- discord/ext/commands/_types.py | 6 +- discord/ext/commands/bot.py | 95 +-- discord/ext/commands/cog.py | 50 +- discord/ext/commands/context.py | 29 +- discord/ext/commands/converter.py | 196 +++--- discord/ext/commands/cooldowns.py | 56 +- discord/ext/commands/core.py | 357 ++++++---- discord/ext/commands/errors.py | 314 ++++++--- discord/ext/commands/flags.py | 78 ++- discord/ext/commands/help.py | 110 +-- discord/ext/commands/view.py | 22 +- discord/ext/tasks/__init__.py | 50 +- discord/file.py | 18 +- discord/flags.py | 36 +- discord/gateway.py | 452 ++++++------ discord/guild.py | 447 ++++++------ discord/http.py | 879 +++++++++++++----------- discord/integrations.py | 118 ++-- discord/interactions.py | 114 +-- discord/invite.py | 138 ++-- discord/iterators.py | 89 +-- discord/member.py | 192 +++--- discord/mentions.py | 28 +- discord/message.py | 381 +++++----- discord/mixins.py | 6 +- discord/object.py | 10 +- discord/oggparse.py | 30 +- discord/opus.py | 191 ++--- discord/partial_emoji.py | 45 +- discord/permissions.py | 34 +- discord/player.py | 155 +++-- discord/raw_models.py | 93 ++- discord/reaction.py | 22 +- discord/role.py | 80 +-- discord/shard.py | 44 +- discord/stage_instance.py | 46 +- discord/state.py | 560 +++++++-------- discord/sticker.py | 120 ++-- discord/team.py | 32 +- discord/template.py | 62 +- discord/threads.py | 132 ++-- discord/types/activity.py | 2 +- discord/types/appinfo.py | 5 + discord/types/audit_log.py | 105 +-- discord/types/embed.py | 13 +- discord/types/guild.py | 44 +- discord/types/integration.py | 2 +- discord/types/interactions.py | 4 +- discord/types/message.py | 2 +- discord/types/raw_models.py | 3 +- discord/types/team.py | 2 + discord/types/voice.py | 2 +- discord/ui/button.py | 46 +- discord/ui/item.py | 16 +- discord/ui/select.py | 47 +- discord/ui/view.py | 32 +- discord/user.py | 80 +-- discord/utils.py | 147 ++-- discord/voice_client.py | 98 +-- discord/webhook/async_.py | 288 ++++---- discord/webhook/sync.py | 132 ++-- discord/welcome_screen.py | 42 +- discord/widget.py | 75 +- docs/conf.py | 229 +++--- docs/extensions/attributetable.py | 149 ++-- docs/extensions/builder.py | 53 +- docs/extensions/details.py | 25 +- docs/extensions/exception_hierarchy.py | 13 +- docs/extensions/nitpick_file_ignorer.py | 9 +- docs/extensions/resourcelinks.py | 18 +- examples/background_task.py | 14 +- examples/background_task_asyncio.py | 11 +- examples/basic_bot.py | 39 +- examples/basic_voice.py | 63 +- examples/converters.py | 24 +- examples/custom_context.py | 12 +- examples/deleted.py | 16 +- examples/edits.py | 16 +- examples/guessing_game.py | 20 +- examples/new_member.py | 9 +- examples/reaction_roles.py | 12 +- examples/reply.py | 12 +- examples/secret.py | 40 +- examples/views/confirm.py | 25 +- examples/views/counter.py | 13 +- examples/views/dropdown.py | 27 +- examples/views/ephemeral.py | 23 +- examples/views/link.py | 14 +- examples/views/persistent.py | 21 +- examples/views/tic_tac_toe.py | 23 +- setup.py | 114 +-- 108 files changed, 5369 insertions(+), 4858 deletions(-) diff --git a/discord/__init__.py b/discord/__init__.py index 288da41b..f0625cf8 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -9,13 +9,13 @@ A basic wrapper for the Discord API. """ -__title__ = 'discord' -__author__ = 'Rapptz' -__license__ = 'MIT' -__copyright__ = 'Copyright 2015-present Rapptz' -__version__ = '2.0.0a' +__title__ = "discord" +__author__ = "Rapptz" +__license__ = "MIT" +__copyright__ = "Copyright 2015-present Rapptz" +__version__ = "2.0.0a" -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +__path__ = __import__("pkgutil").extend_path(__path__, __name__) import logging from typing import NamedTuple, Literal @@ -70,6 +70,6 @@ class VersionInfo(NamedTuple): serial: int -version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0) +version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel="alpha", serial=0) logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/__main__.py b/discord/__main__.py index 513b0cb3..6e93c09b 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -31,26 +31,29 @@ import pkg_resources import aiohttp import platform + def show_version(): entries = [] - entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) + entries.append("- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info)) version_info = discord.version_info - entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info)) - if version_info.releaselevel != 'final': - pkg = pkg_resources.get_distribution('discord.py') + entries.append("- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info)) + if version_info.releaselevel != "final": + pkg = pkg_resources.get_distribution("discord.py") if pkg: - entries.append(f' - discord.py pkg_resources: v{pkg.version}') + entries.append(f" - discord.py pkg_resources: v{pkg.version}") - entries.append(f'- aiohttp v{aiohttp.__version__}') + entries.append(f"- aiohttp v{aiohttp.__version__}") uname = platform.uname() - entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) - print('\n'.join(entries)) + entries.append("- system info: {0.system} {0.release} {0.version}".format(uname)) + print("\n".join(entries)) + def core(parser, args): if args.version: show_version() + _bot_template = """#!/usr/bin/env python3 from discord.ext import commands @@ -120,7 +123,7 @@ def setup(bot): bot.add_cog({name}(bot)) ''' -_cog_extras = ''' +_cog_extras = """ def cog_unload(self): # clean up logic goes here pass @@ -149,22 +152,22 @@ _cog_extras = ''' # called after a command is called here pass -''' +""" # certain file names and directory names are forbidden # see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx # although some of this doesn't apply to Linux, we might as well be consistent _base_table = { - '<': '-', - '>': '-', - ':': '-', - '"': '-', + "<": "-", + ">": "-", + ":": "-", + '"': "-", # '/': '-', these are fine # '\\': '-', - '|': '-', - '?': '-', - '*': '-', + "|": "-", + "?": "-", + "*": "-", } # NUL (0) and 1-31 are disallowed @@ -172,21 +175,45 @@ _base_table.update((chr(i), None) for i in range(32)) _translation_table = str.maketrans(_base_table) + def to_path(parser, name, *, replace_spaces=False): if isinstance(name, Path): return name - if sys.platform == 'win32': - forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \ - 'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9') + if sys.platform == "win32": + forbidden = ( + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", + ) if len(name) <= 4 and name.upper() in forbidden: - parser.error('invalid directory name given, use a different one') + parser.error("invalid directory name given, use a different one") name = name.translate(_translation_table) if replace_spaces: - name = name.replace(' ', '-') + name = name.replace(" ", "-") return Path(name) + def newbot(parser, args): new_directory = to_path(parser, args.directory) / to_path(parser, args.name) @@ -195,106 +222,114 @@ def newbot(parser, args): try: new_directory.mkdir(exist_ok=True, parents=True) except OSError as exc: - parser.error(f'could not create our bot directory ({exc})') + parser.error(f"could not create our bot directory ({exc})") - cogs = new_directory / 'cogs' + cogs = new_directory / "cogs" try: cogs.mkdir(exist_ok=True) - init = cogs / '__init__.py' + init = cogs / "__init__.py" init.touch() except OSError as exc: - print(f'warning: could not create cogs directory ({exc})') + print(f"warning: could not create cogs directory ({exc})") try: - with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp: + with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp: fp.write('token = "place your token here"\ncogs = []\n') except OSError as exc: - parser.error(f'could not create config file ({exc})') + parser.error(f"could not create config file ({exc})") try: - with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp: - base = 'Bot' if not args.sharded else 'AutoShardedBot' + with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp: + base = "Bot" if not args.sharded else "AutoShardedBot" fp.write(_bot_template.format(base=base, prefix=args.prefix)) except OSError as exc: - parser.error(f'could not create bot file ({exc})') + parser.error(f"could not create bot file ({exc})") if not args.no_git: try: - with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp: + with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp: fp.write(_gitignore_template) except OSError as exc: - print(f'warning: could not create .gitignore file ({exc})') + print(f"warning: could not create .gitignore file ({exc})") + + print("successfully made bot at", new_directory) - print('successfully made bot at', new_directory) def newcog(parser, args): cog_dir = to_path(parser, args.directory) try: cog_dir.mkdir(exist_ok=True) except OSError as exc: - print(f'warning: could not create cogs directory ({exc})') + print(f"warning: could not create cogs directory ({exc})") directory = cog_dir / to_path(parser, args.name) - directory = directory.with_suffix('.py') + directory = directory.with_suffix(".py") try: - with open(str(directory), 'w', encoding='utf-8') as fp: - attrs = '' - extra = _cog_extras if args.full else '' + with open(str(directory), "w", encoding="utf-8") as fp: + attrs = "" + extra = _cog_extras if args.full else "" if args.class_name: name = args.class_name else: name = str(directory.stem) - if '-' in name or '_' in name: - translation = str.maketrans('-_', ' ') - name = name.translate(translation).title().replace(' ', '') + if "-" in name or "_" in name: + translation = str.maketrans("-_", " ") + name = name.translate(translation).title().replace(" ", "") else: name = name.title() if args.display_name: attrs += f', name="{args.display_name}"' if args.hide_commands: - attrs += ', command_attrs=dict(hidden=True)' + attrs += ", command_attrs=dict(hidden=True)" fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs)) except OSError as exc: - parser.error(f'could not create cog file ({exc})') + parser.error(f"could not create cog file ({exc})") else: - print('successfully made cog at', directory) + print("successfully made cog at", directory) + def add_newbot_args(subparser): - parser = subparser.add_parser('newbot', help='creates a command bot project quickly') + parser = subparser.add_parser("newbot", help="creates a command bot project quickly") parser.set_defaults(func=newbot) - parser.add_argument('name', help='the bot project name') - parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd()) - parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='') - parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true') - parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') + parser.add_argument("name", help="the bot project name") + parser.add_argument("directory", help="the directory to place it in (default: .)", nargs="?", default=Path.cwd()) + parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="") + parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true") + parser.add_argument("--no-git", help="do not create a .gitignore file", action="store_true", dest="no_git") + def add_newcog_args(subparser): - parser = subparser.add_parser('newcog', help='creates a new cog template quickly') + parser = subparser.add_parser("newcog", help="creates a new cog template quickly") parser.set_defaults(func=newcog) - parser.add_argument('name', help='the cog name') - parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs')) - parser.add_argument('--class-name', help='the class name of the cog (default: )', dest='class_name') - parser.add_argument('--display-name', help='the cog name (default: )') - parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true') - parser.add_argument('--full', help='add all special methods as well', action='store_true') + parser.add_argument("name", help="the cog name") + parser.add_argument( + "directory", help="the directory to place it in (default: cogs)", nargs="?", default=Path("cogs") + ) + parser.add_argument("--class-name", help="the class name of the cog (default: )", dest="class_name") + parser.add_argument("--display-name", help="the cog name (default: )") + parser.add_argument("--hide-commands", help="whether to hide all commands in the cog", action="store_true") + parser.add_argument("--full", help="add all special methods as well", action="store_true") + def parse_args(): - parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') - parser.add_argument('-v', '--version', action='store_true', help='shows the library version') + parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with discord.py") + parser.add_argument("-v", "--version", action="store_true", help="shows the library version") parser.set_defaults(func=core) - subparser = parser.add_subparsers(dest='subcommand', title='subcommands') + subparser = parser.add_subparsers(dest="subcommand", title="subcommands") add_newbot_args(subparser) add_newcog_args(subparser) return parser, parser.parse_args() + def main(): parser, args = parse_args() args.func(parser, args) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/discord/abc.py b/discord/abc.py index fd2dc4bb..196043e3 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -56,15 +56,15 @@ from .sticker import GuildSticker, StickerItem from . import utils __all__ = ( - 'Snowflake', - 'User', - 'PrivateChannel', - 'GuildChannel', - 'Messageable', - 'Connectable', + "Snowflake", + "User", + "PrivateChannel", + "GuildChannel", + "Messageable", + "Connectable", ) -T = TypeVar('T', bound=VoiceProtocol) +T = TypeVar("T", bound=VoiceProtocol) if TYPE_CHECKING: from datetime import datetime @@ -98,7 +98,7 @@ MISSING = utils.MISSING class _Undefined: def __repr__(self) -> str: - return 'see-below' + return "see-below" _undefined: Any = _Undefined() @@ -189,23 +189,23 @@ class PrivateChannel(Snowflake, Protocol): class _Overwrites: - __slots__ = ('id', 'allow', 'deny', 'type') + __slots__ = ("id", "allow", "deny", "type") ROLE = 0 MEMBER = 1 def __init__(self, data: PermissionOverwritePayload): - self.id: int = int(data['id']) - self.allow: int = int(data.get('allow', 0)) - self.deny: int = int(data.get('deny', 0)) - self.type: OverwriteType = data['type'] + self.id: int = int(data["id"]) + self.allow: int = int(data.get("allow", 0)) + self.deny: int = int(data.get("deny", 0)) + self.type: OverwriteType = data["type"] def _asdict(self) -> PermissionOverwritePayload: return { - 'id': self.id, - 'allow': str(self.allow), - 'deny': str(self.deny), - 'type': self.type, + "id": self.id, + "allow": str(self.allow), + "deny": str(self.deny), + "type": self.type, } def is_role(self) -> bool: @@ -215,7 +215,7 @@ class _Overwrites: return self.type == 1 -GCH = TypeVar('GCH', bound='GuildChannel') +GCH = TypeVar("GCH", bound="GuildChannel") class GuildChannel: @@ -276,7 +276,7 @@ class GuildChannel: reason: Optional[str], ) -> None: if position < 0: - raise InvalidArgument('Channel position cannot be less than 0.') + raise InvalidArgument("Channel position cannot be less than 0.") http = self._state.http bucket = self._sorting_bucket @@ -297,7 +297,7 @@ class GuildChannel: payload = [] for index, c in enumerate(channels): - d: Dict[str, Any] = {'id': c.id, 'position': index} + d: Dict[str, Any] = {"id": c.id, "position": index} if parent_id is not _undefined and c.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) @@ -306,81 +306,81 @@ class GuildChannel: async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]: try: - parent = options.pop('category') + parent = options.pop("category") except KeyError: parent_id = _undefined else: parent_id = parent and parent.id try: - options['rate_limit_per_user'] = options.pop('slowmode_delay') + options["rate_limit_per_user"] = options.pop("slowmode_delay") except KeyError: pass try: - rtc_region = options.pop('rtc_region') + rtc_region = options.pop("rtc_region") except KeyError: pass else: - options['rtc_region'] = None if rtc_region is None else str(rtc_region) + options["rtc_region"] = None if rtc_region is None else str(rtc_region) try: - video_quality_mode = options.pop('video_quality_mode') + video_quality_mode = options.pop("video_quality_mode") except KeyError: pass else: - options['video_quality_mode'] = int(video_quality_mode) + options["video_quality_mode"] = int(video_quality_mode) - lock_permissions = options.pop('sync_permissions', False) + lock_permissions = options.pop("sync_permissions", False) try: - position = options.pop('position') + position = options.pop("position") except KeyError: if parent_id is not _undefined: if lock_permissions: category = self.guild.get_channel(parent_id) if category: - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] - options['parent_id'] = parent_id + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] + options["parent_id"] = parent_id elif lock_permissions and self.category_id is not None: # if we're syncing permissions on a pre-existing channel category without changing it # we need to update the permissions to point to the pre-existing category category = self.guild.get_channel(self.category_id) if category: - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] else: await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) - overwrites = options.get('overwrites', None) + overwrites = options.get("overwrites", None) if overwrites is not None: perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") allow, deny = perm.pair() payload = { - 'allow': allow.value, - 'deny': deny.value, - 'id': target.id, + "allow": allow.value, + "deny": deny.value, + "id": target.id, } if isinstance(target, Role): - payload['type'] = _Overwrites.ROLE + payload["type"] = _Overwrites.ROLE else: - payload['type'] = _Overwrites.MEMBER + payload["type"] = _Overwrites.MEMBER perms.append(payload) - options['permission_overwrites'] = perms + options["permission_overwrites"] = perms try: - ch_type = options['type'] + ch_type = options["type"] except KeyError: pass else: if not isinstance(ch_type, ChannelType): - raise InvalidArgument('type field must be of type ChannelType') - options['type'] = ch_type.value + raise InvalidArgument("type field must be of type ChannelType") + options["type"] = ch_type.value if options: return await self._state.http.edit_channel(self.id, reason=reason, **options) @@ -390,7 +390,7 @@ class GuildChannel: everyone_index = 0 everyone_id = self.guild.id - for index, overridden in enumerate(data.get('permission_overwrites', [])): + for index, overridden in enumerate(data.get("permission_overwrites", [])): overwrite = _Overwrites(overridden) self._overwrites.append(overwrite) @@ -429,7 +429,7 @@ class GuildChannel: @property def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" - return f'<#{self.id}>' + return f"<#{self.id}>" @property def created_at(self) -> datetime: @@ -779,18 +779,18 @@ class GuildChannel: elif isinstance(target, Role): perm_type = _Overwrites.ROLE else: - raise InvalidArgument('target parameter must be either Member or Role') + raise InvalidArgument("target parameter must be either Member or Role") if overwrite is _undefined: if len(permissions) == 0: - raise InvalidArgument('No overwrite provided.') + raise InvalidArgument("No overwrite provided.") try: overwrite = PermissionOverwrite(**permissions) except (ValueError, TypeError): - raise InvalidArgument('Invalid permissions given to keyword arguments.') + raise InvalidArgument("Invalid permissions given to keyword arguments.") else: if len(permissions) > 0: - raise InvalidArgument('Cannot mix overwrite and keyword arguments.') + raise InvalidArgument("Cannot mix overwrite and keyword arguments.") # TODO: wait for event @@ -800,7 +800,7 @@ class GuildChannel: (allow, deny) = overwrite.pair() await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) else: - raise InvalidArgument('Invalid overwrite type provided.') + raise InvalidArgument("Invalid overwrite type provided.") async def _clone_impl( self: GCH, @@ -809,9 +809,9 @@ class GuildChannel: name: Optional[str] = None, reason: Optional[str] = None, ) -> GCH: - base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites] - base_attrs['parent_id'] = self.category_id - base_attrs['name'] = name or self.name + base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites] + base_attrs["parent_id"] = self.category_id + base_attrs["name"] = name or self.name guild_id = self.guild.id cls = self.__class__ data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) @@ -964,14 +964,14 @@ class GuildChannel: if not kwargs: return - beginning, end = kwargs.get('beginning'), kwargs.get('end') - before, after = kwargs.get('before'), kwargs.get('after') - offset = kwargs.get('offset', 0) + beginning, end = kwargs.get("beginning"), kwargs.get("end") + before, after = kwargs.get("before"), kwargs.get("after") + offset = kwargs.get("offset", 0) if sum(bool(a) for a in (beginning, end, before, after)) > 1: - raise InvalidArgument('Only one of [before, after, end, beginning] can be used.') + raise InvalidArgument("Only one of [before, after, end, beginning] can be used.") bucket = self._sorting_bucket - parent_id = kwargs.get('category', MISSING) + parent_id = kwargs.get("category", MISSING) # fmt: off channels: List[GuildChannel] if parent_id not in (MISSING, None): @@ -1011,14 +1011,14 @@ class GuildChannel: index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) if index is None: - raise InvalidArgument('Could not resolve appropriate move position') + raise InvalidArgument("Could not resolve appropriate move position") channels.insert(max((index + offset), 0), self) payload = [] - lock_permissions = kwargs.get('sync_permissions', False) - reason = kwargs.get('reason') + lock_permissions = kwargs.get("sync_permissions", False) + reason = kwargs.get("reason") for index, channel in enumerate(channels): - d = {'id': channel.id, 'position': index} + d = {"id": channel.id, "position": index} if parent_id is not MISSING and channel.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) @@ -1332,14 +1332,14 @@ class Messageable: content = str(content) if content is not None else None if embed is not None and embeds is not None: - raise InvalidArgument('cannot pass both embed and embeds parameter to send()') + raise InvalidArgument("cannot pass both embed and embeds parameter to send()") if embed is not None: embed = embed.to_dict() elif embeds is not None: if len(embeds) > 10: - raise InvalidArgument('embeds parameter must be a list of up to 10 elements') + raise InvalidArgument("embeds parameter must be a list of up to 10 elements") embeds = [embed.to_dict() for embed in embeds] if stickers is not None: @@ -1355,28 +1355,30 @@ class Messageable: if mention_author is not None: allowed_mentions = allowed_mentions or AllowedMentions().to_dict() - allowed_mentions['replied_user'] = bool(mention_author) + allowed_mentions["replied_user"] = bool(mention_author) if reference is not None: try: reference = reference.to_message_reference_dict() except AttributeError: - raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None + raise InvalidArgument( + "reference parameter must be Message, MessageReference, or PartialMessage" + ) from None if view: - if not hasattr(view, '__discord_ui_view__'): - raise InvalidArgument(f'view parameter must be View not {view.__class__!r}') + if not hasattr(view, "__discord_ui_view__"): + raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") components = view.to_components() else: components = None if file is not None and files is not None: - raise InvalidArgument('cannot pass both file and files parameter to send()') + raise InvalidArgument("cannot pass both file and files parameter to send()") if file is not None: if not isinstance(file, File): - raise InvalidArgument('file parameter must be File') + raise InvalidArgument("file parameter must be File") try: data = await state.http.send_files( @@ -1397,9 +1399,9 @@ class Messageable: elif files is not None: if len(files) > 10: - raise InvalidArgument('files parameter must be a list of up to 10 elements') + raise InvalidArgument("files parameter must be a list of up to 10 elements") elif not all(isinstance(file, File) for file in files): - raise InvalidArgument('files parameter must be a list of File') + raise InvalidArgument("files parameter must be a list of File") try: data = await state.http.send_files( @@ -1666,13 +1668,13 @@ class Connectable(Protocol): state = self._state if state._get_voice_client(key_id): - raise ClientException('Already connected to a voice channel.') + raise ClientException("Already connected to a voice channel.") client = state._get_client() voice = cls(client, self) if not isinstance(voice, VoiceProtocol): - raise TypeError('Type must meet VoiceProtocol abstract base class.') + raise TypeError("Type must meet VoiceProtocol abstract base class.") state._add_voice_client(key_id, voice) diff --git a/discord/activity.py b/discord/activity.py index 51205377..7294dbd9 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji from .utils import _get_as_snowflake __all__ = ( - 'BaseActivity', - 'Activity', - 'Streaming', - 'Game', - 'Spotify', - 'CustomActivity', + "BaseActivity", + "Activity", + "Streaming", + "Game", + "Spotify", + "CustomActivity", ) """If curious, this is the current schema for an activity. @@ -119,10 +119,10 @@ class BaseActivity: .. versionadded:: 1.3 """ - __slots__ = ('_created_at',) + __slots__ = ("_created_at",) def __init__(self, **kwargs): - self._created_at: Optional[float] = kwargs.pop('created_at', None) + self._created_at: Optional[float] = kwargs.pop("created_at", None) @property def created_at(self) -> Optional[datetime.datetime]: @@ -199,58 +199,58 @@ class Activity(BaseActivity): """ __slots__ = ( - 'state', - 'details', - '_created_at', - 'timestamps', - 'assets', - 'party', - 'flags', - 'sync_id', - 'session_id', - 'type', - 'name', - 'url', - 'application_id', - 'emoji', - 'buttons', + "state", + "details", + "_created_at", + "timestamps", + "assets", + "party", + "flags", + "sync_id", + "session_id", + "type", + "name", + "url", + "application_id", + "emoji", + "buttons", ) def __init__(self, **kwargs): super().__init__(**kwargs) - self.state: Optional[str] = kwargs.pop('state', None) - self.details: Optional[str] = kwargs.pop('details', None) - self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {}) - self.assets: ActivityAssets = kwargs.pop('assets', {}) - self.party: ActivityParty = kwargs.pop('party', {}) - self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id') - self.name: Optional[str] = kwargs.pop('name', None) - self.url: Optional[str] = kwargs.pop('url', None) - self.flags: int = kwargs.pop('flags', 0) - self.sync_id: Optional[str] = kwargs.pop('sync_id', None) - self.session_id: Optional[str] = kwargs.pop('session_id', None) - self.buttons: List[ActivityButton] = kwargs.pop('buttons', []) + self.state: Optional[str] = kwargs.pop("state", None) + self.details: Optional[str] = kwargs.pop("details", None) + self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {}) + self.assets: ActivityAssets = kwargs.pop("assets", {}) + self.party: ActivityParty = kwargs.pop("party", {}) + self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id") + self.name: Optional[str] = kwargs.pop("name", None) + self.url: Optional[str] = kwargs.pop("url", None) + self.flags: int = kwargs.pop("flags", 0) + self.sync_id: Optional[str] = kwargs.pop("sync_id", None) + self.session_id: Optional[str] = kwargs.pop("session_id", None) + self.buttons: List[ActivityButton] = kwargs.pop("buttons", []) - activity_type = kwargs.pop('type', -1) + activity_type = kwargs.pop("type", -1) self.type: ActivityType = ( activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type) ) - emoji = kwargs.pop('emoji', None) + emoji = kwargs.pop("emoji", None) self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None def __repr__(self) -> str: attrs = ( - ('type', self.type), - ('name', self.name), - ('url', self.url), - ('details', self.details), - ('application_id', self.application_id), - ('session_id', self.session_id), - ('emoji', self.emoji), + ("type", self.type), + ("name", self.name), + ("url", self.url), + ("details", self.details), + ("application_id", self.application_id), + ("session_id", self.session_id), + ("emoji", self.emoji), ) - inner = ' '.join('%s=%r' % t for t in attrs) - return f'' + inner = " ".join("%s=%r" % t for t in attrs) + return f"" def to_dict(self) -> Dict[str, Any]: ret: Dict[str, Any] = {} @@ -263,16 +263,16 @@ class Activity(BaseActivity): continue ret[attr] = value - ret['type'] = int(self.type) + ret["type"] = int(self.type) if self.emoji: - ret['emoji'] = self.emoji.to_dict() + ret["emoji"] = self.emoji.to_dict() return ret @property def start(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" try: - timestamp = self.timestamps['start'] / 1000 + timestamp = self.timestamps["start"] / 1000 except KeyError: return None else: @@ -282,7 +282,7 @@ class Activity(BaseActivity): def end(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" try: - timestamp = self.timestamps['end'] / 1000 + timestamp = self.timestamps["end"] / 1000 except KeyError: return None else: @@ -295,11 +295,11 @@ class Activity(BaseActivity): return None try: - large_image = self.assets['large_image'] + large_image = self.assets["large_image"] except KeyError: return None else: - return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png' + return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png" @property def small_image_url(self) -> Optional[str]: @@ -308,21 +308,21 @@ class Activity(BaseActivity): return None try: - small_image = self.assets['small_image'] + small_image = self.assets["small_image"] except KeyError: return None else: - return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png' + return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png" @property def large_image_text(self) -> Optional[str]: """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" - return self.assets.get('large_text', None) + return self.assets.get("large_text", None) @property def small_image_text(self) -> Optional[str]: """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" - return self.assets.get('small_text', None) + return self.assets.get("small_text", None) class Game(BaseActivity): @@ -359,20 +359,20 @@ class Game(BaseActivity): The game's name. """ - __slots__ = ('name', '_end', '_start') + __slots__ = ("name", "_end", "_start") def __init__(self, name: str, **extra): super().__init__(**extra) self.name: str = name try: - timestamps: ActivityTimestamps = extra['timestamps'] + timestamps: ActivityTimestamps = extra["timestamps"] except KeyError: self._start = 0 self._end = 0 else: - self._start = timestamps.get('start', 0) - self._end = timestamps.get('end', 0) + self._start = timestamps.get("start", 0) + self._end = timestamps.get("end", 0) @property def type(self) -> ActivityType: @@ -400,15 +400,15 @@ class Game(BaseActivity): return str(self.name) def __repr__(self) -> str: - return f'' + return f"" def to_dict(self) -> Dict[str, Any]: timestamps: Dict[str, Any] = {} if self._start: - timestamps['start'] = self._start + timestamps["start"] = self._start if self._end: - timestamps['end'] = self._end + timestamps["end"] = self._end # fmt: off return { @@ -473,16 +473,16 @@ class Streaming(BaseActivity): A dictionary comprising of similar keys than those in :attr:`Activity.assets`. """ - __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') + __slots__ = ("platform", "name", "game", "url", "details", "assets") def __init__(self, *, name: Optional[str], url: str, **extra: Any): super().__init__(**extra) self.platform: Optional[str] = name - self.name: Optional[str] = extra.pop('details', name) - self.game: Optional[str] = extra.pop('state', None) + self.name: Optional[str] = extra.pop("details", name) + self.game: Optional[str] = extra.pop("state", None) self.url: str = url - self.details: Optional[str] = extra.pop('details', self.name) # compatibility - self.assets: ActivityAssets = extra.pop('assets', {}) + self.details: Optional[str] = extra.pop("details", self.name) # compatibility + self.assets: ActivityAssets = extra.pop("assets", {}) @property def type(self) -> ActivityType: @@ -496,7 +496,7 @@ class Streaming(BaseActivity): return str(self.name) def __repr__(self) -> str: - return f'' + return f"" @property def twitch_name(self): @@ -507,11 +507,11 @@ class Streaming(BaseActivity): """ try: - name = self.assets['large_image'] + name = self.assets["large_image"] except KeyError: return None else: - return name[7:] if name[:7] == 'twitch:' else None + return name[7:] if name[:7] == "twitch:" else None def to_dict(self) -> Dict[str, Any]: # fmt: off @@ -523,7 +523,7 @@ class Streaming(BaseActivity): } # fmt: on if self.details: - ret['details'] = self.details + ret["details"] = self.details return ret def __eq__(self, other: Any) -> bool: @@ -559,17 +559,17 @@ class Spotify: Returns the string 'Spotify'. """ - __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') + __slots__ = ("_state", "_details", "_timestamps", "_assets", "_party", "_sync_id", "_session_id", "_created_at") def __init__(self, **data): - self._state: str = data.pop('state', '') - self._details: str = data.pop('details', '') - self._timestamps: Dict[str, int] = data.pop('timestamps', {}) - self._assets: ActivityAssets = data.pop('assets', {}) - self._party: ActivityParty = data.pop('party', {}) - self._sync_id: str = data.pop('sync_id') - self._session_id: str = data.pop('session_id') - self._created_at: Optional[float] = data.pop('created_at', None) + self._state: str = data.pop("state", "") + self._details: str = data.pop("details", "") + self._timestamps: Dict[str, int] = data.pop("timestamps", {}) + self._assets: ActivityAssets = data.pop("assets", {}) + self._party: ActivityParty = data.pop("party", {}) + self._sync_id: str = data.pop("sync_id") + self._session_id: str = data.pop("session_id") + self._created_at: Optional[float] = data.pop("created_at", None) @property def type(self) -> ActivityType: @@ -604,21 +604,21 @@ class Spotify: def to_dict(self) -> Dict[str, Any]: return { - 'flags': 48, # SYNC | PLAY - 'name': 'Spotify', - 'assets': self._assets, - 'party': self._party, - 'sync_id': self._sync_id, - 'session_id': self._session_id, - 'timestamps': self._timestamps, - 'details': self._details, - 'state': self._state, + "flags": 48, # SYNC | PLAY + "name": "Spotify", + "assets": self._assets, + "party": self._party, + "sync_id": self._sync_id, + "session_id": self._session_id, + "timestamps": self._timestamps, + "details": self._details, + "state": self._state, } @property def name(self) -> str: """:class:`str`: The activity's name. This will always return "Spotify".""" - return 'Spotify' + return "Spotify" def __eq__(self, other: Any) -> bool: return ( @@ -635,10 +635,10 @@ class Spotify: return hash(self._session_id) def __str__(self) -> str: - return 'Spotify' + return "Spotify" def __repr__(self) -> str: - return f'' + return f"" @property def title(self) -> str: @@ -648,7 +648,7 @@ class Spotify: @property def artists(self) -> List[str]: """List[:class:`str`]: The artists of the song being played.""" - return self._state.split('; ') + return self._state.split("; ") @property def artist(self) -> str: @@ -662,16 +662,16 @@ class Spotify: @property def album(self) -> str: """:class:`str`: The album that the song being played belongs to.""" - return self._assets.get('large_text', '') + return self._assets.get("large_text", "") @property def album_cover_url(self) -> str: """:class:`str`: The album cover image URL from Spotify's CDN.""" - large_image = self._assets.get('large_image', '') - if large_image[:8] != 'spotify:': - return '' + large_image = self._assets.get("large_image", "") + if large_image[:8] != "spotify:": + return "" album_image_id = large_image[8:] - return 'https://i.scdn.co/image/' + album_image_id + return "https://i.scdn.co/image/" + album_image_id @property def track_id(self) -> str: @@ -684,17 +684,17 @@ class Spotify: .. versionadded:: 2.0 """ - return f'https://open.spotify.com/track/{self.track_id}' + return f"https://open.spotify.com/track/{self.track_id}" @property def start(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user started playing this song in UTC.""" - return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp(self._timestamps["start"] / 1000, tz=datetime.timezone.utc) @property def end(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" - return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp(self._timestamps["end"] / 1000, tz=datetime.timezone.utc) @property def duration(self) -> datetime.timedelta: @@ -704,7 +704,7 @@ class Spotify: @property def party_id(self) -> str: """:class:`str`: The party ID of the listening party.""" - return self._party.get('id', '') + return self._party.get("id", "") class CustomActivity(BaseActivity): @@ -738,13 +738,13 @@ class CustomActivity(BaseActivity): The emoji to pass to the activity, if any. """ - __slots__ = ('name', 'emoji', 'state') + __slots__ = ("name", "emoji", "state") def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): super().__init__(**extra) self.name: Optional[str] = name - self.state: Optional[str] = extra.pop('state', None) - if self.name == 'Custom Status': + self.state: Optional[str] = extra.pop("state", None) + if self.name == "Custom Status": self.name = self.state self.emoji: Optional[PartialEmoji] @@ -757,7 +757,7 @@ class CustomActivity(BaseActivity): elif isinstance(emoji, PartialEmoji): self.emoji = emoji else: - raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.') + raise TypeError(f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.") @property def type(self) -> ActivityType: @@ -770,18 +770,18 @@ class CustomActivity(BaseActivity): def to_dict(self) -> Dict[str, Any]: if self.name == self.state: o = { - 'type': ActivityType.custom.value, - 'state': self.name, - 'name': 'Custom Status', + "type": ActivityType.custom.value, + "state": self.name, + "name": "Custom Status", } else: o = { - 'type': ActivityType.custom.value, - 'name': self.name, + "type": ActivityType.custom.value, + "name": self.name, } if self.emoji: - o['emoji'] = self.emoji.to_dict() + o["emoji"] = self.emoji.to_dict() return o def __eq__(self, other: Any) -> bool: @@ -796,47 +796,50 @@ class CustomActivity(BaseActivity): def __str__(self) -> str: if self.emoji: if self.name: - return f'{self.emoji} {self.name}' + return f"{self.emoji} {self.name}" return str(self.emoji) else: return str(self.name) def __repr__(self) -> str: - return f'' + return f"" ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] + @overload def create_activity(data: ActivityPayload) -> ActivityTypes: ... + @overload def create_activity(data: None) -> None: ... + def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: if not data: return None - game_type = try_enum(ActivityType, data.get('type', -1)) + game_type = try_enum(ActivityType, data.get("type", -1)) if game_type is ActivityType.playing: - if 'application_id' in data or 'session_id' in data: + if "application_id" in data or "session_id" in data: return Activity(**data) return Game(**data) elif game_type is ActivityType.custom: try: - name = data.pop('name') + name = data.pop("name") except KeyError: return Activity(**data) else: # we removed the name key from data already - return CustomActivity(name=name, **data) # type: ignore + return CustomActivity(name=name, **data) # type: ignore elif game_type is ActivityType.streaming: - if 'url' in data: + if "url" in data: # the url won't be None here - return Streaming(**data) # type: ignore + return Streaming(**data) # type: ignore return Activity(**data) - elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: + elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data: return Spotify(**data) return Activity(**data) diff --git a/discord/appinfo.py b/discord/appinfo.py index de1f7a73..2fd3c58b 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -40,8 +40,8 @@ if TYPE_CHECKING: from .state import ConnectionState __all__ = ( - 'AppInfo', - 'PartialAppInfo', + "AppInfo", + "PartialAppInfo", ) @@ -115,58 +115,58 @@ class AppInfo: """ __slots__ = ( - '_state', - 'description', - 'id', - 'name', - 'rpc_origins', - 'bot_public', - 'bot_require_code_grant', - 'owner', - '_icon', - 'summary', - 'verify_key', - 'team', - 'guild_id', - 'primary_sku_id', - 'slug', - '_cover_image', - 'terms_of_service_url', - 'privacy_policy_url', + "_state", + "description", + "id", + "name", + "rpc_origins", + "bot_public", + "bot_require_code_grant", + "owner", + "_icon", + "summary", + "verify_key", + "team", + "guild_id", + "primary_sku_id", + "slug", + "_cover_image", + "terms_of_service_url", + "privacy_policy_url", ) def __init__(self, state: ConnectionState, data: AppInfoPayload): from .team import Team self._state: ConnectionState = state - self.id: int = int(data['id']) - self.name: str = data['name'] - self.description: str = data['description'] - self._icon: Optional[str] = data['icon'] - self.rpc_origins: List[str] = data['rpc_origins'] - self.bot_public: bool = data['bot_public'] - self.bot_require_code_grant: bool = data['bot_require_code_grant'] - self.owner: User = state.create_user(data['owner']) + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.description: str = data["description"] + self._icon: Optional[str] = data["icon"] + self.rpc_origins: List[str] = data["rpc_origins"] + self.bot_public: bool = data["bot_public"] + self.bot_require_code_grant: bool = data["bot_require_code_grant"] + self.owner: User = state.create_user(data["owner"]) - team: Optional[TeamPayload] = data.get('team') + team: Optional[TeamPayload] = data.get("team") self.team: Optional[Team] = Team(state, team) if team else None - self.summary: str = data['summary'] - self.verify_key: str = data['verify_key'] + self.summary: str = data["summary"] + self.verify_key: str = data["verify_key"] - self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") - self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id') - self.slug: Optional[str] = data.get('slug') - self._cover_image: Optional[str] = data.get('cover_image') - self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') - self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') + self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, "primary_sku_id") + self.slug: Optional[str] = data.get("slug") + self._cover_image: Optional[str] = data.get("cover_image") + self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url") + self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url") def __repr__(self) -> str: return ( - f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' - f'description={self.description!r} public={self.bot_public} ' - f'owner={self.owner!r}>' + f"<{self.__class__.__name__} id={self.id} name={self.name!r} " + f"description={self.description!r} public={self.bot_public} " + f"owner={self.owner!r}>" ) @property @@ -174,7 +174,7 @@ class AppInfo: """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='app') + return Asset._from_icon(self._state, self.id, self._icon, path="app") @property def cover_image(self) -> Optional[Asset]: @@ -195,6 +195,7 @@ class AppInfo: """ return self._state._get_guild(self.guild_id) + class PartialAppInfo: """Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite` @@ -222,26 +223,37 @@ class PartialAppInfo: The application's privacy policy URL, if set. """ - __slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon') + __slots__ = ( + "_state", + "id", + "name", + "description", + "rpc_origins", + "summary", + "verify_key", + "terms_of_service_url", + "privacy_policy_url", + "_icon", + ) def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) - self.name: str = data['name'] - self._icon: Optional[str] = data.get('icon') - self.description: str = data['description'] - self.rpc_origins: Optional[List[str]] = data.get('rpc_origins') - self.summary: str = data['summary'] - self.verify_key: str = data['verify_key'] - self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') - self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') + self.id: int = int(data["id"]) + self.name: str = data["name"] + self._icon: Optional[str] = data.get("icon") + self.description: str = data["description"] + self.rpc_origins: Optional[List[str]] = data.get("rpc_origins") + self.summary: str = data["summary"] + self.verify_key: str = data["verify_key"] + self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url") + self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url") def __repr__(self) -> str: - return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>' + return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>" @property def icon(self) -> Optional[Asset]: """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='app') + return Asset._from_icon(self._state, self.id, self._icon, path="app") diff --git a/discord/asset.py b/discord/asset.py index 25c72648..7d622984 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -33,13 +33,11 @@ from . import utils import yarl -__all__ = ( - 'Asset', -) +__all__ = ("Asset",) if TYPE_CHECKING: - ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] - ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] + ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"] + ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} @@ -47,6 +45,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} MISSING = utils.MISSING + class AssetMixin: url: str _state: Optional[Any] @@ -71,7 +70,7 @@ class AssetMixin: The content of the asset. """ if self._state is None: - raise DiscordException('Invalid state (no ConnectionState provided)') + raise DiscordException("Invalid state (no ConnectionState provided)") return await self._state.http.get_from_cdn(self.url) @@ -112,7 +111,7 @@ class AssetMixin: fp.seek(0) return written else: - with open(fp, 'wb') as f: + with open(fp, "wb") as f: return f.write(data) @@ -143,13 +142,13 @@ class Asset(AssetMixin): """ __slots__: Tuple[str, ...] = ( - '_state', - '_url', - '_animated', - '_key', + "_state", + "_url", + "_animated", + "_key", ) - BASE = 'https://cdn.discordapp.com' + BASE = "https://cdn.discordapp.com" def __init__(self, state, *, url: str, key: str, animated: bool = False): self._state = state @@ -161,26 +160,26 @@ class Asset(AssetMixin): def _from_default_avatar(cls, state, index: int) -> Asset: return cls( state, - url=f'{cls.BASE}/embed/avatars/{index}.png', + url=f"{cls.BASE}/embed/avatars/{index}.png", key=str(index), animated=False, ) @classmethod def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: - animated = avatar.startswith('a_') - format = 'gif' if animated else 'png' + animated = avatar.startswith("a_") + format = "gif" if animated else "png" return cls( state, - url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024', + url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024", key=avatar, animated=animated, ) @classmethod def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: - animated = avatar.startswith('a_') - format = 'gif' if animated else 'png' + animated = avatar.startswith("a_") + format = "gif" if animated else "png" return cls( state, url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024", @@ -192,7 +191,7 @@ class Asset(AssetMixin): def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: return cls( state, - url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', + url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024", key=icon_hash, animated=False, ) @@ -201,7 +200,7 @@ class Asset(AssetMixin): def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: return cls( state, - url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', + url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024", key=cover_image_hash, animated=False, ) @@ -210,18 +209,18 @@ class Asset(AssetMixin): def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: return cls( state, - url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024', + url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024", key=image, animated=False, ) @classmethod def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: - animated = icon_hash.startswith('a_') - format = 'gif' if animated else 'png' + animated = icon_hash.startswith("a_") + format = "gif" if animated else "png" return cls( state, - url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024', + url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024", key=icon_hash, animated=animated, ) @@ -230,20 +229,20 @@ class Asset(AssetMixin): def _from_sticker_banner(cls, state, banner: int) -> Asset: return cls( state, - url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', + url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png", key=str(banner), animated=False, ) @classmethod def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: - animated = banner_hash.startswith('a_') - format = 'gif' if animated else 'png' + animated = banner_hash.startswith("a_") + format = "gif" if animated else "png" return cls( state, - url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512', + url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512", key=banner_hash, - animated=animated + animated=animated, ) def __str__(self) -> str: @@ -253,8 +252,8 @@ class Asset(AssetMixin): return len(self._url) def __repr__(self): - shorten = self._url.replace(self.BASE, '') - return f'' + shorten = self._url.replace(self.BASE, "") + return f"" def __eq__(self, other): return isinstance(other, Asset) and self._url == other._url @@ -312,21 +311,21 @@ class Asset(AssetMixin): if format is not MISSING: if self._animated: if format not in VALID_ASSET_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') - url = url.with_path(f'{path}.{format}') + raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}") + url = url.with_path(f"{path}.{format}") elif static_format is MISSING: if format not in VALID_STATIC_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') - url = url.with_path(f'{path}.{format}') + raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}") + url = url.with_path(f"{path}.{format}") if static_format is not MISSING and not self._animated: if static_format not in VALID_STATIC_FORMATS: - raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}') - url = url.with_path(f'{path}.{static_format}') + raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}") + url = url.with_path(f"{path}.{static_format}") if size is not MISSING: if not utils.valid_icon_size(size): - raise InvalidArgument('size must be a power of 2 between 16 and 4096') + raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = url.with_query(size=size) else: url = url.with_query(url.raw_query_string) @@ -353,7 +352,7 @@ class Asset(AssetMixin): The new updated asset. """ if not utils.valid_icon_size(size): - raise InvalidArgument('size must be a power of 2 between 16 and 4096') + raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = str(yarl.URL(self._url).with_query(size=size)) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) @@ -379,14 +378,14 @@ class Asset(AssetMixin): if self._animated: if format not in VALID_ASSET_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') + raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}") else: if format not in VALID_STATIC_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') + raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}") url = yarl.URL(self._url) path, _ = os.path.splitext(url.path) - url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) + url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string)) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: diff --git a/discord/audit_logs.py b/discord/audit_logs.py index b74bbfef..7d5babb5 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -35,9 +35,9 @@ from .object import Object from .permissions import PermissionOverwrite, Permissions __all__ = ( - 'AuditLogDiff', - 'AuditLogChanges', - 'AuditLogEntry', + "AuditLogDiff", + "AuditLogChanges", + "AuditLogEntry", ) @@ -85,6 +85,7 @@ def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Uni return None return entry._get_member(int(data)) + def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]: if data is None: return None @@ -96,16 +97,16 @@ def _transform_overwrites( ) -> List[Tuple[Object, PermissionOverwrite]]: overwrites = [] for elem in data: - allow = Permissions(int(elem['allow'])) - deny = Permissions(int(elem['deny'])) + allow = Permissions(int(elem["allow"])) + deny = Permissions(int(elem["deny"])) ow = PermissionOverwrite.from_pair(allow, deny) - ow_type = elem['type'] - ow_id = int(elem['id']) + ow_type = elem["type"] + ow_id = int(elem["id"]) target = None - if ow_type == '0': + if ow_type == "0": target = entry.guild.get_role(ow_id) - elif ow_type == '1': + elif ow_type == "1": target = entry._get_member(ow_id) if target is None: @@ -137,7 +138,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str] return _transform -T = TypeVar('T', bound=enums.Enum) +T = TypeVar("T", bound=enums.Enum) def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]: @@ -146,12 +147,14 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]: return _transform + def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]: - if entry.action.name.startswith('sticker_'): + if entry.action.name.startswith("sticker_"): return enums.try_enum(enums.StickerType, data) else: return enums.try_enum(enums.ChannelType, data) + class AuditLogDiff: def __len__(self) -> int: return len(self.__dict__) @@ -160,8 +163,8 @@ class AuditLogDiff: yield from self.__dict__.items() def __repr__(self) -> str: - values = ' '.join('%s=%r' % item for item in self.__dict__.items()) - return f'' + values = " ".join("%s=%r" % item for item in self.__dict__.items()) + return f"" if TYPE_CHECKING: @@ -217,14 +220,14 @@ class AuditLogChanges: self.after = AuditLogDiff() for elem in data: - attr = elem['key'] + attr = elem["key"] # special cases for role add/remove - if attr == '$add': - self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore + if attr == "$add": + self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore continue - elif attr == '$remove': - self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore + elif attr == "$remove": + self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore continue try: @@ -238,7 +241,7 @@ class AuditLogChanges: transformer: Optional[Transformer] try: - before = elem['old_value'] + before = elem["old_value"] except KeyError: before = None else: @@ -248,7 +251,7 @@ class AuditLogChanges: setattr(self.before, attr, before) try: - after = elem['new_value'] + after = elem["new_value"] except KeyError: after = None else: @@ -258,34 +261,36 @@ class AuditLogChanges: setattr(self.after, attr, after) # add an alias - if hasattr(self.after, 'colour'): + if hasattr(self.after, "colour"): self.after.color = self.after.colour self.before.color = self.before.colour - if hasattr(self.after, 'expire_behavior'): + if hasattr(self.after, "expire_behavior"): self.after.expire_behaviour = self.after.expire_behavior self.before.expire_behaviour = self.before.expire_behavior def __repr__(self) -> str: - return f'' + return f"" - def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None: - if not hasattr(first, 'roles'): - setattr(first, 'roles', []) + def _handle_role( + self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload] + ) -> None: + if not hasattr(first, "roles"): + setattr(first, "roles", []) data = [] g: Guild = entry.guild # type: ignore for e in elem: - role_id = int(e['id']) + role_id = int(e["id"]) role = g.get_role(role_id) if role is None: role = Object(id=role_id) - role.name = e['name'] # type: ignore + role.name = e["name"] # type: ignore data.append(role) - setattr(second, 'roles', data) + setattr(second, "roles", data) class _AuditLogProxyMemberPrune: @@ -365,56 +370,56 @@ class AuditLogEntry(Hashable): self._from_data(data) def _from_data(self, data: AuditLogEntryPayload) -> None: - self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) - self.id = int(data['id']) + self.action = enums.try_enum(enums.AuditLogAction, data["action_type"]) + self.id = int(data["id"]) # this key is technically not usually present - self.reason = data.get('reason') - self.extra = data.get('options') + self.reason = data.get("reason") + self.extra = data.get("options") if isinstance(self.action, enums.AuditLogAction) and self.extra: if self.action is enums.AuditLogAction.member_prune: # member prune has two keys with useful information self.extra: _AuditLogProxyMemberPrune = type( - '_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()} + "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()} )() elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: - channel_id = int(self.extra['channel_id']) + channel_id = int(self.extra["channel_id"]) elems = { - 'count': int(self.extra['count']), - 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), + "count": int(self.extra["count"]), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), } - self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)() + self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type("_AuditLogProxy", (), elems)() elif self.action is enums.AuditLogAction.member_disconnect: # The member disconnect action has a dict with some information elems = { - 'count': int(self.extra['count']), + "count": int(self.extra["count"]), } - self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)() - elif self.action.name.endswith('pin'): + self.extra: _AuditLogProxyMemberDisconnect = type("_AuditLogProxy", (), elems)() + elif self.action.name.endswith("pin"): # the pin actions have a dict with some information - channel_id = int(self.extra['channel_id']) + channel_id = int(self.extra["channel_id"]) elems = { - 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), - 'message_id': int(self.extra['message_id']), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), + "message_id": int(self.extra["message_id"]), } - self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)() - elif self.action.name.startswith('overwrite_'): + self.extra: _AuditLogProxyPinAction = type("_AuditLogProxy", (), elems)() + elif self.action.name.startswith("overwrite_"): # the overwrite_ actions have a dict with some information - instance_id = int(self.extra['id']) - the_type = self.extra.get('type') - if the_type == '1': + instance_id = int(self.extra["id"]) + the_type = self.extra.get("type") + if the_type == "1": self.extra = self._get_member(instance_id) - elif the_type == '0': + elif the_type == "0": role = self.guild.get_role(instance_id) if role is None: role = Object(id=instance_id) - role.name = self.extra.get('role_name') # type: ignore + role.name = self.extra.get("role_name") # type: ignore self.extra: Role = role - elif self.action.name.startswith('stage_instance'): - channel_id = int(self.extra['channel_id']) - elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)} - self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)() + elif self.action.name.startswith("stage_instance"): + channel_id = int(self.extra["channel_id"]) + elems = {"channel": self.guild.get_channel(channel_id) or Object(id=channel_id)} + self.extra: _AuditLogProxyStageInstanceAction = type("_AuditLogProxy", (), elems)() # fmt: off self.extra: Union[ @@ -433,16 +438,16 @@ class AuditLogEntry(Hashable): # where new_value and old_value are not guaranteed to be there depending # on the action type, so let's just fetch it for now and only turn it # into meaningful data when requested - self._changes = data.get('changes', []) + self._changes = data.get("changes", []) - self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore - self._target_id = utils._get_as_snowflake(data, 'target_id') + self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore + self._target_id = utils._get_as_snowflake(data, "target_id") def _get_member(self, user_id: int) -> Union[Member, User, None]: return self.guild.get_member(user_id) or self._users.get(user_id) def __repr__(self) -> str: - return f'' + return f"" @utils.cached_property def created_at(self) -> datetime.datetime: @@ -450,9 +455,13 @@ class AuditLogEntry(Hashable): return utils.snowflake_time(self.id) @utils.cached_property - def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]: + def target( + self, + ) -> Union[ + Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None + ]: try: - converter = getattr(self, '_convert_target_' + self.action.target_type) + converter = getattr(self, "_convert_target_" + self.action.target_type) except AttributeError: return Object(id=self._target_id) else: @@ -498,11 +507,11 @@ class AuditLogEntry(Hashable): changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after fake_payload = { - 'max_age': changeset.max_age, - 'max_uses': changeset.max_uses, - 'code': changeset.code, - 'temporary': changeset.temporary, - 'uses': changeset.uses, + "max_age": changeset.max_age, + "max_uses": changeset.max_uses, + "code": changeset.code, + "temporary": changeset.temporary, + "uses": changeset.uses, } obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore diff --git a/discord/backoff.py b/discord/backoff.py index 903ecf76..3b2d9f77 100644 --- a/discord/backoff.py +++ b/discord/backoff.py @@ -29,11 +29,10 @@ import time import random from typing import Callable, Generic, Literal, TypeVar, overload, Union -T = TypeVar('T', bool, Literal[True], Literal[False]) +T = TypeVar("T", bool, Literal[True], Literal[False]) + +__all__ = ("ExponentialBackoff",) -__all__ = ( - 'ExponentialBackoff', -) class ExponentialBackoff(Generic[T]): """An implementation of the exponential backoff algorithm @@ -69,7 +68,7 @@ class ExponentialBackoff(Generic[T]): rand = random.Random() rand.seed() - self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore + self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore @overload def delay(self: ExponentialBackoff[Literal[False]]) -> float: diff --git a/discord/channel.py b/discord/channel.py index f467d1cc..b9bc5861 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -57,14 +57,14 @@ from .threads import Thread from .iterators import ArchivedThreadIterator __all__ = ( - 'TextChannel', - 'VoiceChannel', - 'StageChannel', - 'DMChannel', - 'CategoryChannel', - 'StoreChannel', - 'GroupChannel', - 'PartialMessageable', + "TextChannel", + "VoiceChannel", + "StageChannel", + "DMChannel", + "CategoryChannel", + "StoreChannel", + "GroupChannel", + "PartialMessageable", ) if TYPE_CHECKING: @@ -155,51 +155,51 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ __slots__ = ( - 'name', - 'id', - 'guild', - 'topic', - '_state', - 'nsfw', - 'category_id', - 'position', - 'slowmode_delay', - '_overwrites', - '_type', - 'last_message_id', - 'default_auto_archive_duration', + "name", + "id", + "guild", + "topic", + "_state", + "nsfw", + "category_id", + "position", + "slowmode_delay", + "_overwrites", + "_type", + "last_message_id", + "default_auto_archive_duration", ) def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) - self._type: int = data['type'] + self.id: int = int(data["id"]) + self._type: int = data["type"] self._update(guild, data) def __repr__(self) -> str: attrs = [ - ('id', self.id), - ('name', self.name), - ('position', self.position), - ('nsfw', self.nsfw), - ('news', self.is_news()), - ('category_id', self.category_id), + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("news", self.is_news()), + ("category_id", self.category_id), ] - joined = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {joined}>' + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" def _update(self, guild: Guild, data: TextChannelPayload) -> None: self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.topic: Optional[str] = data.get('topic') - self.position: int = data['position'] - self.nsfw: bool = data.get('nsfw', False) + self.name: str = data["name"] + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.topic: Optional[str] = data.get("topic") + self.position: int = data["position"] + self.nsfw: bool = data.get("nsfw", False) # Does this need coercion into `int`? No idea yet. - self.slowmode_delay: int = data.get('rate_limit_per_user', 0) - self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) - self._type: int = data.get('type', self._type) - self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') + self.slowmode_delay: int = data.get("rate_limit_per_user", 0) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440) + self._type: int = data.get("type", self._type) + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") self._fill_overwrites(data) async def _get_channel(self): @@ -371,7 +371,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): @utils.copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel: return await self._clone_impl( - {'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason + {"topic": self.topic, "nsfw": self.nsfw, "rate_limit_per_user": self.slowmode_delay}, + name=name, + reason=reason, ) async def delete_messages(self, messages: Iterable[Snowflake]) -> None: @@ -418,7 +420,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return if len(messages) > 100: - raise ClientException('Can only bulk delete messages up to 100 messages') + raise ClientException("Can only bulk delete messages up to 100 messages") message_ids: SnowflakeList = [m.id for m in messages] await self._state.http.delete_messages(self.id, message_ids) @@ -558,7 +560,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: + async def create_webhook( + self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None + ) -> Webhook: """|coro| Creates a webhook for this channel. @@ -635,10 +639,10 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ if not self.is_news(): - raise ClientException('The channel must be a news channel.') + raise ClientException("The channel must be a news channel.") if not isinstance(destination, TextChannel): - raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}') + raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}") from .webhook import Webhook @@ -802,40 +806,40 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): __slots__ = ( - 'name', - 'id', - 'guild', - 'bitrate', - 'user_limit', - '_state', - 'position', - '_overwrites', - 'category_id', - 'rtc_region', - 'video_quality_mode', + "name", + "id", + "guild", + "bitrate", + "user_limit", + "_state", + "position", + "_overwrites", + "category_id", + "rtc_region", + "video_quality_mode", ) def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(guild, data) def _get_voice_client_key(self) -> Tuple[int, str]: - return self.guild.id, 'guild_id' + return self.guild.id, "guild_id" def _get_voice_state_pair(self) -> Tuple[int, int]: return self.guild.id, self.id def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: self.guild = guild - self.name: str = data['name'] - rtc = data.get('rtc_region') + self.name: str = data["name"] + rtc = data.get("rtc_region") self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None - self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.position: int = data['position'] - self.bitrate: int = data.get('bitrate') - self.user_limit: int = data.get('user_limit') + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.position: int = data["position"] + self.bitrate: int = data.get("bitrate") + self.user_limit: int = data.get("user_limit") self._fill_overwrites(data) @property @@ -943,17 +947,17 @@ class VoiceChannel(VocalGuildChannel): def __repr__(self) -> str: attrs = [ - ('id', self.id), - ('name', self.name), - ('rtc_region', self.rtc_region), - ('position', self.position), - ('bitrate', self.bitrate), - ('video_quality_mode', self.video_quality_mode), - ('user_limit', self.user_limit), - ('category_id', self.category_id), + ("id", self.id), + ("name", self.name), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), ] - joined = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {joined}>' + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" @property def type(self) -> ChannelType: @@ -962,7 +966,9 @@ class VoiceChannel(VocalGuildChannel): @utils.copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: - return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason) + return await self._clone_impl( + {"bitrate": self.bitrate, "user_limit": self.user_limit}, name=name, reason=reason + ) @overload async def edit( @@ -1103,26 +1109,26 @@ class StageChannel(VocalGuildChannel): .. versionadded:: 2.0 """ - __slots__ = ('topic',) + __slots__ = ("topic",) def __repr__(self) -> str: attrs = [ - ('id', self.id), - ('name', self.name), - ('topic', self.topic), - ('rtc_region', self.rtc_region), - ('position', self.position), - ('bitrate', self.bitrate), - ('video_quality_mode', self.video_quality_mode), - ('user_limit', self.user_limit), - ('category_id', self.category_id), + ("id", self.id), + ("name", self.name), + ("topic", self.topic), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), ] - joined = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {joined}>' + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" def _update(self, guild: Guild, data: StageChannelPayload) -> None: super()._update(guild, data) - self.topic = data.get('topic') + self.topic = data.get("topic") @property def requesting_to_speak(self) -> List[Member]: @@ -1211,13 +1217,13 @@ class StageChannel(VocalGuildChannel): The newly created stage instance. """ - payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic} + payload: Dict[str, Any] = {"channel_id": self.id, "topic": topic} if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument('privacy_level field must be of type PrivacyLevel') + raise InvalidArgument("privacy_level field must be of type PrivacyLevel") - payload['privacy_level'] = privacy_level.value + payload["privacy_level"] = privacy_level.value data = await self._state.http.create_stage_instance(**payload, reason=reason) return StageInstance(guild=self.guild, state=self._state, data=data) @@ -1371,22 +1377,22 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. """ - __slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') + __slots__ = ("name", "id", "guild", "nsfw", "_state", "position", "_overwrites", "category_id") def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(guild, data) def __repr__(self) -> str: - return f'' + return f"" def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.nsfw: bool = data.get('nsfw', False) - self.position: int = data['position'] + self.name: str = data["name"] + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.nsfw: bool = data.get("nsfw", False) + self.position: int = data["position"] self._fill_overwrites(data) @property @@ -1404,7 +1410,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): @utils.copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel: - return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) + return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload async def edit( @@ -1473,7 +1479,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): @utils.copy_doc(discord.abc.GuildChannel.move) async def move(self, **kwargs): - kwargs.pop('category', None) + kwargs.pop("category", None) await super().move(**kwargs) @property @@ -1600,30 +1606,30 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): """ __slots__ = ( - 'name', - 'id', - 'guild', - '_state', - 'nsfw', - 'category_id', - 'position', - '_overwrites', + "name", + "id", + "guild", + "_state", + "nsfw", + "category_id", + "position", + "_overwrites", ) def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(guild, data) def __repr__(self) -> str: - return f'' + return f"" def _update(self, guild: Guild, data: StoreChannelPayload) -> None: self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.position: int = data['position'] - self.nsfw: bool = data.get('nsfw', False) + self.name: str = data["name"] + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.position: int = data["position"] + self.nsfw: bool = data.get("nsfw", False) self._fill_overwrites(data) @property @@ -1650,7 +1656,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): @utils.copy_doc(discord.abc.GuildChannel.clone) async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel: - return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) + return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload async def edit( @@ -1726,7 +1732,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore -DMC = TypeVar('DMC', bound='DMChannel') +DMC = TypeVar("DMC", bound="DMChannel") class DMChannel(discord.abc.Messageable, Hashable): @@ -1766,24 +1772,24 @@ class DMChannel(discord.abc.Messageable, Hashable): The direct message channel ID. """ - __slots__ = ('id', 'recipient', 'me', '_state') + __slots__ = ("id", "recipient", "me", "_state") def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): self._state: ConnectionState = state - self.recipient: Optional[User] = state.store_user(data['recipients'][0]) + self.recipient: Optional[User] = state.store_user(data["recipients"][0]) self.me: ClientUser = me - self.id: int = int(data['id']) + self.id: int = int(data["id"]) async def _get_channel(self): return self def __str__(self) -> str: if self.recipient: - return f'Direct Message with {self.recipient}' - return 'Direct Message with Unknown User' + return f"Direct Message with {self.recipient}" + return "Direct Message with Unknown User" def __repr__(self) -> str: - return f'' + return f"" @classmethod def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC: @@ -1902,19 +1908,19 @@ class GroupChannel(discord.abc.Messageable, Hashable): The group channel's name if provided. """ - __slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state') + __slots__ = ("id", "recipients", "owner_id", "owner", "_icon", "name", "me", "_state") def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self.me: ClientUser = me self._update_group(data) def _update_group(self, data: GroupChannelPayload) -> None: - self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id') - self._icon: Optional[str] = data.get('icon') - self.name: Optional[str] = data.get('name') - self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])] + self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id") + self._icon: Optional[str] = data.get("icon") + self.name: Optional[str] = data.get("name") + self.recipients: List[User] = [self._state.store_user(u) for u in data.get("recipients", [])] self.owner: Optional[BaseUser] if self.owner_id == self.me.id: @@ -1930,12 +1936,12 @@ class GroupChannel(discord.abc.Messageable, Hashable): return self.name if len(self.recipients) == 0: - return 'Unnamed' + return "Unnamed" - return ', '.join(map(lambda x: x.name, self.recipients)) + return ", ".join(map(lambda x: x.name, self.recipients)) def __repr__(self) -> str: - return f'' + return f"" @property def type(self) -> ChannelType: @@ -1947,7 +1953,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): """Optional[:class:`Asset`]: Returns the channel's icon asset if available.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='channel') + return Asset._from_icon(self._state, self.id, self._icon, path="channel") @property def created_at(self) -> datetime.datetime: diff --git a/discord/client.py b/discord/client.py index 746dd995..40f82443 100644 --- a/discord/client.py +++ b/discord/client.py @@ -29,7 +29,20 @@ import logging import signal import sys import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Generator, + List, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + TypeVar, + Union, +) import aiohttp @@ -69,46 +82,49 @@ if TYPE_CHECKING: from .member import Member from .voice_client import VoiceProtocol -__all__ = ( - 'Client', -) +__all__ = ("Client",) -Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) +Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) _log = logging.getLogger(__name__) + def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} if not tasks: return - _log.info('Cleaning up after %d tasks.', len(tasks)) + _log.info("Cleaning up after %d tasks.", len(tasks)) for task in tasks: task.cancel() loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - _log.info('All tasks finished cancelling.') + _log.info("All tasks finished cancelling.") for task in tasks: if task.cancelled(): continue if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'Unhandled exception during Client.run shutdown.', - 'exception': task.exception(), - 'task': task - }) + loop.call_exception_handler( + { + "message": "Unhandled exception during Client.run shutdown.", + "exception": task.exception(), + "task": task, + } + ) + def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: try: _cancel_tasks(loop) loop.run_until_complete(loop.shutdown_asyncgens()) finally: - _log.info('Closing the event loop.') + _log.info("Closing the event loop.") loop.close() + class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -199,6 +215,7 @@ class Client: loop: :class:`asyncio.AbstractEventLoop` The event loop that the client uses for asynchronous operations. """ + def __init__( self, *, @@ -212,24 +229,22 @@ class Client: self.ws: DiscordWebSocket = None # type: ignore self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} - self.shard_id: Optional[int] = options.get('shard_id') - self.shard_count: Optional[int] = options.get('shard_count') + self.shard_id: Optional[int] = options.get("shard_id") + self.shard_count: Optional[int] = options.get("shard_count") - connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None) - proxy: Optional[str] = options.pop('proxy', None) - proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) - unsync_clock: bool = options.pop('assume_unsync_clock', True) - self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) + connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None) + proxy: Optional[str] = options.pop("proxy", None) + proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None) + unsync_clock: bool = options.pop("assume_unsync_clock", True) + self.http: HTTPClient = HTTPClient( + connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop + ) - self._handlers: Dict[str, Callable] = { - 'ready': self._handle_ready - } + self._handlers: Dict[str, Callable] = {"ready": self._handle_ready} - self._hooks: Dict[str, Callable] = { - 'before_identify': self._call_before_identify_hook - } + self._hooks: Dict[str, Callable] = {"before_identify": self._call_before_identify_hook} - self._enable_debug_events: bool = options.pop('enable_debug_events', False) + self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count self._closed: bool = False @@ -247,8 +262,14 @@ class Client: return self.ws def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, - hooks=self._hooks, http=self.http, loop=self.loop, **options) + return ConnectionState( + dispatch=self.dispatch, + handlers=self._handlers, + hooks=self._hooks, + http=self.http, + loop=self.loop, + **options, + ) def _handle_ready(self) -> None: self._ready.set() @@ -260,7 +281,7 @@ class Client: This could be referred to as the Discord WebSocket protocol latency. """ ws = self.ws - return float('nan') if not ws else ws.latency + return float("nan") if not ws else ws.latency def is_ws_ratelimited(self) -> bool: """:class:`bool`: Whether the websocket is currently rate limited. @@ -331,7 +352,7 @@ class Client: If this is not passed via ``__init__`` then this is retrieved through the gateway when an event contains the data. Usually after :func:`~discord.on_connect` is called. - + .. versionadded:: 2.0 """ return self._connection.application_id @@ -348,7 +369,9 @@ class Client: """:class:`bool`: Specifies if the client's internal cache is ready for use.""" return self._ready.is_set() - async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: + async def _run_event( + self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any + ) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -359,14 +382,16 @@ class Client: except asyncio.CancelledError: pass - def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: + def _schedule_event( + self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any + ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task - return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') + return asyncio.create_task(wrapped, name=f"discord.py: {event_name}") def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug('Dispatching event %s', event) - method = 'on_' + event + _log.debug("Dispatching event %s", event) + method = "on_" + event listeners = self._listeners.get(event) if listeners: @@ -413,7 +438,7 @@ class Client: overridden to have a different implementation. Check :func:`~discord.on_error` for more details. """ - print(f'Ignoring exception in {event_method}', file=sys.stderr) + print(f"Ignoring exception in {event_method}", file=sys.stderr) traceback.print_exc() # hooks @@ -470,7 +495,7 @@ class Client: passing status code. """ - _log.info('logging in using static token') + _log.info("logging in using static token") data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) @@ -502,29 +527,31 @@ class Client: backoff = ExponentialBackoff() ws_params = { - 'initial': True, - 'shard_id': self.shard_id, + "initial": True, + "shard_id": self.shard_id, } while not self.is_closed(): try: coro = DiscordWebSocket.from_client(self, **ws_params) self.ws = await asyncio.wait_for(coro, timeout=60.0) - ws_params['initial'] = False + ws_params["initial"] = False while True: await self.ws.poll_event() except ReconnectWebSocket as e: - _log.info('Got a request to %s the websocket.', e.op) - self.dispatch('disconnect') + _log.info("Got a request to %s the websocket.", e.op) + self.dispatch("disconnect") ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) continue - except (OSError, - HTTPException, - GatewayNotFound, - ConnectionClosed, - aiohttp.ClientError, - asyncio.TimeoutError) as exc: + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) as exc: - self.dispatch('disconnect') + self.dispatch("disconnect") if not reconnect: await self.close() if isinstance(exc, ConnectionClosed) and exc.code == 1000: @@ -654,10 +681,10 @@ class Client: try: loop.run_forever() except KeyboardInterrupt: - _log.info('Received signal to terminate bot and event loop.') + _log.info("Received signal to terminate bot and event loop.") finally: future.remove_done_callback(stop_loop_on_completion) - _log.info('Cleaning up tasks.') + _log.info("Cleaning up tasks.") _cleanup_loop(loop) if not future.cancelled(): @@ -686,10 +713,10 @@ class Client: self._connection._activity = None elif isinstance(value, BaseActivity): # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] - self._connection._activity = value.to_dict() # type: ignore + self._connection._activity = value.to_dict() # type: ignore else: - raise TypeError('activity must derive from BaseActivity.') - + raise TypeError("activity must derive from BaseActivity.") + @property def status(self): """:class:`.Status`: @@ -704,11 +731,11 @@ class Client: @status.setter def status(self, value): if value is Status.offline: - self._connection._status = 'invisible' + self._connection._status = "invisible" elif isinstance(value, Status): self._connection._status = str(value) else: - raise TypeError('status must derive from Status.') + raise TypeError("status must derive from Status.") @property def allowed_mentions(self) -> Optional[AllowedMentions]: @@ -723,7 +750,7 @@ class Client: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: - raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}') + raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}") @property def intents(self) -> Intents: @@ -760,7 +787,7 @@ class Client: This is useful if you have a channel_id but don't want to do an API call to send messages to it. - + .. versionadded:: 2.0 Parameters @@ -1033,8 +1060,10 @@ class Client: future = self.loop.create_future() if check is None: + def _check(*args): return True + check = _check ev = event.lower() @@ -1072,10 +1101,10 @@ class Client: """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('event registered must be a coroutine function') + raise TypeError("event registered must be a coroutine function") setattr(self, coro.__name__, coro) - _log.debug('%s has successfully been registered as an event', coro.__name__) + _log.debug("%s has successfully been registered as an event", coro.__name__) return coro async def change_presence( @@ -1114,10 +1143,10 @@ class Client: """ if status is None: - status_str = 'online' + status_str = "online" status = Status.online elif status is Status.offline: - status_str = 'invisible' + status_str = "invisible" status = Status.offline else: status_str = str(status) @@ -1139,11 +1168,7 @@ class Client: # Guild stuff def fetch_guilds( - self, - *, - limit: Optional[int] = 100, - before: SnowflakeTime = None, - after: SnowflakeTime = None + self, *, limit: Optional[int] = 100, before: SnowflakeTime = None, after: SnowflakeTime = None ) -> GuildIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. @@ -1223,7 +1248,7 @@ class Client: """ code = utils.resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return Template(data=data, state=self._connection) # type: ignore async def fetch_guild(self, guild_id: int, /) -> Guild: """|coro| @@ -1339,12 +1364,14 @@ class Client: The stage instance from the stage channel ID. """ data = await self.http.get_stage_instance(channel_id) - guild = self.get_guild(int(data['guild_id'])) + guild = self.get_guild(int(data["guild_id"])) return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore # Invite management - async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: + async def fetch_invite( + self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True + ) -> Invite: """|coro| Gets an :class:`.Invite` from a discord.gg URL or ID. @@ -1460,8 +1487,8 @@ class Client: The bot's application information. """ data = await self.http.application_info() - if 'rpc_origins' not in data: - data['rpc_origins'] = None + if "rpc_origins" not in data: + data["rpc_origins"] = None return AppInfo(self._connection, data) async def fetch_user(self, user_id: int, /) -> User: @@ -1524,19 +1551,19 @@ class Client: """ data = await self.http.get_channel(channel_id) - factory, ch_type = _threaded_channel_factory(data['type']) + factory, ch_type = _threaded_channel_factory(data["type"]) if factory is None: - raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here - channel = factory(me=self.user, data=data, state=self._connection) # type: ignore + channel = factory(me=self.user, data=data, state=self._connection) # type: ignore else: # the factory can't be a DMChannel or GroupChannel here - guild_id = int(data['guild_id']) # type: ignore + guild_id = int(data["guild_id"]) # type: ignore guild = self.get_guild(guild_id) or Object(id=guild_id) # GuildChannels expect a Guild, we may be passing an Object - channel = factory(guild=guild, state=self._connection, data=data) # type: ignore + channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel @@ -1582,8 +1609,8 @@ class Client: The sticker you requested. """ data = await self.http.get_sticker(sticker_id) - cls, _ = _sticker_factory(data['type']) # type: ignore - return cls(state=self._connection, data=data) # type: ignore + cls, _ = _sticker_factory(data["type"]) # type: ignore + return cls(state=self._connection, data=data) # type: ignore async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """|coro| @@ -1603,7 +1630,7 @@ class Client: All available premium sticker packs. """ data = await self.http.list_premium_sticker_packs() - return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']] + return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -1638,7 +1665,7 @@ class Client: This method should be used for when a view is comprised of components that last longer than the lifecycle of the program. - + .. versionadded:: 2.0 Parameters @@ -1660,17 +1687,17 @@ class Client: """ if not isinstance(view, View): - raise TypeError(f'expected an instance of View not {view.__class__!r}') + raise TypeError(f"expected an instance of View not {view.__class__!r}") if not view.is_persistent(): - raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout') + raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout") self._connection.store_view(view, message_id) @property def persistent_views(self) -> Sequence[View]: """Sequence[:class:`.View`]: A sequence of persistent views added to the client. - + .. versionadded:: 2.0 """ return self._connection.persistent_views diff --git a/discord/colour.py b/discord/colour.py index 43ad6c6f..95447ff0 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -35,11 +35,11 @@ from typing import ( ) __all__ = ( - 'Colour', - 'Color', + "Colour", + "Color", ) -CT = TypeVar('CT', bound='Colour') +CT = TypeVar("CT", bound="Colour") class Colour: @@ -76,16 +76,16 @@ class Colour: The raw integer colour value. """ - __slots__ = ('value',) + __slots__ = ("value",) def __init__(self, value: int): if not isinstance(value, int): - raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.') + raise TypeError(f"Expected int parameter, received {value.__class__.__name__} instead.") self.value: int = value def _get_byte(self, byte: int) -> int: - return (self.value >> (8 * byte)) & 0xff + return (self.value >> (8 * byte)) & 0xFF def __eq__(self, other: Any) -> bool: return isinstance(other, Colour) and self.value == other.value @@ -94,13 +94,13 @@ class Colour: return not self.__eq__(other) def __str__(self) -> str: - return f'#{self.value:0>6x}' + return f"#{self.value:0>6x}" def __int__(self) -> int: return self.value def __repr__(self) -> str: - return f'' + return f"" def __hash__(self) -> int: return hash(self.value) @@ -164,12 +164,12 @@ class Colour: @classmethod def teal(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" - return cls(0x1abc9c) + return cls(0x1ABC9C) @classmethod def dark_teal(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" - return cls(0x11806a) + return cls(0x11806A) @classmethod def brand_green(cls: Type[CT]) -> CT: @@ -182,17 +182,17 @@ class Colour: @classmethod def green(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" - return cls(0x2ecc71) + return cls(0x2ECC71) @classmethod def dark_green(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" - return cls(0x1f8b4c) + return cls(0x1F8B4C) @classmethod def blue(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" - return cls(0x3498db) + return cls(0x3498DB) @classmethod def dark_blue(cls: Type[CT]) -> CT: @@ -202,42 +202,42 @@ class Colour: @classmethod def purple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" - return cls(0x9b59b6) + return cls(0x9B59B6) @classmethod def dark_purple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" - return cls(0x71368a) + return cls(0x71368A) @classmethod def magenta(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" - return cls(0xe91e63) + return cls(0xE91E63) @classmethod def dark_magenta(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" - return cls(0xad1457) + return cls(0xAD1457) @classmethod def gold(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" - return cls(0xf1c40f) + return cls(0xF1C40F) @classmethod def dark_gold(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" - return cls(0xc27c0e) + return cls(0xC27C0E) @classmethod def orange(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" - return cls(0xe67e22) + return cls(0xE67E22) @classmethod def dark_orange(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" - return cls(0xa84300) + return cls(0xA84300) @classmethod def brand_red(cls: Type[CT]) -> CT: @@ -250,52 +250,52 @@ class Colour: @classmethod def red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" - return cls(0xe74c3c) - + return cls(0xE74C3C) + @classmethod def nitro_booster(cls): """A factory method that returns a :class:`Colour` with a value of ``0xf47fff``. .. versionadded:: 2.0""" - return cls(0xf47fff) + return cls(0xF47FFF) @classmethod def dark_red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" - return cls(0x992d22) + return cls(0x992D22) @classmethod def lighter_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" - return cls(0x95a5a6) + return cls(0x95A5A6) lighter_gray = lighter_grey @classmethod def dark_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" - return cls(0x607d8b) + return cls(0x607D8B) dark_gray = dark_grey @classmethod def light_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" - return cls(0x979c9f) + return cls(0x979C9F) light_gray = light_grey @classmethod def darker_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" - return cls(0x546e7a) + return cls(0x546E7A) darker_gray = darker_grey @classmethod def og_blurple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" - return cls(0x7289da) + return cls(0x7289DA) @classmethod def blurple(cls: Type[CT]) -> CT: @@ -305,7 +305,7 @@ class Colour: @classmethod def greyple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" - return cls(0x99aab5) + return cls(0x99AAB5) @classmethod def dark_theme(cls: Type[CT]) -> CT: @@ -331,7 +331,7 @@ class Colour: .. versionadded:: 2.0 """ return cls(0xFEE75C) - + @classmethod def dark_blurple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``. diff --git a/discord/components.py b/discord/components.py index 74c7be3d..aee13620 100644 --- a/discord/components.py +++ b/discord/components.py @@ -41,14 +41,14 @@ if TYPE_CHECKING: __all__ = ( - 'Component', - 'ActionRow', - 'Button', - 'SelectMenu', - 'SelectOption', + "Component", + "ActionRow", + "Button", + "SelectMenu", + "SelectOption", ) -C = TypeVar('C', bound='Component') +C = TypeVar("C", bound="Component") class Component: @@ -70,14 +70,14 @@ class Component: The type of component. """ - __slots__: Tuple[str, ...] = ('type',) + __slots__: Tuple[str, ...] = ("type",) __repr_info__: ClassVar[Tuple[str, ...]] type: ComponentType def __repr__(self) -> str: - attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) - return f'<{self.__class__.__name__} {attrs}>' + attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__) + return f"<{self.__class__.__name__} {attrs}>" @classmethod def _raw_construct(cls: Type[C], **kwargs) -> C: @@ -112,18 +112,18 @@ class ActionRow(Component): The children components that this holds, if any. """ - __slots__: Tuple[str, ...] = ('children',) + __slots__: Tuple[str, ...] = ("children",) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.children: List[Component] = [_component_factory(d) for d in data.get("components", [])] def to_dict(self) -> ActionRowPayload: return { - 'type': int(self.type), - 'components': [child.to_dict() for child in self.children], + "type": int(self.type), + "components": [child.to_dict() for child in self.children], } # type: ignore @@ -157,44 +157,44 @@ class Button(Component): """ __slots__: Tuple[str, ...] = ( - 'style', - 'custom_id', - 'url', - 'disabled', - 'label', - 'emoji', + "style", + "custom_id", + "url", + "disabled", + "label", + "emoji", ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ButtonComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) - self.custom_id: Optional[str] = data.get('custom_id') - self.url: Optional[str] = data.get('url') - self.disabled: bool = data.get('disabled', False) - self.label: Optional[str] = data.get('label') + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) + self.custom_id: Optional[str] = data.get("custom_id") + self.url: Optional[str] = data.get("url") + self.disabled: bool = data.get("disabled", False) + self.label: Optional[str] = data.get("label") self.emoji: Optional[PartialEmoji] try: - self.emoji = PartialEmoji.from_dict(data['emoji']) + self.emoji = PartialEmoji.from_dict(data["emoji"]) except KeyError: self.emoji = None def to_dict(self) -> ButtonComponentPayload: payload = { - 'type': 2, - 'style': int(self.style), - 'label': self.label, - 'disabled': self.disabled, + "type": 2, + "style": int(self.style), + "label": self.label, + "disabled": self.disabled, } if self.custom_id: - payload['custom_id'] = self.custom_id + payload["custom_id"] = self.custom_id if self.url: - payload['url'] = self.url + payload["url"] = self.url if self.emoji: - payload['emoji'] = self.emoji.to_dict() + payload["emoji"] = self.emoji.to_dict() return payload # type: ignore @@ -231,37 +231,37 @@ class SelectMenu(Component): """ __slots__: Tuple[str, ...] = ( - 'custom_id', - 'placeholder', - 'min_values', - 'max_values', - 'options', - 'disabled', + "custom_id", + "placeholder", + "min_values", + "max_values", + "options", + "disabled", ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: SelectMenuPayload): self.type = ComponentType.select - self.custom_id: str = data['custom_id'] - self.placeholder: Optional[str] = data.get('placeholder') - self.min_values: int = data.get('min_values', 1) - self.max_values: int = data.get('max_values', 1) - self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] - self.disabled: bool = data.get('disabled', False) + self.custom_id: str = data["custom_id"] + self.placeholder: Optional[str] = data.get("placeholder") + self.min_values: int = data.get("min_values", 1) + self.max_values: int = data.get("max_values", 1) + self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])] + self.disabled: bool = data.get("disabled", False) def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { - 'type': self.type.value, - 'custom_id': self.custom_id, - 'min_values': self.min_values, - 'max_values': self.max_values, - 'options': [op.to_dict() for op in self.options], - 'disabled': self.disabled, + "type": self.type.value, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + "options": [op.to_dict() for op in self.options], + "disabled": self.disabled, } if self.placeholder: - payload['placeholder'] = self.placeholder + payload["placeholder"] = self.placeholder return payload @@ -292,11 +292,11 @@ class SelectOption: """ __slots__: Tuple[str, ...] = ( - 'label', - 'value', - 'description', - 'emoji', - 'default', + "label", + "value", + "description", + "emoji", + "default", ) def __init__( @@ -318,60 +318,60 @@ class SelectOption: elif isinstance(emoji, _EmojiTag): emoji = emoji._to_partial() else: - raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') + raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}") self.emoji = emoji self.default = default def __repr__(self) -> str: return ( - f'' + f"" ) def __str__(self) -> str: if self.emoji: - base = f'{self.emoji} {self.label}' + base = f"{self.emoji} {self.label}" else: base = self.label if self.description: - return f'{base}\n{self.description}' + return f"{base}\n{self.description}" return base @classmethod def from_dict(cls, data: SelectOptionPayload) -> SelectOption: try: - emoji = PartialEmoji.from_dict(data['emoji']) + emoji = PartialEmoji.from_dict(data["emoji"]) except KeyError: emoji = None return cls( - label=data['label'], - value=data['value'], - description=data.get('description'), + label=data["label"], + value=data["value"], + description=data.get("description"), emoji=emoji, - default=data.get('default', False), + default=data.get("default", False), ) def to_dict(self) -> SelectOptionPayload: payload: SelectOptionPayload = { - 'label': self.label, - 'value': self.value, - 'default': self.default, + "label": self.label, + "value": self.value, + "default": self.default, } if self.emoji: - payload['emoji'] = self.emoji.to_dict() # type: ignore + payload["emoji"] = self.emoji.to_dict() # type: ignore if self.description: - payload['description'] = self.description + payload["description"] = self.description return payload def _component_factory(data: ComponentPayload) -> Component: - component_type = data['type'] + component_type = data["type"] if component_type == 1: return ActionRow(data) elif component_type == 2: diff --git a/discord/context_managers.py b/discord/context_managers.py index a3ab0d19..226e5662 100644 --- a/discord/context_managers.py +++ b/discord/context_managers.py @@ -32,11 +32,10 @@ if TYPE_CHECKING: from types import TracebackType - TypingT = TypeVar('TypingT', bound='Typing') + TypingT = TypeVar("TypingT", bound="Typing") + +__all__ = ("Typing",) -__all__ = ( - 'Typing', -) def _typing_done_callback(fut: asyncio.Future) -> None: # just retrieve any exception and call it a day @@ -45,6 +44,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None: except (asyncio.CancelledError, Exception): pass + class Typing: def __init__(self, messageable: Messageable) -> None: self.loop: asyncio.AbstractEventLoop = messageable._state.loop @@ -67,7 +67,8 @@ class Typing: self.task.add_done_callback(_typing_done_callback) return self - def __exit__(self, + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], @@ -79,7 +80,8 @@ class Typing: await channel._state.http.send_typing(channel.id) return self.__enter__() - async def __aexit__(self, + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], diff --git a/discord/embeds.py b/discord/embeds.py index 52d71ef4..4e1647f1 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -30,9 +30,7 @@ from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Typ from . import utils from .colour import Colour -__all__ = ( - 'Embed', -) +__all__ = ("Embed",) class _EmptyEmbed: @@ -40,7 +38,7 @@ class _EmptyEmbed: return False def __repr__(self) -> str: - return 'Embed.Empty' + return "Embed.Empty" def __len__(self) -> int: return 0 @@ -57,51 +55,45 @@ class EmbedProxy: return len(self.__dict__) def __repr__(self) -> str: - inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_'))) - return f'EmbedProxy({inner})' + inner = ", ".join((f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_"))) + return f"EmbedProxy({inner})" def __getattr__(self, attr: str) -> _EmptyEmbed: return EmptyEmbed -E = TypeVar('E', bound='Embed') +E = TypeVar("E", bound="Embed") if TYPE_CHECKING: from discord.types.embed import Embed as EmbedData, EmbedType - T = TypeVar('T') + T = TypeVar("T") MaybeEmpty = Union[T, _EmptyEmbed] - class _EmbedFooterProxy(Protocol): text: MaybeEmpty[str] icon_url: MaybeEmpty[str] - class _EmbedFieldProxy(Protocol): name: MaybeEmpty[str] value: MaybeEmpty[str] inline: bool - class _EmbedMediaProxy(Protocol): url: MaybeEmpty[str] proxy_url: MaybeEmpty[str] height: MaybeEmpty[int] width: MaybeEmpty[int] - class _EmbedVideoProxy(Protocol): url: MaybeEmpty[str] height: MaybeEmpty[int] width: MaybeEmpty[int] - class _EmbedProviderProxy(Protocol): name: MaybeEmpty[str] url: MaybeEmpty[str] - class _EmbedAuthorProxy(Protocol): name: MaybeEmpty[str] url: MaybeEmpty[str] @@ -163,33 +155,33 @@ class Embed: """ __slots__ = ( - 'title', - 'url', - 'type', - '_timestamp', - '_colour', - '_footer', - '_image', - '_thumbnail', - '_video', - '_provider', - '_author', - '_fields', - 'description', + "title", + "url", + "type", + "_timestamp", + "_colour", + "_footer", + "_image", + "_thumbnail", + "_video", + "_provider", + "_author", + "_fields", + "description", ) Empty: Final = EmptyEmbed def __init__( - self, - *, - colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, - color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, - title: MaybeEmpty[Any] = EmptyEmbed, - type: EmbedType = 'rich', - url: MaybeEmpty[Any] = EmptyEmbed, - description: MaybeEmpty[Any] = EmptyEmbed, - timestamp: datetime.datetime = None, + self, + *, + colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, + color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, + title: MaybeEmpty[Any] = EmptyEmbed, + type: EmbedType = "rich", + url: MaybeEmpty[Any] = EmptyEmbed, + description: MaybeEmpty[Any] = EmptyEmbed, + timestamp: datetime.datetime = None, ): self.colour = colour if colour is not EmptyEmbed else color @@ -231,10 +223,10 @@ class Embed: # fill in the basic fields - self.title = data.get('title', EmptyEmbed) - self.type = data.get('type', EmptyEmbed) - self.description = data.get('description', EmptyEmbed) - self.url = data.get('url', EmptyEmbed) + self.title = data.get("title", EmptyEmbed) + self.type = data.get("type", EmptyEmbed) + self.description = data.get("description", EmptyEmbed) + self.url = data.get("url", EmptyEmbed) if self.title is not EmptyEmbed: self.title = str(self.title) @@ -248,22 +240,22 @@ class Embed: # try to fill in the more rich fields try: - self._colour = Colour(value=data['color']) + self._colour = Colour(value=data["color"]) except KeyError: pass try: - self._timestamp = utils.parse_time(data['timestamp']) + self._timestamp = utils.parse_time(data["timestamp"]) except KeyError: pass - for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'): + for attr in ("thumbnail", "video", "provider", "author", "fields", "image", "footer"): try: value = data[attr] except KeyError: continue else: - setattr(self, '_' + attr, value) + setattr(self, "_" + attr, value) return self @@ -273,11 +265,11 @@ class Embed: def __len__(self) -> int: total = len(self.title) + len(self.description) - for field in getattr(self, '_fields', []): - total += len(field['name']) + len(field['value']) + for field in getattr(self, "_fields", []): + total += len(field["name"]) + len(field["value"]) try: - footer_text = self._footer['text'] + footer_text = self._footer["text"] except (AttributeError, KeyError): pass else: @@ -288,7 +280,7 @@ class Embed: except AttributeError: pass else: - total += len(author['name']) + total += len(author["name"]) return total @@ -312,7 +304,7 @@ class Embed: @property def colour(self) -> MaybeEmpty[Colour]: - return getattr(self, '_colour', EmptyEmbed) + return getattr(self, "_colour", EmptyEmbed) @colour.setter def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore @@ -321,13 +313,15 @@ class Embed: elif isinstance(value, int): self._colour = Colour(value=value) else: - raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.') + raise TypeError( + f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead." + ) color = colour @property def timestamp(self) -> MaybeEmpty[datetime.datetime]: - return getattr(self, '_timestamp', EmptyEmbed) + return getattr(self, "_timestamp", EmptyEmbed) @timestamp.setter def timestamp(self, value: MaybeEmpty[datetime.datetime]): @@ -348,7 +342,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_footer', {})) # type: ignore + return EmbedProxy(getattr(self, "_footer", {})) # type: ignore def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: """Sets the footer for the embed content. @@ -366,10 +360,10 @@ class Embed: self._footer = {} if text is not EmptyEmbed: - self._footer['text'] = str(text) + self._footer["text"] = str(text) if icon_url is not EmptyEmbed: - self._footer['icon_url'] = str(icon_url) + self._footer["icon_url"] = str(icon_url) return self @@ -401,7 +395,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_image', {})) # type: ignore + return EmbedProxy(getattr(self, "_image", {})) # type: ignore @image.setter def image(self: E, url: Any): @@ -409,7 +403,7 @@ class Embed: del self._image else: self._image = { - 'url': str(url), + "url": str(url), } @image.deleter @@ -451,7 +445,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore + return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore @thumbnail.setter def thumbnail(self: E, url: Any): @@ -459,7 +453,7 @@ class Embed: del self._thumbnail else: self._thumbnail = { - 'url': str(url), + "url": str(url), } @thumbnail.deleter @@ -500,7 +494,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_video', {})) # type: ignore + return EmbedProxy(getattr(self, "_video", {})) # type: ignore @property def provider(self) -> _EmbedProviderProxy: @@ -510,7 +504,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_provider', {})) # type: ignore + return EmbedProxy(getattr(self, "_provider", {})) # type: ignore @property def author(self) -> _EmbedAuthorProxy: @@ -520,9 +514,11 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_author', {})) # type: ignore + return EmbedProxy(getattr(self, "_author", {})) # type: ignore - def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: + def set_author( + self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed + ) -> E: """Sets the author for the embed content. This function returns the class instance to allow for fluent-style @@ -539,14 +535,14 @@ class Embed: """ self._author = { - 'name': str(name), + "name": str(name), } if url is not EmptyEmbed: - self._author['url'] = str(url) + self._author["url"] = str(url) if icon_url is not EmptyEmbed: - self._author['icon_url'] = str(icon_url) + self._author["icon_url"] = str(icon_url) return self @@ -573,7 +569,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore + return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E: """Adds a field to the embed object. @@ -592,9 +588,9 @@ class Embed: """ field = { - 'inline': inline, - 'name': str(name), - 'value': str(value), + "inline": inline, + "name": str(name), + "value": str(value), } try: @@ -625,9 +621,9 @@ class Embed: """ field = { - 'inline': inline, - 'name': str(name), - 'value': str(value), + "inline": inline, + "name": str(name), + "value": str(value), } try: @@ -693,11 +689,11 @@ class Embed: try: field = self._fields[index] except (TypeError, IndexError, AttributeError): - raise IndexError('field index out of range') + raise IndexError("field index out of range") - field['name'] = str(name) - field['value'] = str(value) - field['inline'] = inline + field["name"] = str(name) + field["value"] = str(value) + field["inline"] = inline return self def to_dict(self) -> EmbedData: @@ -715,35 +711,35 @@ class Embed: # deal with basic convenience wrappers try: - colour = result.pop('colour') + colour = result.pop("colour") except KeyError: pass else: if colour: - result['color'] = colour.value + result["color"] = colour.value try: - timestamp = result.pop('timestamp') + timestamp = result.pop("timestamp") except KeyError: pass else: if timestamp: if timestamp.tzinfo: - result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat() + result["timestamp"] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat() else: - result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat() + result["timestamp"] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat() # add in the non raw attribute ones if self.type: - result['type'] = self.type + result["type"] = self.type if self.description: - result['description'] = self.description + result["description"] = self.description if self.url: - result['url'] = self.url + result["url"] = self.url if self.title: - result['title'] = self.title + result["title"] = self.title return result # type: ignore diff --git a/discord/emoji.py b/discord/emoji.py index 39fa0218..891afef1 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -30,9 +30,7 @@ from .utils import SnowflakeList, snowflake_time, MISSING from .partial_emoji import _EmojiTag, PartialEmoji from .user import User -__all__ = ( - 'Emoji', -) +__all__ = ("Emoji",) if TYPE_CHECKING: from .types.emoji import Emoji as EmojiPayload @@ -98,16 +96,16 @@ class Emoji(_EmojiTag, AssetMixin): """ __slots__: Tuple[str, ...] = ( - 'require_colons', - 'animated', - 'managed', - 'id', - 'name', - '_roles', - 'guild_id', - '_state', - 'user', - 'available', + "require_colons", + "animated", + "managed", + "id", + "name", + "_roles", + "guild_id", + "_state", + "user", + "available", ) def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload): @@ -116,14 +114,14 @@ class Emoji(_EmojiTag, AssetMixin): self._from_data(data) def _from_data(self, emoji: EmojiPayload): - self.require_colons: bool = emoji.get('require_colons', False) - self.managed: bool = emoji.get('managed', False) - self.id: int = int(emoji['id']) # type: ignore - self.name: str = emoji['name'] # type: ignore - self.animated: bool = emoji.get('animated', False) - self.available: bool = emoji.get('available', True) - self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', []))) - user = emoji.get('user') + self.require_colons: bool = emoji.get("require_colons", False) + self.managed: bool = emoji.get("managed", False) + self.id: int = int(emoji["id"]) # type: ignore + self.name: str = emoji["name"] # type: ignore + self.animated: bool = emoji.get("animated", False) + self.available: bool = emoji.get("available", True) + self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", []))) + user = emoji.get("user") self.user: Optional[User] = User(state=self._state, data=user) if user else None def _to_partial(self) -> PartialEmoji: @@ -131,21 +129,21 @@ class Emoji(_EmojiTag, AssetMixin): def __iter__(self) -> Iterator[Tuple[str, Any]]: for attr in self.__slots__: - if attr[0] != '_': + if attr[0] != "_": value = getattr(self, attr, None) if value is not None: yield (attr, value) def __str__(self) -> str: if self.animated: - return f'' - return f'<:{self.name}:{self.id}>' + return f"" + return f"<:{self.name}:{self.id}>" def __int__(self) -> int: return self.id def __repr__(self) -> str: - return f'' + return f"" def __eq__(self, other: Any) -> bool: return isinstance(other, _EmojiTag) and self.id == other.id @@ -164,8 +162,8 @@ class Emoji(_EmojiTag, AssetMixin): @property def url(self) -> str: """:class:`str`: Returns the URL of the emoji.""" - fmt = 'gif' if self.animated else 'png' - return f'{Asset.BASE}/emojis/{self.id}.{fmt}' + fmt = "gif" if self.animated else "png" + return f"{Asset.BASE}/emojis/{self.id}.{fmt}" @property def roles(self) -> List[Role]: @@ -219,7 +217,9 @@ class Emoji(_EmojiTag, AssetMixin): await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) - async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji: + async def edit( + self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None + ) -> Emoji: r"""|coro| Edits the custom emoji. @@ -254,9 +254,9 @@ class Emoji(_EmojiTag, AssetMixin): payload = {} if name is not MISSING: - payload['name'] = name + payload["name"] = name if roles is not MISSING: - payload['roles'] = [role.id for role in roles] + payload["roles"] = [role.id for role in roles] data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason) return Emoji(guild=self.guild, data=data, state=self._state) diff --git a/discord/enums.py b/discord/enums.py index af8ee2b0..023ab7d0 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -27,41 +27,41 @@ from collections import namedtuple from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar __all__ = ( - 'Enum', - 'ChannelType', - 'MessageType', - 'VoiceRegion', - 'SpeakingState', - 'VerificationLevel', - 'ContentFilter', - 'Status', - 'DefaultAvatar', - 'AuditLogAction', - 'AuditLogActionCategory', - 'UserFlags', - 'ActivityType', - 'NotificationLevel', - 'TeamMembershipState', - 'WebhookType', - 'ExpireBehaviour', - 'ExpireBehavior', - 'StickerType', - 'StickerFormatType', - 'InviteTarget', - 'VideoQualityMode', - 'ComponentType', - 'ButtonStyle', - 'StagePrivacyLevel', - 'InteractionType', - 'InteractionResponseType', - 'NSFWLevel', + "Enum", + "ChannelType", + "MessageType", + "VoiceRegion", + "SpeakingState", + "VerificationLevel", + "ContentFilter", + "Status", + "DefaultAvatar", + "AuditLogAction", + "AuditLogActionCategory", + "UserFlags", + "ActivityType", + "NotificationLevel", + "TeamMembershipState", + "WebhookType", + "ExpireBehaviour", + "ExpireBehavior", + "StickerType", + "StickerFormatType", + "InviteTarget", + "VideoQualityMode", + "ComponentType", + "ButtonStyle", + "StagePrivacyLevel", + "InteractionType", + "InteractionResponseType", + "NSFWLevel", ) def _create_value_cls(name, comparable): - cls = namedtuple('_EnumValue_' + name, 'name value') - cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' - cls.__str__ = lambda self: f'{name}.{self.name}' + cls = namedtuple("_EnumValue_" + name, "name value") + cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" + cls.__str__ = lambda self: f"{name}.{self.name}" if comparable: cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value @@ -69,8 +69,9 @@ def _create_value_cls(name, comparable): cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value return cls + def _is_descriptor(obj): - return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') + return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") class EnumMeta(type): @@ -88,7 +89,7 @@ class EnumMeta(type): value_cls = _create_value_cls(name, comparable) for key, value in list(attrs.items()): is_descriptor = _is_descriptor(value) - if key[0] == '_' and not is_descriptor: + if key[0] == "_" and not is_descriptor: continue # Special case classmethod to just pass through @@ -110,10 +111,10 @@ class EnumMeta(type): member_mapping[key] = new_value attrs[key] = new_value - attrs['_enum_value_map_'] = value_mapping - attrs['_enum_member_map_'] = member_mapping - attrs['_enum_member_names_'] = member_names - attrs['_enum_value_cls_'] = value_cls + attrs["_enum_value_map_"] = value_mapping + attrs["_enum_member_map_"] = member_mapping + attrs["_enum_member_names_"] = member_names + attrs["_enum_value_cls_"] = value_cls actual_cls = super().__new__(cls, name, bases, attrs) value_cls._actual_enum_cls_ = actual_cls # type: ignore return actual_cls @@ -128,7 +129,7 @@ class EnumMeta(type): return len(cls._enum_member_names_) def __repr__(cls): - return f'' + return f"" @property def __members__(cls): @@ -144,10 +145,10 @@ class EnumMeta(type): return cls._enum_member_map_[key] def __setattr__(cls, name, value): - raise TypeError('Enums are immutable.') + raise TypeError("Enums are immutable.") def __delattr__(cls, attr): - raise TypeError('Enums are immutable') + raise TypeError("Enums are immutable") def __instancecheck__(self, instance): # isinstance(x, Y) @@ -215,29 +216,29 @@ class MessageType(Enum): class VoiceRegion(Enum): - us_west = 'us-west' - us_east = 'us-east' - us_south = 'us-south' - us_central = 'us-central' - eu_west = 'eu-west' - eu_central = 'eu-central' - singapore = 'singapore' - london = 'london' - sydney = 'sydney' - amsterdam = 'amsterdam' - frankfurt = 'frankfurt' - brazil = 'brazil' - hongkong = 'hongkong' - russia = 'russia' - japan = 'japan' - southafrica = 'southafrica' - south_korea = 'south-korea' - india = 'india' - europe = 'europe' - dubai = 'dubai' - vip_us_east = 'vip-us-east' - vip_us_west = 'vip-us-west' - vip_amsterdam = 'vip-amsterdam' + us_west = "us-west" + us_east = "us-east" + us_south = "us-south" + us_central = "us-central" + eu_west = "eu-west" + eu_central = "eu-central" + singapore = "singapore" + london = "london" + sydney = "sydney" + amsterdam = "amsterdam" + frankfurt = "frankfurt" + brazil = "brazil" + hongkong = "hongkong" + russia = "russia" + japan = "japan" + southafrica = "southafrica" + south_korea = "south-korea" + india = "india" + europe = "europe" + dubai = "dubai" + vip_us_east = "vip-us-east" + vip_us_west = "vip-us-west" + vip_amsterdam = "vip-amsterdam" def __str__(self): return self.value @@ -277,12 +278,12 @@ class ContentFilter(Enum, comparable=True): class Status(Enum): - online = 'online' - offline = 'offline' - idle = 'idle' - dnd = 'dnd' - do_not_disturb = 'dnd' - invisible = 'invisible' + online = "online" + offline = "offline" + idle = "idle" + dnd = "dnd" + do_not_disturb = "dnd" + invisible = "invisible" def __str__(self): return self.value @@ -415,33 +416,33 @@ class AuditLogAction(Enum): def target_type(self) -> Optional[str]: v = self.value if v == -1: - return 'all' + return "all" elif v < 10: - return 'guild' + return "guild" elif v < 20: - return 'channel' + return "channel" elif v < 30: - return 'user' + return "user" elif v < 40: - return 'role' + return "role" elif v < 50: - return 'invite' + return "invite" elif v < 60: - return 'webhook' + return "webhook" elif v < 70: - return 'emoji' + return "emoji" elif v == 73: - return 'channel' + return "channel" elif v < 80: - return 'message' + return "message" elif v < 83: - return 'integration' + return "integration" elif v < 90: - return 'stage_instance' + return "stage_instance" elif v < 93: - return 'sticker' + return "sticker" elif v < 113: - return 'thread' + return "thread" class UserFlags(Enum): @@ -589,12 +590,12 @@ class NSFWLevel(Enum, comparable=True): age_restricted = 3 -T = TypeVar('T') +T = TypeVar("T") def create_unknown_value(cls: Type[T], val: Any) -> T: value_cls = cls._enum_value_cls_ # type: ignore - name = f'unknown_{val}' + name = f"unknown_{val}" return value_cls(name=name, value=val) diff --git a/discord/errors.py b/discord/errors.py index bc2398d5..ba571c30 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -38,20 +38,20 @@ if TYPE_CHECKING: from .interactions import Interaction __all__ = ( - 'DiscordException', - 'ClientException', - 'NoMoreItems', - 'GatewayNotFound', - 'HTTPException', - 'Forbidden', - 'NotFound', - 'DiscordServerError', - 'InvalidData', - 'InvalidArgument', - 'LoginFailure', - 'ConnectionClosed', - 'PrivilegedIntentsRequired', - 'InteractionResponded', + "DiscordException", + "ClientException", + "NoMoreItems", + "GatewayNotFound", + "HTTPException", + "Forbidden", + "NotFound", + "DiscordServerError", + "InvalidData", + "InvalidArgument", + "LoginFailure", + "ConnectionClosed", + "PrivilegedIntentsRequired", + "InteractionResponded", ) @@ -83,22 +83,22 @@ class GatewayNotFound(DiscordException): """An exception that is raised when the gateway for Discord could not be found""" def __init__(self): - message = 'The gateway to connect to discord was not found.' + message = "The gateway to connect to discord was not found." super().__init__(message) -def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]: +def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]: items: List[Tuple[str, str]] = [] for k, v in d.items(): - new_key = key + '.' + k if key else k + new_key = key + "." + k if key else k if isinstance(v, dict): try: - _errors: List[Dict[str, Any]] = v['_errors'] + _errors: List[Dict[str, Any]] = v["_errors"] except KeyError: items.extend(_flatten_error_dict(v, new_key).items()) else: - items.append((new_key, ' '.join(x.get('message', '') for x in _errors))) + items.append((new_key, " ".join(x.get("message", "") for x in _errors))) else: items.append((new_key, v)) @@ -129,22 +129,22 @@ class HTTPException(DiscordException): self.code: int self.text: str if isinstance(message, dict): - self.code = message.get('code', 0) - base = message.get('message', '') - errors = message.get('errors') + self.code = message.get("code", 0) + base = message.get("message", "") + errors = message.get("errors") if errors: errors = _flatten_error_dict(errors) - helpful = '\n'.join('In %s: %s' % t for t in errors.items()) - self.text = base + '\n' + helpful + helpful = "\n".join("In %s: %s" % t for t in errors.items()) + self.text = base + "\n" + helpful else: self.text = base else: - self.text = message or '' + self.text = message or "" self.code = 0 - fmt = '{0.status} {0.reason} (error code: {1})' + fmt = "{0.status} {0.reason} (error code: {1})" if len(self.text): - fmt += ': {2}' + fmt += ": {2}" super().__init__(fmt.format(self.response, self.code, self.text)) @@ -226,9 +226,9 @@ class ConnectionClosed(ClientException): # reconfigured to subclass ClientException for users self.code: int = code or socket.close_code or -1 # aiohttp doesn't seem to consistently provide close reason - self.reason: str = '' + self.reason: str = "" self.shard_id: Optional[int] = shard_id - super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}') + super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}") class PrivilegedIntentsRequired(ClientException): @@ -250,10 +250,10 @@ class PrivilegedIntentsRequired(ClientException): def __init__(self, shard_id: Optional[int]): self.shard_id: Optional[int] = shard_id msg = ( - 'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' - 'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' - 'and explicitly enable the privileged intents within your application\'s page. If this is not ' - 'possible, then consider disabling the privileged intents instead.' + "Shard ID %s is requesting privileged intents that have not been explicitly enabled in the " + "developer portal. It is recommended to go to https://discord.com/developers/applications/ " + "and explicitly enable the privileged intents within your application's page. If this is not " + "possible, then consider disabling the privileged intents instead." ) super().__init__(msg % shard_id) @@ -274,4 +274,4 @@ class InteractionResponded(ClientException): def __init__(self, interaction: Interaction): self.interaction: Interaction = interaction - super().__init__('This interaction has already been responded to before') + super().__init__("This interaction has already been responded to before") diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 9b155987..1f904ef9 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -31,7 +31,7 @@ if TYPE_CHECKING: from .cog import Cog from .errors import CommandError -T = TypeVar('T') +T = TypeVar("T") Coro = Coroutine[Any, Any, T] MaybeCoro = Union[T, Coro[T]] @@ -39,7 +39,9 @@ CoroFunc = Callable[..., Coro[Any]] Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]] Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] -Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]] +Error = Union[ + Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]] +] # This is merely a tag type to avoid circular import issues. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index e03562b6..8a341f90 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -54,17 +54,18 @@ if TYPE_CHECKING: ) __all__ = ( - 'when_mentioned', - 'when_mentioned_or', - 'Bot', - 'AutoShardedBot', + "when_mentioned", + "when_mentioned_or", + "Bot", + "AutoShardedBot", ) MISSING: Any = discord.utils.MISSING -T = TypeVar('T') -CFT = TypeVar('CFT', bound='CoroFunc') -CXT = TypeVar('CXT', bound='Context') +T = TypeVar("T") +CFT = TypeVar("CFT", bound="CoroFunc") +CXT = TypeVar("CXT", bound="Context") + def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. @@ -72,7 +73,8 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. """ # bot.user will never be None when this is called - return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore + return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore + def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: """A callable that implements when mentioned or other prefixes provided. @@ -103,6 +105,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M ---------- :func:`.when_mentioned` """ + def inner(bot, msg): r = list(prefixes) r = when_mentioned(bot, msg) + r @@ -110,15 +113,19 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M return inner + def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") + class _DefaultRepr: def __repr__(self): - return '' + return "" + _default = _DefaultRepr() + class BotBase(GroupMixin): def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options): super().__init__(**options, intents=intents) @@ -131,16 +138,16 @@ class BotBase(GroupMixin): self._before_invoke = None self._after_invoke = None self._help_command = None - self.description = inspect.cleandoc(description) if description else '' - self.owner_id = options.get('owner_id') - self.owner_ids = options.get('owner_ids', set()) - self.strip_after_prefix = options.get('strip_after_prefix', False) + self.description = inspect.cleandoc(description) if description else "" + self.owner_id = options.get("owner_id") + self.owner_ids = options.get("owner_ids", set()) + self.strip_after_prefix = options.get("strip_after_prefix", False) if self.owner_id and self.owner_ids: - raise TypeError('Both owner_id and owner_ids are set.') + raise TypeError("Both owner_id and owner_ids are set.") if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): - raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}') + raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}") if help_command is _default: self.help_command = DefaultHelpCommand() @@ -152,7 +159,7 @@ class BotBase(GroupMixin): def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: # super() will resolve to Client super().dispatch(event_name, *args, **kwargs) # type: ignore - ev = 'on_' + event_name + ev = "on_" + event_name for event in self.extra_events.get(ev, []): self._schedule_event(event, ev, *args, **kwargs) # type: ignore @@ -182,7 +189,7 @@ class BotBase(GroupMixin): This only fires if you do not specify any listeners for command error. """ - if self.extra_events.get('on_command_error', None): + if self.extra_events.get("on_command_error", None): return command = context.command @@ -193,7 +200,7 @@ class BotBase(GroupMixin): if cog and cog.has_error_handler(): return - print(f'Ignoring exception in command {context.command}:', file=sys.stderr) + print(f"Ignoring exception in command {context.command}:", file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) # global check registration @@ -425,7 +432,7 @@ class BotBase(GroupMixin): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The pre-invoke hook must be a coroutine.') + raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro return coro @@ -458,7 +465,7 @@ class BotBase(GroupMixin): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The post-invoke hook must be a coroutine.') + raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro return coro @@ -490,7 +497,7 @@ class BotBase(GroupMixin): name = func.__name__ if name is MISSING else name if not asyncio.iscoroutinefunction(func): - raise TypeError('Listeners must be coroutines') + raise TypeError("Listeners must be coroutines") if name in self.extra_events: self.extra_events[name].append(func) @@ -586,14 +593,14 @@ class BotBase(GroupMixin): """ if not isinstance(cog, Cog): - raise TypeError('cogs must derive from Cog') + raise TypeError("cogs must derive from Cog") cog_name = cog.__cog_name__ existing = self.__cogs.get(cog_name) if existing is not None: if not override: - raise discord.ClientException(f'Cog named {cog_name!r} already loaded') + raise discord.ClientException(f"Cog named {cog_name!r} already loaded") self.remove_cog(cog_name) cog = cog._inject(self) @@ -681,7 +688,7 @@ class BotBase(GroupMixin): def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: - func = getattr(lib, 'teardown') + func = getattr(lib, "teardown") except AttributeError: pass else: @@ -708,7 +715,7 @@ class BotBase(GroupMixin): raise errors.ExtensionFailed(key, e) from e try: - setup = getattr(lib, 'setup') + setup = getattr(lib, "setup") except AttributeError: del sys.modules[key] raise errors.NoEntryPointError(key) @@ -858,11 +865,7 @@ class BotBase(GroupMixin): raise errors.ExtensionNotLoaded(name) # get the previous module states from sys modules - modules = { - name: module - for name, module in sys.modules.items() - if _is_submodule(lib.__name__, name) - } + modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)} try: # Unload and then load the module... @@ -895,7 +898,7 @@ class BotBase(GroupMixin): def help_command(self, value: Optional[HelpCommand]) -> None: if value is not None: if not isinstance(value, HelpCommand): - raise TypeError('help_command must be a subclass of HelpCommand') + raise TypeError("help_command must be a subclass of HelpCommand") if self._help_command is not None: self._help_command._remove_from_bot(self) self._help_command = value @@ -938,8 +941,10 @@ class BotBase(GroupMixin): if isinstance(ret, collections.abc.Iterable): raise - raise TypeError("command_prefix must be plain string, iterable of strings, or callable " - f"returning either of these, not {ret.__class__.__name__}") + raise TypeError( + "command_prefix must be plain string, iterable of strings, or callable " + f"returning either of these, not {ret.__class__.__name__}" + ) if not ret: raise ValueError("Iterable command_prefix must contain at least one prefix") @@ -999,14 +1004,18 @@ class BotBase(GroupMixin): except TypeError: if not isinstance(prefix, list): - raise TypeError("get_prefix must return either a string or a list of string, " - f"not {prefix.__class__.__name__}") + raise TypeError( + "get_prefix must return either a string or a list of string, " + f"not {prefix.__class__.__name__}" + ) # It's possible a bad command_prefix got us here. for value in prefix: if not isinstance(value, str): - raise TypeError("Iterable command_prefix or list returned from get_prefix must " - f"contain only strings, not {value.__class__.__name__}") + raise TypeError( + "Iterable command_prefix or list returned from get_prefix must " + f"contain only strings, not {value.__class__.__name__}" + ) # Getting here shouldn't happen raise @@ -1033,19 +1042,19 @@ class BotBase(GroupMixin): The invocation context to invoke. """ if ctx.command is not None: - self.dispatch('command', ctx) + self.dispatch("command", ctx) try: if await self.can_run(ctx, call_once=True): await ctx.command.invoke(ctx) else: - raise errors.CheckFailure('The global check once functions failed.') + raise errors.CheckFailure("The global check once functions failed.") except errors.CommandError as exc: await ctx.command.dispatch_error(ctx, exc) else: - self.dispatch('command_completion', ctx) + self.dispatch("command_completion", ctx) elif ctx.invoked_with: exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') - self.dispatch('command_error', ctx, exc) + self.dispatch("command_error", ctx, exc) async def process_commands(self, message: Message) -> None: """|coro| @@ -1078,6 +1087,7 @@ class BotBase(GroupMixin): async def on_message(self, message): await self.process_commands(message) + class Bot(BotBase, discord.Client): """Represents a discord bot. @@ -1148,10 +1158,13 @@ class Bot(BotBase, discord.Client): .. versionadded:: 1.7 """ + pass + class AutoShardedBot(BotBase, discord.AutoShardedClient): """This is similar to :class:`.Bot` except that it is inherited from :class:`discord.AutoShardedClient` instead. """ + pass diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 9931557d..38de6226 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -36,15 +36,16 @@ if TYPE_CHECKING: from .core import Command __all__ = ( - 'CogMeta', - 'Cog', + "CogMeta", + "Cog", ) -CogT = TypeVar('CogT', bound='Cog') -FuncT = TypeVar('FuncT', bound=Callable[..., Any]) +CogT = TypeVar("CogT", bound="Cog") +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) MISSING: Any = discord.utils.MISSING + class CogMeta(type): """A metaclass for defining a cog. @@ -104,6 +105,7 @@ class CogMeta(type): async def bar(self, ctx): pass # hidden -> False """ + __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[Command] @@ -111,17 +113,17 @@ class CogMeta(type): def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: name, bases, attrs = args - attrs['__cog_name__'] = kwargs.pop('name', name) - attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) + attrs["__cog_name__"] = kwargs.pop("name", name) + attrs["__cog_settings__"] = kwargs.pop("command_attrs", {}) - description = kwargs.pop('description', None) + description = kwargs.pop("description", None) if description is None: - description = inspect.cleandoc(attrs.get('__doc__', '')) - attrs['__cog_description__'] = description + description = inspect.cleandoc(attrs.get("__doc__", "")) + attrs["__cog_description__"] = description commands = {} listeners = {} - no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' + no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})" new_cls = super().__new__(cls, name, bases, attrs, **kwargs) for base in reversed(new_cls.__mro__): @@ -136,21 +138,21 @@ class CogMeta(type): value = value.__func__ if isinstance(value, _BaseCommand): if is_static_method: - raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.') - if elem.startswith(('cog_', 'bot_')): + raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.") + if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value elif inspect.iscoroutinefunction(value): try: - getattr(value, '__cog_listener__') + getattr(value, "__cog_listener__") except AttributeError: continue else: - if elem.startswith(('cog_', 'bot_')): + if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) listeners[elem] = value - new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ + new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ listeners_as_list = [] for listener in listeners.values(): @@ -169,10 +171,12 @@ class CogMeta(type): def qualified_name(cls) -> str: return cls.__cog_name__ + def _cog_special_method(func: FuncT) -> FuncT: func.__cog_special_method__ = None return func + class Cog(metaclass=CogMeta): """The base class that all cogs must inherit from. @@ -183,6 +187,7 @@ class Cog(metaclass=CogMeta): When inheriting from this class, the options shown in :class:`CogMeta` are equally valid here. """ + __cog_name__: ClassVar[str] __cog_settings__: ClassVar[Dict[str, Any]] __cog_commands__: ClassVar[List[Command]] @@ -199,10 +204,7 @@ class Cog(metaclass=CogMeta): # r.e type ignore, type-checker complains about overriding a ClassVar self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore - lookup = { - cmd.qualified_name: cmd - for cmd in self.__cog_commands__ - } + lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__} # Update the Command instances dynamically as well for command in self.__cog_commands__: @@ -255,6 +257,7 @@ class Cog(metaclass=CogMeta): A command or group from the cog. """ from .core import GroupMixin + for command in self.__cog_commands__: if command.parent is None: yield command @@ -274,7 +277,7 @@ class Cog(metaclass=CogMeta): @classmethod def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: """Return None if the method is not overridden. Otherwise returns the overridden method.""" - return getattr(method.__func__, '__cog_special_method__', method) + return getattr(method.__func__, "__cog_special_method__", method) @classmethod def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: @@ -296,14 +299,14 @@ class Cog(metaclass=CogMeta): """ if name is not MISSING and not isinstance(name, str): - raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') + raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.") def decorator(func: FuncT) -> FuncT: actual = func if isinstance(actual, staticmethod): actual = actual.__func__ if not inspect.iscoroutinefunction(actual): - raise TypeError('Listener function must be a coroutine function.') + raise TypeError("Listener function must be a coroutine function.") actual.__cog_listener__ = True to_assign = name or actual.__name__ try: @@ -315,6 +318,7 @@ class Cog(metaclass=CogMeta): # to pick it up but the metaclass unfurls the function and # thus the assignments need to be on the actual function return func + return decorator def has_error_handler(self) -> bool: @@ -322,7 +326,7 @@ class Cog(metaclass=CogMeta): .. versionadded:: 1.7 """ - return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') + return not hasattr(self.cog_command_error.__func__, "__cog_special_method__") @_cog_special_method def cog_unload(self) -> None: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index fa16c74a..a4135793 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -50,21 +50,19 @@ if TYPE_CHECKING: from .help import HelpCommand from .view import StringView -__all__ = ( - 'Context', -) +__all__ = ("Context",) MISSING: Any = discord.utils.MISSING -T = TypeVar('T') -BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") -CogT = TypeVar('CogT', bound="Cog") +T = TypeVar("T") +BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") +CogT = TypeVar("CogT", bound="Cog") if TYPE_CHECKING: - P = ParamSpec('P') + P = ParamSpec("P") else: - P = TypeVar('P') + P = TypeVar("P") class Context(discord.abc.Messageable, Generic[BotT]): @@ -123,7 +121,8 @@ class Context(discord.abc.Messageable, Generic[BotT]): or invoked. """ - def __init__(self, + def __init__( + self, *, message: Message, bot: BotT, @@ -220,7 +219,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): cmd = self.command view = self.view if cmd is None: - raise ValueError('This context is not valid.') + raise ValueError("This context is not valid.") # some state to revert to when we're done index, previous = view.index, view.previous @@ -231,10 +230,10 @@ class Context(discord.abc.Messageable, Generic[BotT]): if restart: to_call = cmd.root_parent or cmd - view.index = len(self.prefix or '') + view.index = len(self.prefix or "") view.previous = 0 self.invoked_parents = [] - self.invoked_with = view.get_word() # advance to get the root command + self.invoked_with = view.get_word() # advance to get the root command else: to_call = cmd @@ -264,7 +263,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): .. versionadded:: 2.0 """ if self.prefix is None: - return '' + return "" user = self.me # this breaks if the prefix mention is not the bot itself but I @@ -272,7 +271,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): # for this common use case rather than waste performance for the # odd one. pattern = re.compile(r"<@!?%s>" % user.id) - return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix) + return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix) @property def cog(self) -> Optional[Cog]: @@ -389,7 +388,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): await cmd.prepare_help_command(self, entity.qualified_name) try: - if hasattr(entity, '__cog_commands__'): + if hasattr(entity, "__cog_commands__"): injected = wrap_callback(cmd.send_cog_help) return await injected(entity) elif isinstance(entity, Group): diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 5740a188..a4791b8f 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -52,32 +52,32 @@ if TYPE_CHECKING: __all__ = ( - 'Converter', - 'ObjectConverter', - 'MemberConverter', - 'UserConverter', - 'MessageConverter', - 'PartialMessageConverter', - 'TextChannelConverter', - 'InviteConverter', - 'GuildConverter', - 'RoleConverter', - 'GameConverter', - 'ColourConverter', - 'ColorConverter', - 'VoiceChannelConverter', - 'StageChannelConverter', - 'EmojiConverter', - 'PartialEmojiConverter', - 'CategoryChannelConverter', - 'IDConverter', - 'StoreChannelConverter', - 'ThreadConverter', - 'GuildChannelConverter', - 'GuildStickerConverter', - 'clean_content', - 'Greedy', - 'run_converters', + "Converter", + "ObjectConverter", + "MemberConverter", + "UserConverter", + "MessageConverter", + "PartialMessageConverter", + "TextChannelConverter", + "InviteConverter", + "GuildConverter", + "RoleConverter", + "GameConverter", + "ColourConverter", + "ColorConverter", + "VoiceChannelConverter", + "StageChannelConverter", + "EmojiConverter", + "PartialEmojiConverter", + "CategoryChannelConverter", + "IDConverter", + "StoreChannelConverter", + "ThreadConverter", + "GuildChannelConverter", + "GuildStickerConverter", + "clean_content", + "Greedy", + "run_converters", ) @@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument): _utils_get = discord.utils.get -T = TypeVar('T') -T_co = TypeVar('T_co', covariant=True) -CT = TypeVar('CT', bound=discord.abc.GuildChannel) -TT = TypeVar('TT', bound=discord.Thread) +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +CT = TypeVar("CT", bound=discord.abc.GuildChannel) +TT = TypeVar("TT", bound=discord.Thread) @runtime_checkable @@ -132,10 +132,10 @@ class Converter(Protocol[T_co]): :exc:`.BadArgument` The converter failed to convert the argument. """ - raise NotImplementedError('Derived classes need to implement this.') + raise NotImplementedError("Derived classes need to implement this.") -_ID_REGEX = re.compile(r'([0-9]{15,20})$') +_ID_REGEX = re.compile(r"([0-9]{15,20})$") class IDConverter(Converter[T_co]): @@ -158,7 +158,7 @@ class ObjectConverter(IDConverter[discord.Object]): """ async def convert(self, ctx: Context, argument: str) -> discord.Object: - match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument) if match is None: raise ObjectNotFound(argument) @@ -192,8 +192,8 @@ class MemberConverter(IDConverter[discord.Member]): async def query_member_named(self, guild, argument): cache = guild._state.member_cache_flags.joined - if len(argument) > 5 and argument[-5] == '#': - username, _, discriminator = argument.rpartition('#') + if len(argument) > 5 and argument[-5] == "#": + username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) return discord.utils.get(members, name=username, discriminator=discriminator) else: @@ -223,7 +223,7 @@ class MemberConverter(IDConverter[discord.Member]): async def convert(self, ctx: Context, argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) guild = ctx.guild result = None user_id = None @@ -232,13 +232,13 @@ class MemberConverter(IDConverter[discord.Member]): if guild: result = guild.get_member_named(argument) else: - result = _get_from_guilds(bot, 'get_member_named', argument) + result = _get_from_guilds(bot, "get_member_named", argument) else: user_id = int(match.group(1)) if guild: result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id) else: - result = _get_from_guilds(bot, 'get_member', user_id) + result = _get_from_guilds(bot, "get_member", user_id) if result is None: if guild is None: @@ -276,7 +276,7 @@ class UserConverter(IDConverter[discord.User]): """ async def convert(self, ctx: Context, argument: str) -> discord.User: - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) result = None state = ctx._state @@ -294,12 +294,12 @@ class UserConverter(IDConverter[discord.User]): arg = argument # Remove the '@' character if this is the first character from the argument - if arg[0] == '@': + if arg[0] == "@": # Remove first character arg = arg[1:] # check for discriminator if it exists, - if len(arg) > 5 and arg[-5] == '#': + if len(arg) > 5 and arg[-5] == "#": discrim = arg[-4:] name = arg[:-5] predicate = lambda u: u.name == name and u.discriminator == discrim @@ -330,22 +330,22 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _get_id_matches(ctx, argument): - id_regex = re.compile(r'(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$') + id_regex = re.compile(r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$") link_regex = re.compile( - r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/' - r'(?P[0-9]{15,20}|@me)' - r'/(?P[0-9]{15,20})/(?P[0-9]{15,20})/?$' + r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" + r"(?P[0-9]{15,20}|@me)" + r"/(?P[0-9]{15,20})/(?P[0-9]{15,20})/?$" ) match = id_regex.match(argument) or link_regex.match(argument) if not match: raise MessageNotFound(argument) data = match.groupdict() - channel_id = discord.utils._get_as_snowflake(data, 'channel_id') - message_id = int(data['message_id']) - guild_id = data.get('guild_id') + channel_id = discord.utils._get_as_snowflake(data, "channel_id") + message_id = int(data["message_id"]) + guild_id = data.get("guild_id") if guild_id is None: guild_id = ctx.guild and ctx.guild.id - elif guild_id == '@me': + elif guild_id == "@me": guild_id = None else: guild_id = int(guild_id) @@ -417,13 +417,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) + return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel) @staticmethod def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -443,7 +443,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): if guild: result = guild.get_channel(channel_id) else: - result = _get_from_guilds(bot, 'get_channel', channel_id) + result = _get_from_guilds(bot, "get_channel", channel_id) if not isinstance(result, type): raise ChannelNotFound(argument) @@ -454,7 +454,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -491,7 +491,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) + return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel) class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): @@ -511,7 +511,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) + return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -530,7 +530,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) + return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -550,7 +550,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) + return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel) class StoreChannelConverter(IDConverter[discord.StoreChannel]): @@ -569,7 +569,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) + return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel) class ThreadConverter(IDConverter[discord.Thread]): @@ -587,7 +587,7 @@ class ThreadConverter(IDConverter[discord.Thread]): """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: - return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) + return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread) class ColourConverter(Converter[discord.Colour]): @@ -616,10 +616,10 @@ class ColourConverter(Converter[discord.Colour]): Added support for ``rgb`` function and 3-digit hex shortcuts """ - RGB_REGEX = re.compile(r'rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)') + RGB_REGEX = re.compile(r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)") def parse_hex_number(self, argument): - arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument + arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument try: value = int(arg, base=16) if not (0 <= value <= 0xFFFFFF): @@ -630,7 +630,7 @@ class ColourConverter(Converter[discord.Colour]): return discord.Color(value=value) def parse_rgb_number(self, argument, number): - if number[-1] == '%': + if number[-1] == "%": value = int(number[:-1]) if not (0 <= value <= 100): raise BadColourArgument(argument) @@ -646,29 +646,29 @@ class ColourConverter(Converter[discord.Colour]): if match is None: raise BadColourArgument(argument) - red = self.parse_rgb_number(argument, match.group('r')) - green = self.parse_rgb_number(argument, match.group('g')) - blue = self.parse_rgb_number(argument, match.group('b')) + red = self.parse_rgb_number(argument, match.group("r")) + green = self.parse_rgb_number(argument, match.group("g")) + blue = self.parse_rgb_number(argument, match.group("b")) return discord.Color.from_rgb(red, green, blue) async def convert(self, ctx: Context, argument: str) -> discord.Colour: - if argument[0] == '#': + if argument[0] == "#": return self.parse_hex_number(argument[1:]) - if argument[0:2] == '0x': + if argument[0:2] == "0x": rest = argument[2:] # Legacy backwards compatible syntax - if rest.startswith('#'): + if rest.startswith("#"): return self.parse_hex_number(rest[1:]) return self.parse_hex_number(rest) arg = argument.lower() - if arg[0:3] == 'rgb': + if arg[0:3] == "rgb": return self.parse_rgb(arg) - arg = arg.replace(' ', '_') + arg = arg.replace(" ", "_") method = getattr(discord.Colour, arg, None) - if arg.startswith('from_') or method is None or not inspect.ismethod(method): + if arg.startswith("from_") or method is None or not inspect.ismethod(method): raise BadColourArgument(arg) return method() @@ -697,7 +697,7 @@ class RoleConverter(IDConverter[discord.Role]): if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument) if match: result = guild.get_role(int(match.group(1))) else: @@ -776,7 +776,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.Emoji: - match = self._get_id_match(argument) or re.match(r'$', argument) + match = self._get_id_match(argument) or re.match(r"$", argument) result = None bot = ctx.bot guild = ctx.guild @@ -810,7 +810,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: - match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) + match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument) if match: emoji_animated = bool(match.group(1)) @@ -903,37 +903,37 @@ class clean_content(Converter[str]): def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) - return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' + return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) - return f'@{r.name}' if r else '@deleted-role' + return f"@{r.name}" if r else "@deleted-role" else: def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) - return f'@{m.name}' if m else '@deleted-user' + return f"@{m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: - return '@deleted-role' + return "@deleted-role" if self.fix_channel_mentions and ctx.guild: def resolve_channel(id: int) -> str: c = ctx.guild.get_channel(id) - return f'#{c.name}' if c else '#deleted-channel' + return f"#{c.name}" if c else "#deleted-channel" else: def resolve_channel(id: int) -> str: - return f'<#{id}>' + return f"<#{id}>" transforms = { - '@': resolve_member, - '@!': resolve_member, - '#': resolve_channel, - '@&': resolve_role, + "@": resolve_member, + "@!": resolve_member, + "#": resolve_channel, + "@&": resolve_role, } def repl(match: re.Match) -> str: @@ -942,7 +942,7 @@ class clean_content(Converter[str]): transformed = transforms[type](id) return transformed - result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) + result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument) if self.escape_markdown: result = discord.utils.escape_markdown(result) elif self.remove_markdown: @@ -974,42 +974,42 @@ class Greedy(List[T]): For more information, check :ref:`ext_commands_special_converters`. """ - __slots__ = ('converter',) + __slots__ = ("converter",) def __init__(self, *, converter: T): self.converter = converter def __repr__(self): - converter = getattr(self.converter, '__name__', repr(self.converter)) - return f'Greedy[{converter}]' + converter = getattr(self.converter, "__name__", repr(self.converter)) + return f"Greedy[{converter}]" def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: if not isinstance(params, tuple): params = (params,) if len(params) != 1: - raise TypeError('Greedy[...] only takes a single argument') + raise TypeError("Greedy[...] only takes a single argument") converter = params[0] - origin = getattr(converter, '__origin__', None) - args = getattr(converter, '__args__', ()) + origin = getattr(converter, "__origin__", None) + args = getattr(converter, "__args__", ()) if not (callable(converter) or isinstance(converter, Converter) or origin is not None): - raise TypeError('Greedy[...] expects a type or a Converter instance.') + raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: - raise TypeError(f'Greedy[{converter.__name__}] is invalid.') + raise TypeError(f"Greedy[{converter.__name__}] is invalid.") if origin is Union and type(None) in args: - raise TypeError(f'Greedy[{converter!r}] is invalid.') + raise TypeError(f"Greedy[{converter!r}] is invalid.") return cls(converter=converter) def _convert_to_bool(argument: str) -> bool: lowered = argument.lower() - if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): + if lowered in ("yes", "y", "true", "t", "1", "enable", "on"): return True - elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): + elif lowered in ("no", "n", "false", "f", "0", "disable", "off"): return False else: raise BadBoolArgument(lowered) @@ -1065,7 +1065,7 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp except AttributeError: pass else: - if module is not None and (module.startswith('discord.') and not module.endswith('converter')): + if module is not None and (module.startswith("discord.") and not module.endswith("converter")): converter = CONVERTER_MAPPING.get(converter, converter) try: @@ -1124,7 +1124,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect. Any The resulting conversion. """ - origin = getattr(converter, '__origin__', None) + origin = getattr(converter, "__origin__", None) if origin is Union: errors = [] diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 2e008aed..3808a989 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -38,24 +38,25 @@ if TYPE_CHECKING: from ...message import Message __all__ = ( - 'BucketType', - 'Cooldown', - 'CooldownMapping', - 'DynamicCooldownMapping', - 'MaxConcurrency', + "BucketType", + "Cooldown", + "CooldownMapping", + "DynamicCooldownMapping", + "MaxConcurrency", ) -C = TypeVar('C', bound='CooldownMapping') -MC = TypeVar('MC', bound='MaxConcurrency') +C = TypeVar("C", bound="CooldownMapping") +MC = TypeVar("MC", bound="MaxConcurrency") + class BucketType(Enum): - default = 0 - user = 1 - guild = 2 - channel = 3 - member = 4 + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 category = 5 - role = 6 + role = 6 def get_key(self, msg: Message) -> Any: if self is BucketType.user: @@ -90,7 +91,7 @@ class Cooldown: The length of the cooldown period in seconds. """ - __slots__ = ('rate', 'per', '_window', '_tokens', '_last') + __slots__ = ("rate", "per", "_window", "_tokens", "_last") def __init__(self, rate: float, per: float) -> None: self.rate: int = int(rate) @@ -190,7 +191,8 @@ class Cooldown: return Cooldown(self.rate, self.per) def __repr__(self) -> str: - return f'' + return f"" + class CooldownMapping: def __init__( @@ -199,7 +201,7 @@ class CooldownMapping: type: Callable[[Message], Any], ) -> None: if not callable(type): - raise TypeError('Cooldown type must be a BucketType or callable') + raise TypeError("Cooldown type must be a BucketType or callable") self._cache: Dict[Any, Cooldown] = {} self._cooldown: Optional[Cooldown] = original @@ -256,13 +258,9 @@ class CooldownMapping: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) -class DynamicCooldownMapping(CooldownMapping): - def __init__( - self, - factory: Callable[[Message], Cooldown], - type: Callable[[Message], Any] - ) -> None: +class DynamicCooldownMapping(CooldownMapping): + def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory @@ -278,6 +276,7 @@ class DynamicCooldownMapping(CooldownMapping): def create_bucket(self, message: Message) -> Cooldown: return self._factory(message) + class _Semaphore: """This class is a version of a semaphore. @@ -291,7 +290,7 @@ class _Semaphore: overkill for what is basically a counter. """ - __slots__ = ('value', 'loop', '_waiters') + __slots__ = ("value", "loop", "_waiters") def __init__(self, number: int) -> None: self.value: int = number @@ -299,7 +298,7 @@ class _Semaphore: self._waiters: Deque[asyncio.Future] = deque() def __repr__(self) -> str: - return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' + return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>" def locked(self) -> bool: return self.value == 0 @@ -337,8 +336,9 @@ class _Semaphore: self.value += 1 self.wake_up() + class MaxConcurrency: - __slots__ = ('number', 'per', 'wait', '_mapping') + __slots__ = ("number", "per", "wait", "_mapping") def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: self._mapping: Dict[Any, _Semaphore] = {} @@ -347,16 +347,16 @@ class MaxConcurrency: self.wait: bool = wait if number <= 0: - raise ValueError('max_concurrency \'number\' cannot be less than 1') + raise ValueError("max_concurrency 'number' cannot be less than 1") if not isinstance(per, BucketType): - raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') + raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") def copy(self: MC) -> MC: return self.__class__(self.number, per=self.per, wait=self.wait) def __repr__(self) -> str: - return f'' + return f"" def get_key(self, message: Message) -> Any: return self.per.get_key(message) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index f122e9ad..836b799a 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -70,52 +70,53 @@ if TYPE_CHECKING: __all__ = ( - 'Command', - 'Group', - 'GroupMixin', - 'command', - 'group', - 'has_role', - 'has_permissions', - 'has_any_role', - 'check', - 'check_any', - 'before_invoke', - 'after_invoke', - 'bot_has_role', - 'bot_has_permissions', - 'bot_has_any_role', - 'cooldown', - 'dynamic_cooldown', - 'max_concurrency', - 'dm_only', - 'guild_only', - 'is_owner', - 'is_nsfw', - 'has_guild_permissions', - 'bot_has_guild_permissions' + "Command", + "Group", + "GroupMixin", + "command", + "group", + "has_role", + "has_permissions", + "has_any_role", + "check", + "check_any", + "before_invoke", + "after_invoke", + "bot_has_role", + "bot_has_permissions", + "bot_has_any_role", + "cooldown", + "dynamic_cooldown", + "max_concurrency", + "dm_only", + "guild_only", + "is_owner", + "is_nsfw", + "has_guild_permissions", + "bot_has_guild_permissions", ) MISSING: Any = discord.utils.MISSING -T = TypeVar('T') -CogT = TypeVar('CogT', bound='Cog') -CommandT = TypeVar('CommandT', bound='Command') -ContextT = TypeVar('ContextT', bound='Context') +T = TypeVar("T") +CogT = TypeVar("CogT", bound="Cog") +CommandT = TypeVar("CommandT", bound="Command") +ContextT = TypeVar("ContextT", bound="Context") # CHT = TypeVar('CHT', bound='Check') -GroupT = TypeVar('GroupT', bound='Group') -HookT = TypeVar('HookT', bound='Hook') -ErrorT = TypeVar('ErrorT', bound='Error') +GroupT = TypeVar("GroupT", bound="Group") +HookT = TypeVar("HookT", bound="Hook") +ErrorT = TypeVar("ErrorT", bound="Error") if TYPE_CHECKING: - P = ParamSpec('P') + P = ParamSpec("P") else: - P = TypeVar('P') + P = TypeVar("P") + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: partial = functools.partial while True: - if hasattr(function, '__wrapped__'): + if hasattr(function, "__wrapped__"): function = function.__wrapped__ elif isinstance(function, partial): function = function.func @@ -139,7 +140,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A annotation = eval_annotation(annotation, globalns, globalns, cache) if annotation is Greedy: - raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') + raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") params[name] = parameter.replace(annotation=annotation) @@ -158,8 +159,10 @@ def wrap_callback(coro): except Exception as exc: raise CommandInvokeError(exc) from exc return ret + return wrapped + def hooked_wrapped_callback(command, ctx, coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -180,6 +183,7 @@ def hooked_wrapped_callback(command, ctx, coro): await command.call_after_hooks(ctx) return ret + return wrapped @@ -202,6 +206,7 @@ class _CaseInsensitiveDict(dict): def __setitem__(self, k, v): super().__setitem__(k.casefold(), v) + class Command(_BaseCommand, Generic[CogT, P, T]): r"""A class that implements the protocol for a bot text command. @@ -269,8 +274,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): which calls converters. If ``False`` then cooldown processing is done first and then the converters are called second. Defaults to ``False``. extras: :class:`dict` - A dict of user provided extras to attach to the Command. - + A dict of user provided extras to attach to the Command. + .. note:: This object may be copied by the library. @@ -295,56 +300,60 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.__original_kwargs__ = kwargs.copy() return self - def __init__(self, func: Union[ + def __init__( + self, + func: Union[ Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], - ], **kwargs: Any): + ], + **kwargs: Any, + ): if not asyncio.iscoroutinefunction(func): - raise TypeError('Callback must be a coroutine.') + raise TypeError("Callback must be a coroutine.") - name = kwargs.get('name') or func.__name__ + name = kwargs.get("name") or func.__name__ if not isinstance(name, str): - raise TypeError('Name of a command must be a string.') + raise TypeError("Name of a command must be a string.") self.name: str = name self.callback = func - self.enabled: bool = kwargs.get('enabled', True) + self.enabled: bool = kwargs.get("enabled", True) - help_doc = kwargs.get('help') + help_doc = kwargs.get("help") if help_doc is not None: help_doc = inspect.cleandoc(help_doc) else: help_doc = inspect.getdoc(func) if isinstance(help_doc, bytes): - help_doc = help_doc.decode('utf-8') + help_doc = help_doc.decode("utf-8") self.help: Optional[str] = help_doc - self.brief: Optional[str] = kwargs.get('brief') - self.usage: Optional[str] = kwargs.get('usage') - self.rest_is_raw: bool = kwargs.get('rest_is_raw', False) - self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', []) - self.extras: Dict[str, Any] = kwargs.get('extras', {}) + self.brief: Optional[str] = kwargs.get("brief") + self.usage: Optional[str] = kwargs.get("usage") + self.rest_is_raw: bool = kwargs.get("rest_is_raw", False) + self.aliases: Union[List[str], Tuple[str]] = kwargs.get("aliases", []) + self.extras: Dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): raise TypeError("Aliases of a command must be a list or a tuple of strings.") - self.description: str = inspect.cleandoc(kwargs.get('description', '')) - self.hidden: bool = kwargs.get('hidden', False) + self.description: str = inspect.cleandoc(kwargs.get("description", "")) + self.hidden: bool = kwargs.get("hidden", False) try: checks = func.__commands_checks__ checks.reverse() except AttributeError: - checks = kwargs.get('checks', []) + checks = kwargs.get("checks", []) self.checks: List[Check] = checks try: cooldown = func.__commands_cooldown__ except AttributeError: - cooldown = kwargs.get('cooldown') - + cooldown = kwargs.get("cooldown") + if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): @@ -356,17 +365,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]): try: max_concurrency = func.__commands_max_concurrency__ except AttributeError: - max_concurrency = kwargs.get('max_concurrency') + max_concurrency = kwargs.get("max_concurrency") self._max_concurrency: Optional[MaxConcurrency] = max_concurrency - self.require_var_positional: bool = kwargs.get('require_var_positional', False) - self.ignore_extra: bool = kwargs.get('ignore_extra', True) - self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False) + self.require_var_positional: bool = kwargs.get("require_var_positional", False) + self.ignore_extra: bool = kwargs.get("ignore_extra", True) + self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) self.cog: Optional[CogT] = None # bandaid for the fact that sometimes parent can be the bot instance - parent = kwargs.get('parent') + parent = kwargs.get("parent") self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore self._before_invoke: Optional[Hook] = None @@ -386,17 +395,19 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.after_invoke(after_invoke) @property - def callback(self) -> Union[ - Callable[Concatenate[CogT, Context, P], Coro[T]], - Callable[Concatenate[Context, P], Coro[T]], - ]: + def callback( + self, + ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: return self._callback @callback.setter - def callback(self, function: Union[ + def callback( + self, + function: Union[ Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]], - ]) -> None: + ], + ) -> None: self._callback = function unwrap = unwrap_function(function) self.module = unwrap.__module__ @@ -527,7 +538,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): wrapped = wrap_callback(local) await wrapped(ctx, error) finally: - ctx.bot.dispatch('command_error', ctx, error) + ctx.bot.dispatch("command_error", ctx, error) async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: required = param.default is param.empty @@ -551,11 +562,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if view.eof: if param.kind == param.VAR_POSITIONAL: - raise RuntimeError() # break the loop + raise RuntimeError() # break the loop if required: if self._is_typing_optional(param.annotation): return None - if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible(): + if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible(): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return param.default @@ -577,7 +588,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # type-checker fails to narrow argument return await run_converters(ctx, converter, argument, param) # type: ignore - async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any: + async def _transform_greedy_pos( + self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any + ) -> Any: view = ctx.view result = [] while not view.eof: @@ -606,7 +619,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous - raise RuntimeError() from None # break loop + raise RuntimeError() from None # break loop else: return value @@ -643,11 +656,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]): entries = [] command = self # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO - while command.parent is not None: # type: ignore - command = command.parent # type: ignore - entries.append(command.name) # type: ignore + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command.name) # type: ignore - return ' '.join(reversed(entries)) + return " ".join(reversed(entries)) @property def parents(self) -> List[Group]: @@ -661,8 +674,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ entries = [] command = self - while command.parent is not None: # type: ignore - command = command.parent # type: ignore + while command.parent is not None: # type: ignore + command = command.parent # type: ignore entries.append(command) return entries @@ -690,7 +703,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): parent = self.full_parent_name if parent: - return parent + ' ' + self.name + return parent + " " + self.name else: return self.name @@ -745,7 +758,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): break if not self.ignore_extra and not view.eof: - raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) + raise TooManyArguments("Too many arguments passed to " + self.qualified_name) async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -753,7 +766,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): cog = self.cog if self._before_invoke is not None: # should be cog if @commands.before_invoke is used - instance = getattr(self._before_invoke, '__self__', cog) + instance = getattr(self._before_invoke, "__self__", cog) # __self__ only exists for methods, not functions # however, if @command.before_invoke is used, it will be a function if instance: @@ -775,7 +788,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): async def call_after_hooks(self, ctx: Context) -> None: cog = self.cog if self._after_invoke is not None: - instance = getattr(self._after_invoke, '__self__', cog) + instance = getattr(self._after_invoke, "__self__", cog) if instance: await self._after_invoke(instance, ctx) # type: ignore else: @@ -805,7 +818,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.command = self if not await self.can_run(ctx): - raise CheckFailure(f'The check functions for command {self.qualified_name} failed.') + raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -929,7 +942,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The error handler must be a coroutine.') + raise TypeError("The error handler must be a coroutine.") self.on_error: Error = coro return coro @@ -939,7 +952,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): .. versionadded:: 1.7 """ - return hasattr(self, 'on_error') + return hasattr(self, "on_error") def before_invoke(self, coro: HookT) -> HookT: """A decorator that registers a coroutine as a pre-invoke hook. @@ -963,7 +976,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The pre-invoke hook must be a coroutine.') + raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro return coro @@ -990,7 +1003,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The post-invoke hook must be a coroutine.') + raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro return coro @@ -1011,11 +1024,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if self.brief is not None: return self.brief if self.help is not None: - return self.help.split('\n', 1)[0] - return '' + return self.help.split("\n", 1)[0] + return "" def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: - return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore + return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore @property def signature(self) -> str: @@ -1025,7 +1038,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): params = self.clean_params if not params: - return '' + return "" result = [] for name, param in params.items(): @@ -1035,41 +1048,40 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # parameter signature is a literal list of it's values annotation = param.annotation.converter if greedy else param.annotation - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if not greedy and origin is Union: none_cls = type(None) union_args = annotation.__args__ optional = union_args[-1] is none_cls if len(union_args) == 2 and optional: annotation = union_args[0] - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) + name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. should_print = param.default if isinstance(param.default, str) else param.default is not None if should_print: - result.append(f'[{name}={param.default}]' if not greedy else - f'[{name}={param.default}]...') + result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...") continue else: - result.append(f'[{name}]') + result.append(f"[{name}]") elif param.kind == param.VAR_POSITIONAL: if self.require_var_positional: - result.append(f'<{name}...>') + result.append(f"<{name}...>") else: - result.append(f'[{name}...]') + result.append(f"[{name}...]") elif greedy: - result.append(f'[{name}]...') + result.append(f"[{name}]...") elif optional: - result.append(f'[{name}]') + result.append(f"[{name}]") else: - result.append(f'<{name}>') + result.append(f"<{name}>") - return ' '.join(result) + return " ".join(result) async def can_run(self, ctx: Context) -> bool: """|coro| @@ -1099,14 +1111,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ if not self.enabled: - raise DisabledCommand(f'{self.name} command is disabled') + raise DisabledCommand(f"{self.name} command is disabled") original = ctx.command ctx.command = self try: if not await ctx.bot.can_run(ctx): - raise CheckFailure(f'The global check functions for command {self.qualified_name} failed.') + raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") cog = self.cog if cog is not None: @@ -1125,6 +1137,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.command = original + class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. @@ -1137,8 +1150,9 @@ class GroupMixin(Generic[CogT]): case_insensitive: :class:`bool` Whether the commands should be case insensitive. Defaults to ``True``. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: - case_insensitive = kwargs.get('case_insensitive', True) + case_insensitive = kwargs.get("case_insensitive", True) self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1177,7 +1191,7 @@ class GroupMixin(Generic[CogT]): """ if not isinstance(command, Command): - raise TypeError('The command passed must be a subclass of Command') + raise TypeError("The command passed must be a subclass of Command") if isinstance(self, Command): command.parent = self @@ -1267,7 +1281,7 @@ class GroupMixin(Generic[CogT]): """ # fast path, no space in name. - if ' ' not in name: + if " " not in name: return self.all_commands.get(name) names = name.split() @@ -1298,7 +1312,9 @@ class GroupMixin(Generic[CogT]): Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ], Command[CogT, P, T]]: + ], + Command[CogT, P, T], + ]: ... @overload @@ -1326,8 +1342,9 @@ class GroupMixin(Generic[CogT]): Callable[..., :class:`Command`] A decorator that converts the provided method into a Command, adds it to the bot, then returns it. """ + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: - kwargs.setdefault('parent', self) + kwargs.setdefault("parent", self) result = command(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result @@ -1341,12 +1358,10 @@ class GroupMixin(Generic[CogT]): cls: Type[Group[CogT, P, T]] = ..., *args: Any, **kwargs: Any, - ) -> Callable[[ - Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]] - ] - ], Group[CogT, P, T]]: + ) -> Callable[ + [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], + Group[CogT, P, T], + ]: ... @overload @@ -1374,14 +1389,16 @@ class GroupMixin(Generic[CogT]): Callable[..., :class:`Group`] A decorator that converts the provided method into a Group, adds it to the bot, then returns it. """ + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: - kwargs.setdefault('parent', self) + kwargs.setdefault("parent", self) result = group(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result return decorator + class Group(GroupMixin[CogT], Command[CogT, P, T]): """A class that implements a grouping protocol for commands to be executed as subcommands. @@ -1404,8 +1421,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): Indicates if the group's commands should be case insensitive. Defaults to ``False``. """ + def __init__(self, *args: Any, **attrs: Any) -> None: - self.invoke_without_command: bool = attrs.pop('invoke_without_command', False) + self.invoke_without_command: bool = attrs.pop("invoke_without_command", False) super().__init__(*args, **attrs) def copy(self: GroupT) -> GroupT: @@ -1492,8 +1510,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) + # Decorators + @overload def command( name: str = ..., @@ -1505,10 +1525,12 @@ def command( Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ] -, Command[CogT, P, T]]: + ], + Command[CogT, P, T], +]: ... + @overload def command( name: str = ..., @@ -1520,22 +1542,23 @@ def command( Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], ] - ] -, CommandT]: + ], + CommandT, +]: ... + def command( - name: str = MISSING, - cls: Type[CommandT] = MISSING, - **attrs: Any + name: str = MISSING, cls: Type[CommandT] = MISSING, **attrs: Any ) -> Callable[ [ Union[ Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[T]], ] - ] -, Union[Command[CogT, P, T], CommandT]]: + ], + Union[Command[CogT, P, T], CommandT], +]: """A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. @@ -1568,16 +1591,19 @@ def command( if cls is MISSING: cls = Command # type: ignore - def decorator(func: Union[ + def decorator( + func: Union[ Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]], - ]) -> CommandT: + ] + ) -> CommandT: if isinstance(func, Command): - raise TypeError('Callback is already a command.') + raise TypeError("Callback is already a command.") return cls(func, name=name, **attrs) return decorator + @overload def group( name: str = ..., @@ -1589,10 +1615,12 @@ def group( Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ] -, Group[CogT, P, T]]: + ], + Group[CogT, P, T], +]: ... + @overload def group( name: str = ..., @@ -1604,10 +1632,12 @@ def group( Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], ] - ] -, GroupT]: + ], + GroupT, +]: ... + def group( name: str = MISSING, cls: Type[GroupT] = MISSING, @@ -1618,8 +1648,9 @@ def group( Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[T]], ] - ] -, Union[Group[CogT, P, T], GroupT]]: + ], + Union[Group[CogT, P, T], GroupT], +]: """A decorator that transforms a function into a :class:`.Group`. This is similar to the :func:`.command` decorator but the ``cls`` @@ -1632,6 +1663,7 @@ def group( cls = Group # type: ignore return command(name=name, cls=cls, **attrs) # type: ignore + def check(predicate: Check) -> Callable[[T], T]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1707,7 +1739,7 @@ def check(predicate: Check) -> Callable[[T], T]: if isinstance(func, Command): func.checks.append(predicate) else: - if not hasattr(func, '__commands_checks__'): + if not hasattr(func, "__commands_checks__"): func.__commands_checks__ = [] func.__commands_checks__.append(predicate) @@ -1717,13 +1749,16 @@ def check(predicate: Check) -> Callable[[T], T]: if inspect.iscoroutinefunction(predicate): decorator.predicate = predicate else: + @functools.wraps(predicate) async def wrapper(ctx): return predicate(ctx) # type: ignore + decorator.predicate = wrapper return decorator # type: ignore + def check_any(*checks: Check) -> Callable[[T], T]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1773,7 +1808,7 @@ def check_any(*checks: Check) -> Callable[[T], T]: try: pred = wrapped.predicate except AttributeError: - raise TypeError(f'{wrapped!r} must be wrapped by commands.check decorator') from None + raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None else: unwrapped.append(pred) @@ -1792,6 +1827,7 @@ def check_any(*checks: Check) -> Callable[[T], T]: return check(predicate) + def has_role(item: Union[int, str]) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -1834,6 +1870,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]: return check(predicate) + def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: r"""A :func:`.check` that is added that checks if the member invoking the command has **any** of the roles specified. This means that if they have @@ -1865,18 +1902,22 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: async def cool(ctx): await ctx.send('You are cool indeed') """ + def predicate(ctx): if ctx.guild is None: raise NoPrivateMessage() # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore - if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): + if any( + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items + ): return True raise MissingAnyRole(list(items)) return check(predicate) + def bot_has_role(item: int) -> Callable[[T], T]: """Similar to :func:`.has_role` except checks if the bot itself has the role. @@ -1903,8 +1944,10 @@ def bot_has_role(item: int) -> Callable[[T], T]: if role is None: raise BotMissingRole(item) return True + return check(predicate) + def bot_has_any_role(*items: int) -> Callable[[T], T]: """Similar to :func:`.has_any_role` except checks if the bot itself has any of the roles listed. @@ -1918,17 +1961,22 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage` instead of generic checkfailure """ + def predicate(ctx): if ctx.guild is None: raise NoPrivateMessage() me = ctx.me getter = functools.partial(discord.utils.get, me.roles) - if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): + if any( + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items + ): return True raise BotMissingAnyRole(list(items)) + return check(predicate) + def has_permissions(**perms: bool) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. @@ -1976,6 +2024,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def bot_has_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. @@ -2002,6 +2051,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions`, but operates on guild wide permissions instead of the current channel permissions. @@ -2030,6 +2080,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_guild_permissions`, but checks the bot members guild permissions. @@ -2055,6 +2106,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def dm_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when @@ -2073,6 +2125,7 @@ def dm_only() -> Callable[[T], T]: return check(predicate) + def guild_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when @@ -2089,6 +2142,7 @@ def guild_only() -> Callable[[T], T]: return check(predicate) + def is_owner() -> Callable[[T], T]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -2101,11 +2155,12 @@ def is_owner() -> Callable[[T], T]: async def predicate(ctx: Context) -> bool: if not await ctx.bot.is_owner(ctx.author): - raise NotOwner('You do not own this bot.') + raise NotOwner("You do not own this bot.") return True return check(predicate) + def is_nsfw() -> Callable[[T], T]: """A :func:`.check` that checks if the channel is a NSFW channel. @@ -2117,14 +2172,19 @@ def is_nsfw() -> Callable[[T], T]: Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. DM channels will also now pass this check. """ + def pred(ctx: Context) -> bool: ch = ctx.channel if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True raise NSFWChannelRequired(ch) # type: ignore + return check(pred) -def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]: + +def cooldown( + rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default +) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` A cooldown allows a command to only be used a specific amount @@ -2157,9 +2217,13 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], else: func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) return func + return decorator # type: ignore -def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]: + +def dynamic_cooldown( + cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default +) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` This differs from :func:`.cooldown` in that it takes a function that @@ -2197,8 +2261,10 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type else: func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) return func + return decorator # type: ignore + def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. @@ -2230,8 +2296,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: else: func.__commands_max_concurrency__ = value return func + return decorator # type: ignore + def before_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a pre-invoke hook. @@ -2270,14 +2338,17 @@ def before_invoke(coro) -> Callable[[T], T]: bot.add_cog(What()) """ + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.before_invoke(coro) else: func.__before_invoke__ = coro return func + return decorator # type: ignore + def after_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a post-invoke hook. @@ -2286,10 +2357,12 @@ def after_invoke(coro) -> Callable[[T], T]: .. versionadded:: 1.4 """ + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.after_invoke(coro) else: func.__after_invoke__ = coro return func + return decorator # type: ignore diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 93834385..7f6f5cb4 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -41,65 +41,66 @@ if TYPE_CHECKING: __all__ = ( - 'CommandError', - 'MissingRequiredArgument', - 'BadArgument', - 'PrivateMessageOnly', - 'NoPrivateMessage', - 'CheckFailure', - 'CheckAnyFailure', - 'CommandNotFound', - 'DisabledCommand', - 'CommandInvokeError', - 'TooManyArguments', - 'UserInputError', - 'CommandOnCooldown', - 'MaxConcurrencyReached', - 'NotOwner', - 'MessageNotFound', - 'ObjectNotFound', - 'MemberNotFound', - 'GuildNotFound', - 'UserNotFound', - 'ChannelNotFound', - 'ThreadNotFound', - 'ChannelNotReadable', - 'BadColourArgument', - 'BadColorArgument', - 'RoleNotFound', - 'BadInviteArgument', - 'EmojiNotFound', - 'GuildStickerNotFound', - 'PartialEmojiConversionFailure', - 'BadBoolArgument', - 'MissingRole', - 'BotMissingRole', - 'MissingAnyRole', - 'BotMissingAnyRole', - 'MissingPermissions', - 'BotMissingPermissions', - 'NSFWChannelRequired', - 'ConversionError', - 'BadUnionArgument', - 'BadLiteralArgument', - 'ArgumentParsingError', - 'UnexpectedQuoteError', - 'InvalidEndOfQuotedStringError', - 'ExpectedClosingQuoteError', - 'ExtensionError', - 'ExtensionAlreadyLoaded', - 'ExtensionNotLoaded', - 'NoEntryPointError', - 'ExtensionFailed', - 'ExtensionNotFound', - 'CommandRegistrationError', - 'FlagError', - 'BadFlagArgument', - 'MissingFlagArgument', - 'TooManyFlags', - 'MissingRequiredFlag', + "CommandError", + "MissingRequiredArgument", + "BadArgument", + "PrivateMessageOnly", + "NoPrivateMessage", + "CheckFailure", + "CheckAnyFailure", + "CommandNotFound", + "DisabledCommand", + "CommandInvokeError", + "TooManyArguments", + "UserInputError", + "CommandOnCooldown", + "MaxConcurrencyReached", + "NotOwner", + "MessageNotFound", + "ObjectNotFound", + "MemberNotFound", + "GuildNotFound", + "UserNotFound", + "ChannelNotFound", + "ThreadNotFound", + "ChannelNotReadable", + "BadColourArgument", + "BadColorArgument", + "RoleNotFound", + "BadInviteArgument", + "EmojiNotFound", + "GuildStickerNotFound", + "PartialEmojiConversionFailure", + "BadBoolArgument", + "MissingRole", + "BotMissingRole", + "MissingAnyRole", + "BotMissingAnyRole", + "MissingPermissions", + "BotMissingPermissions", + "NSFWChannelRequired", + "ConversionError", + "BadUnionArgument", + "BadLiteralArgument", + "ArgumentParsingError", + "UnexpectedQuoteError", + "InvalidEndOfQuotedStringError", + "ExpectedClosingQuoteError", + "ExtensionError", + "ExtensionAlreadyLoaded", + "ExtensionNotLoaded", + "NoEntryPointError", + "ExtensionFailed", + "ExtensionNotFound", + "CommandRegistrationError", + "FlagError", + "BadFlagArgument", + "MissingFlagArgument", + "TooManyFlags", + "MissingRequiredFlag", ) + class CommandError(DiscordException): r"""The base exception type for all command related errors. @@ -109,14 +110,16 @@ class CommandError(DiscordException): in a special way as they are caught and passed into a special event from :class:`.Bot`\, :func:`.on_command_error`. """ + def __init__(self, message: Optional[str] = None, *args: Any) -> None: if message is not None: # clean-up @everyone and @here mentions - m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") super().__init__(m, *args) else: super().__init__(*args) + class ConversionError(CommandError): """Exception raised when a Converter class raises non-CommandError. @@ -130,18 +133,22 @@ class ConversionError(CommandError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, converter: Converter, original: Exception) -> None: self.converter: Converter = converter self.original: Exception = original + class UserInputError(CommandError): """The base exception type for errors that involve errors regarding user input. This inherits from :exc:`CommandError`. """ + pass + class CommandNotFound(CommandError): """Exception raised when a command is attempted to be invoked but no command under that name is found. @@ -151,8 +158,10 @@ class CommandNotFound(CommandError): This inherits from :exc:`CommandError`. """ + pass + class MissingRequiredArgument(UserInputError): """Exception raised when parsing a command and a parameter that is required is not encountered. @@ -164,9 +173,11 @@ class MissingRequiredArgument(UserInputError): param: :class:`inspect.Parameter` The argument that is missing. """ + def __init__(self, param: Parameter) -> None: self.param: Parameter = param - super().__init__(f'{param.name} is a required argument that is missing.') + super().__init__(f"{param.name} is a required argument that is missing.") + class TooManyArguments(UserInputError): """Exception raised when the command was passed too many arguments and its @@ -174,23 +185,29 @@ class TooManyArguments(UserInputError): This inherits from :exc:`UserInputError` """ + pass + class BadArgument(UserInputError): """Exception raised when a parsing or conversion failure is encountered on an argument to pass into a command. This inherits from :exc:`UserInputError` """ + pass + class CheckFailure(CommandError): """Exception raised when the predicates in :attr:`.Command.checks` have failed. This inherits from :exc:`CommandError` """ + pass + class CheckAnyFailure(CheckFailure): """Exception raised when all predicates in :func:`check_any` fail. @@ -209,7 +226,8 @@ class CheckAnyFailure(CheckFailure): def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: self.checks: List[CheckFailure] = checks self.errors: List[Callable[[Context], bool]] = errors - super().__init__('You do not have permission to run this command.') + super().__init__("You do not have permission to run this command.") + class PrivateMessageOnly(CheckFailure): """Exception raised when an operation does not work outside of private @@ -217,8 +235,10 @@ class PrivateMessageOnly(CheckFailure): This inherits from :exc:`CheckFailure` """ + def __init__(self, message: Optional[str] = None) -> None: - super().__init__(message or 'This command can only be used in private messages.') + super().__init__(message or "This command can only be used in private messages.") + class NoPrivateMessage(CheckFailure): """Exception raised when an operation does not work in private message @@ -228,15 +248,18 @@ class NoPrivateMessage(CheckFailure): """ def __init__(self, message: Optional[str] = None) -> None: - super().__init__(message or 'This command cannot be used in private messages.') + super().__init__(message or "This command cannot be used in private messages.") + class NotOwner(CheckFailure): """Exception raised when the message author is not the owner of the bot. This inherits from :exc:`CheckFailure` """ + pass + class ObjectNotFound(BadArgument): """Exception raised when the argument provided did not match the format of an ID or a mention. @@ -250,9 +273,11 @@ class ObjectNotFound(BadArgument): argument: :class:`str` The argument supplied by the caller that was not matched """ + def __init__(self, argument: str) -> None: self.argument: str = argument - super().__init__(f'{argument!r} does not follow a valid ID or mention format.') + super().__init__(f"{argument!r} does not follow a valid ID or mention format.") + class MemberNotFound(BadArgument): """Exception raised when the member provided was not found in the bot's @@ -267,10 +292,12 @@ class MemberNotFound(BadArgument): argument: :class:`str` The member supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Member "{argument}" not found.') + class GuildNotFound(BadArgument): """Exception raised when the guild provided was not found in the bot's cache. @@ -283,10 +310,12 @@ class GuildNotFound(BadArgument): argument: :class:`str` The guild supplied by the called that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Guild "{argument}" not found.') + class UserNotFound(BadArgument): """Exception raised when the user provided was not found in the bot's cache. @@ -300,10 +329,12 @@ class UserNotFound(BadArgument): argument: :class:`str` The user supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'User "{argument}" not found.') + class MessageNotFound(BadArgument): """Exception raised when the message provided was not found in the channel. @@ -316,10 +347,12 @@ class MessageNotFound(BadArgument): argument: :class:`str` The message supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Message "{argument}" not found.') + class ChannelNotReadable(BadArgument): """Exception raised when the bot does not have permission to read messages in the channel. @@ -333,10 +366,12 @@ class ChannelNotReadable(BadArgument): argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel supplied by the caller that was not readable """ + def __init__(self, argument: Union[GuildChannel, Thread]) -> None: self.argument: Union[GuildChannel, Thread] = argument super().__init__(f"Can't read messages in {argument.mention}.") + class ChannelNotFound(BadArgument): """Exception raised when the bot can not find the channel. @@ -349,10 +384,12 @@ class ChannelNotFound(BadArgument): argument: :class:`str` The channel supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Channel "{argument}" not found.') + class ThreadNotFound(BadArgument): """Exception raised when the bot can not find the thread. @@ -365,10 +402,12 @@ class ThreadNotFound(BadArgument): argument: :class:`str` The thread supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Thread "{argument}" not found.') + class BadColourArgument(BadArgument): """Exception raised when the colour is not valid. @@ -381,12 +420,15 @@ class BadColourArgument(BadArgument): argument: :class:`str` The colour supplied by the caller that was not valid """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Colour "{argument}" is invalid.') + BadColorArgument = BadColourArgument + class RoleNotFound(BadArgument): """Exception raised when the bot can not find the role. @@ -399,10 +441,12 @@ class RoleNotFound(BadArgument): argument: :class:`str` The role supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Role "{argument}" not found.') + class BadInviteArgument(BadArgument): """Exception raised when the invite is invalid or expired. @@ -410,10 +454,12 @@ class BadInviteArgument(BadArgument): .. versionadded:: 1.5 """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Invite "{argument}" is invalid or expired.') + class EmojiNotFound(BadArgument): """Exception raised when the bot can not find the emoji. @@ -426,10 +472,12 @@ class EmojiNotFound(BadArgument): argument: :class:`str` The emoji supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Emoji "{argument}" not found.') + class PartialEmojiConversionFailure(BadArgument): """Exception raised when the emoji provided does not match the correct format. @@ -443,10 +491,12 @@ class PartialEmojiConversionFailure(BadArgument): argument: :class:`str` The emoji supplied by the caller that did not match the regex """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.') + class GuildStickerNotFound(BadArgument): """Exception raised when the bot can not find the sticker. @@ -459,10 +509,12 @@ class GuildStickerNotFound(BadArgument): argument: :class:`str` The sticker supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Sticker "{argument}" not found.') + class BadBoolArgument(BadArgument): """Exception raised when a boolean argument was not convertable. @@ -475,17 +527,21 @@ class BadBoolArgument(BadArgument): argument: :class:`str` The boolean argument supplied by the caller that is not in the predefined list """ + def __init__(self, argument: str) -> None: self.argument: str = argument - super().__init__(f'{argument} is not a recognised boolean option') + super().__init__(f"{argument} is not a recognised boolean option") + class DisabledCommand(CommandError): """Exception raised when the command being invoked is disabled. This inherits from :exc:`CommandError` """ + pass + class CommandInvokeError(CommandError): """Exception raised when the command being invoked raised an exception. @@ -497,9 +553,11 @@ class CommandInvokeError(CommandError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, e: Exception) -> None: self.original: Exception = e - super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}') + super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") + class CommandOnCooldown(CommandError): """Exception raised when the command being invoked is on cooldown. @@ -516,11 +574,13 @@ class CommandOnCooldown(CommandError): retry_after: :class:`float` The amount of seconds to wait before you can retry again. """ + def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: self.cooldown: Cooldown = cooldown self.retry_after: float = retry_after self.type: BucketType = type - super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') + super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s") + class MaxConcurrencyReached(CommandError): """Exception raised when the command being invoked has reached its maximum concurrency. @@ -539,10 +599,11 @@ class MaxConcurrencyReached(CommandError): self.number: int = number self.per: BucketType = per name = per.name - suffix = 'per %s' % name if per.name != 'default' else 'globally' - plural = '%s times %s' if number > 1 else '%s time %s' + suffix = "per %s" % name if per.name != "default" else "globally" + plural = "%s times %s" if number > 1 else "%s time %s" fmt = plural % (number, suffix) - super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.') + super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") + class MissingRole(CheckFailure): """Exception raised when the command invoker lacks a role to run a command. @@ -557,11 +618,13 @@ class MissingRole(CheckFailure): The required role that is missing. This is the parameter passed to :func:`~.commands.has_role`. """ + def __init__(self, missing_role: Snowflake) -> None: self.missing_role: Snowflake = missing_role - message = f'Role {missing_role!r} is required to run this command.' + message = f"Role {missing_role!r} is required to run this command." super().__init__(message) + class BotMissingRole(CheckFailure): """Exception raised when the bot's member lacks a role to run a command. @@ -575,11 +638,13 @@ class BotMissingRole(CheckFailure): The required role that is missing. This is the parameter passed to :func:`~.commands.has_role`. """ + def __init__(self, missing_role: Snowflake) -> None: self.missing_role: Snowflake = missing_role - message = f'Bot requires the role {missing_role!r} to run this command' + message = f"Bot requires the role {missing_role!r} to run this command" super().__init__(message) + class MissingAnyRole(CheckFailure): """Exception raised when the command invoker lacks any of the roles specified to run a command. @@ -594,15 +659,16 @@ class MissingAnyRole(CheckFailure): The roles that the invoker is missing. These are the parameters passed to :func:`~.commands.has_any_role`. """ + def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles missing = [f"'{role}'" for role in missing_roles] if len(missing) > 2: - fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' or '.join(missing) + fmt = " or ".join(missing) message = f"You are missing at least one of the required roles: {fmt}" super().__init__(message) @@ -623,19 +689,21 @@ class BotMissingAnyRole(CheckFailure): These are the parameters passed to :func:`~.commands.has_any_role`. """ + def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles missing = [f"'{role}'" for role in missing_roles] if len(missing) > 2: - fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' or '.join(missing) + fmt = " or ".join(missing) message = f"Bot is missing at least one of the required roles: {fmt}" super().__init__(message) + class NSFWChannelRequired(CheckFailure): """Exception raised when a channel does not have the required NSFW setting. @@ -648,10 +716,12 @@ class NSFWChannelRequired(CheckFailure): channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel that does not have NSFW enabled. """ + def __init__(self, channel: Union[GuildChannel, Thread]) -> None: self.channel: Union[GuildChannel, Thread] = channel super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") + class MissingPermissions(CheckFailure): """Exception raised when the command invoker lacks permissions to run a command. @@ -663,18 +733,20 @@ class MissingPermissions(CheckFailure): missing_permissions: List[:class:`str`] The required permissions that are missing. """ + def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions - missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] + missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions] if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' and '.join(missing) - message = f'You are missing {fmt} permission(s) to run this command.' + fmt = " and ".join(missing) + message = f"You are missing {fmt} permission(s) to run this command." super().__init__(message, *args) + class BotMissingPermissions(CheckFailure): """Exception raised when the bot's member lacks permissions to run a command. @@ -686,18 +758,20 @@ class BotMissingPermissions(CheckFailure): missing_permissions: List[:class:`str`] The required permissions that are missing. """ + def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions - missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] + missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions] if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' and '.join(missing) - message = f'Bot requires {fmt} permission(s) to run this command.' + fmt = " and ".join(missing) + message = f"Bot requires {fmt} permission(s) to run this command." super().__init__(message, *args) + class BadUnionArgument(UserInputError): """Exception raised when a :data:`typing.Union` converter fails for all its associated types. @@ -713,6 +787,7 @@ class BadUnionArgument(UserInputError): errors: List[:class:`CommandError`] A list of errors that were caught from failing the conversion. """ + def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: self.param: Parameter = param self.converters: Tuple[Type, ...] = converters @@ -722,18 +797,19 @@ class BadUnionArgument(UserInputError): try: return x.__name__ except AttributeError: - if hasattr(x, '__origin__'): + if hasattr(x, "__origin__"): return repr(x) return x.__class__.__name__ to_string = [_get_name(x) for x in converters] if len(to_string) > 2: - fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) + fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1]) else: - fmt = ' or '.join(to_string) + fmt = " or ".join(to_string) super().__init__(f'Could not convert "{param.name}" into {fmt}.') + class BadLiteralArgument(UserInputError): """Exception raised when a :data:`typing.Literal` converter fails for all its associated values. @@ -751,6 +827,7 @@ class BadLiteralArgument(UserInputError): errors: List[:class:`CommandError`] A list of errors that were caught from failing the conversion. """ + def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None: self.param: Parameter = param self.literals: Tuple[Any, ...] = literals @@ -758,12 +835,13 @@ class BadLiteralArgument(UserInputError): to_string = [repr(l) for l in literals] if len(to_string) > 2: - fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) + fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1]) else: - fmt = ' or '.join(to_string) + fmt = " or ".join(to_string) super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.') + class ArgumentParsingError(UserInputError): """An exception raised when the parser fails to parse a user's input. @@ -772,8 +850,10 @@ class ArgumentParsingError(UserInputError): There are child classes that implement more granular parsing errors for i18n purposes. """ + pass + class UnexpectedQuoteError(ArgumentParsingError): """An exception raised when the parser encounters a quote mark inside a non-quoted string. @@ -784,9 +864,11 @@ class UnexpectedQuoteError(ArgumentParsingError): quote: :class:`str` The quote mark that was found inside the non-quoted string. """ + def __init__(self, quote: str) -> None: self.quote: str = quote - super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string') + super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string") + class InvalidEndOfQuotedStringError(ArgumentParsingError): """An exception raised when a space is expected after the closing quote in a string @@ -799,9 +881,11 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError): char: :class:`str` The character found instead of the expected string. """ + def __init__(self, char: str) -> None: self.char: str = char - super().__init__(f'Expected space after closing quotation but received {char!r}') + super().__init__(f"Expected space after closing quotation but received {char!r}") + class ExpectedClosingQuoteError(ArgumentParsingError): """An exception raised when a quote character is expected but not found. @@ -816,7 +900,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError): def __init__(self, close_quote: str) -> None: self.close_quote: str = close_quote - super().__init__(f'Expected closing {close_quote}.') + super().__init__(f"Expected closing {close_quote}.") + class ExtensionError(DiscordException): """Base exception for extension related errors. @@ -828,37 +913,45 @@ class ExtensionError(DiscordException): name: :class:`str` The extension that had an error. """ + def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None: self.name: str = name - message = message or f'Extension {name!r} had an error.' + message = message or f"Extension {name!r} had an error." # clean-up @everyone and @here mentions - m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") super().__init__(m, *args) + class ExtensionAlreadyLoaded(ExtensionError): """An exception raised when an extension has already been loaded. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: - super().__init__(f'Extension {name!r} is already loaded.', name=name) + super().__init__(f"Extension {name!r} is already loaded.", name=name) + class ExtensionNotLoaded(ExtensionError): """An exception raised when an extension was not loaded. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: - super().__init__(f'Extension {name!r} has not been loaded.', name=name) + super().__init__(f"Extension {name!r} has not been loaded.", name=name) + class NoEntryPointError(ExtensionError): """An exception raised when an extension does not have a ``setup`` entry point function. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) + class ExtensionFailed(ExtensionError): """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. @@ -872,11 +965,13 @@ class ExtensionFailed(ExtensionError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, name: str, original: Exception) -> None: self.original: Exception = original - msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}' + msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}" super().__init__(msg, name=name) + class ExtensionNotFound(ExtensionError): """An exception raised when an extension is not found. @@ -890,10 +985,12 @@ class ExtensionNotFound(ExtensionError): name: :class:`str` The extension that had the error. """ + def __init__(self, name: str) -> None: - msg = f'Extension {name!r} could not be loaded.' + msg = f"Extension {name!r} could not be loaded." super().__init__(msg, name=name) + class CommandRegistrationError(ClientException): """An exception raised when the command can't be added because the name is already taken by a different command. @@ -909,11 +1006,13 @@ class CommandRegistrationError(ClientException): alias_conflict: :class:`bool` Whether the name that conflicts is an alias of the command we try to add. """ + def __init__(self, name: str, *, alias_conflict: bool = False) -> None: self.name: str = name self.alias_conflict: bool = alias_conflict - type_ = 'alias' if alias_conflict else 'command' - super().__init__(f'The {type_} {name} is already an existing command or alias.') + type_ = "alias" if alias_conflict else "command" + super().__init__(f"The {type_} {name} is already an existing command or alias.") + class FlagError(BadArgument): """The base exception type for all flag parsing related errors. @@ -922,8 +1021,10 @@ class FlagError(BadArgument): .. versionadded:: 2.0 """ + pass + class TooManyFlags(FlagError): """An exception raised when a flag has received too many values. @@ -938,10 +1039,12 @@ class TooManyFlags(FlagError): values: List[:class:`str`] The values that were passed. """ + def __init__(self, flag: Flag, values: List[str]) -> None: self.flag: Flag = flag self.values: List[str] = values - super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.') + super().__init__(f"Too many flag values, expected {flag.max_args} but received {len(values)}.") + class BadFlagArgument(FlagError): """An exception raised when a flag failed to convert a value. @@ -955,6 +1058,7 @@ class BadFlagArgument(FlagError): flag: :class:`~discord.ext.commands.Flag` The flag that failed to convert. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag try: @@ -962,7 +1066,8 @@ class BadFlagArgument(FlagError): except AttributeError: name = flag.annotation.__class__.__name__ - super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}') + super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}") + class MissingRequiredFlag(FlagError): """An exception raised when a required flag was not given. @@ -976,9 +1081,11 @@ class MissingRequiredFlag(FlagError): flag: :class:`~discord.ext.commands.Flag` The required flag that was not found. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag - super().__init__(f'Flag {flag.name!r} is required and missing') + super().__init__(f"Flag {flag.name!r} is required and missing") + class MissingFlagArgument(FlagError): """An exception raised when a flag did not get a value. @@ -992,6 +1099,7 @@ class MissingFlagArgument(FlagError): flag: :class:`~discord.ext.commands.Flag` The flag that did not get a value. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag - super().__init__(f'Flag {flag.name!r} does not have an argument') + super().__init__(f"Flag {flag.name!r} does not have an argument") diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index b356af34..367127a4 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -59,9 +59,9 @@ import sys import re __all__ = ( - 'Flag', - 'flag', - 'FlagConverter', + "Flag", + "flag", + "FlagConverter", ) @@ -148,20 +148,20 @@ def flag( def validate_flag_name(name: str, forbidden: Set[str]): if not name: - raise ValueError('flag names should not be empty') + raise ValueError("flag names should not be empty") for ch in name: if ch.isspace(): - raise ValueError(f'flag name {name!r} cannot have spaces') - if ch == '\\': - raise ValueError(f'flag name {name!r} cannot have backslashes') + raise ValueError(f"flag name {name!r} cannot have spaces") + if ch == "\\": + raise ValueError(f"flag name {name!r} cannot have backslashes") if ch in forbidden: - raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them') + raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them") def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: - annotations = namespace.get('__annotations__', {}) - case_insensitive = namespace['__commands_flag_case_insensitive__'] + annotations = namespace.get("__annotations__", {}) + case_insensitive = namespace["__commands_flag_case_insensitive__"] flags: Dict[str, Flag] = {} cache: Dict[str, Any] = {} names: Set[str] = set() @@ -178,7 +178,11 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) - if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible(): + if ( + flag.default is MISSING + and hasattr(annotation, "__commands_is_flag__") + and annotation._can_be_constructible() + ): flag.default = annotation._construct_default if flag.aliases is MISSING: @@ -229,7 +233,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s if flag.max_args is MISSING: flag.max_args = 1 else: - raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag') + raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag") if flag.override is MISSING: flag.override = False @@ -237,7 +241,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s # Validate flag names are unique name = flag.name.casefold() if case_insensitive else flag.name if name in names: - raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.') + raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.") else: names.add(name) @@ -245,7 +249,7 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s # Validate alias is unique alias = alias.casefold() if case_insensitive else alias if alias in names: - raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.') + raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.") else: names.add(alias) @@ -274,10 +278,10 @@ class FlagsMeta(type): delimiter: str = MISSING, prefix: str = MISSING, ): - attrs['__commands_is_flag__'] = True + attrs["__commands_is_flag__"] = True try: - global_ns = sys.modules[attrs['__module__']].__dict__ + global_ns = sys.modules[attrs["__module__"]].__dict__ except KeyError: global_ns = {} @@ -296,26 +300,26 @@ class FlagsMeta(type): flags: Dict[str, Flag] = {} aliases: Dict[str, str] = {} for base in reversed(bases): - if base.__dict__.get('__commands_is_flag__', False): - flags.update(base.__dict__['__commands_flags__']) - aliases.update(base.__dict__['__commands_flag_aliases__']) + if base.__dict__.get("__commands_is_flag__", False): + flags.update(base.__dict__["__commands_flags__"]) + aliases.update(base.__dict__["__commands_flag_aliases__"]) if case_insensitive is MISSING: - attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__'] + attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"] if delimiter is MISSING: - attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__'] + attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"] if prefix is MISSING: - attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__'] + attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"] if case_insensitive is not MISSING: - attrs['__commands_flag_case_insensitive__'] = case_insensitive + attrs["__commands_flag_case_insensitive__"] = case_insensitive if delimiter is not MISSING: - attrs['__commands_flag_delimiter__'] = delimiter + attrs["__commands_flag_delimiter__"] = delimiter if prefix is not MISSING: - attrs['__commands_flag_prefix__'] = prefix + attrs["__commands_flag_prefix__"] = prefix - case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False) - delimiter = attrs.setdefault('__commands_flag_delimiter__', ':') - prefix = attrs.setdefault('__commands_flag_prefix__', '') + case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False) + delimiter = attrs.setdefault("__commands_flag_delimiter__", ":") + prefix = attrs.setdefault("__commands_flag_prefix__", "") for flag_name, flag in get_flags(attrs, global_ns, local_ns).items(): flags[flag_name] = flag @@ -337,11 +341,11 @@ class FlagsMeta(type): keys.extend(re.escape(a) for a in aliases) keys = sorted(keys, key=lambda t: len(t), reverse=True) - joined = '|'.join(keys) - pattern = re.compile(f'(({re.escape(prefix)})(?P{joined}){re.escape(delimiter)})', regex_flags) - attrs['__commands_flag_regex__'] = pattern - attrs['__commands_flags__'] = flags - attrs['__commands_flag_aliases__'] = aliases + joined = "|".join(keys) + pattern = re.compile(f"(({re.escape(prefix)})(?P{joined}){re.escape(delimiter)})", regex_flags) + attrs["__commands_flag_regex__"] = pattern + attrs["__commands_flags__"] = flags + attrs["__commands_flag_aliases__"] = aliases return type.__new__(cls, name, bases, attrs) @@ -432,7 +436,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) - raise BadFlagArgument(flag) from e -F = TypeVar('F', bound='FlagConverter') +F = TypeVar("F", bound="FlagConverter") class FlagConverter(metaclass=FlagsMeta): @@ -493,8 +497,8 @@ class FlagConverter(metaclass=FlagsMeta): return self def __repr__(self) -> str: - pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()]) - return f'<{self.__class__.__name__} {pairs}>' + pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()]) + return f"<{self.__class__.__name__} {pairs}>" @classmethod def parse_flags(cls, argument: str) -> Dict[str, List[str]]: @@ -507,7 +511,7 @@ class FlagConverter(metaclass=FlagsMeta): case_insensitive = cls.__commands_flag_case_insensitive__ for match in cls.__commands_flag_regex__.finditer(argument): begin, end = match.span(0) - key = match.group('flag') + key = match.group("flag") if case_insensitive: key = key.casefold() diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index afaacbfb..0630ea81 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -39,10 +39,10 @@ if TYPE_CHECKING: from .context import Context __all__ = ( - 'Paginator', - 'HelpCommand', - 'DefaultHelpCommand', - 'MinimalHelpCommand', + "Paginator", + "HelpCommand", + "DefaultHelpCommand", + "MinimalHelpCommand", ) # help -> shows info of bot on top/bottom and lists subcommands @@ -89,7 +89,7 @@ class Paginator: .. versionadded:: 1.7 """ - def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'): + def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"): self.prefix = prefix self.suffix = suffix self.max_size = max_size @@ -118,7 +118,7 @@ class Paginator: def _linesep_len(self): return len(self.linesep) - def add_line(self, line='', *, empty=False): + def add_line(self, line="", *, empty=False): """Adds a line to the current page. If the line exceeds the :attr:`max_size` then an exception @@ -138,7 +138,7 @@ class Paginator: """ max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len if len(line) > max_page_size: - raise RuntimeError(f'Line exceeds maximum page size {max_page_size}') + raise RuntimeError(f"Line exceeds maximum page size {max_page_size}") if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len: self.close_page() @@ -147,7 +147,7 @@ class Paginator: self._current_page.append(line) if empty: - self._current_page.append('') + self._current_page.append("") self._count += self._linesep_len def close_page(self): @@ -176,7 +176,7 @@ class Paginator: return self._pages def __repr__(self): - fmt = '' + fmt = "" return fmt.format(self) @@ -197,7 +197,7 @@ class _HelpCommandImpl(Command): self.callback = injected.command_callback on_error = injected.on_help_command_error - if not hasattr(on_error, '__help_command_not_overriden__'): + if not hasattr(on_error, "__help_command_not_overriden__"): if self.cog is not None: self.on_error = self._on_error_cog_implementation else: @@ -224,7 +224,7 @@ class _HelpCommandImpl(Command): try: del result[next(iter(result))] except StopIteration: - raise ValueError('Missing context parameter') from None + raise ValueError("Missing context parameter") from None else: return result @@ -296,13 +296,13 @@ class HelpCommand: """ MENTION_TRANSFORMS = { - '@everyone': '@\u200beveryone', - '@here': '@\u200bhere', - r'<@!?[0-9]{17,22}>': '@deleted-user', - r'<@&[0-9]{17,22}>': '@deleted-role', + "@everyone": "@\u200beveryone", + "@here": "@\u200bhere", + r"<@!?[0-9]{17,22}>": "@deleted-user", + r"<@&[0-9]{17,22}>": "@deleted-role", } - MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) + MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys())) def __new__(cls, *args, **kwargs): # To prevent race conditions of a single instance while also allowing @@ -321,11 +321,11 @@ class HelpCommand: return self def __init__(self, **options): - self.show_hidden = options.pop('show_hidden', False) - self.verify_checks = options.pop('verify_checks', True) - self.command_attrs = attrs = options.pop('command_attrs', {}) - attrs.setdefault('name', 'help') - attrs.setdefault('help', 'Shows this message') + self.show_hidden = options.pop("show_hidden", False) + self.verify_checks = options.pop("verify_checks", True) + self.command_attrs = attrs = options.pop("command_attrs", {}) + attrs.setdefault("name", "help") + attrs.setdefault("help", "Shows this message") self.context: Context = discord.utils.MISSING self._command_impl = _HelpCommandImpl(self, **self.command_attrs) @@ -422,20 +422,20 @@ class HelpCommand: if not parent.signature or parent.invoke_without_command: entries.append(parent.name) else: - entries.append(parent.name + ' ' + parent.signature) + entries.append(parent.name + " " + parent.signature) parent = parent.parent - parent_sig = ' '.join(reversed(entries)) + parent_sig = " ".join(reversed(entries)) if len(command.aliases) > 0: - aliases = '|'.join(command.aliases) - fmt = f'[{command.name}|{aliases}]' + aliases = "|".join(command.aliases) + fmt = f"[{command.name}|{aliases}]" if parent_sig: - fmt = parent_sig + ' ' + fmt + fmt = parent_sig + " " + fmt alias = fmt else: - alias = command.name if not parent_sig else parent_sig + ' ' + command.name + alias = command.name if not parent_sig else parent_sig + " " + command.name - return f'{self.context.clean_prefix}{alias} {command.signature}' + return f"{self.context.clean_prefix}{alias} {command.signature}" def remove_mentions(self, string): """Removes mentions from the string to prevent abuse. @@ -449,7 +449,7 @@ class HelpCommand: """ def replace(obj, *, transforms=self.MENTION_TRANSFORMS): - return transforms.get(obj.group(0), '@invalid') + return transforms.get(obj.group(0), "@invalid") return self.MENTION_PATTERN.sub(replace, string) @@ -846,7 +846,7 @@ class HelpCommand: # Since we want to have detailed errors when someone # passes an invalid subcommand, we need to walk through # the command group chain ourselves. - keys = command.split(' ') + keys = command.split(" ") cmd = bot.all_commands.get(keys[0]) if cmd is None: string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0])) @@ -907,14 +907,14 @@ class DefaultHelpCommand(HelpCommand): """ def __init__(self, **options): - self.width = options.pop('width', 80) - self.indent = options.pop('indent', 2) - self.sort_commands = options.pop('sort_commands', True) - self.dm_help = options.pop('dm_help', False) - self.dm_help_threshold = options.pop('dm_help_threshold', 1000) - self.commands_heading = options.pop('commands_heading', "Commands:") - self.no_category = options.pop('no_category', 'No Category') - self.paginator = options.pop('paginator', None) + self.width = options.pop("width", 80) + self.indent = options.pop("indent", 2) + self.sort_commands = options.pop("sort_commands", True) + self.dm_help = options.pop("dm_help", False) + self.dm_help_threshold = options.pop("dm_help_threshold", 1000) + self.commands_heading = options.pop("commands_heading", "Commands:") + self.no_category = options.pop("no_category", "No Category") + self.paginator = options.pop("paginator", None) if self.paginator is None: self.paginator = Paginator() @@ -924,7 +924,7 @@ class DefaultHelpCommand(HelpCommand): def shorten_text(self, text): """:class:`str`: Shortens text to fit into the :attr:`width`.""" if len(text) > self.width: - return text[:self.width - 3].rstrip() + '...' + return text[: self.width - 3].rstrip() + "..." return text def get_ending_note(self): @@ -1021,11 +1021,11 @@ class DefaultHelpCommand(HelpCommand): # portion self.paginator.add_line(bot.description, empty=True) - no_category = f'\u200b{self.no_category}:' + no_category = f"\u200b{self.no_category}:" def get_category(command, *, no_category=no_category): cog = command.cog - return cog.qualified_name + ':' if cog is not None else no_category + return cog.qualified_name + ":" if cog is not None else no_category filtered = await self.filter_commands(bot.commands, sort=True, key=get_category) max_size = self.get_max_size(filtered) @@ -1110,13 +1110,13 @@ class MinimalHelpCommand(HelpCommand): """ def __init__(self, **options): - self.sort_commands = options.pop('sort_commands', True) - self.commands_heading = options.pop('commands_heading', "Commands") - self.dm_help = options.pop('dm_help', False) - self.dm_help_threshold = options.pop('dm_help_threshold', 1000) - self.aliases_heading = options.pop('aliases_heading', "Aliases:") - self.no_category = options.pop('no_category', 'No Category') - self.paginator = options.pop('paginator', None) + self.sort_commands = options.pop("sort_commands", True) + self.commands_heading = options.pop("commands_heading", "Commands") + self.dm_help = options.pop("dm_help", False) + self.dm_help_threshold = options.pop("dm_help_threshold", 1000) + self.aliases_heading = options.pop("aliases_heading", "Aliases:") + self.no_category = options.pop("no_category", "No Category") + self.paginator = options.pop("paginator", None) if self.paginator is None: self.paginator = Paginator(suffix=None, prefix=None) @@ -1149,7 +1149,7 @@ class MinimalHelpCommand(HelpCommand): ) def get_command_signature(self, command): - return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}' + return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}" def get_ending_note(self): """Return the help command's ending note. This is mainly useful to override for i18n purposes. @@ -1180,8 +1180,8 @@ class MinimalHelpCommand(HelpCommand): """ if commands: # U+2002 Middle Dot - joined = '\u2002'.join(c.name for c in commands) - self.paginator.add_line(f'__**{heading}**__') + joined = "\u2002".join(c.name for c in commands) + self.paginator.add_line(f"__**{heading}**__") self.paginator.add_line(joined) def add_subcommand_formatting(self, command): @@ -1197,7 +1197,7 @@ class MinimalHelpCommand(HelpCommand): command: :class:`Command` The command to show information of. """ - fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}' + fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}" self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) def add_aliases_formatting(self, aliases): @@ -1268,7 +1268,7 @@ class MinimalHelpCommand(HelpCommand): if note: self.paginator.add_line(note, empty=True) - no_category = f'\u200b{self.no_category}' + no_category = f"\u200b{self.no_category}" def get_category(command, *, no_category=no_category): cog = command.cog @@ -1302,7 +1302,7 @@ class MinimalHelpCommand(HelpCommand): filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands) if filtered: - self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**') + self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**") for command in filtered: self.add_subcommand_formatting(command) @@ -1322,7 +1322,7 @@ class MinimalHelpCommand(HelpCommand): if note: self.paginator.add_line(note, empty=True) - self.paginator.add_line(f'**{self.commands_heading}**') + self.paginator.add_line(f"**{self.commands_heading}**") for command in filtered: self.add_subcommand_formatting(command) diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index a7dc7236..9c503ac4 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -46,6 +46,7 @@ _quotes = { } _all_quotes = set(_quotes.keys()) | set(_quotes.values()) + class StringView: def __init__(self, buffer): self.index = 0 @@ -81,20 +82,20 @@ class StringView: def skip_string(self, string): strlen = len(string) - if self.buffer[self.index:self.index + strlen] == string: + if self.buffer[self.index : self.index + strlen] == string: self.previous = self.index self.index += strlen return True return False def read_rest(self): - result = self.buffer[self.index:] + result = self.buffer[self.index :] self.previous = self.index self.index = self.end return result def read(self, n): - result = self.buffer[self.index:self.index + n] + result = self.buffer[self.index : self.index + n] self.previous = self.index self.index += n return result @@ -120,7 +121,7 @@ class StringView: except IndexError: break self.previous = self.index - result = self.buffer[self.index:self.index + pos] + result = self.buffer[self.index : self.index + pos] self.index += pos return result @@ -144,11 +145,11 @@ class StringView: if is_quoted: # unexpected EOF raise ExpectedClosingQuoteError(close_quote) - return ''.join(result) + return "".join(result) # currently we accept strings in the format of "hello world" # to embed a quote inside the string you must escape it: "a \"world\"" - if current == '\\': + if current == "\\": next_char = self.get() if not next_char: # string ends with \ and no character after it @@ -156,7 +157,7 @@ class StringView: # if we're quoted then we're expecting a closing quote raise ExpectedClosingQuoteError(close_quote) # if we aren't then we just let it through - return ''.join(result) + return "".join(result) if next_char in _escaped_quotes: # escaped quote @@ -179,14 +180,13 @@ class StringView: raise InvalidEndOfQuotedStringError(next_char) # we're quoted so it's okay - return ''.join(result) + return "".join(result) if current.isspace() and not is_quoted: # end of word found - return ''.join(result) + return "".join(result) result.append(current) - def __repr__(self): - return f'' + return f"" diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 5b78f10e..2754f96c 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -48,19 +48,17 @@ from collections.abc import Sequence from discord.backoff import ExponentialBackoff from discord.utils import MISSING -__all__ = ( - 'loop', -) +__all__ = ("loop",) -T = TypeVar('T') +T = TypeVar("T") _func = Callable[..., Awaitable[Any]] -LF = TypeVar('LF', bound=_func) -FT = TypeVar('FT', bound=_func) -ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) +LF = TypeVar("LF", bound=_func) +FT = TypeVar("FT", bound=_func) +ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) class SleepHandle: - __slots__ = ('future', 'loop', 'handle') + __slots__ = ("future", "loop", "handle") def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: self.loop = loop @@ -124,7 +122,7 @@ class Loop(Generic[LF]): self._stop_next_iteration = False if self.count is not None and self.count <= 0: - raise ValueError('count must be greater than 0 or None.') + raise ValueError("count must be greater than 0 or None.") self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self._last_iteration_failed = False @@ -132,10 +130,10 @@ class Loop(Generic[LF]): self._next_iteration = None if not inspect.iscoroutinefunction(self.coro): - raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') + raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.") async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: - coro = getattr(self, '_' + name) + coro = getattr(self, "_" + name) if coro is None: return @@ -150,7 +148,7 @@ class Loop(Generic[LF]): async def _loop(self, *args: Any, **kwargs: Any) -> None: backoff = ExponentialBackoff() - await self._call_loop_function('before_loop') + await self._call_loop_function("before_loop") self._last_iteration_failed = False if self._time is not MISSING: # the time index should be prepared every time the internal loop is started @@ -193,10 +191,10 @@ class Loop(Generic[LF]): raise except Exception as exc: self._has_failed = True - await self._call_loop_function('error', exc) + await self._call_loop_function("error", exc) raise exc finally: - await self._call_loop_function('after_loop') + await self._call_loop_function("after_loop") self._handle.cancel() self._is_being_cancelled = False self._current_loop = 0 @@ -323,7 +321,7 @@ class Loop(Generic[LF]): """ if self._task is not MISSING and not self._task.done(): - raise RuntimeError('Task is already launched and is not completed.') + raise RuntimeError("Task is already launched and is not completed.") if self._injected is not None: args = (self._injected, *args) @@ -410,9 +408,9 @@ class Loop(Generic[LF]): for exc in exceptions: if not inspect.isclass(exc): - raise TypeError(f'{exc!r} must be a class.') + raise TypeError(f"{exc!r} must be a class.") if not issubclass(exc, BaseException): - raise TypeError(f'{exc!r} must inherit from BaseException.') + raise TypeError(f"{exc!r} must inherit from BaseException.") self._valid_exception = (*self._valid_exception, *exceptions) @@ -466,7 +464,7 @@ class Loop(Generic[LF]): async def _error(self, *args: Any) -> None: exception: Exception = args[-1] - print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) + print(f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr) traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) def before_loop(self, coro: FT) -> FT: @@ -489,7 +487,7 @@ class Loop(Generic[LF]): """ if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.") self._before_loop = coro return coro @@ -517,7 +515,7 @@ class Loop(Generic[LF]): """ if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.") self._after_loop = coro return coro @@ -543,7 +541,7 @@ class Loop(Generic[LF]): The function was not a coroutine. """ if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.") self._error = coro # type: ignore return coro @@ -601,16 +599,16 @@ class Loop(Generic[LF]): return [inner] if not isinstance(time, Sequence): raise TypeError( - f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.' + f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead." ) if not time: - raise ValueError('time parameter must not be an empty sequence.') + raise ValueError("time parameter must not be an empty sequence.") ret: List[datetime.time] = [] for index, t in enumerate(time): if not isinstance(t, dt): raise TypeError( - f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.' + f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead." ) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) @@ -663,7 +661,7 @@ class Loop(Generic[LF]): hours = hours or 0 sleep = seconds + (minutes * 60.0) + (hours * 3600.0) if sleep < 0: - raise ValueError('Total number of seconds cannot be less than zero.') + raise ValueError("Total number of seconds cannot be less than zero.") self._sleep = sleep self._seconds = float(seconds) @@ -672,7 +670,7 @@ class Loop(Generic[LF]): self._time: List[datetime.time] = MISSING else: if any((seconds, minutes, hours)): - raise TypeError('Cannot mix explicit time with relative time') + raise TypeError("Cannot mix explicit time with relative time") self._time = self._get_time_parameter(time) self._sleep = self._seconds = self._minutes = self._hours = MISSING diff --git a/discord/file.py b/discord/file.py index 5303e325..849c6e85 100644 --- a/discord/file.py +++ b/discord/file.py @@ -28,9 +28,7 @@ from typing import Optional, TYPE_CHECKING, Union import os import io -__all__ = ( - 'File', -) +__all__ = ("File",) class File: @@ -64,7 +62,7 @@ class File: Whether the attachment is a spoiler. """ - __slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer') + __slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer") if TYPE_CHECKING: fp: io.BufferedIOBase @@ -80,12 +78,12 @@ class File: ): if isinstance(fp, io.IOBase): if not (fp.seekable() and fp.readable()): - raise ValueError(f'File buffer {fp!r} must be seekable and readable') + raise ValueError(f"File buffer {fp!r} must be seekable and readable") self.fp = fp self._original_pos = fp.tell() self._owner = False else: - self.fp = open(fp, 'rb') + self.fp = open(fp, "rb") self._original_pos = 0 self._owner = True @@ -100,14 +98,14 @@ class File: if isinstance(fp, str): _, self.filename = os.path.split(fp) else: - self.filename = getattr(fp, 'name', None) + self.filename = getattr(fp, "name", None) else: self.filename = filename - if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'): - self.filename = 'SPOILER_' + self.filename + if spoiler and self.filename is not None and not self.filename.startswith("SPOILER_"): + self.filename = "SPOILER_" + self.filename - self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_')) + self.spoiler = spoiler or (self.filename is not None and self.filename.startswith("SPOILER_")) def reset(self, *, seek: Union[int, bool] = True) -> None: # The `seek` parameter is needed because diff --git a/discord/flags.py b/discord/flags.py index 3c5956a4..920c190f 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -29,16 +29,16 @@ from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optio from .enums import UserFlags __all__ = ( - 'SystemChannelFlags', - 'MessageFlags', - 'PublicUserFlags', - 'Intents', - 'MemberCacheFlags', - 'ApplicationFlags', + "SystemChannelFlags", + "MessageFlags", + "PublicUserFlags", + "Intents", + "MemberCacheFlags", + "ApplicationFlags", ) -FV = TypeVar('FV', bound='flag_value') -BF = TypeVar('BF', bound='BaseFlags') +FV = TypeVar("FV", bound="flag_value") +BF = TypeVar("BF", bound="BaseFlags") class flag_value: @@ -63,7 +63,7 @@ class flag_value: instance._set_flag(self.flag, value) def __repr__(self): - return f'' + return f"" class alias_flag_value(flag_value): @@ -98,13 +98,13 @@ class BaseFlags: value: int - __slots__ = ('value',) + __slots__ = ("value",) def __init__(self, **kwargs: bool): self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid flag name.') + raise TypeError(f"{key!r} is not a valid flag name.") setattr(self, key, value) @classmethod @@ -123,7 +123,7 @@ class BaseFlags: return hash(self.value) def __repr__(self) -> str: - return f'<{self.__class__.__name__} value={self.value}>' + return f"<{self.__class__.__name__} value={self.value}>" def __iter__(self) -> Iterator[Tuple[str, bool]]: for name, value in self.__class__.__dict__.items(): @@ -142,7 +142,7 @@ class BaseFlags: elif toggle is False: self.value &= ~o else: - raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.') + raise TypeError(f"Value to set for {self.__class__.__name__} must be a bool.") @fill_with_flags(inverted=True) @@ -196,7 +196,7 @@ class SystemChannelFlags(BaseFlags): elif toggle is False: self.value |= o else: - raise TypeError('Value to set for SystemChannelFlags must be a bool.') + raise TypeError("Value to set for SystemChannelFlags must be a bool.") @flag_value def join_notifications(self): @@ -461,7 +461,7 @@ class Intents(BaseFlags): self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid flag name.') + raise TypeError(f"{key!r} is not a valid flag name.") setattr(self, key, value) @classmethod @@ -907,7 +907,7 @@ class MemberCacheFlags(BaseFlags): self.value = (1 << bits) - 1 for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid flag name.') + raise TypeError(f"{key!r} is not a valid flag name.") setattr(self, key, value) @classmethod @@ -977,10 +977,10 @@ class MemberCacheFlags(BaseFlags): def _verify_intents(self, intents: Intents): if self.voice and not intents.voice_states: - raise ValueError('MemberCacheFlags.voice requires Intents.voice_states') + raise ValueError("MemberCacheFlags.voice requires Intents.voice_states") if self.joined and not intents.members: - raise ValueError('MemberCacheFlags.joined requires Intents.members') + raise ValueError("MemberCacheFlags.joined requires Intents.members") @property def _voice_only(self): diff --git a/discord/gateway.py b/discord/gateway.py index fbbc3c5e..5ef651f1 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -24,7 +24,20 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque +from typing import ( + TYPE_CHECKING, + TypedDict, + Any, + Optional, + List, + TypeVar, + Type, + Dict, + Callable, + Coroutine, + NamedTuple, + Deque, +) import asyncio from collections import deque @@ -42,16 +55,16 @@ import aiohttp from . import utils from .activity import BaseActivity from .enums import SpeakingState -from .errors import ConnectionClosed, InvalidArgument +from .errors import ConnectionClosed, InvalidArgument if TYPE_CHECKING: from .client import Client from .state import ConnectionState from .voice_client import VoiceClient - T = TypeVar('T') - DWS = TypeVar('DWS', bound='DiscordWebSocket') - DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') + T = TypeVar("T") + DWS = TypeVar("DWS", bound="DiscordWebSocket") + DVWS = TypeVar("DVWS", bound="DiscordVoiceWebSocket") Coro = Callable[..., Coroutine[Any, Any, Any]] Predicate = Callable[[Dict[str, Any]], bool] @@ -63,11 +76,11 @@ _log: logging.Logger = logging.getLogger(__name__) __all__ = ( - 'DiscordWebSocket', - 'KeepAliveHandler', - 'VoiceKeepAliveHandler', - 'DiscordVoiceWebSocket', - 'ReconnectWebSocket', + "DiscordWebSocket", + "KeepAliveHandler", + "VoiceKeepAliveHandler", + "DiscordVoiceWebSocket", + "ReconnectWebSocket", ) @@ -78,14 +91,16 @@ class Heartbeat(TypedDict): class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" + def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None: self.shard_id: Optional[int] = shard_id self.resume: bool = resume - self.op = 'RESUME' if resume else 'IDENTIFY' + self.op = "RESUME" if resume else "IDENTIFY" class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" + pass @@ -134,48 +149,50 @@ class GatewayRatelimiter: async with self.lock: delta = self.get_delay() if delta: - _log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta) + _log.warning("WebSocket in shard ID %s is ratelimited, waiting %.2f seconds", self.shard_id, delta) await asyncio.sleep(delta) class KeepAliveHandler(threading.Thread): def __init__(self, *args: Any, **kwargs: Any) -> None: - ws = kwargs.pop('ws') - interval = kwargs.pop('interval', None) - shard_id = kwargs.pop('shard_id', None) + ws = kwargs.pop("ws") + interval = kwargs.pop("interval", None) + shard_id = kwargs.pop("shard_id", None) threading.Thread.__init__(self, *args, **kwargs) self.ws: DiscordWebSocket = ws self._main_thread_id: int = ws.thread_id self.interval: Optional[float] = interval self.daemon: bool = True self.shard_id: Optional[int] = shard_id - self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.' - self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.' - self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' + self.msg: str = "Keeping shard ID %s websocket alive with sequence %s." + self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds." + self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind." self._stop_ev: threading.Event = threading.Event() self._last_ack: float = time.perf_counter() self._last_send: float = time.perf_counter() self._last_recv: float = time.perf_counter() - self.latency: float = float('inf') + self.latency: float = float("inf") self.heartbeat_timeout: float = ws._max_heartbeat_timeout def run(self) -> None: while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): - _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) + _log.warning( + "Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id + ) coro = self.ws.close(4000) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: f.result() except Exception: - _log.exception('An error occurred while stopping the gateway. Ignoring.') + _log.exception("An error occurred while stopping the gateway. Ignoring.") finally: self.stop() return data = self.get_payload() - _log.debug(self.msg, self.shard_id, data['d']) + _log.debug(self.msg, self.shard_id, data["d"]) coro = self.ws.send_heartbeat(data) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: @@ -192,8 +209,8 @@ class KeepAliveHandler(threading.Thread): except KeyError: msg = self.block_msg else: - stack = ''.join(traceback.format_stack(frame)) - msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}' + stack = "".join(traceback.format_stack(frame)) + msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}" _log.warning(msg, self.shard_id, total) except Exception: @@ -203,9 +220,9 @@ class KeepAliveHandler(threading.Thread): def get_payload(self) -> Heartbeat: return { - 'op': self.ws.HEARTBEAT, + "op": self.ws.HEARTBEAT, # the websocket's sequence won't be None here - 'd': self.ws.sequence # type: ignore + "d": self.ws.sequence, # type: ignore } def stop(self) -> None: @@ -221,19 +238,17 @@ class KeepAliveHandler(threading.Thread): if self.latency > 10: _log.warning(self.behind_msg, self.shard_id, self.latency) + class VoiceKeepAliveHandler(KeepAliveHandler): def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.recent_ack_latencies: Deque[float] = deque(maxlen=20) - self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' - self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' - self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' + self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s." + self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds" + self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind" def get_payload(self) -> Heartbeat: - return { - 'op': self.ws.HEARTBEAT, - 'd': int(time.time() * 1000) - } + return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)} def ack(self) -> None: ack_time = time.perf_counter() @@ -244,7 +259,7 @@ class VoiceKeepAliveHandler(KeepAliveHandler): class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): - async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: + async def close(self, *, code: int = 4000, message: bytes = b"") -> bool: return await super().close(code=code, message=message) @@ -288,19 +303,19 @@ class DiscordWebSocket: The authentication token for discord. """ - DISPATCH = 0 - HEARTBEAT = 1 - IDENTIFY = 2 - PRESENCE = 3 - VOICE_STATE = 4 - VOICE_PING = 5 - RESUME = 6 - RECONNECT = 7 - REQUEST_MEMBERS = 8 + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE = 3 + VOICE_STATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQUEST_MEMBERS = 8 INVALIDATE_SESSION = 9 - HELLO = 10 - HEARTBEAT_ACK = 11 - GUILD_SYNC = 12 + HELLO = 10 + HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: self.socket: aiohttp.ClientWebSocketResponse = socket @@ -342,13 +357,23 @@ class DiscordWebSocket: return self._rate_limiter.is_ratelimited() def debug_log_receive(self, data, /) -> None: - self._dispatch('socket_raw_receive', data) + self._dispatch("socket_raw_receive", data) def log_receive(self, _, /) -> None: pass @classmethod - async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS: + async def from_client( + cls: Type[DWS], + client: Client, + *, + initial: bool = False, + gateway: Optional[str] = None, + shard_id: Optional[int] = None, + session: Optional[str] = None, + sequence: Optional[int] = None, + resume: bool = False, + ) -> DWS: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -360,7 +385,7 @@ class DiscordWebSocket: # dynamically add attributes needed # the token won't be None here - ws.token = client.http.token # type: ignore + ws.token = client.http.token # type: ignore ws._connection = client._connection ws._discord_parsers = client._connection.parsers ws._dispatch = client.dispatch @@ -380,7 +405,7 @@ class DiscordWebSocket: client._connection._update_references(ws) - _log.debug('Created websocket connected to %s', gateway) + _log.debug("Created websocket connected to %s", gateway) # poll event for OP Hello await ws.poll_event() @@ -420,77 +445,64 @@ class DiscordWebSocket: async def identify(self) -> None: """Sends the IDENTIFY packet.""" payload = { - 'op': self.IDENTIFY, - 'd': { - 'token': self.token, - 'properties': { - '$os': sys.platform, - '$browser': 'discord.py', - '$device': 'discord.py', - '$referrer': '', - '$referring_domain': '' + "op": self.IDENTIFY, + "d": { + "token": self.token, + "properties": { + "$os": sys.platform, + "$browser": "discord.py", + "$device": "discord.py", + "$referrer": "", + "$referring_domain": "", }, - 'compress': True, - 'large_threshold': 250, - 'v': 3 - } + "compress": True, + "large_threshold": 250, + "v": 3, + }, } if self.shard_id is not None and self.shard_count is not None: - payload['d']['shard'] = [self.shard_id, self.shard_count] + payload["d"]["shard"] = [self.shard_id, self.shard_count] state = self._connection if state._activity is not None or state._status is not None: - payload['d']['presence'] = { - 'status': state._status, - 'game': state._activity, - 'since': 0, - 'afk': False - } + payload["d"]["presence"] = {"status": state._status, "game": state._activity, "since": 0, "afk": False} if state._intents is not None: - payload['d']['intents'] = state._intents.value + payload["d"]["intents"] = state._intents.value - await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) + await self.call_hooks("before_identify", self.shard_id, initial=self._initial_identify) await self.send_as_json(payload) - _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) + _log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id) async def resume(self) -> None: """Sends the RESUME packet.""" - payload = { - 'op': self.RESUME, - 'd': { - 'seq': self.sequence, - 'session_id': self.session_id, - 'token': self.token - } - } + payload = {"op": self.RESUME, "d": {"seq": self.sequence, "session_id": self.session_id, "token": self.token}} await self.send_as_json(payload) - _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) + _log.info("Shard ID %s has sent the RESUME payload.", self.shard_id) - - async def received_message(self, msg, /) -> None: + async def received_message(self, msg, /) -> None: if type(msg) is bytes: self._buffer.extend(msg) - if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff": return msg = self._zlib.decompress(self._buffer) - msg = msg.decode('utf-8') + msg = msg.decode("utf-8") self._buffer = bytearray() self.log_receive(msg) msg = utils._from_json(msg) - _log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg) - event = msg.get('t') + _log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg) + event = msg.get("t") if event: - self._dispatch('socket_event_type', event) + self._dispatch("socket_event_type", event) - op = msg.get('op') - data = msg.get('d') - seq = msg.get('s') + op = msg.get("op") + data = msg.get("d") + seq = msg.get("s") if seq is not None: self.sequence = seq @@ -502,7 +514,7 @@ class DiscordWebSocket: # "reconnect" can only be handled by the Client # so we terminate our connection and raise an # internal exception signalling to reconnect. - _log.debug('Received RECONNECT opcode.') + _log.debug("Received RECONNECT opcode.") await self.close() raise ReconnectWebSocket(self.shard_id) @@ -518,7 +530,7 @@ class DiscordWebSocket: return if op == self.HELLO: - interval = data['heartbeat_interval'] / 1000.0 + interval = data["heartbeat_interval"] / 1000.0 self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id) # send a heartbeat immediately await self.send_as_json(self._keep_alive.get_payload()) @@ -532,33 +544,41 @@ class DiscordWebSocket: self.sequence = None self.session_id = None - _log.info('Shard ID %s session has been invalidated.', self.shard_id) + _log.info("Shard ID %s session has been invalidated.", self.shard_id) await self.close(code=1000) raise ReconnectWebSocket(self.shard_id, resume=False) - _log.warning('Unknown OP code %s.', op) + _log.warning("Unknown OP code %s.", op) return - if event == 'READY': - self._trace = trace = data.get('_trace', []) - self.sequence = msg['s'] - self.session_id = data['session_id'] + if event == "READY": + self._trace = trace = data.get("_trace", []) + self.sequence = msg["s"] + self.session_id = data["session_id"] # pass back shard ID to ready handler - data['__shard_id__'] = self.shard_id - _log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).', - self.shard_id, ', '.join(trace), self.session_id) + data["__shard_id__"] = self.shard_id + _log.info( + "Shard ID %s has connected to Gateway: %s (Session ID: %s).", + self.shard_id, + ", ".join(trace), + self.session_id, + ) - elif event == 'RESUMED': - self._trace = trace = data.get('_trace', []) + elif event == "RESUMED": + self._trace = trace = data.get("_trace", []) # pass back the shard ID to the resumed handler - data['__shard_id__'] = self.shard_id - _log.info('Shard ID %s has successfully RESUMED session %s under trace %s.', - self.shard_id, self.session_id, ', '.join(trace)) + data["__shard_id__"] = self.shard_id + _log.info( + "Shard ID %s has successfully RESUMED session %s under trace %s.", + self.shard_id, + self.session_id, + ", ".join(trace), + ) try: func = self._discord_parsers[event] except KeyError: - _log.debug('Unknown event %s.', event) + _log.debug("Unknown event %s.", event) else: func(data) @@ -591,7 +611,7 @@ class DiscordWebSocket: def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive - return float('inf') if heartbeat is None else heartbeat.latency + return float("inf") if heartbeat is None else heartbeat.latency def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code @@ -612,10 +632,10 @@ class DiscordWebSocket: elif msg.type is aiohttp.WSMsgType.BINARY: await self.received_message(msg.data) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise msg.data elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE): - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise WebSocketClosure except (asyncio.TimeoutError, WebSocketClosure) as e: # Ensure the keep alive handler is closed @@ -624,20 +644,20 @@ class DiscordWebSocket: self._keep_alive = None if isinstance(e, asyncio.TimeoutError): - _log.info('Timed out receiving packet. Attempting a reconnect.') + _log.info("Timed out receiving packet. Attempting a reconnect.") raise ReconnectWebSocket(self.shard_id) from None code = self._close_code or self.socket.close_code if self._can_handle_close(): - _log.info('Websocket closed with %s, attempting a reconnect.', code) + _log.info("Websocket closed with %s, attempting a reconnect.", code) raise ReconnectWebSocket(self.shard_id) from None else: - _log.info('Websocket closed with %s, cannot reconnect.', code) + _log.info("Websocket closed with %s, cannot reconnect.", code) raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None async def debug_send(self, data, /) -> None: await self._rate_limiter.block() - self._dispatch('socket_raw_send', data) + self._dispatch("socket_raw_send", data) await self.socket.send_str(data) async def send(self, data, /) -> None: @@ -659,65 +679,57 @@ class DiscordWebSocket: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None: + async def change_presence( + self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0 + ) -> None: if activity is not None: if not isinstance(activity, BaseActivity): - raise InvalidArgument('activity must derive from BaseActivity.') + raise InvalidArgument("activity must derive from BaseActivity.") activities = [activity.to_dict()] else: activities = [] - if status == 'idle': + if status == "idle": since = int(time.time() * 1000) - payload = { - 'op': self.PRESENCE, - 'd': { - 'activities': activities, - 'afk': False, - 'since': since, - 'status': status - } - } + payload = {"op": self.PRESENCE, "d": {"activities": activities, "afk": False, "since": since, "status": status}} sent = utils._to_json(payload) _log.debug('Sending "%s" to change status', sent) await self.send(sent) - async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None: - payload = { - 'op': self.REQUEST_MEMBERS, - 'd': { - 'guild_id': guild_id, - 'presences': presences, - 'limit': limit - } - } + async def request_chunks( + self, + guild_id: int, + query: Optional[str] = None, + *, + limit: int, + user_ids: Optional[List[int]] = None, + presences: bool = False, + nonce: Optional[int] = None, + ) -> None: + payload = {"op": self.REQUEST_MEMBERS, "d": {"guild_id": guild_id, "presences": presences, "limit": limit}} if nonce: - payload['d']['nonce'] = nonce + payload["d"]["nonce"] = nonce if user_ids: - payload['d']['user_ids'] = user_ids + payload["d"]["user_ids"] = user_ids if query is not None: - payload['d']['query'] = query - + payload["d"]["query"] = query await self.send_as_json(payload) - async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None: + async def voice_state( + self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False + ) -> None: payload = { - 'op': self.VOICE_STATE, - 'd': { - 'guild_id': guild_id, - 'channel_id': channel_id, - 'self_mute': self_mute, - 'self_deaf': self_deaf - } + "op": self.VOICE_STATE, + "d": {"guild_id": guild_id, "channel_id": channel_id, "self_mute": self_mute, "self_deaf": self_deaf}, } - _log.debug('Updating our voice state to %s.', payload) + _log.debug("Updating our voice state to %s.", payload) await self.send_as_json(payload) async def close(self, code: int = 4000) -> None: @@ -728,6 +740,7 @@ class DiscordWebSocket: self._close_code = code await self.socket.close(code=code) + class DiscordVoiceWebSocket: """Implements the websocket protocol for handling voice connections. @@ -759,20 +772,22 @@ class DiscordVoiceWebSocket: Receive only. Indicates a user has disconnected from voice. """ - IDENTIFY = 0 - SELECT_PROTOCOL = 1 - READY = 2 - HEARTBEAT = 3 + IDENTIFY = 0 + SELECT_PROTOCOL = 1 + READY = 2 + HEARTBEAT = 3 SESSION_DESCRIPTION = 4 - SPEAKING = 5 - HEARTBEAT_ACK = 6 - RESUME = 7 - HELLO = 8 - RESUMED = 9 - CLIENT_CONNECT = 12 - CLIENT_DISCONNECT = 13 + SPEAKING = 5 + HEARTBEAT_ACK = 6 + RESUME = 7 + HELLO = 8 + RESUMED = 9 + CLIENT_CONNECT = 12 + CLIENT_DISCONNECT = 13 - def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None: + def __init__( + self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None + ) -> None: self.ws: aiohttp.ClientWebSocketResponse = socket self.loop: asyncio.AbstractEventLoop = loop self._keep_alive: VoiceKeepAliveHandler = utils.MISSING @@ -784,14 +799,13 @@ class DiscordVoiceWebSocket: self.thread_id: int = utils.MISSING if hook: # we want to redeclare self._hook - self._hook = hook # type: ignore + self._hook = hook # type: ignore async def _hook(self, *args: Any) -> Any: pass - async def send_as_json(self, data) -> None: - _log.debug('Sending voice websocket frame: %s.', data) + _log.debug("Sending voice websocket frame: %s.", data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json @@ -799,32 +813,30 @@ class DiscordVoiceWebSocket: async def resume(self) -> None: state = self._connection payload = { - 'op': self.RESUME, - 'd': { - 'token': state.token, - 'server_id': str(state.server_id), - 'session_id': state.session_id - } + "op": self.RESUME, + "d": {"token": state.token, "server_id": str(state.server_id), "session_id": state.session_id}, } await self.send_as_json(payload) async def identify(self): state = self._connection payload = { - 'op': self.IDENTIFY, - 'd': { - 'server_id': str(state.server_id), - 'user_id': str(state.user.id), - 'session_id': state.session_id, - 'token': state.token - } + "op": self.IDENTIFY, + "d": { + "server_id": str(state.server_id), + "user_id": str(state.user.id), + "session_id": state.session_id, + "token": state.token, + }, } await self.send_as_json(payload) @classmethod - async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS: + async def from_client( + cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None + ) -> DVWS: """Creates a voice websocket for the :class:`VoiceClient`.""" - gateway = 'wss://' + client.endpoint + '/?v=4' + gateway = "wss://" + client.endpoint + "/?v=4" http = client._state.http socket = await http.ws_connect(gateway, compress=15) ws = cls(socket, loop=client.loop, hook=hook) @@ -842,57 +854,38 @@ class DiscordVoiceWebSocket: async def select_protocol(self, ip, port, mode) -> None: payload = { - 'op': self.SELECT_PROTOCOL, - 'd': { - 'protocol': 'udp', - 'data': { - 'address': ip, - 'port': port, - 'mode': mode - } - } + "op": self.SELECT_PROTOCOL, + "d": {"protocol": "udp", "data": {"address": ip, "port": port, "mode": mode}}, } await self.send_as_json(payload) async def client_connect(self) -> None: - payload = { - 'op': self.CLIENT_CONNECT, - 'd': { - 'audio_ssrc': self._connection.ssrc - } - } + payload = {"op": self.CLIENT_CONNECT, "d": {"audio_ssrc": self._connection.ssrc}} await self.send_as_json(payload) async def speak(self, state=SpeakingState.voice) -> None: - payload = { - 'op': self.SPEAKING, - 'd': { - 'speaking': int(state), - 'delay': 0 - } - } + payload = {"op": self.SPEAKING, "d": {"speaking": int(state), "delay": 0}} await self.send_as_json(payload) - - async def received_message(self, msg) -> None: - _log.debug('Voice websocket frame received: %s', msg) - op = msg['op'] - data = msg.get('d') + async def received_message(self, msg) -> None: + _log.debug("Voice websocket frame received: %s", msg) + op = msg["op"] + data = msg.get("d") if op == self.READY: await self.initial_connection(data) elif op == self.HEARTBEAT_ACK: self._keep_alive.ack() elif op == self.RESUMED: - _log.info('Voice RESUME succeeded.') + _log.info("Voice RESUME succeeded.") elif op == self.SESSION_DESCRIPTION: - self._connection.mode = data['mode'] + self._connection.mode = data["mode"] await self.load_secret_key(data) elif op == self.HELLO: - interval = data['heartbeat_interval'] / 1000.0 + interval = data["heartbeat_interval"] / 1000.0 self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive.start() @@ -900,53 +893,52 @@ class DiscordVoiceWebSocket: async def initial_connection(self, data) -> None: state = self._connection - state.ssrc = data['ssrc'] - state.voice_port = data['port'] - state.endpoint_ip = data['ip'] + state.ssrc = data["ssrc"] + state.voice_port = data["port"] + state.endpoint_ip = data["ip"] packet = bytearray(70) - struct.pack_into('>H', packet, 0, 1) # 1 = Send - struct.pack_into('>H', packet, 2, 70) # 70 = Length - struct.pack_into('>I', packet, 4, state.ssrc) + struct.pack_into(">H", packet, 0, 1) # 1 = Send + struct.pack_into(">H", packet, 2, 70) # 70 = Length + struct.pack_into(">I", packet, 4, state.ssrc) state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) recv = await self.loop.sock_recv(state.socket, 70) - _log.debug('received packet in initial_connection: %s', recv) + _log.debug("received packet in initial_connection: %s", recv) # the ip is ascii starting at the 4th byte and ending at the first null ip_start = 4 ip_end = recv.index(0, ip_start) - state.ip = recv[ip_start:ip_end].decode('ascii') + state.ip = recv[ip_start:ip_end].decode("ascii") - state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] - _log.debug('detected ip: %s port: %s', state.ip, state.port) + state.port = struct.unpack_from(">H", recv, len(recv) - 2)[0] + _log.debug("detected ip: %s port: %s", state.ip, state.port) # there *should* always be at least one supported mode (xsalsa20_poly1305) - modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] - _log.debug('received supported encryption modes: %s', ", ".join(modes)) + modes = [mode for mode in data["modes"] if mode in self._connection.supported_modes] + _log.debug("received supported encryption modes: %s", ", ".join(modes)) mode = modes[0] await self.select_protocol(state.ip, state.port, mode) - _log.info('selected the voice protocol for use (%s)', mode) + _log.info("selected the voice protocol for use (%s)", mode) @property def latency(self) -> float: """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive - return float('inf') if heartbeat is None else heartbeat.latency + return float("inf") if heartbeat is None else heartbeat.latency @property def average_latency(self) -> float: """:class:`list`: Average of last 20 HEARTBEAT latencies.""" heartbeat = self._keep_alive if heartbeat is None or not heartbeat.recent_ack_latencies: - return float('inf') + return float("inf") return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) - async def load_secret_key(self, data) -> None: - _log.info('received secret key for voice connection') - self.secret_key = self._connection.secret_key = data.get('secret_key') + _log.info("received secret key for voice connection") + self.secret_key = self._connection.secret_key = data.get("secret_key") await self.speak() await self.speak(False) @@ -956,10 +948,10 @@ class DiscordVoiceWebSocket: if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) async def close(self, code: int = 1000) -> None: diff --git a/discord/guild.py b/discord/guild.py index cb53f44c..4a599164 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -79,9 +79,7 @@ from .file import File from .welcome_screen import WelcomeScreen, WelcomeChannel -__all__ = ( - 'Guild', -) +__all__ = ("Guild",) MISSING = utils.MISSING @@ -240,45 +238,45 @@ class Guild(Hashable): """ __slots__ = ( - 'afk_timeout', - 'afk_channel', - 'name', - 'id', - 'unavailable', - 'region', - 'owner_id', - 'mfa_level', - 'emojis', - 'stickers', - 'features', - 'verification_level', - 'explicit_content_filter', - 'default_notifications', - 'description', - 'max_presences', - 'max_members', - 'max_video_channel_users', - 'premium_tier', - 'premium_subscription_count', - 'preferred_locale', - 'nsfw_level', - '_members', - '_channels', - '_icon', - '_banner', - '_state', - '_roles', - '_member_count', - '_large', - '_splash', - '_voice_states', - '_system_channel_id', - '_system_channel_flags', - '_discovery_splash', - '_rules_channel_id', - '_public_updates_channel_id', - '_stage_instances', - '_threads', + "afk_timeout", + "afk_channel", + "name", + "id", + "unavailable", + "region", + "owner_id", + "mfa_level", + "emojis", + "stickers", + "features", + "verification_level", + "explicit_content_filter", + "default_notifications", + "description", + "max_presences", + "max_members", + "max_video_channel_users", + "premium_tier", + "premium_subscription_count", + "preferred_locale", + "nsfw_level", + "_members", + "_channels", + "_icon", + "_banner", + "_state", + "_roles", + "_member_count", + "_large", + "_splash", + "_voice_states", + "_system_channel_id", + "_system_channel_flags", + "_discovery_splash", + "_rules_channel_id", + "_public_updates_channel_id", + "_stage_instances", + "_threads", ) _PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = { @@ -338,21 +336,23 @@ class Guild(Hashable): return to_remove def __str__(self) -> str: - return self.name or '' + return self.name or "" def __repr__(self) -> str: attrs = ( - ('id', self.id), - ('name', self.name), - ('shard_id', self.shard_id), - ('chunked', self.chunked), - ('member_count', getattr(self, '_member_count', None)), + ("id", self.id), + ("name", self.name), + ("shard_id", self.shard_id), + ("chunked", self.chunked), + ("member_count", getattr(self, "_member_count", None)), ) - inner = ' '.join('%s=%r' % t for t in attrs) - return f'' + inner = " ".join("%s=%r" % t for t in attrs) + return f"" - def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: - user_id = int(data['user_id']) + def _update_voice_state( + self, data: GuildVoiceState, channel_id: int + ) -> Tuple[Optional[Member], VoiceState, VoiceState]: + user_id = int(data["user_id"]) channel = self.get_channel(channel_id) try: # check if we should remove the voice state from cache @@ -372,7 +372,7 @@ class Guild(Hashable): member = self.get_member(user_id) if member is None: try: - member = Member(data=data['member'], state=self._state, guild=self) + member = Member(data=data["member"], state=self._state, guild=self) except KeyError: member = None @@ -404,57 +404,57 @@ class Guild(Hashable): def _from_data(self, guild: GuildPayload) -> None: # according to Stan, this is always available even if the guild is unavailable # I don't have this guarantee when someone updates the guild. - member_count = guild.get('member_count', None) + member_count = guild.get("member_count", None) if member_count is not None: self._member_count: int = member_count - self.name: str = guild.get('name') - self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region')) - self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level')) + self.name: str = guild.get("name") + self.region: VoiceRegion = try_enum(VoiceRegion, guild.get("region")) + self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get("verification_level")) self.default_notifications: NotificationLevel = try_enum( - NotificationLevel, guild.get('default_message_notifications') + NotificationLevel, guild.get("default_message_notifications") ) - self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0)) - self.afk_timeout: int = guild.get('afk_timeout') - self._icon: Optional[str] = guild.get('icon') - self._banner: Optional[str] = guild.get('banner') - self.unavailable: bool = guild.get('unavailable', False) - self.id: int = int(guild['id']) + self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get("explicit_content_filter", 0)) + self.afk_timeout: int = guild.get("afk_timeout") + self._icon: Optional[str] = guild.get("icon") + self._banner: Optional[str] = guild.get("banner") + self.unavailable: bool = guild.get("unavailable", False) + self.id: int = int(guild["id"]) self._roles: Dict[int, Role] = {} state = self._state # speed up attribute access - for r in guild.get('roles', []): + for r in guild.get("roles", []): role = Role(guild=self, data=r, state=state) self._roles[role.id] = role - self.mfa_level: MFALevel = guild.get('mfa_level') - self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', []))) + self.mfa_level: MFALevel = guild.get("mfa_level") + self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", []))) self.stickers: Tuple[GuildSticker, ...] = tuple( - map(lambda d: state.store_sticker(self, d), guild.get('stickers', [])) + map(lambda d: state.store_sticker(self, d), guild.get("stickers", [])) ) - self.features: List[GuildFeature] = guild.get('features', []) - self._splash: Optional[str] = guild.get('splash') - self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id') - self.description: Optional[str] = guild.get('description') - self.max_presences: Optional[int] = guild.get('max_presences') - self.max_members: Optional[int] = guild.get('max_members') - self.max_video_channel_users: Optional[int] = guild.get('max_video_channel_users') - self.premium_tier: int = guild.get('premium_tier', 0) - self.premium_subscription_count: int = guild.get('premium_subscription_count') or 0 - self._system_channel_flags: int = guild.get('system_channel_flags', 0) - self.preferred_locale: Optional[str] = guild.get('preferred_locale') - self._discovery_splash: Optional[str] = guild.get('discovery_splash') - self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'rules_channel_id') - self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'public_updates_channel_id') - self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get('nsfw_level', 0)) + self.features: List[GuildFeature] = guild.get("features", []) + self._splash: Optional[str] = guild.get("splash") + self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, "system_channel_id") + self.description: Optional[str] = guild.get("description") + self.max_presences: Optional[int] = guild.get("max_presences") + self.max_members: Optional[int] = guild.get("max_members") + self.max_video_channel_users: Optional[int] = guild.get("max_video_channel_users") + self.premium_tier: int = guild.get("premium_tier", 0) + self.premium_subscription_count: int = guild.get("premium_subscription_count") or 0 + self._system_channel_flags: int = guild.get("system_channel_flags", 0) + self.preferred_locale: Optional[str] = guild.get("preferred_locale") + self._discovery_splash: Optional[str] = guild.get("discovery_splash") + self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, "rules_channel_id") + self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, "public_updates_channel_id") + self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0)) self._stage_instances: Dict[int, StageInstance] = {} - for s in guild.get('stage_instances', []): + for s in guild.get("stage_instances", []): stage_instance = StageInstance(guild=self, data=s, state=state) self._stage_instances[stage_instance.id] = stage_instance cache_joined = self._state.member_cache_flags.joined self_id = self._state.self_id - for mdata in guild.get('members', []): + for mdata in guild.get("members", []): member = Member(data=mdata, guild=self, state=state) if cache_joined or member.id == self_id: self._add_member(member) @@ -462,35 +462,35 @@ class Guild(Hashable): self._sync(guild) self._large: Optional[bool] = None if member_count is None else self._member_count >= 250 - self.owner_id: Optional[int] = utils._get_as_snowflake(guild, 'owner_id') - self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, 'afk_channel_id')) # type: ignore + self.owner_id: Optional[int] = utils._get_as_snowflake(guild, "owner_id") + self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore - for obj in guild.get('voice_states', []): - self._update_voice_state(obj, int(obj['channel_id'])) + for obj in guild.get("voice_states", []): + self._update_voice_state(obj, int(obj["channel_id"])) # TODO: refactor/remove? def _sync(self, data: GuildPayload) -> None: try: - self._large = data['large'] + self._large = data["large"] except KeyError: pass empty_tuple = tuple() - for presence in data.get('presences', []): - user_id = int(presence['user']['id']) + for presence in data.get("presences", []): + user_id = int(presence["user"]["id"]) member = self.get_member(user_id) if member is not None: member._presence_update(presence, empty_tuple) # type: ignore - if 'channels' in data: - channels = data['channels'] + if "channels" in data: + channels = data["channels"] for c in channels: - factory, ch_type = _guild_channel_factory(c['type']) + factory, ch_type = _guild_channel_factory(c["type"]) if factory: self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore - if 'threads' in data: - threads = data['threads'] + if "threads" in data: + threads = data["threads"] for thread in threads: self._add_thread(Thread(guild=self, state=self._state, data=thread)) @@ -713,7 +713,7 @@ class Guild(Hashable): @property def emoji_limit(self) -> int: """:class:`int`: The maximum number of emoji slots this guild has.""" - more_emoji = 200 if 'MORE_EMOJI' in self.features else 50 + more_emoji = 200 if "MORE_EMOJI" in self.features else 50 return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji) @property @@ -722,13 +722,13 @@ class Guild(Hashable): .. versionadded:: 2.0 """ - more_stickers = 60 if 'MORE_STICKERS' in self.features else 0 + more_stickers = 60 if "MORE_STICKERS" in self.features else 0 return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers) @property def bitrate_limit(self) -> float: """:class:`float`: The maximum bitrate for voice channels this guild can have.""" - vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if 'VIP_REGIONS' in self.features else 96e3 + vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if "VIP_REGIONS" in self.features else 96e3 return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate) @property @@ -744,15 +744,15 @@ class Guild(Hashable): @property def humans(self) -> List[Member]: """List[:class:`Member`]: A list of human members that belong to this guild. - - .. versionadded:: 2.0 """ + + .. versionadded:: 2.0""" return [member for member in self.members if not member.bot] @property def bots(self) -> List[Member]: """List[:class:`Member`]: A list of bots that belong to this guild. - - .. versionadded:: 2.0 """ + + .. versionadded:: 2.0""" return [member for member in self.members if member.bot] def get_member(self, user_id: int, /) -> Optional[Member]: @@ -872,21 +872,21 @@ class Guild(Hashable): """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None - return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') + return Asset._from_guild_image(self._state, self.id, self._banner, path="banners") @property def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None - return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') + return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes") @property def discovery_splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available.""" if self._discovery_splash is None: return None - return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes') + return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path="discovery-splashes") @property def member_count(self) -> int: @@ -910,7 +910,7 @@ class Guild(Hashable): If this value returns ``False``, then you should request for offline members. """ - count = getattr(self, '_member_count', None) + count = getattr(self, "_member_count", None) if count is None: return False return count == len(self._members) @@ -957,7 +957,7 @@ class Guild(Hashable): result = None members = self.members - if len(name) > 5 and name[-5] == '#': + if len(name) > 5 and name[-5] == "#": # The 5 length is checking to see if #0000 is in the string, # as a#0000 has a length of 6, the minimum for a potential # discriminator lookup. @@ -985,20 +985,20 @@ class Guild(Hashable): if overwrites is MISSING: overwrites = {} elif not isinstance(overwrites, dict): - raise InvalidArgument('overwrites parameter expects a dict.') + raise InvalidArgument("overwrites parameter expects a dict.") perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") allow, deny = perm.pair() - payload = {'allow': allow.value, 'deny': deny.value, 'id': target.id} + payload = {"allow": allow.value, "deny": deny.value, "id": target.id} if isinstance(target, Role): - payload['type'] = abc._Overwrites.ROLE + payload["type"] = abc._Overwrites.ROLE else: - payload['type'] = abc._Overwrites.MEMBER + payload["type"] = abc._Overwrites.MEMBER perms.append(payload) @@ -1099,16 +1099,16 @@ class Guild(Hashable): options = {} if position is not MISSING: - options['position'] = position + options["position"] = position if topic is not MISSING: - options['topic'] = topic + options["topic"] = topic if slowmode_delay is not MISSING: - options['rate_limit_per_user'] = slowmode_delay + options["rate_limit_per_user"] = slowmode_delay if nsfw is not MISSING: - options['nsfw'] = nsfw + options["nsfw"] = nsfw data = await self._create_channel( name, overwrites=overwrites, channel_type=ChannelType.text, category=category, reason=reason, **options @@ -1183,19 +1183,19 @@ class Guild(Hashable): """ options = {} if position is not MISSING: - options['position'] = position + options["position"] = position if bitrate is not MISSING: - options['bitrate'] = bitrate + options["bitrate"] = bitrate if user_limit is not MISSING: - options['user_limit'] = user_limit + options["user_limit"] = user_limit if rtc_region is not MISSING: - options['rtc_region'] = None if rtc_region is None else str(rtc_region) + options["rtc_region"] = None if rtc_region is None else str(rtc_region) if video_quality_mode is not MISSING: - options['video_quality_mode'] = video_quality_mode.value + options["video_quality_mode"] = video_quality_mode.value data = await self._create_channel( name, overwrites=overwrites, channel_type=ChannelType.voice, category=category, reason=reason, **options @@ -1258,13 +1258,18 @@ class Guild(Hashable): """ options: Dict[str, Any] = { - 'topic': topic, + "topic": topic, } if position is not MISSING: - options['position'] = position + options["position"] = position data = await self._create_channel( - name, overwrites=overwrites, channel_type=ChannelType.stage_voice, category=category, reason=reason, **options + name, + overwrites=overwrites, + channel_type=ChannelType.stage_voice, + category=category, + reason=reason, + **options, ) channel = StageChannel(state=self._state, guild=self, data=data) @@ -1305,7 +1310,7 @@ class Guild(Hashable): """ options: Dict[str, Any] = {} if position is not MISSING: - options['position'] = position + options["position"] = position data = await self._create_channel( name, overwrites=overwrites, channel_type=ChannelType.category, reason=reason, **options @@ -1480,108 +1485,108 @@ class Guild(Hashable): fields: Dict[str, Any] = {} if name is not MISSING: - fields['name'] = name + fields["name"] = name if description is not MISSING: - fields['description'] = description + fields["description"] = description if preferred_locale is not MISSING: - fields['preferred_locale'] = preferred_locale + fields["preferred_locale"] = preferred_locale if afk_timeout is not MISSING: - fields['afk_timeout'] = afk_timeout + fields["afk_timeout"] = afk_timeout if icon is not MISSING: if icon is None: - fields['icon'] = icon + fields["icon"] = icon else: - fields['icon'] = utils._bytes_to_base64_data(icon) + fields["icon"] = utils._bytes_to_base64_data(icon) if banner is not MISSING: if banner is None: - fields['banner'] = banner + fields["banner"] = banner else: - fields['banner'] = utils._bytes_to_base64_data(banner) + fields["banner"] = utils._bytes_to_base64_data(banner) if splash is not MISSING: if splash is None: - fields['splash'] = splash + fields["splash"] = splash else: - fields['splash'] = utils._bytes_to_base64_data(splash) + fields["splash"] = utils._bytes_to_base64_data(splash) if discovery_splash is not MISSING: if discovery_splash is None: - fields['discovery_splash'] = discovery_splash + fields["discovery_splash"] = discovery_splash else: - fields['discovery_splash'] = utils._bytes_to_base64_data(discovery_splash) + fields["discovery_splash"] = utils._bytes_to_base64_data(discovery_splash) if default_notifications is not MISSING: if not isinstance(default_notifications, NotificationLevel): - raise InvalidArgument('default_notifications field must be of type NotificationLevel') - fields['default_message_notifications'] = default_notifications.value + raise InvalidArgument("default_notifications field must be of type NotificationLevel") + fields["default_message_notifications"] = default_notifications.value if afk_channel is not MISSING: if afk_channel is None: - fields['afk_channel_id'] = afk_channel + fields["afk_channel_id"] = afk_channel else: - fields['afk_channel_id'] = afk_channel.id + fields["afk_channel_id"] = afk_channel.id if system_channel is not MISSING: if system_channel is None: - fields['system_channel_id'] = system_channel + fields["system_channel_id"] = system_channel else: - fields['system_channel_id'] = system_channel.id + fields["system_channel_id"] = system_channel.id if rules_channel is not MISSING: if rules_channel is None: - fields['rules_channel_id'] = rules_channel + fields["rules_channel_id"] = rules_channel else: - fields['rules_channel_id'] = rules_channel.id + fields["rules_channel_id"] = rules_channel.id if public_updates_channel is not MISSING: if public_updates_channel is None: - fields['public_updates_channel_id'] = public_updates_channel + fields["public_updates_channel_id"] = public_updates_channel else: - fields['public_updates_channel_id'] = public_updates_channel.id + fields["public_updates_channel_id"] = public_updates_channel.id if owner is not MISSING: if self.owner_id != self._state.self_id: - raise InvalidArgument('To transfer ownership you must be the owner of the guild.') + raise InvalidArgument("To transfer ownership you must be the owner of the guild.") - fields['owner_id'] = owner.id + fields["owner_id"] = owner.id if region is not MISSING: - fields['region'] = str(region) + fields["region"] = str(region) if verification_level is not MISSING: if not isinstance(verification_level, VerificationLevel): - raise InvalidArgument('verification_level field must be of type VerificationLevel') + raise InvalidArgument("verification_level field must be of type VerificationLevel") - fields['verification_level'] = verification_level.value + fields["verification_level"] = verification_level.value if explicit_content_filter is not MISSING: if not isinstance(explicit_content_filter, ContentFilter): - raise InvalidArgument('explicit_content_filter field must be of type ContentFilter') + raise InvalidArgument("explicit_content_filter field must be of type ContentFilter") - fields['explicit_content_filter'] = explicit_content_filter.value + fields["explicit_content_filter"] = explicit_content_filter.value if system_channel_flags is not MISSING: if not isinstance(system_channel_flags, SystemChannelFlags): - raise InvalidArgument('system_channel_flags field must be of type SystemChannelFlags') + raise InvalidArgument("system_channel_flags field must be of type SystemChannelFlags") - fields['system_channel_flags'] = system_channel_flags.value + fields["system_channel_flags"] = system_channel_flags.value if community is not MISSING: features = [] if community: - if 'rules_channel_id' in fields and 'public_updates_channel_id' in fields: - features.append('COMMUNITY') + if "rules_channel_id" in fields and "public_updates_channel_id" in fields: + features.append("COMMUNITY") else: raise InvalidArgument( - 'community field requires both rules_channel and public_updates_channel fields to be provided' + "community field requires both rules_channel and public_updates_channel fields to be provided" ) - fields['features'] = features + fields["features"] = features data = await http.edit_guild(self.id, reason=reason, **fields) return Guild(data=data, state=self._state) @@ -1612,9 +1617,9 @@ class Guild(Hashable): data = await self._state.http.get_all_guild_channels(self.id) def convert(d): - factory, ch_type = _guild_channel_factory(d['type']) + factory, ch_type = _guild_channel_factory(d["type"]) if factory is None: - raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(d)) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d)) channel = factory(guild=self, state=self._state, data=d) return channel @@ -1641,10 +1646,10 @@ class Guild(Hashable): The active threads """ data = await self._state.http.get_active_threads(self.id) - threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])] + threads = [Thread(guild=self, state=self._state, data=d) for d in data.get("threads", [])] thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads} - for member in data.get('members', []): - thread = thread_lookup.get(int(member['id'])) + for member in data.get("members", []): + thread = thread_lookup.get(int(member["id"])) if thread is not None: thread._add_member(ThreadMember(parent=thread, data=member)) @@ -1700,7 +1705,7 @@ class Guild(Hashable): """ if not self._state._intents.members: - raise ClientException('Intents.members must be enabled to use this.') + raise ClientException("Intents.members must be enabled to use this.") return MemberIterator(self, limit=limit, after=after) @@ -1791,7 +1796,7 @@ class Guild(Hashable): The :class:`BanEntry` object for the specified user. """ data: BanPayload = await self._state.http.get_ban(user.id, self.id) - return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason']) + return BanEntry(user=User(state=self._state, data=data["user"]), reason=data["reason"]) async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: """|coro| @@ -1824,16 +1829,16 @@ class Guild(Hashable): """ data = await self._state.http.get_channel(channel_id) - factory, ch_type = _threaded_guild_channel_factory(data['type']) + factory, ch_type = _threaded_guild_channel_factory(data["type"]) if factory is None: - raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): - raise InvalidData('Channel ID resolved to a private channel') + raise InvalidData("Channel ID resolved to a private channel") - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) if self.id != guild_id: - raise InvalidData('Guild ID resolved to a different guild') + raise InvalidData("Guild ID resolved to a different guild") channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore return channel @@ -1860,7 +1865,7 @@ class Guild(Hashable): """ data: List[BanPayload] = await self._state.http.get_bans(self.id) - return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data] + return [BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) for e in data] async def prune_members( self, @@ -1920,7 +1925,7 @@ class Guild(Hashable): """ if not isinstance(days, int): - raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') + raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.") if roles: role_ids = [str(role.id) for role in roles] @@ -1930,7 +1935,7 @@ class Guild(Hashable): data = await self._state.http.prune_members( self.id, days, compute_prune_count=compute_prune_count, roles=role_ids, reason=reason ) - return data['pruned'] + return data["pruned"] async def templates(self) -> List[Template]: """|coro| @@ -2012,7 +2017,7 @@ class Guild(Hashable): """ if not isinstance(days, int): - raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') + raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.") if roles: role_ids = [str(role.id) for role in roles] @@ -2020,7 +2025,7 @@ class Guild(Hashable): role_ids = [] data = await self._state.http.estimate_pruned_members(self.id, days, role_ids) - return data['pruned'] + return data["pruned"] async def invites(self) -> List[Invite]: """|coro| @@ -2046,7 +2051,7 @@ class Guild(Hashable): data = await self._state.http.invites_from(self.id) result = [] for invite in data: - channel = self.get_channel(int(invite['channel']['id'])) + channel = self.get_channel(int(invite["channel"]["id"])) result.append(Invite(state=self._state, data=invite, guild=self, channel=channel)) return result @@ -2070,10 +2075,10 @@ class Guild(Hashable): """ from .template import Template - payload = {'name': name} + payload = {"name": name} if description: - payload['description'] = description + payload["description"] = description data = await self._state.http.create_template(self.id, payload) @@ -2130,9 +2135,9 @@ class Guild(Hashable): data = await self._state.http.get_all_integrations(self.id) def convert(d): - factory, _ = _integration_factory(d['type']) + factory, _ = _integration_factory(d["type"]) if factory is None: - raise InvalidData('Unknown integration type {type!r} for integration ID {id}'.format_map(d)) + raise InvalidData("Unknown integration type {type!r} for integration ID {id}".format_map(d)) return factory(guild=self, data=d) return [convert(d) for d in data] @@ -2237,20 +2242,20 @@ class Guild(Hashable): The created sticker. """ payload = { - 'name': name, + "name": name, } if description: - payload['description'] = description + payload["description"] = description try: emoji = unicodedata.name(emoji) except TypeError: pass else: - emoji = emoji.replace(' ', '_') + emoji = emoji.replace(" ", "_") - payload['tags'] = emoji + payload["tags"] = emoji data = await self._state.http.create_guild_sticker(self.id, payload, file, reason) return self._state.store_sticker(self, data) @@ -2517,24 +2522,24 @@ class Guild(Hashable): """ fields: Dict[str, Any] = {} if permissions is not MISSING: - fields['permissions'] = str(permissions.value) + fields["permissions"] = str(permissions.value) else: - fields['permissions'] = '0' + fields["permissions"] = "0" actual_colour = colour or color or Colour.default() if isinstance(actual_colour, int): - fields['color'] = actual_colour + fields["color"] = actual_colour else: - fields['color'] = actual_colour.value + fields["color"] = actual_colour.value if hoist is not MISSING: - fields['hoist'] = hoist + fields["hoist"] = hoist if mentionable is not MISSING: - fields['mentionable'] = mentionable + fields["mentionable"] = mentionable if name is not MISSING: - fields['name'] = name + fields["name"] = name data = await self._state.http.create_role(self.id, reason=reason, **fields) role = Role(guild=self, data=data, state=self._state) @@ -2587,12 +2592,12 @@ class Guild(Hashable): A list of all the roles in the guild. """ if not isinstance(positions, dict): - raise InvalidArgument('positions parameter expects a dict.') + raise InvalidArgument("positions parameter expects a dict.") role_positions: List[Dict[str, Any]] = [] for role, position in positions.items(): - payload = {'id': role.id, 'position': position} + payload = {"id": role.id, "position": position} role_positions.append(payload) @@ -2665,16 +2670,16 @@ class Guild(Hashable): The edited welcome screen. """ try: - welcome_channels = kwargs['welcome_channels'] + welcome_channels = kwargs["welcome_channels"] except KeyError: pass else: welcome_channels_serialised = [] for wc in welcome_channels: if not isinstance(wc, WelcomeChannel): - raise InvalidArgument('welcome_channels parameter must be a list of WelcomeChannel') + raise InvalidArgument("welcome_channels parameter must be a list of WelcomeChannel") welcome_channels_serialised.append(wc.to_dict()) - kwargs['welcome_channels'] = welcome_channels_serialised + kwargs["welcome_channels"] = welcome_channels_serialised if kwargs: data = await self._state.http.edit_welcome_screen(self.id, kwargs) @@ -2793,19 +2798,19 @@ class Guild(Hashable): # we start with { code: abc } payload = await self._state.http.get_vanity_code(self.id) - if not payload['code']: + if not payload["code"]: return None # get the vanity URL channel since default channels aren't # reliable or a thing anymore - data = await self._state.http.get_invite(payload['code']) + data = await self._state.http.get_invite(payload["code"]) - channel = self.get_channel(int(data['channel']['id'])) - payload['revoked'] = False - payload['temporary'] = False - payload['max_uses'] = 0 - payload['max_age'] = 0 - payload['uses'] = payload.get('uses', 0) + channel = self.get_channel(int(data["channel"]["id"])) + payload["revoked"] = False + payload["temporary"] = False + payload["max_uses"] = 0 + payload["max_age"] = 0 + payload["uses"] = payload.get("uses", 0) return Invite(state=self._state, data=payload, guild=self, channel=channel) # TODO: use MISSING when async iterators get refactored @@ -2882,7 +2887,13 @@ class Guild(Hashable): action = action.value return AuditLogIterator( - self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action + self, + before=before, + after=after, + limit=limit, + oldest_first=oldest_first, + user_id=user_id, + action_type=action, ) async def widget(self) -> Widget: @@ -2936,9 +2947,9 @@ class Guild(Hashable): """ payload = {} if channel is not MISSING: - payload['channel_id'] = None if channel is None else channel.id + payload["channel_id"] = None if channel is None else channel.id if enabled is not MISSING: - payload['enabled'] = enabled + payload["enabled"] = enabled await self._state.http.edit_widget(self.id, payload=payload) @@ -2964,7 +2975,7 @@ class Guild(Hashable): """ if not self._state._intents.members: - raise ClientException('Intents.members must be enabled to use this.') + raise ClientException("Intents.members must be enabled to use this.") if not self._state.is_guild_evicted(self): return await self._state.chunk_guild(self, cache=cache) @@ -3025,20 +3036,20 @@ class Guild(Hashable): """ if presences and not self._state._intents.presences: - raise ClientException('Intents.presences must be enabled to use this.') + raise ClientException("Intents.presences must be enabled to use this.") if query is None: - if query == '': - raise ValueError('Cannot pass empty query string.') + if query == "": + raise ValueError("Cannot pass empty query string.") if user_ids is None: - raise ValueError('Must pass either query or user_ids') + raise ValueError("Must pass either query or user_ids") if user_ids is not None and query is not None: - raise ValueError('Cannot pass both query and user_ids') + raise ValueError("Cannot pass both query and user_ids") if user_ids is not None and not user_ids: - raise ValueError('user_ids must contain at least 1 value') + raise ValueError("user_ids must contain at least 1 value") limit = min(100, limit or 5) return await self._state.query_members( diff --git a/discord/http.py b/discord/http.py index 4f86fc87..e9b67420 100644 --- a/discord/http.py +++ b/discord/http.py @@ -48,7 +48,15 @@ import weakref import aiohttp -from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument +from .errors import ( + HTTPException, + Forbidden, + NotFound, + LoginFailure, + DiscordServerError, + GatewayNotFound, + InvalidArgument, +) from .gateway import DiscordClientWebSocketResponse from . import __version__, utils from .utils import MISSING @@ -90,16 +98,16 @@ if TYPE_CHECKING: from types import TracebackType - T = TypeVar('T') - BE = TypeVar('BE', bound=BaseException) - MU = TypeVar('MU', bound='MaybeUnlock') + T = TypeVar("T") + BE = TypeVar("BE", bound=BaseException) + MU = TypeVar("MU", bound="MaybeUnlock") Response = Coroutine[Any, Any, T] async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any], str]: - text = await response.text(encoding='utf-8') + text = await response.text(encoding="utf-8") try: - if response.headers['content-type'] == 'application/json': + if response.headers["content-type"] == "application/json": return utils._from_json(text) except KeyError: # Thanks Cloudflare @@ -109,7 +117,7 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any] class Route: - BASE: ClassVar[str] = 'https://discord.com/api/v8' + BASE: ClassVar[str] = "https://discord.com/api/v8" def __init__(self, method: str, path: str, **parameters: Any) -> None: self.path: str = path @@ -120,15 +128,15 @@ class Route: self.url: str = url # major parameters: - self.channel_id: Optional[Snowflake] = parameters.get('channel_id') - self.guild_id: Optional[Snowflake] = parameters.get('guild_id') - self.webhook_id: Optional[Snowflake] = parameters.get('webhook_id') - self.webhook_token: Optional[str] = parameters.get('webhook_token') + self.channel_id: Optional[Snowflake] = parameters.get("channel_id") + self.guild_id: Optional[Snowflake] = parameters.get("guild_id") + self.webhook_id: Optional[Snowflake] = parameters.get("webhook_id") + self.webhook_token: Optional[str] = parameters.get("webhook_token") @property def bucket(self) -> str: # the bucket is just method + path w/ major parameters - return f'{self.channel_id}:{self.guild_id}:{self.path}' + return f"{self.channel_id}:{self.guild_id}:{self.path}" class MaybeUnlock: @@ -154,7 +162,7 @@ class MaybeUnlock: # For some reason, the Discord voice websocket expects this header to be # completely lowercase while aiohttp respects spec and does it as case-insensitive -aiohttp.hdrs.WEBSOCKET = 'websocket' # type: ignore +aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore class HTTPClient: @@ -181,7 +189,7 @@ class HTTPClient: self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth self.use_clock: bool = not unsync_clock - user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' + user_agent = "DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) def recreate(self) -> None: @@ -192,15 +200,15 @@ class HTTPClient: async def ws_connect(self, url: str, *, compress: int = 0) -> Any: kwargs = { - 'proxy_auth': self.proxy_auth, - 'proxy': self.proxy, - 'max_msg_size': 0, - 'timeout': 30.0, - 'autoclose': False, - 'headers': { - 'User-Agent': self.user_agent, + "proxy_auth": self.proxy_auth, + "proxy": self.proxy, + "max_msg_size": 0, + "timeout": 30.0, + "autoclose": False, + "headers": { + "User-Agent": self.user_agent, }, - 'compress': compress, + "compress": compress, } return await self.__session.ws_connect(url, **kwargs) @@ -225,31 +233,31 @@ class HTTPClient: # header creation headers: Dict[str, str] = { - 'User-Agent': self.user_agent, + "User-Agent": self.user_agent, } if self.token is not None: - headers['Authorization'] = 'Bot ' + self.token + headers["Authorization"] = "Bot " + self.token # some checking if it's a JSON request - if 'json' in kwargs: - headers['Content-Type'] = 'application/json' - kwargs['data'] = utils._to_json(kwargs.pop('json')) + if "json" in kwargs: + headers["Content-Type"] = "application/json" + kwargs["data"] = utils._to_json(kwargs.pop("json")) try: - reason = kwargs.pop('reason') + reason = kwargs.pop("reason") except KeyError: pass else: if reason: - headers['X-Audit-Log-Reason'] = _uriquote(reason, safe='/ ') + headers["X-Audit-Log-Reason"] = _uriquote(reason, safe="/ ") - kwargs['headers'] = headers + kwargs["headers"] = headers # Proxy support if self.proxy is not None: - kwargs['proxy'] = self.proxy + kwargs["proxy"] = self.proxy if self.proxy_auth is not None: - kwargs['proxy_auth'] = self.proxy_auth + kwargs["proxy_auth"] = self.proxy_auth if not self._global_over.is_set(): # wait until the global lock is complete @@ -268,55 +276,55 @@ class HTTPClient: form_data = aiohttp.FormData() for params in form: form_data.add_field(**params) - kwargs['data'] = form_data + kwargs["data"] = form_data try: async with self.__session.request(method, url, **kwargs) as response: - _log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status) + _log.debug("%s %s with %s has returned %s", method, url, kwargs.get("data"), response.status) # even errors have text involved in them so this is safe to call data = await json_or_text(response) # check if we have rate limit header information - remaining = response.headers.get('X-Ratelimit-Remaining') - if remaining == '0' and response.status != 429: + remaining = response.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and response.status != 429: # we've depleted our current bucket delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock) - _log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta) + _log.debug("A rate limit bucket has been exhausted (bucket: %s, retry: %s).", bucket, delta) maybe_lock.defer() self.loop.call_later(delta, lock.release) # the request was successful so just return the text/json if 300 > response.status >= 200: - _log.debug('%s %s has received %s', method, url, data) + _log.debug("%s %s has received %s", method, url, data) return data # we are being rate limited if response.status == 429: - if not response.headers.get('Via') or isinstance(data, str): + if not response.headers.get("Via") or isinstance(data, str): # Banned by Cloudflare more than likely. raise HTTPException(response, data) fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' # sleep a bit - retry_after: float = data['retry_after'] + retry_after: float = data["retry_after"] _log.warning(fmt, retry_after, bucket) # check if it's a global rate limit - is_global = data.get('global', False) + is_global = data.get("global", False) if is_global: - _log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) + _log.warning("Global rate limit has been hit. Retrying in %.2f seconds.", retry_after) self._global_over.clear() await asyncio.sleep(retry_after) - _log.debug('Done sleeping for the rate limit. Retrying...') + _log.debug("Done sleeping for the rate limit. Retrying...") # release the global lock now that the # global rate limit has passed if is_global: self._global_over.set() - _log.debug('Global rate limit is now over.') + _log.debug("Global rate limit is now over.") continue @@ -350,18 +358,18 @@ class HTTPClient: raise HTTPException(response, data) - raise RuntimeError('Unreachable code in HTTP handling') + raise RuntimeError("Unreachable code in HTTP handling") async def get_from_cdn(self, url: str) -> bytes: async with self.__session.get(url) as resp: if resp.status == 200: return await resp.read() elif resp.status == 404: - raise NotFound(resp, 'asset not found') + raise NotFound(resp, "asset not found") elif resp.status == 403: - raise Forbidden(resp, 'cannot retrieve asset') + raise Forbidden(resp, "cannot retrieve asset") else: - raise HTTPException(resp, 'failed to get asset') + raise HTTPException(resp, "failed to get asset") # state management @@ -373,43 +381,45 @@ class HTTPClient: async def static_login(self, token: str) -> user.User: # Necessary to get aiohttp to stop complaining about session creation - self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) + self.__session = aiohttp.ClientSession( + connector=self.connector, ws_response_class=DiscordClientWebSocketResponse + ) old_token = self.token self.token = token try: - data = await self.request(Route('GET', '/users/@me')) + data = await self.request(Route("GET", "/users/@me")) except HTTPException as exc: self.token = old_token if exc.status == 401: - raise LoginFailure('Improper token has been passed.') from exc + raise LoginFailure("Improper token has been passed.") from exc raise return data def logout(self) -> Response[None]: - return self.request(Route('POST', '/auth/logout')) + return self.request(Route("POST", "/auth/logout")) # Group functionality def start_group(self, user_id: Snowflake, recipients: List[int]) -> Response[channel.GroupDMChannel]: payload = { - 'recipients': recipients, + "recipients": recipients, } - return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload) + return self.request(Route("POST", "/users/{user_id}/channels", user_id=user_id), json=payload) def leave_group(self, channel_id) -> Response[None]: - return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id)) + return self.request(Route("DELETE", "/channels/{channel_id}", channel_id=channel_id)) # Message management def start_private_message(self, user_id: Snowflake) -> Response[channel.DMChannel]: payload = { - 'recipient_id': user_id, + "recipient_id": user_id, } - return self.request(Route('POST', '/users/@me/channels'), json=payload) + return self.request(Route("POST", "/users/@me/channels"), json=payload) def send_message( self, @@ -425,40 +435,40 @@ class HTTPClient: stickers: Optional[List[sticker.StickerItem]] = None, components: Optional[List[components.Component]] = None, ) -> Response[message.Message]: - r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) payload = {} if content: - payload['content'] = content + payload["content"] = content if tts: - payload['tts'] = True + payload["tts"] = True if embed: - payload['embeds'] = [embed] + payload["embeds"] = [embed] if embeds: - payload['embeds'] = embeds + payload["embeds"] = embeds if nonce: - payload['nonce'] = nonce + payload["nonce"] = nonce if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions + payload["allowed_mentions"] = allowed_mentions if message_reference: - payload['message_reference'] = message_reference + payload["message_reference"] = message_reference if components: - payload['components'] = components + payload["components"] = components if stickers: - payload['sticker_ids'] = stickers + payload["sticker_ids"] = stickers return self.request(r, json=payload) def send_typing(self, channel_id: Snowflake) -> Response[None]: - return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id)) + return self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id)) def send_multipart_helper( self, @@ -477,43 +487,43 @@ class HTTPClient: ) -> Response[message.Message]: form = [] - payload: Dict[str, Any] = {'tts': tts} + payload: Dict[str, Any] = {"tts": tts} if content: - payload['content'] = content + payload["content"] = content if embed: - payload['embeds'] = [embed] + payload["embeds"] = [embed] if embeds: - payload['embeds'] = embeds + payload["embeds"] = embeds if nonce: - payload['nonce'] = nonce + payload["nonce"] = nonce if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions + payload["allowed_mentions"] = allowed_mentions if message_reference: - payload['message_reference'] = message_reference + payload["message_reference"] = message_reference if components: - payload['components'] = components + payload["components"] = components if stickers: - payload['sticker_ids'] = stickers + payload["sticker_ids"] = stickers - form.append({'name': 'payload_json', 'value': utils._to_json(payload)}) + form.append({"name": "payload_json", "value": utils._to_json(payload)}) if len(files) == 1: file = files[0] form.append( { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', + "name": "file", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", } ) else: for index, file in enumerate(files): form.append( { - 'name': f'file{index}', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', + "name": f"file{index}", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", } ) @@ -534,7 +544,7 @@ class HTTPClient: stickers: Optional[List[sticker.StickerItem]] = None, components: Optional[List[components.Component]] = None, ) -> Response[message.Message]: - r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) return self.send_multipart_helper( r, files=files, @@ -552,27 +562,29 @@ class HTTPClient: def delete_message( self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None ) -> Response[None]: - r = Route('DELETE', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + r = Route( + "DELETE", "/channels/{channel_id}/messages/{message_id}", channel_id=channel_id, message_id=message_id + ) return self.request(r, reason=reason) def delete_messages( self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None ) -> Response[None]: - r = Route('POST', '/channels/{channel_id}/messages/bulk-delete', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/messages/bulk-delete", channel_id=channel_id) payload = { - 'messages': message_ids, + "messages": message_ids, } return self.request(r, json=payload, reason=reason) def edit_message(self, channel_id: Snowflake, message_id: Snowflake, **fields: Any) -> Response[message.Message]: - r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + r = Route("PATCH", "/channels/{channel_id}/messages/{message_id}", channel_id=channel_id, message_id=message_id) return self.request(r, json=fields) def add_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( - 'PUT', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me', + "PUT", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", channel_id=channel_id, message_id=message_id, emoji=emoji, @@ -583,8 +595,8 @@ class HTTPClient: self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}", channel_id=channel_id, message_id=message_id, member_id=member_id, @@ -594,8 +606,8 @@ class HTTPClient: def remove_own_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", channel_id=channel_id, message_id=message_id, emoji=emoji, @@ -611,24 +623,24 @@ class HTTPClient: after: Optional[Snowflake] = None, ) -> Response[List[user.User]]: r = Route( - 'GET', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}', + "GET", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", channel_id=channel_id, message_id=message_id, emoji=emoji, ) params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if after: - params['after'] = after + params["after"] = after return self.request(r, params=params) def clear_reactions(self, channel_id: Snowflake, message_id: Snowflake) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions", channel_id=channel_id, message_id=message_id, ) @@ -637,8 +649,8 @@ class HTTPClient: def clear_single_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", channel_id=channel_id, message_id=message_id, emoji=emoji, @@ -646,11 +658,11 @@ class HTTPClient: return self.request(r) def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: - r = Route('GET', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + r = Route("GET", "/channels/{channel_id}/messages/{message_id}", channel_id=channel_id, message_id=message_id) return self.request(r) def get_channel(self, channel_id: Snowflake) -> Response[channel.Channel]: - r = Route('GET', '/channels/{channel_id}', channel_id=channel_id) + r = Route("GET", "/channels/{channel_id}", channel_id=channel_id) return self.request(r) def logs_from( @@ -662,23 +674,23 @@ class HTTPClient: around: Optional[Snowflake] = None, ) -> Response[List[message.Message]]: params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if before is not None: - params['before'] = before + params["before"] = before if after is not None: - params['after'] = after + params["after"] = after if around is not None: - params['around'] = around + params["around"] = around - return self.request(Route('GET', '/channels/{channel_id}/messages', channel_id=channel_id), params=params) + return self.request(Route("GET", "/channels/{channel_id}/messages", channel_id=channel_id), params=params) def publish_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: return self.request( Route( - 'POST', - '/channels/{channel_id}/messages/{message_id}/crosspost', + "POST", + "/channels/{channel_id}/messages/{message_id}/crosspost", channel_id=channel_id, message_id=message_id, ) @@ -686,32 +698,34 @@ class HTTPClient: def pin_message(self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None) -> Response[None]: r = Route( - 'PUT', - '/channels/{channel_id}/pins/{message_id}', + "PUT", + "/channels/{channel_id}/pins/{message_id}", channel_id=channel_id, message_id=message_id, ) return self.request(r, reason=reason) - def unpin_message(self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None) -> Response[None]: + def unpin_message( + self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None + ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/pins/{message_id}', + "DELETE", + "/channels/{channel_id}/pins/{message_id}", channel_id=channel_id, message_id=message_id, ) return self.request(r, reason=reason) def pins_from(self, channel_id: Snowflake) -> Response[List[message.Message]]: - return self.request(Route('GET', '/channels/{channel_id}/pins', channel_id=channel_id)) + return self.request(Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id)) # Member management def kick(self, user_id: Snowflake, guild_id: Snowflake, reason: Optional[str] = None) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("DELETE", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id) if reason: # thanks aiohttp - r.url = f'{r.url}?reason={_uriquote(reason)}' + r.url = f"{r.url}?reason={_uriquote(reason)}" return self.request(r) @@ -722,15 +736,15 @@ class HTTPClient: delete_message_days: int = 1, reason: Optional[str] = None, ) -> Response[None]: - r = Route('PUT', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("PUT", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id) params = { - 'delete_message_days': delete_message_days, + "delete_message_days": delete_message_days, } return self.request(r, params=params, reason=reason) def unban(self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("DELETE", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id) return self.request(r, reason=reason) def guild_voice_state( @@ -742,18 +756,18 @@ class HTTPClient: deafen: Optional[bool] = None, reason: Optional[str] = None, ) -> Response[member.Member]: - r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("PATCH", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id) payload = {} if mute is not None: - payload['mute'] = mute + payload["mute"] = mute if deafen is not None: - payload['deaf'] = deafen + payload["deaf"] = deafen return self.request(r, json=payload, reason=reason) def edit_profile(self, payload: Dict[str, Any]) -> Response[user.User]: - return self.request(Route('PATCH', '/users/@me'), json=payload) + return self.request(Route("PATCH", "/users/@me"), json=payload) def change_my_nickname( self, @@ -762,9 +776,9 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[member.Nickname]: - r = Route('PATCH', '/guilds/{guild_id}/members/@me/nick', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/members/@me/nick", guild_id=guild_id) payload = { - 'nick': nickname, + "nick": nickname, } return self.request(r, json=payload, reason=reason) @@ -776,18 +790,18 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[member.Member]: - r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("PATCH", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id) payload = { - 'nick': nickname, + "nick": nickname, } return self.request(r, json=payload, reason=reason) def edit_my_voice_state(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[None]: - r = Route('PATCH', '/guilds/{guild_id}/voice-states/@me', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/voice-states/@me", guild_id=guild_id) return self.request(r, json=payload) def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any]) -> Response[None]: - r = Route('PATCH', '/guilds/{guild_id}/voice-states/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("PATCH", "/guilds/{guild_id}/voice-states/{user_id}", guild_id=guild_id, user_id=user_id) return self.request(r, json=payload) def edit_member( @@ -798,7 +812,7 @@ class HTTPClient: reason: Optional[str] = None, **fields: Any, ) -> Response[member.MemberWithUser]: - r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route("PATCH", "/guilds/{guild_id}/members/{user_id}", guild_id=guild_id, user_id=user_id) return self.request(r, json=fields, reason=reason) # Channel management @@ -810,25 +824,25 @@ class HTTPClient: reason: Optional[str] = None, **options: Any, ) -> Response[channel.Channel]: - r = Route('PATCH', '/channels/{channel_id}', channel_id=channel_id) + r = Route("PATCH", "/channels/{channel_id}", channel_id=channel_id) valid_keys = ( - 'name', - 'parent_id', - 'topic', - 'bitrate', - 'nsfw', - 'user_limit', - 'position', - 'permission_overwrites', - 'rate_limit_per_user', - 'type', - 'rtc_region', - 'video_quality_mode', - 'archived', - 'auto_archive_duration', - 'locked', - 'invitable', - 'default_auto_archive_duration', + "name", + "parent_id", + "topic", + "bitrate", + "nsfw", + "user_limit", + "position", + "permission_overwrites", + "rate_limit_per_user", + "type", + "rtc_region", + "video_quality_mode", + "archived", + "auto_archive_duration", + "locked", + "invitable", + "default_auto_archive_duration", ) payload = {k: v for k, v in options.items() if k in valid_keys} return self.request(r, reason=reason, json=payload) @@ -840,7 +854,7 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - r = Route('PATCH', '/guilds/{guild_id}/channels', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/channels", guild_id=guild_id) return self.request(r, json=data, reason=reason) def create_channel( @@ -852,26 +866,28 @@ class HTTPClient: **options: Any, ) -> Response[channel.GuildChannel]: payload = { - 'type': channel_type, + "type": channel_type, } valid_keys = ( - 'name', - 'parent_id', - 'topic', - 'bitrate', - 'nsfw', - 'user_limit', - 'position', - 'permission_overwrites', - 'rate_limit_per_user', - 'rtc_region', - 'video_quality_mode', - 'auto_archive_duration', + "name", + "parent_id", + "topic", + "bitrate", + "nsfw", + "user_limit", + "position", + "permission_overwrites", + "rate_limit_per_user", + "rtc_region", + "video_quality_mode", + "auto_archive_duration", ) payload.update({k: v for k, v in options.items() if k in valid_keys and v is not None}) - return self.request(Route('POST', '/guilds/{guild_id}/channels', guild_id=guild_id), json=payload, reason=reason) + return self.request( + Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id), json=payload, reason=reason + ) def delete_channel( self, @@ -879,7 +895,7 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id), reason=reason) + return self.request(Route("DELETE", "/channels/{channel_id}", channel_id=channel_id), reason=reason) # Thread management @@ -893,12 +909,12 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[threads.Thread]: payload = { - 'name': name, - 'auto_archive_duration': auto_archive_duration, + "name": name, + "auto_archive_duration": auto_archive_duration, } route = Route( - 'POST', '/channels/{channel_id}/messages/{message_id}/threads', channel_id=channel_id, message_id=message_id + "POST", "/channels/{channel_id}/messages/{message_id}/threads", channel_id=channel_id, message_id=message_id ) return self.request(route, json=payload, reason=reason) @@ -913,68 +929,70 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[threads.Thread]: payload = { - 'name': name, - 'auto_archive_duration': auto_archive_duration, - 'type': type, - 'invitable': invitable, + "name": name, + "auto_archive_duration": auto_archive_duration, + "type": type, + "invitable": invitable, } - route = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id) + route = Route("POST", "/channels/{channel_id}/threads", channel_id=channel_id) return self.request(route, json=payload, reason=reason) def join_thread(self, channel_id: Snowflake) -> Response[None]: - return self.request(Route('POST', '/channels/{channel_id}/thread-members/@me', channel_id=channel_id)) + return self.request(Route("POST", "/channels/{channel_id}/thread-members/@me", channel_id=channel_id)) def add_user_to_thread(self, channel_id: Snowflake, user_id: Snowflake) -> Response[None]: return self.request( - Route('PUT', '/channels/{channel_id}/thread-members/{user_id}', channel_id=channel_id, user_id=user_id) + Route("PUT", "/channels/{channel_id}/thread-members/{user_id}", channel_id=channel_id, user_id=user_id) ) def leave_thread(self, channel_id: Snowflake) -> Response[None]: - return self.request(Route('DELETE', '/channels/{channel_id}/thread-members/@me', channel_id=channel_id)) + return self.request(Route("DELETE", "/channels/{channel_id}/thread-members/@me", channel_id=channel_id)) def remove_user_from_thread(self, channel_id: Snowflake, user_id: Snowflake) -> Response[None]: - route = Route('DELETE', '/channels/{channel_id}/thread-members/{user_id}', channel_id=channel_id, user_id=user_id) + route = Route( + "DELETE", "/channels/{channel_id}/thread-members/{user_id}", channel_id=channel_id, user_id=user_id + ) return self.request(route) def get_public_archived_threads( self, channel_id: Snowflake, before: Optional[Snowflake] = None, limit: int = 50 ) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/threads/archived/public', channel_id=channel_id) + route = Route("GET", "/channels/{channel_id}/threads/archived/public", channel_id=channel_id) params = {} if before: - params['before'] = before - params['limit'] = limit + params["before"] = before + params["limit"] = limit return self.request(route, params=params) def get_private_archived_threads( self, channel_id: Snowflake, before: Optional[Snowflake] = None, limit: int = 50 ) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/threads/archived/private', channel_id=channel_id) + route = Route("GET", "/channels/{channel_id}/threads/archived/private", channel_id=channel_id) params = {} if before: - params['before'] = before - params['limit'] = limit + params["before"] = before + params["limit"] = limit return self.request(route, params=params) def get_joined_private_archived_threads( self, channel_id: Snowflake, before: Optional[Snowflake] = None, limit: int = 50 ) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/users/@me/threads/archived/private', channel_id=channel_id) + route = Route("GET", "/channels/{channel_id}/users/@me/threads/archived/private", channel_id=channel_id) params = {} if before: - params['before'] = before - params['limit'] = limit + params["before"] = before + params["limit"] = limit return self.request(route, params=params) def get_active_threads(self, guild_id: Snowflake) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/guilds/{guild_id}/threads/active', guild_id=guild_id) + route = Route("GET", "/guilds/{guild_id}/threads/active", guild_id=guild_id) return self.request(route) def get_thread_members(self, channel_id: Snowflake) -> Response[List[threads.ThreadMember]]: - route = Route('GET', '/channels/{channel_id}/thread-members', channel_id=channel_id) + route = Route("GET", "/channels/{channel_id}/thread-members", channel_id=channel_id) return self.request(route) # Webhook management @@ -988,22 +1006,22 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[webhook.Webhook]: payload: Dict[str, Any] = { - 'name': name, + "name": name, } if avatar is not None: - payload['avatar'] = avatar + payload["avatar"] = avatar - r = Route('POST', '/channels/{channel_id}/webhooks', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/webhooks", channel_id=channel_id) return self.request(r, json=payload, reason=reason) def channel_webhooks(self, channel_id: Snowflake) -> Response[List[webhook.Webhook]]: - return self.request(Route('GET', '/channels/{channel_id}/webhooks', channel_id=channel_id)) + return self.request(Route("GET", "/channels/{channel_id}/webhooks", channel_id=channel_id)) def guild_webhooks(self, guild_id: Snowflake) -> Response[List[webhook.Webhook]]: - return self.request(Route('GET', '/guilds/{guild_id}/webhooks', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/webhooks", guild_id=guild_id)) def get_webhook(self, webhook_id: Snowflake) -> Response[webhook.Webhook]: - return self.request(Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)) + return self.request(Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)) def follow_webhook( self, @@ -1012,10 +1030,10 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[None]: payload = { - 'webhook_channel_id': str(webhook_channel_id), + "webhook_channel_id": str(webhook_channel_id), } return self.request( - Route('POST', '/channels/{channel_id}/followers', channel_id=channel_id), json=payload, reason=reason + Route("POST", "/channels/{channel_id}/followers", channel_id=channel_id), json=payload, reason=reason ) # Guild management @@ -1027,140 +1045,142 @@ class HTTPClient: after: Optional[Snowflake] = None, ) -> Response[List[guild.Guild]]: params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if before: - params['before'] = before + params["before"] = before if after: - params['after'] = after + params["after"] = after - return self.request(Route('GET', '/users/@me/guilds'), params=params) + return self.request(Route("GET", "/users/@me/guilds"), params=params) def leave_guild(self, guild_id: Snowflake) -> Response[None]: - return self.request(Route('DELETE', '/users/@me/guilds/{guild_id}', guild_id=guild_id)) + return self.request(Route("DELETE", "/users/@me/guilds/{guild_id}", guild_id=guild_id)) def get_guild(self, guild_id: Snowflake) -> Response[guild.Guild]: - return self.request(Route('GET', '/guilds/{guild_id}', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}", guild_id=guild_id)) def delete_guild(self, guild_id: Snowflake) -> Response[None]: - return self.request(Route('DELETE', '/guilds/{guild_id}', guild_id=guild_id)) + return self.request(Route("DELETE", "/guilds/{guild_id}", guild_id=guild_id)) def create_guild(self, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: payload = { - 'name': name, - 'region': region, + "name": name, + "region": region, } if icon: - payload['icon'] = icon + payload["icon"] = icon - return self.request(Route('POST', '/guilds'), json=payload) + return self.request(Route("POST", "/guilds"), json=payload) def edit_guild(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[guild.Guild]: valid_keys = ( - 'name', - 'region', - 'icon', - 'afk_timeout', - 'owner_id', - 'afk_channel_id', - 'splash', - 'discovery_splash', - 'features', - 'verification_level', - 'system_channel_id', - 'default_message_notifications', - 'description', - 'explicit_content_filter', - 'banner', - 'system_channel_flags', - 'rules_channel_id', - 'public_updates_channel_id', - 'preferred_locale', + "name", + "region", + "icon", + "afk_timeout", + "owner_id", + "afk_channel_id", + "splash", + "discovery_splash", + "features", + "verification_level", + "system_channel_id", + "default_message_notifications", + "description", + "explicit_content_filter", + "banner", + "system_channel_flags", + "rules_channel_id", + "public_updates_channel_id", + "preferred_locale", ) payload = {k: v for k, v in fields.items() if k in valid_keys} - return self.request(Route('PATCH', '/guilds/{guild_id}', guild_id=guild_id), json=payload, reason=reason) + return self.request(Route("PATCH", "/guilds/{guild_id}", guild_id=guild_id), json=payload, reason=reason) def get_template(self, code: str) -> Response[template.Template]: - return self.request(Route('GET', '/guilds/templates/{code}', code=code)) + return self.request(Route("GET", "/guilds/templates/{code}", code=code)) def guild_templates(self, guild_id: Snowflake) -> Response[List[template.Template]]: - return self.request(Route('GET', '/guilds/{guild_id}/templates', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/templates", guild_id=guild_id)) def create_template(self, guild_id: Snowflake, payload: template.CreateTemplate) -> Response[template.Template]: - return self.request(Route('POST', '/guilds/{guild_id}/templates', guild_id=guild_id), json=payload) + return self.request(Route("POST", "/guilds/{guild_id}/templates", guild_id=guild_id), json=payload) def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: - return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) + return self.request(Route("PUT", "/guilds/{guild_id}/templates/{code}", guild_id=guild_id, code=code)) def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]: valid_keys = ( - 'name', - 'description', + "name", + "description", ) payload = {k: v for k, v in payload.items() if k in valid_keys} return self.request( - Route('PATCH', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code), json=payload + Route("PATCH", "/guilds/{guild_id}/templates/{code}", guild_id=guild_id, code=code), json=payload ) def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]: - return self.request(Route('DELETE', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) + return self.request(Route("DELETE", "/guilds/{guild_id}/templates/{code}", guild_id=guild_id, code=code)) def create_from_template(self, code: str, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: payload = { - 'name': name, - 'region': region, + "name": name, + "region": region, } if icon: - payload['icon'] = icon - return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload) + payload["icon"] = icon + return self.request(Route("POST", "/guilds/templates/{code}", code=code), json=payload) def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]: - return self.request(Route('GET', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/welcome-screen", guild_id=guild_id)) def edit_welcome_screen(self, guild_id: Snowflake, payload: Any) -> Response[welcome_screen.WelcomeScreen]: valid_keys = ( - 'description', - 'welcome_channels', - 'enabled', + "description", + "welcome_channels", + "enabled", ) - payload = { - k: v for k, v in payload.items() if k in valid_keys - } - return self.request(Route('PATCH', '/guilds/{guild_id}/welcome-screen', guild_id=guild_id), json=payload) + payload = {k: v for k, v in payload.items() if k in valid_keys} + return self.request(Route("PATCH", "/guilds/{guild_id}/welcome-screen", guild_id=guild_id), json=payload) def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: - return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id)) def get_ban(self, user_id: Snowflake, guild_id: Snowflake) -> Response[guild.Ban]: - return self.request(Route('GET', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id)) + return self.request(Route("GET", "/guilds/{guild_id}/bans/{user_id}", guild_id=guild_id, user_id=user_id)) def get_vanity_code(self, guild_id: Snowflake) -> Response[invite.VanityInvite]: - return self.request(Route('GET', '/guilds/{guild_id}/vanity-url', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/vanity-url", guild_id=guild_id)) def change_vanity_code(self, guild_id: Snowflake, code: str, *, reason: Optional[str] = None) -> Response[None]: - payload: Dict[str, Any] = {'code': code} - return self.request(Route('PATCH', '/guilds/{guild_id}/vanity-url', guild_id=guild_id), json=payload, reason=reason) + payload: Dict[str, Any] = {"code": code} + return self.request( + Route("PATCH", "/guilds/{guild_id}/vanity-url", guild_id=guild_id), json=payload, reason=reason + ) def get_all_guild_channels(self, guild_id: Snowflake) -> Response[List[guild.GuildChannel]]: - return self.request(Route('GET', '/guilds/{guild_id}/channels', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/channels", guild_id=guild_id)) def get_members( self, guild_id: Snowflake, limit: int, after: Optional[Snowflake] ) -> Response[List[member.MemberWithUser]]: params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if after: - params['after'] = after + params["after"] = after - r = Route('GET', '/guilds/{guild_id}/members', guild_id=guild_id) + r = Route("GET", "/guilds/{guild_id}/members", guild_id=guild_id) return self.request(r, params=params) def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.MemberWithUser]: - return self.request(Route('GET', '/guilds/{guild_id}/members/{member_id}', guild_id=guild_id, member_id=member_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/members/{member_id}", guild_id=guild_id, member_id=member_id) + ) def prune_members( self, @@ -1172,13 +1192,13 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[guild.GuildPrune]: payload: Dict[str, Any] = { - 'days': days, - 'compute_prune_count': 'true' if compute_prune_count else 'false', + "days": days, + "compute_prune_count": "true" if compute_prune_count else "false", } if roles: - payload['include_roles'] = ', '.join(roles) + payload["include_roles"] = ", ".join(roles) - return self.request(Route('POST', '/guilds/{guild_id}/prune', guild_id=guild_id), json=payload, reason=reason) + return self.request(Route("POST", "/guilds/{guild_id}/prune", guild_id=guild_id), json=payload, reason=reason) def estimate_pruned_members( self, @@ -1187,25 +1207,25 @@ class HTTPClient: roles: List[str], ) -> Response[guild.GuildPrune]: params: Dict[str, Any] = { - 'days': days, + "days": days, } if roles: - params['include_roles'] = ', '.join(roles) + params["include_roles"] = ", ".join(roles) - return self.request(Route('GET', '/guilds/{guild_id}/prune', guild_id=guild_id), params=params) + return self.request(Route("GET", "/guilds/{guild_id}/prune", guild_id=guild_id), params=params) def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]: - return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id)) + return self.request(Route("GET", "/stickers/{sticker_id}", sticker_id=sticker_id)) def list_premium_sticker_packs(self) -> Response[sticker.ListPremiumStickerPacks]: - return self.request(Route('GET', '/sticker-packs')) + return self.request(Route("GET", "/sticker-packs")) def get_all_guild_stickers(self, guild_id: Snowflake) -> Response[List[sticker.GuildSticker]]: - return self.request(Route('GET', '/guilds/{guild_id}/stickers', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/stickers", guild_id=guild_id)) def get_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake) -> Response[sticker.GuildSticker]: return self.request( - Route('GET', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id) + Route("GET", "/guilds/{guild_id}/stickers/{sticker_id}", guild_id=guild_id, sticker_id=sticker_id) ) def create_guild_sticker( @@ -1216,54 +1236,58 @@ class HTTPClient: try: mime_type = utils._get_mime_type_for_image(initial_bytes) except InvalidArgument: - if initial_bytes.startswith(b'{'): - mime_type = 'application/json' + if initial_bytes.startswith(b"{"): + mime_type = "application/json" else: - mime_type = 'application/octet-stream' + mime_type = "application/octet-stream" finally: file.reset() form: List[Dict[str, Any]] = [ { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': mime_type, + "name": "file", + "value": file.fp, + "filename": file.filename, + "content_type": mime_type, } ] for k, v in payload.items(): form.append( { - 'name': k, - 'value': v, + "name": k, + "value": v, } ) return self.request( - Route('POST', '/guilds/{guild_id}/stickers', guild_id=guild_id), form=form, files=[file], reason=reason + Route("POST", "/guilds/{guild_id}/stickers", guild_id=guild_id), form=form, files=[file], reason=reason ) def modify_guild_sticker( - self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str], + self, + guild_id: Snowflake, + sticker_id: Snowflake, + payload: sticker.EditGuildSticker, + reason: Optional[str], ) -> Response[sticker.GuildSticker]: return self.request( - Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), + Route("PATCH", "/guilds/{guild_id}/stickers/{sticker_id}", guild_id=guild_id, sticker_id=sticker_id), json=payload, reason=reason, ) def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str]) -> Response[None]: return self.request( - Route('DELETE', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), + Route("DELETE", "/guilds/{guild_id}/stickers/{sticker_id}", guild_id=guild_id, sticker_id=sticker_id), reason=reason, ) def get_all_custom_emojis(self, guild_id: Snowflake) -> Response[List[emoji.Emoji]]: - return self.request(Route('GET', '/guilds/{guild_id}/emojis', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/emojis", guild_id=guild_id)) def get_custom_emoji(self, guild_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]: - return self.request(Route('GET', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id)) + return self.request(Route("GET", "/guilds/{guild_id}/emojis/{emoji_id}", guild_id=guild_id, emoji_id=emoji_id)) def create_custom_emoji( self, @@ -1275,12 +1299,12 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[emoji.Emoji]: payload = { - 'name': name, - 'image': image, - 'roles': roles or [], + "name": name, + "image": image, + "roles": roles or [], } - r = Route('POST', '/guilds/{guild_id}/emojis', guild_id=guild_id) + r = Route("POST", "/guilds/{guild_id}/emojis", guild_id=guild_id) return self.request(r, json=payload, reason=reason) def delete_custom_emoji( @@ -1290,7 +1314,7 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id) + r = Route("DELETE", "/guilds/{guild_id}/emojis/{emoji_id}", guild_id=guild_id, emoji_id=emoji_id) return self.request(r, reason=reason) def edit_custom_emoji( @@ -1301,33 +1325,39 @@ class HTTPClient: payload: Dict[str, Any], reason: Optional[str] = None, ) -> Response[emoji.Emoji]: - r = Route('PATCH', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id) + r = Route("PATCH", "/guilds/{guild_id}/emojis/{emoji_id}", guild_id=guild_id, emoji_id=emoji_id) return self.request(r, json=payload, reason=reason) def get_all_integrations(self, guild_id: Snowflake) -> Response[List[integration.Integration]]: - r = Route('GET', '/guilds/{guild_id}/integrations', guild_id=guild_id) + r = Route("GET", "/guilds/{guild_id}/integrations", guild_id=guild_id) return self.request(r) def create_integration(self, guild_id: Snowflake, type: integration.IntegrationType, id: int) -> Response[None]: payload = { - 'type': type, - 'id': id, + "type": type, + "id": id, } - r = Route('POST', '/guilds/{guild_id}/integrations', guild_id=guild_id) + r = Route("POST", "/guilds/{guild_id}/integrations", guild_id=guild_id) return self.request(r, json=payload) def edit_integration(self, guild_id: Snowflake, integration_id: Snowflake, **payload: Any) -> Response[None]: r = Route( - 'PATCH', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id + "PATCH", + "/guilds/{guild_id}/integrations/{integration_id}", + guild_id=guild_id, + integration_id=integration_id, ) return self.request(r, json=payload) def sync_integration(self, guild_id: Snowflake, integration_id: Snowflake) -> Response[None]: r = Route( - 'POST', '/guilds/{guild_id}/integrations/{integration_id}/sync', guild_id=guild_id, integration_id=integration_id + "POST", + "/guilds/{guild_id}/integrations/{integration_id}/sync", + guild_id=guild_id, + integration_id=integration_id, ) return self.request(r) @@ -1336,7 +1366,10 @@ class HTTPClient: self, guild_id: Snowflake, integration_id: Snowflake, *, reason: Optional[str] = None ) -> Response[None]: r = Route( - 'DELETE', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id + "DELETE", + "/guilds/{guild_id}/integrations/{integration_id}", + guild_id=guild_id, + integration_id=integration_id, ) return self.request(r, reason=reason) @@ -1350,24 +1383,24 @@ class HTTPClient: user_id: Optional[Snowflake] = None, action_type: Optional[AuditLogAction] = None, ) -> Response[audit_log.AuditLog]: - params: Dict[str, Any] = {'limit': limit} + params: Dict[str, Any] = {"limit": limit} if before: - params['before'] = before + params["before"] = before if after: - params['after'] = after + params["after"] = after if user_id: - params['user_id'] = user_id + params["user_id"] = user_id if action_type: - params['action_type'] = action_type + params["action_type"] = action_type - r = Route('GET', '/guilds/{guild_id}/audit-logs', guild_id=guild_id) + r = Route("GET", "/guilds/{guild_id}/audit-logs", guild_id=guild_id) return self.request(r, params=params) def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]: - return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/widget.json", guild_id=guild_id)) def edit_widget(self, guild_id: Snowflake, payload) -> Response[widget.WidgetSettings]: - return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload) + return self.request(Route("PATCH", "/guilds/{guild_id}/widget", guild_id=guild_id), json=payload) # Invite management @@ -1384,22 +1417,22 @@ class HTTPClient: target_user_id: Optional[Snowflake] = None, target_application_id: Optional[Snowflake] = None, ) -> Response[invite.Invite]: - r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/invites", channel_id=channel_id) payload = { - 'max_age': max_age, - 'max_uses': max_uses, - 'temporary': temporary, - 'unique': unique, + "max_age": max_age, + "max_uses": max_uses, + "temporary": temporary, + "unique": unique, } if target_type: - payload['target_type'] = target_type + payload["target_type"] = target_type if target_user_id: - payload['target_user_id'] = target_user_id + payload["target_user_id"] = target_user_id if target_application_id: - payload['target_application_id'] = str(target_application_id) + payload["target_application_id"] = str(target_application_id) return self.request(r, reason=reason, json=payload) @@ -1407,35 +1440,35 @@ class HTTPClient: self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True ) -> Response[invite.Invite]: params = { - 'with_counts': int(with_counts), - 'with_expiration': int(with_expiration), + "with_counts": int(with_counts), + "with_expiration": int(with_expiration), } - return self.request(Route('GET', '/invites/{invite_id}', invite_id=invite_id), params=params) + return self.request(Route("GET", "/invites/{invite_id}", invite_id=invite_id), params=params) def invites_from(self, guild_id: Snowflake) -> Response[List[invite.Invite]]: - return self.request(Route('GET', '/guilds/{guild_id}/invites', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/invites", guild_id=guild_id)) def invites_from_channel(self, channel_id: Snowflake) -> Response[List[invite.Invite]]: - return self.request(Route('GET', '/channels/{channel_id}/invites', channel_id=channel_id)) + return self.request(Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id)) def delete_invite(self, invite_id: str, *, reason: Optional[str] = None) -> Response[None]: - return self.request(Route('DELETE', '/invites/{invite_id}', invite_id=invite_id), reason=reason) + return self.request(Route("DELETE", "/invites/{invite_id}", invite_id=invite_id), reason=reason) # Role management def get_roles(self, guild_id: Snowflake) -> Response[List[role.Role]]: - return self.request(Route('GET', '/guilds/{guild_id}/roles', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/roles", guild_id=guild_id)) def edit_role( self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any ) -> Response[role.Role]: - r = Route('PATCH', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id) - valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable') + r = Route("PATCH", "/guilds/{guild_id}/roles/{role_id}", guild_id=guild_id, role_id=role_id) + valid_keys = ("name", "permissions", "color", "hoist", "mentionable") payload = {k: v for k, v in fields.items() if k in valid_keys} return self.request(r, json=payload, reason=reason) def delete_role(self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id) + r = Route("DELETE", "/guilds/{guild_id}/roles/{role_id}", guild_id=guild_id, role_id=role_id) return self.request(r, reason=reason) def replace_roles( @@ -1449,7 +1482,7 @@ class HTTPClient: return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason) def create_role(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]: - r = Route('POST', '/guilds/{guild_id}/roles', guild_id=guild_id) + r = Route("POST", "/guilds/{guild_id}/roles", guild_id=guild_id) return self.request(r, json=fields, reason=reason) def move_role_position( @@ -1459,15 +1492,15 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[List[role.Role]]: - r = Route('PATCH', '/guilds/{guild_id}/roles', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/roles", guild_id=guild_id) return self.request(r, json=positions, reason=reason) def add_role( self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None ) -> Response[None]: r = Route( - 'PUT', - '/guilds/{guild_id}/members/{user_id}/roles/{role_id}', + "PUT", + "/guilds/{guild_id}/members/{user_id}/roles/{role_id}", guild_id=guild_id, user_id=user_id, role_id=role_id, @@ -1478,8 +1511,8 @@ class HTTPClient: self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None ) -> Response[None]: r = Route( - 'DELETE', - '/guilds/{guild_id}/members/{user_id}/roles/{role_id}', + "DELETE", + "/guilds/{guild_id}/members/{user_id}/roles/{role_id}", guild_id=guild_id, user_id=user_id, role_id=role_id, @@ -1496,14 +1529,14 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - payload = {'id': target, 'allow': allow, 'deny': deny, 'type': type} - r = Route('PUT', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target) + payload = {"id": target, "allow": allow, "deny": deny, "type": type} + r = Route("PUT", "/channels/{channel_id}/permissions/{target}", channel_id=channel_id, target=target) return self.request(r, json=payload, reason=reason) def delete_channel_permissions( self, channel_id: Snowflake, target: channel.OverwriteType, *, reason: Optional[str] = None ) -> Response[None]: - r = Route('DELETE', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target) + r = Route("DELETE", "/channels/{channel_id}/permissions/{target}", channel_id=channel_id, target=target) return self.request(r, reason=reason) # Voice management @@ -1521,50 +1554,52 @@ class HTTPClient: # Stage instance management def get_stage_instance(self, channel_id: Snowflake) -> Response[channel.StageInstance]: - return self.request(Route('GET', '/stage-instances/{channel_id}', channel_id=channel_id)) + return self.request(Route("GET", "/stage-instances/{channel_id}", channel_id=channel_id)) def create_stage_instance(self, *, reason: Optional[str], **payload: Any) -> Response[channel.StageInstance]: valid_keys = ( - 'channel_id', - 'topic', - 'privacy_level', + "channel_id", + "topic", + "privacy_level", ) payload = {k: v for k, v in payload.items() if k in valid_keys} - return self.request(Route('POST', '/stage-instances'), json=payload, reason=reason) + return self.request(Route("POST", "/stage-instances"), json=payload, reason=reason) - def edit_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any) -> Response[None]: + def edit_stage_instance( + self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any + ) -> Response[None]: valid_keys = ( - 'topic', - 'privacy_level', + "topic", + "privacy_level", ) payload = {k: v for k, v in payload.items() if k in valid_keys} return self.request( - Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload, reason=reason + Route("PATCH", "/stage-instances/{channel_id}", channel_id=channel_id), json=payload, reason=reason ) def delete_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: - return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id), reason=reason) + return self.request(Route("DELETE", "/stage-instances/{channel_id}", channel_id=channel_id), reason=reason) # Application commands (global) def get_global_commands(self, application_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]: - return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id)) + return self.request(Route("GET", "/applications/{application_id}/commands", application_id=application_id)) def get_global_command( self, application_id: Snowflake, command_id: Snowflake ) -> Response[interactions.ApplicationCommand]: r = Route( - 'GET', - '/applications/{application_id}/commands/{command_id}', + "GET", + "/applications/{application_id}/commands/{command_id}", application_id=application_id, command_id=command_id, ) return self.request(r) def upsert_global_command(self, application_id: Snowflake, payload) -> Response[interactions.ApplicationCommand]: - r = Route('POST', '/applications/{application_id}/commands', application_id=application_id) + r = Route("POST", "/applications/{application_id}/commands", application_id=application_id) return self.request(r, json=payload) def edit_global_command( @@ -1574,14 +1609,14 @@ class HTTPClient: payload: interactions.EditApplicationCommand, ) -> Response[interactions.ApplicationCommand]: valid_keys = ( - 'name', - 'description', - 'options', + "name", + "description", + "options", ) payload = {k: v for k, v in payload.items() if k in valid_keys} # type: ignore r = Route( - 'PATCH', - '/applications/{application_id}/commands/{command_id}', + "PATCH", + "/applications/{application_id}/commands/{command_id}", application_id=application_id, command_id=command_id, ) @@ -1589,8 +1624,8 @@ class HTTPClient: def delete_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[None]: r = Route( - 'DELETE', - '/applications/{application_id}/commands/{command_id}', + "DELETE", + "/applications/{application_id}/commands/{command_id}", application_id=application_id, command_id=command_id, ) @@ -1599,7 +1634,7 @@ class HTTPClient: def bulk_upsert_global_commands( self, application_id: Snowflake, payload ) -> Response[List[interactions.ApplicationCommand]]: - r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id) + r = Route("PUT", "/applications/{application_id}/commands", application_id=application_id) return self.request(r, json=payload) # Application commands (guild) @@ -1608,8 +1643,8 @@ class HTTPClient: self, application_id: Snowflake, guild_id: Snowflake ) -> Response[List[interactions.ApplicationCommand]]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands", application_id=application_id, guild_id=guild_id, ) @@ -1622,8 +1657,8 @@ class HTTPClient: command_id: Snowflake, ) -> Response[interactions.ApplicationCommand]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1637,8 +1672,8 @@ class HTTPClient: payload: interactions.EditApplicationCommand, ) -> Response[interactions.ApplicationCommand]: r = Route( - 'POST', - '/applications/{application_id}/guilds/{guild_id}/commands', + "POST", + "/applications/{application_id}/guilds/{guild_id}/commands", application_id=application_id, guild_id=guild_id, ) @@ -1652,14 +1687,14 @@ class HTTPClient: payload: interactions.EditApplicationCommand, ) -> Response[interactions.ApplicationCommand]: valid_keys = ( - 'name', - 'description', - 'options', + "name", + "description", + "options", ) payload = {k: v for k, v in payload.items() if k in valid_keys} # type: ignore r = Route( - 'PATCH', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}', + "PATCH", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1673,8 +1708,8 @@ class HTTPClient: command_id: Snowflake, ) -> Response[None]: r = Route( - 'DELETE', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}', + "DELETE", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1688,8 +1723,8 @@ class HTTPClient: payload: List[interactions.EditApplicationCommand], ) -> Response[List[interactions.ApplicationCommand]]: r = Route( - 'PUT', - '/applications/{application_id}/guilds/{guild_id}/commands', + "PUT", + "/applications/{application_id}/guilds/{guild_id}/commands", application_id=application_id, guild_id=guild_id, ) @@ -1708,26 +1743,26 @@ class HTTPClient: payload: Dict[str, Any] = {} if content: - payload['content'] = content + payload["content"] = content if embeds: - payload['embeds'] = embeds + payload["embeds"] = embeds if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions + payload["allowed_mentions"] = allowed_mentions form: List[Dict[str, Any]] = [ { - 'name': 'payload_json', - 'value': utils._to_json(payload), + "name": "payload_json", + "value": utils._to_json(payload), } ] if file: form.append( { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', + "name": "file", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", } ) @@ -1742,17 +1777,17 @@ class HTTPClient: data: Optional[interactions.InteractionApplicationCommandCallbackData] = None, ) -> Response[None]: r = Route( - 'POST', - '/interactions/{interaction_id}/{interaction_token}/callback', + "POST", + "/interactions/{interaction_id}/{interaction_token}/callback", interaction_id=interaction_id, interaction_token=token, ) payload: Dict[str, Any] = { - 'type': type, + "type": type, } if data is not None: - payload['data'] = data + payload["data"] = data return self.request(r, json=payload) @@ -1762,8 +1797,8 @@ class HTTPClient: token: str, ) -> Response[message.Message]: r = Route( - 'GET', - '/webhooks/{application_id}/{interaction_token}/messages/@original', + "GET", + "/webhooks/{application_id}/{interaction_token}/messages/@original", application_id=application_id, interaction_token=token, ) @@ -1779,17 +1814,19 @@ class HTTPClient: allowed_mentions: Optional[message.AllowedMentions] = None, ) -> Response[message.Message]: r = Route( - 'PATCH', - '/webhooks/{application_id}/{interaction_token}/messages/@original', + "PATCH", + "/webhooks/{application_id}/{interaction_token}/messages/@original", application_id=application_id, interaction_token=token, ) - return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions) + return self._edit_webhook_helper( + r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions + ) def delete_original_interaction_response(self, application_id: Snowflake, token: str) -> Response[None]: r = Route( - 'DELETE', - '/webhooks/{application_id}/{interaction_token}/messages/@original', + "DELETE", + "/webhooks/{application_id}/{interaction_token}/messages/@original", application_id=application_id, interaction_token=token, ) @@ -1806,8 +1843,8 @@ class HTTPClient: allowed_mentions: Optional[message.AllowedMentions] = None, ) -> Response[message.Message]: r = Route( - 'POST', - '/webhooks/{application_id}/{interaction_token}', + "POST", + "/webhooks/{application_id}/{interaction_token}", application_id=application_id, interaction_token=token, ) @@ -1831,18 +1868,20 @@ class HTTPClient: allowed_mentions: Optional[message.AllowedMentions] = None, ) -> Response[message.Message]: r = Route( - 'PATCH', - '/webhooks/{application_id}/{interaction_token}/messages/{message_id}', + "PATCH", + "/webhooks/{application_id}/{interaction_token}/messages/{message_id}", application_id=application_id, interaction_token=token, message_id=message_id, ) - return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions) + return self._edit_webhook_helper( + r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions + ) def delete_followup_message(self, application_id: Snowflake, token: str, message_id: Snowflake) -> Response[None]: r = Route( - 'DELETE', - '/webhooks/{application_id}/{interaction_token}/messages/{message_id}', + "DELETE", + "/webhooks/{application_id}/{interaction_token}/messages/{message_id}", application_id=application_id, interaction_token=token, message_id=message_id, @@ -1855,8 +1894,8 @@ class HTTPClient: guild_id: Snowflake, ) -> Response[List[interactions.GuildApplicationCommandPermissions]]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands/permissions', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands/permissions", application_id=application_id, guild_id=guild_id, ) @@ -1869,8 +1908,8 @@ class HTTPClient: command_id: Snowflake, ) -> Response[interactions.GuildApplicationCommandPermissions]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1885,8 +1924,8 @@ class HTTPClient: payload: interactions.BaseGuildApplicationCommandPermissions, ) -> Response[None]: r = Route( - 'PUT', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions', + "PUT", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1900,8 +1939,8 @@ class HTTPClient: payload: List[interactions.PartialGuildApplicationCommandPermissions], ) -> Response[None]: r = Route( - 'PUT', - '/applications/{application_id}/guilds/{guild_id}/commands/permissions', + "PUT", + "/applications/{application_id}/guilds/{guild_id}/commands/permissions", application_id=application_id, guild_id=guild_id, ) @@ -1910,30 +1949,30 @@ class HTTPClient: # Misc def application_info(self) -> Response[appinfo.AppInfo]: - return self.request(Route('GET', '/oauth2/applications/@me')) + return self.request(Route("GET", "/oauth2/applications/@me")) - async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: + async def get_gateway(self, *, encoding: str = "json", zlib: bool = True) -> str: try: - data = await self.request(Route('GET', '/gateway')) + data = await self.request(Route("GET", "/gateway")) except HTTPException as exc: raise GatewayNotFound() from exc if zlib: - value = '{0}?encoding={1}&v=9&compress=zlib-stream' + value = "{0}?encoding={1}&v=9&compress=zlib-stream" else: - value = '{0}?encoding={1}&v=9' - return value.format(data['url'], encoding) + value = "{0}?encoding={1}&v=9" + return value.format(data["url"], encoding) - async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]: + async def get_bot_gateway(self, *, encoding: str = "json", zlib: bool = True) -> Tuple[int, str]: try: - data = await self.request(Route('GET', '/gateway/bot')) + data = await self.request(Route("GET", "/gateway/bot")) except HTTPException as exc: raise GatewayNotFound() from exc if zlib: - value = '{0}?encoding={1}&v=9&compress=zlib-stream' + value = "{0}?encoding={1}&v=9&compress=zlib-stream" else: - value = '{0}?encoding={1}&v=9' - return data['shards'], value.format(data['url'], encoding) + value = "{0}?encoding={1}&v=9" + return data["shards"], value.format(data["url"], encoding) def get_user(self, user_id: Snowflake) -> Response[user.User]: - return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) + return self.request(Route("GET", "/users/{user_id}", user_id=user_id)) diff --git a/discord/integrations.py b/discord/integrations.py index 23d02930..3576500d 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -32,11 +32,11 @@ from .errors import InvalidArgument from .enums import try_enum, ExpireBehaviour __all__ = ( - 'IntegrationAccount', - 'IntegrationApplication', - 'Integration', - 'StreamIntegration', - 'BotIntegration', + "IntegrationAccount", + "IntegrationApplication", + "Integration", + "StreamIntegration", + "BotIntegration", ) if TYPE_CHECKING: @@ -65,14 +65,14 @@ class IntegrationAccount: The account name. """ - __slots__ = ('id', 'name') + __slots__ = ("id", "name") def __init__(self, data: IntegrationAccountPayload) -> None: - self.id: str = data['id'] - self.name: str = data['name'] + self.id: str = data["id"] + self.name: str = data["name"] def __repr__(self) -> str: - return f'' + return f"" class Integration: @@ -99,14 +99,14 @@ class Integration: """ __slots__ = ( - 'guild', - 'id', - '_state', - 'type', - 'name', - 'account', - 'user', - 'enabled', + "guild", + "id", + "_state", + "type", + "name", + "account", + "user", + "enabled", ) def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: @@ -118,14 +118,14 @@ class Integration: return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>" def _from_data(self, data: IntegrationPayload) -> None: - self.id: int = int(data['id']) - self.type: IntegrationType = data['type'] - self.name: str = data['name'] - self.account: IntegrationAccount = IntegrationAccount(data['account']) + self.id: int = int(data["id"]) + self.type: IntegrationType = data["type"] + self.name: str = data["name"] + self.account: IntegrationAccount = IntegrationAccount(data["account"]) - user = data.get('user') + user = data.get("user") self.user = User(state=self._state, data=user) if user else None - self.enabled: bool = data['enabled'] + self.enabled: bool = data["enabled"] async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| @@ -186,26 +186,26 @@ class StreamIntegration(Integration): """ __slots__ = ( - 'revoked', - 'expire_behaviour', - 'expire_grace_period', - 'synced_at', - '_role_id', - 'syncing', - 'enable_emoticons', - 'subscriber_count', + "revoked", + "expire_behaviour", + "expire_grace_period", + "synced_at", + "_role_id", + "syncing", + "enable_emoticons", + "subscriber_count", ) def _from_data(self, data: StreamIntegrationPayload) -> None: super()._from_data(data) - self.revoked: bool = data['revoked'] - self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior']) - self.expire_grace_period: int = data['expire_grace_period'] - self.synced_at: datetime.datetime = parse_time(data['synced_at']) - self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id') - self.syncing: bool = data['syncing'] - self.enable_emoticons: bool = data['enable_emoticons'] - self.subscriber_count: int = data['subscriber_count'] + self.revoked: bool = data["revoked"] + self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"]) + self.expire_grace_period: int = data["expire_grace_period"] + self.synced_at: datetime.datetime = parse_time(data["synced_at"]) + self._role_id: Optional[int] = _get_as_snowflake(data, "role_id") + self.syncing: bool = data["syncing"] + self.enable_emoticons: bool = data["enable_emoticons"] + self.subscriber_count: int = data["subscriber_count"] @property def expire_behavior(self) -> ExpireBehaviour: @@ -252,15 +252,15 @@ class StreamIntegration(Integration): payload: Dict[str, Any] = {} if expire_behaviour is not MISSING: if not isinstance(expire_behaviour, ExpireBehaviour): - raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour') + raise InvalidArgument("expire_behaviour field must be of type ExpireBehaviour") - payload['expire_behavior'] = expire_behaviour.value + payload["expire_behavior"] = expire_behaviour.value if expire_grace_period is not MISSING: - payload['expire_grace_period'] = expire_grace_period + payload["expire_grace_period"] = expire_grace_period if enable_emoticons is not MISSING: - payload['enable_emoticons'] = enable_emoticons + payload["enable_emoticons"] = enable_emoticons # This endpoint is undocumented. # Unsure if it returns the data or not as a result @@ -307,21 +307,21 @@ class IntegrationApplication: """ __slots__ = ( - 'id', - 'name', - 'icon', - 'description', - 'summary', - 'user', + "id", + "name", + "icon", + "description", + "summary", + "user", ) def __init__(self, *, data: IntegrationApplicationPayload, state): - self.id: int = int(data['id']) - self.name: str = data['name'] - self.icon: Optional[str] = data['icon'] - self.description: str = data['description'] - self.summary: str = data['summary'] - user = data.get('bot') + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.icon: Optional[str] = data["icon"] + self.description: str = data["description"] + self.summary: str = data["summary"] + user = data.get("bot") self.user: Optional[User] = User(state=state, data=user) if user else None @@ -350,17 +350,17 @@ class BotIntegration(Integration): The application tied to this integration. """ - __slots__ = ('application',) + __slots__ = ("application",) def _from_data(self, data: BotIntegrationPayload) -> None: super()._from_data(data) - self.application = IntegrationApplication(data=data['application'], state=self._state) + self.application = IntegrationApplication(data=data["application"], state=self._state) def _integration_factory(value: str) -> Tuple[Type[Integration], str]: - if value == 'discord': + if value == "discord": return BotIntegration, value - elif value in ('twitch', 'youtube'): + elif value in ("twitch", "youtube"): return StreamIntegration, value else: return Integration, value diff --git a/discord/interactions.py b/discord/interactions.py index b89d49f5..0a9c7383 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -41,9 +41,9 @@ from .permissions import Permissions from .webhook.async_ import async_context, Webhook, handle_message_parameters __all__ = ( - 'Interaction', - 'InteractionMessage', - 'InteractionResponse', + "Interaction", + "InteractionMessage", + "InteractionResponse", ) if TYPE_CHECKING: @@ -100,23 +100,23 @@ class Interaction: """ __slots__: Tuple[str, ...] = ( - 'id', - 'type', - 'guild_id', - 'channel_id', - 'data', - 'application_id', - 'message', - 'user', - 'token', - 'version', - '_permissions', - '_state', - '_session', - '_original_message', - '_cs_response', - '_cs_followup', - '_cs_channel', + "id", + "type", + "guild_id", + "channel_id", + "data", + "application_id", + "message", + "user", + "token", + "version", + "_permissions", + "_state", + "_session", + "_original_message", + "_cs_response", + "_cs_followup", + "_cs_channel", ) def __init__(self, *, data: InteractionPayload, state: ConnectionState): @@ -126,18 +126,18 @@ class Interaction: self._from_data(data) def _from_data(self, data: InteractionPayload): - self.id: int = int(data['id']) - self.type: InteractionType = try_enum(InteractionType, data['type']) - self.data: Optional[InteractionData] = data.get('data') - self.token: str = data['token'] - self.version: int = data['version'] - self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') - self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') - self.application_id: int = int(data['application_id']) + self.id: int = int(data["id"]) + self.type: InteractionType = try_enum(InteractionType, data["type"]) + self.data: Optional[InteractionData] = data.get("data") + self.token: str = data["token"] + self.version: int = data["version"] + self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id") + self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") + self.application_id: int = int(data["application_id"]) self.message: Optional[Message] try: - self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore + self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore except KeyError: self.message = None @@ -148,15 +148,15 @@ class Interaction: if self.guild_id: guild = self.guild or Object(id=self.guild_id) try: - member = data['member'] # type: ignore + member = data["member"] # type: ignore except KeyError: pass else: self.user = Member(state=self._state, guild=guild, data=member) # type: ignore - self._permissions = int(member.get('permissions', 0)) + self._permissions = int(member.get("permissions", 0)) else: try: - self.user = User(state=self._state, data=data['user']) + self.user = User(state=self._state, data=data["user"]) except KeyError: pass @@ -165,7 +165,7 @@ class Interaction: """Optional[:class:`Guild`]: The guild the interaction was sent from.""" return self._state and self._state._get_guild(self.guild_id) - @utils.cached_slot_property('_cs_channel') + @utils.cached_slot_property("_cs_channel") def channel(self) -> Optional[InteractionChannel]: """Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from. @@ -189,7 +189,7 @@ class Interaction: """ return Permissions(self._permissions) - @utils.cached_slot_property('_cs_response') + @utils.cached_slot_property("_cs_response") def response(self) -> InteractionResponse: """:class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction. @@ -198,13 +198,13 @@ class Interaction: """ return InteractionResponse(self) - @utils.cached_slot_property('_cs_followup') + @utils.cached_slot_property("_cs_followup") def followup(self) -> Webhook: """:class:`Webhook`: Returns the follow up webhook for follow up interactions.""" payload = { - 'id': self.application_id, - 'type': 3, - 'token': self.token, + "id": self.application_id, + "type": 3, + "token": self.token, } return Webhook.from_state(data=payload, state=self._state) @@ -238,7 +238,7 @@ class Interaction: # TODO: fix later to not raise? channel = self.channel if channel is None: - raise ClientException('Channel for message could not be resolved') + raise ClientException("Channel for message could not be resolved") adapter = async_context.get() data = await adapter.get_original_interaction_response( @@ -369,8 +369,8 @@ class InteractionResponse: """ __slots__: Tuple[str, ...] = ( - '_responded', - '_parent', + "_responded", + "_parent", ) def __init__(self, parent: Interaction): @@ -416,7 +416,7 @@ class InteractionResponse: elif parent.type is InteractionType.application_command: defer_type = InteractionResponseType.deferred_channel_message.value if ephemeral: - data = {'flags': 64} + data = {"flags": 64} if defer_type: adapter = async_context.get() @@ -498,28 +498,28 @@ class InteractionResponse: raise InteractionResponded(self._parent) payload: Dict[str, Any] = { - 'tts': tts, + "tts": tts, } if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix embed and embeds keyword arguments') + raise TypeError("cannot mix embed and embeds keyword arguments") if embed is not MISSING: embeds = [embed] if embeds: if len(embeds) > 10: - raise ValueError('embeds cannot exceed maximum of 10 elements') - payload['embeds'] = [e.to_dict() for e in embeds] + raise ValueError("embeds cannot exceed maximum of 10 elements") + payload["embeds"] = [e.to_dict() for e in embeds] if content is not None: - payload['content'] = str(content) + payload["content"] = str(content) if ephemeral: - payload['flags'] = 64 + payload["flags"] = 64 if view is not MISSING: - payload['components'] = view.to_components() + payload["components"] = view.to_components() parent = self._parent adapter = async_context.get() @@ -591,12 +591,12 @@ class InteractionResponse: payload = {} if content is not MISSING: if content is None: - payload['content'] = None + payload["content"] = None else: - payload['content'] = str(content) + payload["content"] = str(content) if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix both embed and embeds keyword arguments') + raise TypeError("cannot mix both embed and embeds keyword arguments") if embed is not MISSING: if embed is None: @@ -605,17 +605,17 @@ class InteractionResponse: embeds = [embed] if embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] + payload["embeds"] = [e.to_dict() for e in embeds] if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] + payload["attachments"] = [a.to_dict() for a in attachments] if view is not MISSING: state.prevent_view_updates_for(message_id) if view is None: - payload['components'] = [] + payload["components"] = [] else: - payload['components'] = view.to_components() + payload["components"] = view.to_components() adapter = async_context.get() await adapter.create_interaction_response( @@ -633,7 +633,7 @@ class InteractionResponse: class _InteractionMessageState: - __slots__ = ('_parent', '_interaction') + __slots__ = ("_parent", "_interaction") def __init__(self, interaction: Interaction, parent: ConnectionState): self._interaction: Interaction = interaction diff --git a/discord/invite.py b/discord/invite.py index 050d2b83..44d4fc9a 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum from .appinfo import PartialAppInfo __all__ = ( - 'PartialInviteChannel', - 'PartialInviteGuild', - 'Invite', + "PartialInviteChannel", + "PartialInviteGuild", + "Invite", ) if TYPE_CHECKING: @@ -52,8 +52,8 @@ if TYPE_CHECKING: from .abc import GuildChannel from .user import User - InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] - InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object] + InviteGuildType = Union[Guild, "PartialInviteGuild", Object] + InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object] import datetime @@ -92,23 +92,23 @@ class PartialInviteChannel: The partial channel's type. """ - __slots__ = ('id', 'name', 'type') + __slots__ = ("id", "name", "type") def __init__(self, data: InviteChannelPayload): - self.id: int = int(data['id']) - self.name: str = data['name'] - self.type: ChannelType = try_enum(ChannelType, data['type']) + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.type: ChannelType = try_enum(ChannelType, data["type"]) def __str__(self) -> str: return self.name def __repr__(self) -> str: - return f'' + return f"" @property def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" - return f'<#{self.id}>' + return f"<#{self.id}>" @property def created_at(self) -> datetime.datetime: @@ -154,26 +154,26 @@ class PartialInviteGuild: The partial guild's description. """ - __slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description') + __slots__ = ("_state", "features", "_icon", "_banner", "id", "name", "_splash", "verification_level", "description") def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): self._state: ConnectionState = state self.id: int = id - self.name: str = data['name'] - self.features: List[str] = data.get('features', []) - self._icon: Optional[str] = data.get('icon') - self._banner: Optional[str] = data.get('banner') - self._splash: Optional[str] = data.get('splash') - self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level')) - self.description: Optional[str] = data.get('description') + self.name: str = data["name"] + self.features: List[str] = data.get("features", []) + self._icon: Optional[str] = data.get("icon") + self._banner: Optional[str] = data.get("banner") + self._splash: Optional[str] = data.get("splash") + self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get("verification_level")) + self.description: Optional[str] = data.get("description") def __str__(self) -> str: return self.name def __repr__(self) -> str: return ( - f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} ' - f'description={self.description!r}>' + f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} " + f"description={self.description!r}>" ) @property @@ -193,17 +193,17 @@ class PartialInviteGuild: """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None - return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') + return Asset._from_guild_image(self._state, self.id, self._banner, path="banners") @property def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None - return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') + return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes") -I = TypeVar('I', bound='Invite') +I = TypeVar("I", bound="Invite") class Invite(Hashable): @@ -308,26 +308,26 @@ class Invite(Hashable): """ __slots__ = ( - 'max_age', - 'code', - 'guild', - 'revoked', - 'created_at', - 'uses', - 'temporary', - 'max_uses', - 'inviter', - 'channel', - 'target_user', - 'target_type', - '_state', - 'approximate_member_count', - 'approximate_presence_count', - 'target_application', - 'expires_at', + "max_age", + "code", + "guild", + "revoked", + "created_at", + "uses", + "temporary", + "max_uses", + "inviter", + "channel", + "target_user", + "target_type", + "_state", + "approximate_member_count", + "approximate_presence_count", + "target_application", + "expires_at", ) - BASE = 'https://discord.gg' + BASE = "https://discord.gg" def __init__( self, @@ -338,31 +338,33 @@ class Invite(Hashable): channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, ): self._state: ConnectionState = state - self.max_age: Optional[int] = data.get('max_age') - self.code: str = data['code'] - self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) - self.revoked: Optional[bool] = data.get('revoked') - self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) - self.temporary: Optional[bool] = data.get('temporary') - self.uses: Optional[int] = data.get('uses') - self.max_uses: Optional[int] = data.get('max_uses') - self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') - self.approximate_member_count: Optional[int] = data.get('approximate_member_count') + self.max_age: Optional[int] = data.get("max_age") + self.code: str = data["code"] + self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get("guild"), guild) + self.revoked: Optional[bool] = data.get("revoked") + self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at")) + self.temporary: Optional[bool] = data.get("temporary") + self.uses: Optional[int] = data.get("uses") + self.max_uses: Optional[int] = data.get("max_uses") + self.approximate_presence_count: Optional[int] = data.get("approximate_presence_count") + self.approximate_member_count: Optional[int] = data.get("approximate_member_count") - expires_at = data.get('expires_at', None) + expires_at = data.get("expires_at", None) self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None - inviter_data = data.get('inviter') + inviter_data = data.get("inviter") self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data) - self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel) + self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get("channel"), channel) - target_user_data = data.get('target_user') - self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data) + target_user_data = data.get("target_user") + self.target_user: Optional[User] = ( + None if target_user_data is None else self._state.create_user(target_user_data) + ) self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) - application = data.get('target_application') + application = data.get("target_application") self.target_application: Optional[PartialAppInfo] = ( PartialAppInfo(data=application, state=state) if application else None ) @@ -371,12 +373,12 @@ class Invite(Hashable): def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I: guild: Optional[Union[Guild, PartialInviteGuild]] try: - guild_data = data['guild'] + guild_data = data["guild"] except KeyError: # If we're here, then this is a group DM guild = None else: - guild_id = int(guild_data['id']) + guild_id = int(guild_data["id"]) guild = state._get_guild(guild_id) if guild is None: # If it's not cached, then it has to be a partial guild @@ -384,7 +386,7 @@ class Invite(Hashable): # As far as I know, invites always need a channel # So this should never raise. - channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) + channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data["channel"]) if guild is not None and not isinstance(guild, PartialInviteGuild): # Upgrade the partial data if applicable channel = guild.get_channel(channel.id) or channel @@ -393,9 +395,9 @@ class Invite(Hashable): @classmethod def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: - guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') + guild_id: Optional[int] = _get_as_snowflake(data, "guild_id") guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) if guild is not None: channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore else: @@ -415,7 +417,7 @@ class Invite(Hashable): if data is None: return None - guild_id = int(data['id']) + guild_id = int(data["id"]) return PartialInviteGuild(self._state, data, guild_id) def _resolve_channel( @@ -439,9 +441,9 @@ class Invite(Hashable): def __repr__(self) -> str: return ( - f'' + f"" ) def __hash__(self) -> int: @@ -455,7 +457,7 @@ class Invite(Hashable): @property def url(self) -> str: """:class:`str`: A property that retrieves the invite URL.""" - return self.BASE + '/' + self.code + return self.BASE + "/" + self.code async def delete(self, *, reason: Optional[str] = None): """|coro| diff --git a/discord/iterators.py b/discord/iterators.py index f725d527..f5a94ae1 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -34,11 +34,11 @@ from .object import Object from .audit_logs import AuditLogEntry __all__ = ( - 'ReactionIterator', - 'HistoryIterator', - 'AuditLogIterator', - 'GuildIterator', - 'MemberIterator', + "ReactionIterator", + "HistoryIterator", + "AuditLogIterator", + "GuildIterator", + "MemberIterator", ) if TYPE_CHECKING: @@ -67,8 +67,8 @@ if TYPE_CHECKING: from .threads import Thread from .abc import Snowflake -T = TypeVar('T') -OT = TypeVar('OT') +T = TypeVar("T") +OT = TypeVar("OT") _Func = Callable[[T], Union[OT, Awaitable[OT]]] OLDEST_OBJECT = Object(id=0) @@ -83,7 +83,7 @@ class _AsyncIterator(AsyncIterator[T]): def get(self, **attrs: Any) -> Awaitable[Optional[T]]: def predicate(elem: T): for attr, val in attrs.items(): - nested = attr.split('__') + nested = attr.split("__") obj = elem for attribute in nested: obj = getattr(obj, attribute) @@ -107,7 +107,7 @@ class _AsyncIterator(AsyncIterator[T]): def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: if max_size <= 0: - raise ValueError('async iterator chunk sizes must be greater than 0.') + raise ValueError("async iterator chunk sizes must be greater than 0.") return _ChunkedAsyncIterator(self, max_size) def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: @@ -182,7 +182,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]): return item -class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): +class ReactionIterator(_AsyncIterator[Union["User", "Member"]]): def __init__(self, message, emoji, limit=100, after=None): self.message = message self.limit = limit @@ -218,14 +218,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): if data: self.limit -= retrieve - self.after = Object(id=int(data[-1]['id'])) + self.after = Object(id=int(data[-1]["id"])) if self.guild is None or isinstance(self.guild, Object): for element in reversed(data): await self.users.put(User(state=self.state, data=element)) else: for element in reversed(data): - member_id = int(element['id']) + member_id = int(element["id"]) member = self.guild.get_member(member_id) if member is not None: await self.users.put(member) @@ -233,7 +233,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): await self.users.put(User(state=self.state, data=element)) -class HistoryIterator(_AsyncIterator['Message']): +class HistoryIterator(_AsyncIterator["Message"]): """Iterator for receiving a channel's message history. The messages endpoint has two behaviours we care about here: @@ -295,7 +295,7 @@ class HistoryIterator(_AsyncIterator['Message']): if self.around: if self.limit is None: - raise ValueError('history does not support around with limit=None') + raise ValueError("history does not support around with limit=None") if self.limit > 101: raise ValueError("history max limit 101 when specifying around parameter") elif self.limit == 101: @@ -303,20 +303,20 @@ class HistoryIterator(_AsyncIterator['Message']): self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore if self.before and self.after: - self._filter = lambda m: self.after.id < int(m['id']) < self.before.id + self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id elif self.before: - self._filter = lambda m: int(m['id']) < self.before.id + self._filter = lambda m: int(m["id"]) < self.before.id elif self.after: - self._filter = lambda m: self.after.id < int(m['id']) + self._filter = lambda m: self.after.id < int(m["id"]) else: if self.reverse: self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore if self.before: - self._filter = lambda m: int(m['id']) < self.before.id + self._filter = lambda m: int(m["id"]) < self.before.id else: self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore if self.after and self.after != OLDEST_OBJECT: - self._filter = lambda m: int(m['id']) > self.after.id + self._filter = lambda m: int(m["id"]) > self.after.id async def next(self) -> Message: if self.messages.empty(): @@ -337,7 +337,7 @@ class HistoryIterator(_AsyncIterator['Message']): return r > 0 async def fill_messages(self): - if not hasattr(self, 'channel'): + if not hasattr(self, "channel"): # do the required set up channel = await self.messageable._get_channel() self.channel = channel @@ -367,7 +367,7 @@ class HistoryIterator(_AsyncIterator['Message']): if len(data): if self.limit is not None: self.limit -= retrieve - self.before = Object(id=int(data[-1]['id'])) + self.before = Object(id=int(data[-1]["id"])) return data async def _retrieve_messages_after_strategy(self, retrieve): @@ -377,7 +377,7 @@ class HistoryIterator(_AsyncIterator['Message']): if len(data): if self.limit is not None: self.limit -= retrieve - self.after = Object(id=int(data[0]['id'])) + self.after = Object(id=int(data[0]["id"])) return data async def _retrieve_messages_around_strategy(self, retrieve): @@ -390,7 +390,7 @@ class HistoryIterator(_AsyncIterator['Message']): return [] -class AuditLogIterator(_AsyncIterator['AuditLogEntry']): +class AuditLogIterator(_AsyncIterator["AuditLogEntry"]): def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): if isinstance(before, datetime.datetime): before = Object(id=time_snowflake(before, high=False)) @@ -420,11 +420,11 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): if self.reverse: self._strategy = self._after_strategy if self.before: - self._filter = lambda m: int(m['id']) < self.before.id + self._filter = lambda m: int(m["id"]) < self.before.id else: self._strategy = self._before_strategy if self.after and self.after != OLDEST_OBJECT: - self._filter = lambda m: int(m['id']) > self.after.id + self._filter = lambda m: int(m["id"]) > self.after.id async def _before_strategy(self, retrieve): before = self.before.id if self.before else None @@ -432,24 +432,24 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before ) - entries = data.get('audit_log_entries', []) + entries = data.get("audit_log_entries", []) if len(data) and entries: if self.limit is not None: self.limit -= retrieve - self.before = Object(id=int(entries[-1]['id'])) - return data.get('users', []), entries + self.before = Object(id=int(entries[-1]["id"])) + return data.get("users", []), entries async def _after_strategy(self, retrieve): after = self.after.id if self.after else None data: AuditLogPayload = await self.request( self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after ) - entries = data.get('audit_log_entries', []) + entries = data.get("audit_log_entries", []) if len(data) and entries: if self.limit is not None: self.limit -= retrieve - self.after = Object(id=int(entries[0]['id'])) - return data.get('users', []), entries + self.after = Object(id=int(entries[0]["id"])) + return data.get("users", []), entries async def next(self) -> AuditLogEntry: if self.entries.empty(): @@ -488,13 +488,13 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): for element in data: # TODO: remove this if statement later - if element['action_type'] is None: + if element["action_type"] is None: continue await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) -class GuildIterator(_AsyncIterator['Guild']): +class GuildIterator(_AsyncIterator["Guild"]): """Iterator for receiving the client's guilds. The guilds endpoint has the same two behaviours as described @@ -543,7 +543,7 @@ class GuildIterator(_AsyncIterator['Guild']): if self.before and self.after: self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore - self._filter = lambda m: int(m['id']) > self.after.id + self._filter = lambda m: int(m["id"]) > self.after.id elif self.after: self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore else: @@ -595,7 +595,7 @@ class GuildIterator(_AsyncIterator['Guild']): if len(data): if self.limit is not None: self.limit -= retrieve - self.before = Object(id=int(data[-1]['id'])) + self.before = Object(id=int(data[-1]["id"])) return data async def _retrieve_guilds_after_strategy(self, retrieve): @@ -605,11 +605,11 @@ class GuildIterator(_AsyncIterator['Guild']): if len(data): if self.limit is not None: self.limit -= retrieve - self.after = Object(id=int(data[0]['id'])) + self.after = Object(id=int(data[0]["id"])) return data -class MemberIterator(_AsyncIterator['Member']): +class MemberIterator(_AsyncIterator["Member"]): def __init__(self, guild, limit=1000, after=None): if isinstance(after, datetime.datetime): @@ -652,7 +652,7 @@ class MemberIterator(_AsyncIterator['Member']): if len(data) < 1000: self.limit = 0 # terminate loop - self.after = Object(id=int(data[-1]['user']['id'])) + self.after = Object(id=int(data[-1]["user"]["id"])) for element in reversed(data): await self.members.put(self.create_member(element)) @@ -663,7 +663,7 @@ class MemberIterator(_AsyncIterator['Member']): return Member(data=data, guild=self.guild, state=self.state) -class ArchivedThreadIterator(_AsyncIterator['Thread']): +class ArchivedThreadIterator(_AsyncIterator["Thread"]): def __init__( self, channel_id: int, @@ -681,7 +681,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): self.http = guild._state.http if joined and not private: - raise ValueError('Cannot iterate over joined public archived threads') + raise ValueError("Cannot iterate over joined public archived threads") self.before: Optional[str] if before is None: @@ -721,11 +721,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): @staticmethod def get_archive_timestamp(data: ThreadPayload) -> str: - return data['thread_metadata']['archive_timestamp'] + return data["thread_metadata"]["archive_timestamp"] @staticmethod def get_thread_id(data: ThreadPayload) -> str: - return data['id'] # type: ignore + return data["id"] # type: ignore async def fill_queue(self) -> None: if not self.has_more: @@ -735,11 +735,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): data = await self.endpoint(self.channel_id, before=self.before, limit=limit) # This stuff is obviously WIP because 'members' is always empty - threads: List[ThreadPayload] = data.get('threads', []) + threads: List[ThreadPayload] = data.get("threads", []) for d in reversed(threads): self.queue.put_nowait(self.create_thread(d)) - self.has_more = data.get('has_more', False) + self.has_more = data.get("has_more", False) if self.limit is not None: self.limit -= len(threads) if self.limit <= 0: @@ -750,4 +750,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): def create_thread(self, data: ThreadPayload) -> Thread: from .threads import Thread + return Thread(guild=self.guild, state=self.guild._state, data=data) diff --git a/discord/member.py b/discord/member.py index 49b7e3ab..4d49af39 100644 --- a/discord/member.py +++ b/discord/member.py @@ -44,8 +44,8 @@ from .colour import Colour from .object import Object __all__ = ( - 'VoiceState', - 'Member', + "VoiceState", + "Member", ) if TYPE_CHECKING: @@ -113,52 +113,54 @@ class VoiceState: """ __slots__ = ( - 'session_id', - 'deaf', - 'mute', - 'self_mute', - 'self_stream', - 'self_video', - 'self_deaf', - 'afk', - 'channel', - 'requested_to_speak_at', - 'suppress', + "session_id", + "deaf", + "mute", + "self_mute", + "self_stream", + "self_video", + "self_deaf", + "afk", + "channel", + "requested_to_speak_at", + "suppress", ) def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): - self.session_id: str = data.get('session_id') + self.session_id: str = data.get("session_id") self._update(data, channel) def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]): - self.self_mute: bool = data.get('self_mute', False) - self.self_deaf: bool = data.get('self_deaf', False) - self.self_stream: bool = data.get('self_stream', False) - self.self_video: bool = data.get('self_video', False) - self.afk: bool = data.get('suppress', False) - self.mute: bool = data.get('mute', False) - self.deaf: bool = data.get('deaf', False) - self.suppress: bool = data.get('suppress', False) - self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp')) + self.self_mute: bool = data.get("self_mute", False) + self.self_deaf: bool = data.get("self_deaf", False) + self.self_stream: bool = data.get("self_stream", False) + self.self_video: bool = data.get("self_video", False) + self.afk: bool = data.get("suppress", False) + self.mute: bool = data.get("mute", False) + self.deaf: bool = data.get("deaf", False) + self.suppress: bool = data.get("suppress", False) + self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time( + data.get("request_to_speak_timestamp") + ) self.channel: Optional[VocalGuildChannel] = channel def __repr__(self) -> str: attrs = [ - ('self_mute', self.self_mute), - ('self_deaf', self.self_deaf), - ('self_stream', self.self_stream), - ('suppress', self.suppress), - ('requested_to_speak_at', self.requested_to_speak_at), - ('channel', self.channel), + ("self_mute", self.self_mute), + ("self_deaf", self.self_deaf), + ("self_stream", self.self_stream), + ("suppress", self.suppress), + ("requested_to_speak_at", self.requested_to_speak_at), + ("channel", self.channel), ] - inner = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {inner}>' + inner = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {inner}>" def flatten_user(cls): for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): # ignore private/special methods - if attr.startswith('_'): + if attr.startswith("_"): continue # don't override what we already have @@ -167,9 +169,9 @@ def flatten_user(cls): # if it's a slotted attribute or a property, redirect it # slotted members are implemented as member_descriptors in Type.__dict__ - if not hasattr(value, '__annotations__'): - getter = attrgetter('_user.' + attr) - setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`')) + if not hasattr(value, "__annotations__"): + getter = attrgetter("_user." + attr) + setattr(cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`")) else: # Technically, this can also use attrgetter # However I'm not sure how I feel about "functions" returning properties @@ -197,7 +199,7 @@ def flatten_user(cls): return cls -M = TypeVar('M', bound='Member') +M = TypeVar("M", bound="Member") @flatten_user @@ -258,17 +260,17 @@ class Member(discord.abc.Messageable, _UserTag): """ __slots__ = ( - '_roles', - 'joined_at', - 'premium_since', - 'activities', - 'guild', - 'pending', - 'nick', - '_client_status', - '_user', - '_state', - '_avatar', + "_roles", + "joined_at", + "premium_since", + "activities", + "guild", + "pending", + "nick", + "_client_status", + "_user", + "_state", + "_avatar", ) if TYPE_CHECKING: @@ -290,16 +292,16 @@ class Member(discord.abc.Messageable, _UserTag): def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState): self._state: ConnectionState = state - self._user: User = state.store_user(data['user']) + self._user: User = state.store_user(data["user"]) self.guild: Guild = guild - self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at')) - self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since')) - self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles'])) - self._client_status: Dict[Optional[str], str] = {None: 'offline'} + self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get("joined_at")) + self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get("premium_since")) + self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"])) + self._client_status: Dict[Optional[str], str] = {None: "offline"} self.activities: Tuple[ActivityTypes, ...] = tuple() - self.nick: Optional[str] = data.get('nick', None) - self.pending: bool = data.get('pending', False) - self._avatar: Optional[str] = data.get('avatar') + self.nick: Optional[str] = data.get("nick", None) + self.pending: bool = data.get("pending", False) + self._avatar: Optional[str] = data.get("avatar") def __str__(self) -> str: return str(self._user) @@ -309,8 +311,8 @@ class Member(discord.abc.Messageable, _UserTag): def __repr__(self) -> str: return ( - f'' + f"" ) def __eq__(self, other: Any) -> bool: @@ -325,25 +327,27 @@ class Member(discord.abc.Messageable, _UserTag): @classmethod def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M: author = message.author - data['user'] = author._to_minimal_user_json() # type: ignore + data["user"] = author._to_minimal_user_json() # type: ignore return cls(data=data, guild=message.guild, state=message._state) # type: ignore def _update_from_message(self, data: MemberPayload) -> None: - self.joined_at = utils.parse_time(data.get('joined_at')) - self.premium_since = utils.parse_time(data.get('premium_since')) - self._roles = utils.SnowflakeList(map(int, data['roles'])) - self.nick = data.get('nick', None) - self.pending = data.get('pending', False) + self.joined_at = utils.parse_time(data.get("joined_at")) + self.premium_since = utils.parse_time(data.get("premium_since")) + self._roles = utils.SnowflakeList(map(int, data["roles"])) + self.nick = data.get("nick", None) + self.pending = data.get("pending", False) @classmethod - def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]: + def _try_upgrade( + cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState + ) -> Union[User, M]: # A User object with a 'member' key try: - member_data = data.pop('member') + member_data = data.pop("member") except KeyError: return state.create_user(data) else: - member_data['user'] = data # type: ignore + member_data["user"] = data # type: ignore return cls(data=member_data, guild=guild, state=state) # type: ignore @classmethod @@ -374,25 +378,25 @@ class Member(discord.abc.Messageable, _UserTag): # the nickname change is optional, # if it isn't in the payload then it didn't change try: - self.nick = data['nick'] + self.nick = data["nick"] except KeyError: pass try: - self.pending = data['pending'] + self.pending = data["pending"] except KeyError: pass - self.premium_since = utils.parse_time(data.get('premium_since')) - self._roles = utils.SnowflakeList(map(int, data['roles'])) - self._avatar = data.get('avatar') + self.premium_since = utils.parse_time(data.get("premium_since")) + self._roles = utils.SnowflakeList(map(int, data["roles"])) + self._avatar = data.get("avatar") def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]: - self.activities = tuple(map(create_activity, data['activities'])) + self.activities = tuple(map(create_activity, data["activities"])) self._client_status = { - sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore + sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore } - self._client_status[None] = sys.intern(data['status']) + self._client_status[None] = sys.intern(data["status"]) if len(user) > 1: return self._update_inner_user(user) @@ -402,7 +406,7 @@ class Member(discord.abc.Messageable, _UserTag): u = self._user original = (u.name, u._avatar, u.discriminator, u._public_flags) # These keys seem to always be available - modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0)) + modified = (user["username"], user["avatar"], user["discriminator"], user.get("public_flags", 0)) if original != modified: to_return = User._copy(self._user) u.name, u._avatar, u.discriminator, u._public_flags = modified @@ -430,21 +434,21 @@ class Member(discord.abc.Messageable, _UserTag): @property def mobile_status(self) -> Status: """:class:`Status`: The member's status on a mobile device, if applicable.""" - return try_enum(Status, self._client_status.get('mobile', 'offline')) + return try_enum(Status, self._client_status.get("mobile", "offline")) @property def desktop_status(self) -> Status: """:class:`Status`: The member's status on the desktop client, if applicable.""" - return try_enum(Status, self._client_status.get('desktop', 'offline')) + return try_enum(Status, self._client_status.get("desktop", "offline")) @property def web_status(self) -> Status: """:class:`Status`: The member's status on the web client, if applicable.""" - return try_enum(Status, self._client_status.get('web', 'offline')) + return try_enum(Status, self._client_status.get("web", "offline")) def is_on_mobile(self) -> bool: """:class:`bool`: A helper function that determines if a member is active on a mobile device.""" - return 'mobile' in self._client_status + return "mobile" in self._client_status @property def colour(self) -> Colour: @@ -497,8 +501,8 @@ class Member(discord.abc.Messageable, _UserTag): def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention the member.""" if self.nick: - return f'<@!{self._user.id}>' - return f'<@{self._user.id}>' + return f"<@!{self._user.id}>" + return f"<@{self._user.id}>" @property def display_name(self) -> str: @@ -720,39 +724,39 @@ class Member(discord.abc.Messageable, _UserTag): payload: Dict[str, Any] = {} if nick is not MISSING: - nick = nick or '' + nick = nick or "" if me: await http.change_my_nickname(guild_id, nick, reason=reason) else: - payload['nick'] = nick + payload["nick"] = nick if deafen is not MISSING: - payload['deaf'] = deafen + payload["deaf"] = deafen if mute is not MISSING: - payload['mute'] = mute + payload["mute"] = mute if suppress is not MISSING: voice_state_payload = { - 'channel_id': self.voice.channel.id, - 'suppress': suppress, + "channel_id": self.voice.channel.id, + "suppress": suppress, } if suppress or self.bot: - voice_state_payload['request_to_speak_timestamp'] = None + voice_state_payload["request_to_speak_timestamp"] = None if me: await http.edit_my_voice_state(guild_id, voice_state_payload) else: if not suppress: - voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat() + voice_state_payload["request_to_speak_timestamp"] = datetime.datetime.utcnow().isoformat() await http.edit_voice_state(guild_id, self.id, voice_state_payload) if voice_channel is not MISSING: - payload['channel_id'] = voice_channel and voice_channel.id + payload["channel_id"] = voice_channel and voice_channel.id if roles is not MISSING: - payload['roles'] = tuple(r.id for r in roles) + payload["roles"] = tuple(r.id for r in roles) if payload: data = await http.edit_member(guild_id, self.id, reason=reason, **payload) @@ -780,12 +784,12 @@ class Member(discord.abc.Messageable, _UserTag): The operation failed. """ payload = { - 'channel_id': self.voice.channel.id, - 'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(), + "channel_id": self.voice.channel.id, + "request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(), } if self._state.self_id != self.id: - payload['suppress'] = False + payload["suppress"] = False await self._state.http.edit_voice_state(self.guild.id, self.id, payload) else: await self._state.http.edit_my_voice_state(self.guild.id, payload) diff --git a/discord/mentions.py b/discord/mentions.py index 0516decf..4aef5387 100644 --- a/discord/mentions.py +++ b/discord/mentions.py @@ -25,9 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union -__all__ = ( - 'AllowedMentions', -) +__all__ = ("AllowedMentions",) if TYPE_CHECKING: from .types.message import AllowedMentions as AllowedMentionsPayload @@ -36,7 +34,7 @@ if TYPE_CHECKING: class _FakeBool: def __repr__(self): - return 'True' + return "True" def __eq__(self, other): return other is True @@ -47,7 +45,7 @@ class _FakeBool: default: Any = _FakeBool() -A = TypeVar('A', bound='AllowedMentions') +A = TypeVar("A", bound="AllowedMentions") class AllowedMentions: @@ -80,7 +78,7 @@ class AllowedMentions: .. versionadded:: 1.6 """ - __slots__ = ('everyone', 'users', 'roles', 'replied_user') + __slots__ = ("everyone", "users", "roles", "replied_user") def __init__( self, @@ -116,22 +114,22 @@ class AllowedMentions: data = {} if self.everyone: - parse.append('everyone') + parse.append("everyone") if self.users == True: - parse.append('users') + parse.append("users") elif self.users != False: - data['users'] = [x.id for x in self.users] + data["users"] = [x.id for x in self.users] if self.roles == True: - parse.append('roles') + parse.append("roles") elif self.roles != False: - data['roles'] = [x.id for x in self.roles] + data["roles"] = [x.id for x in self.roles] if self.replied_user: - data['replied_user'] = True + data["replied_user"] = True - data['parse'] = parse + data["parse"] = parse return data # type: ignore def merge(self, other: AllowedMentions) -> AllowedMentions: @@ -146,6 +144,6 @@ class AllowedMentions: def __repr__(self) -> str: return ( - f'{self.__class__.__name__}(everyone={self.everyone}, ' - f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})' + f"{self.__class__.__name__}(everyone={self.everyone}, " + f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})" ) diff --git a/discord/message.py b/discord/message.py index 49c5e718..30464bd8 100644 --- a/discord/message.py +++ b/discord/message.py @@ -29,7 +29,21 @@ import datetime import re import io from os import PathLike -from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type +from typing import ( + Dict, + TYPE_CHECKING, + Union, + List, + Optional, + Any, + Callable, + Tuple, + ClassVar, + Optional, + overload, + TypeVar, + Type, +) from . import utils from .reaction import Reaction @@ -76,15 +90,15 @@ if TYPE_CHECKING: from .role import Role from .ui.view import View - MR = TypeVar('MR', bound='MessageReference') + MR = TypeVar("MR", bound="MessageReference") EmojiInputType = Union[Emoji, PartialEmoji, str] __all__ = ( - 'Attachment', - 'Message', - 'PartialMessage', - 'MessageReference', - 'DeletedReferencedMessage', + "Attachment", + "Message", + "PartialMessage", + "MessageReference", + "DeletedReferencedMessage", ) @@ -93,15 +107,15 @@ def convert_emoji_reaction(emoji): emoji = emoji.emoji if isinstance(emoji, Emoji): - return f'{emoji.name}:{emoji.id}' + return f"{emoji.name}:{emoji.id}" if isinstance(emoji, PartialEmoji): return emoji._as_reaction() if isinstance(emoji, str): # Reactions can be in :name:id format, but not <:name:id>. # No existing emojis have <> in them, so this should be okay. - return emoji.strip('<>') + return emoji.strip("<>") - raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.') + raise InvalidArgument(f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.") class Attachment(Hashable): @@ -157,28 +171,28 @@ class Attachment(Hashable): .. versionadded:: 1.7 """ - __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') + __slots__ = ("id", "size", "height", "width", "filename", "url", "proxy_url", "_http", "content_type") def __init__(self, *, data: AttachmentPayload, state: ConnectionState): - self.id: int = int(data['id']) - self.size: int = data['size'] - self.height: Optional[int] = data.get('height') - self.width: Optional[int] = data.get('width') - self.filename: str = data['filename'] - self.url: str = data.get('url') - self.proxy_url: str = data.get('proxy_url') + self.id: int = int(data["id"]) + self.size: int = data["size"] + self.height: Optional[int] = data.get("height") + self.width: Optional[int] = data.get("width") + self.filename: str = data["filename"] + self.url: str = data.get("url") + self.proxy_url: str = data.get("proxy_url") self._http = state.http - self.content_type: Optional[str] = data.get('content_type') + self.content_type: Optional[str] = data.get("content_type") def is_spoiler(self) -> bool: """:class:`bool`: Whether this attachment contains a spoiler.""" - return self.filename.startswith('SPOILER_') + return self.filename.startswith("SPOILER_") def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: - return self.url or '' + return self.url or "" async def save( self, @@ -227,7 +241,7 @@ class Attachment(Hashable): fp.seek(0) return written else: - with open(fp, 'wb') as f: + with open(fp, "wb") as f: return f.write(data) async def read(self, *, use_cached: bool = False) -> bytes: @@ -309,19 +323,19 @@ class Attachment(Hashable): def to_dict(self) -> AttachmentPayload: result: AttachmentPayload = { - 'filename': self.filename, - 'id': self.id, - 'proxy_url': self.proxy_url, - 'size': self.size, - 'url': self.url, - 'spoiler': self.is_spoiler(), + "filename": self.filename, + "id": self.id, + "proxy_url": self.proxy_url, + "size": self.size, + "url": self.url, + "spoiler": self.is_spoiler(), } if self.height: - result['height'] = self.height + result["height"] = self.height if self.width: - result['width'] = self.width + result["width"] = self.width if self.content_type: - result['content_type'] = self.content_type + result["content_type"] = self.content_type return result @@ -335,7 +349,7 @@ class DeletedReferencedMessage: .. versionadded:: 1.6 """ - __slots__ = ('_parent',) + __slots__ = ("_parent",) def __init__(self, parent: MessageReference): self._parent: MessageReference = parent @@ -347,7 +361,7 @@ class DeletedReferencedMessage: def id(self) -> int: """:class:`int`: The message ID of the deleted referenced message.""" # the parent's message id won't be None here - return self._parent.message_id # type: ignore + return self._parent.message_id # type: ignore @property def channel_id(self) -> int: @@ -394,9 +408,11 @@ class MessageReference: .. versionadded:: 1.6 """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state') + __slots__ = ("message_id", "channel_id", "guild_id", "fail_if_not_exists", "resolved", "_state") - def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True): + def __init__( + self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True + ): self._state: Optional[ConnectionState] = None self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None self.message_id: Optional[int] = message_id @@ -407,10 +423,10 @@ class MessageReference: @classmethod def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR: self = cls.__new__(cls) - self.message_id = utils._get_as_snowflake(data, 'message_id') - self.channel_id = int(data.pop('channel_id')) - self.guild_id = utils._get_as_snowflake(data, 'guild_id') - self.fail_if_not_exists = data.get('fail_if_not_exists', True) + self.message_id = utils._get_as_snowflake(data, "message_id") + self.channel_id = int(data.pop("channel_id")) + self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.fail_if_not_exists = data.get("fail_if_not_exists", True) self._state = state self.resolved = None return self @@ -439,7 +455,7 @@ class MessageReference: self = cls( message_id=message.id, channel_id=message.channel.id, - guild_id=getattr(message.guild, 'id', None), + guild_id=getattr(message.guild, "id", None), fail_if_not_exists=fail_if_not_exists, ) self._state = message._state @@ -456,36 +472,36 @@ class MessageReference: .. versionadded:: 1.7 """ - guild_id = self.guild_id if self.guild_id is not None else '@me' - return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}' + guild_id = self.guild_id if self.guild_id is not None else "@me" + return f"https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}" def __repr__(self) -> str: - return f'' + return f"" def to_dict(self) -> MessageReferencePayload: - result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} - result['channel_id'] = self.channel_id + result: MessageReferencePayload = {"message_id": self.message_id} if self.message_id is not None else {} + result["channel_id"] = self.channel_id if self.guild_id is not None: - result['guild_id'] = self.guild_id + result["guild_id"] = self.guild_id if self.fail_if_not_exists is not None: - result['fail_if_not_exists'] = self.fail_if_not_exists + result["fail_if_not_exists"] = self.fail_if_not_exists return result to_message_reference_dict = to_dict def flatten_handlers(cls): - prefix = len('_handle_') + prefix = len("_handle_") handlers = [ (key[prefix:], value) for key, value in cls.__dict__.items() - if key.startswith('_handle_') and key != '_handle_member' + if key.startswith("_handle_") and key != "_handle_member" ] # store _handle_member last - handlers.append(('member', cls._handle_member)) + handlers.append(("member", cls._handle_member)) cls._HANDLERS = handlers - cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')] + cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith("_cs_")] return cls @@ -615,36 +631,36 @@ class Message(Hashable): """ __slots__ = ( - '_state', - '_edited_timestamp', - '_cs_channel_mentions', - '_cs_raw_mentions', - '_cs_clean_content', - '_cs_raw_channel_mentions', - '_cs_raw_role_mentions', - '_cs_system_content', - 'tts', - 'content', - 'channel', - 'webhook_id', - 'mention_everyone', - 'embeds', - 'id', - 'mentions', - 'author', - 'attachments', - 'nonce', - 'pinned', - 'role_mentions', - 'type', - 'flags', - 'reactions', - 'reference', - 'application', - 'activity', - 'stickers', - 'components', - 'guild', + "_state", + "_edited_timestamp", + "_cs_channel_mentions", + "_cs_raw_mentions", + "_cs_clean_content", + "_cs_raw_channel_mentions", + "_cs_raw_role_mentions", + "_cs_system_content", + "tts", + "content", + "channel", + "webhook_id", + "mention_everyone", + "embeds", + "id", + "mentions", + "author", + "attachments", + "nonce", + "pinned", + "role_mentions", + "type", + "flags", + "reactions", + "reference", + "application", + "activity", + "stickers", + "components", + "guild", ) if TYPE_CHECKING: @@ -664,39 +680,39 @@ class Message(Hashable): data: MessagePayload, ): self._state: ConnectionState = state - self.id: int = int(data['id']) - self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id') - self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])] - self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']] - self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] - self.application: Optional[MessageApplicationPayload] = data.get('application') - self.activity: Optional[MessageActivityPayload] = data.get('activity') + self.id: int = int(data["id"]) + self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id") + self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get("reactions", [])] + self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data["attachments"]] + self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]] + self.application: Optional[MessageApplicationPayload] = data.get("application") + self.activity: Optional[MessageActivityPayload] = data.get("activity") self.channel: MessageableChannel = channel - self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) - self.type: MessageType = try_enum(MessageType, data['type']) - self.pinned: bool = data['pinned'] - self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0)) - self.mention_everyone: bool = data['mention_everyone'] - self.tts: bool = data['tts'] - self.content: str = data['content'] - self.nonce: Optional[Union[int, str]] = data.get('nonce') - self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] - self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data["edited_timestamp"]) + self.type: MessageType = try_enum(MessageType, data["type"]) + self.pinned: bool = data["pinned"] + self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) + self.mention_everyone: bool = data["mention_everyone"] + self.tts: bool = data["tts"] + self.content: str = data["content"] + self.nonce: Optional[Union[int, str]] = data.get("nonce") + self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])] + self.components: List[Component] = [_component_factory(d) for d in data.get("components", [])] try: # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore except AttributeError: - self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) + self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id")) try: - ref = data['message_reference'] + ref = data["message_reference"] except KeyError: self.reference = None else: self.reference = ref = MessageReference.with_state(state, ref) try: - resolved = data['referenced_message'] + resolved = data["referenced_message"] except KeyError: pass else: @@ -712,18 +728,15 @@ class Message(Hashable): # the channel will be the correct type here ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore - for handler in ('author', 'member', 'mentions', 'mention_roles'): + for handler in ("author", "member", "mentions", "mention_roles"): try: - getattr(self, f'_handle_{handler}')(data[handler]) + getattr(self, f"_handle_{handler}")(data[handler]) except KeyError: continue def __repr__(self) -> str: name = self.__class__.__name__ - return ( - f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>' - ) - + return f"<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>" def __str__(self) -> Optional[str]: return self.content @@ -741,7 +754,7 @@ class Message(Hashable): def _add_reaction(self, data, emoji, user_id) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) - is_me = data['me'] = user_id == self._state.self_id + is_me = data["me"] = user_id == self._state.self_id if reaction is None: reaction = Reaction(message=self, data=data, emoji=emoji) @@ -758,7 +771,7 @@ class Message(Hashable): if reaction is None: # already removed? - raise ValueError('Emoji already removed?') + raise ValueError("Emoji already removed?") # if reaction isn't in the list, we crash. This means discord # sent bad data, or we stored improperly @@ -872,7 +885,7 @@ class Message(Hashable): return for mention in filter(None, mentions): - id_search = int(mention['id']) + id_search = int(mention["id"]) member = guild.get_member(id_search) if member is not None: r.append(member) @@ -894,7 +907,7 @@ class Message(Hashable): self.guild = new_guild self.channel = new_channel - @utils.cached_slot_property('_cs_raw_mentions') + @utils.cached_slot_property("_cs_raw_mentions") def raw_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of user IDs matched with the syntax of ``<@user_id>`` in the message content. @@ -902,30 +915,30 @@ class Message(Hashable): This allows you to receive the user IDs of mentioned users even in a private message context. """ - return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)] + return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)] - @utils.cached_slot_property('_cs_raw_channel_mentions') + @utils.cached_slot_property("_cs_raw_channel_mentions") def raw_channel_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of channel IDs matched with the syntax of ``<#channel_id>`` in the message content. """ - return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)] + return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)] - @utils.cached_slot_property('_cs_raw_role_mentions') + @utils.cached_slot_property("_cs_raw_role_mentions") def raw_role_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of role IDs matched with the syntax of ``<@&role_id>`` in the message content. """ - return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)] + return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)] - @utils.cached_slot_property('_cs_channel_mentions') + @utils.cached_slot_property("_cs_channel_mentions") def channel_mentions(self) -> List[GuildChannel]: if self.guild is None: return [] it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) return utils._unique(it) - @utils.cached_slot_property('_cs_clean_content') + @utils.cached_slot_property("_cs_clean_content") def clean_content(self) -> str: """:class:`str`: A property that returns the content in a "cleaned up" manner. This basically means that mentions are transformed @@ -972,9 +985,9 @@ class Message(Hashable): # fmt: on def repl(obj): - return transformations.get(re.escape(obj.group(0)), '') + return transformations.get(re.escape(obj.group(0)), "") - pattern = re.compile('|'.join(transformations.keys())) + pattern = re.compile("|".join(transformations.keys())) result = pattern.sub(repl, self.content) return escape_mentions(result) @@ -991,8 +1004,8 @@ class Message(Hashable): @property def jump_url(self) -> str: """:class:`str`: Returns a URL that allows the client to jump to this message.""" - guild_id = getattr(self.guild, 'id', '@me') - return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}' + guild_id = getattr(self.guild, "id", "@me") + return f"https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}" def is_system(self) -> bool: """:class:`bool`: Whether the message is a system message. @@ -1009,7 +1022,7 @@ class Message(Hashable): MessageType.thread_starter_message, ) - @utils.cached_slot_property('_cs_system_content') + @utils.cached_slot_property("_cs_system_content") def system_content(self): r""":class:`str`: A property that returns the content that is rendered regardless of the :attr:`Message.type`. @@ -1024,24 +1037,24 @@ class Message(Hashable): if self.type is MessageType.recipient_add: if self.channel.type is ChannelType.group: - return f'{self.author.name} added {self.mentions[0].name} to the group.' + return f"{self.author.name} added {self.mentions[0].name} to the group." else: - return f'{self.author.name} added {self.mentions[0].name} to the thread.' + return f"{self.author.name} added {self.mentions[0].name} to the thread." if self.type is MessageType.recipient_remove: if self.channel.type is ChannelType.group: - return f'{self.author.name} removed {self.mentions[0].name} from the group.' + return f"{self.author.name} removed {self.mentions[0].name} from the group." else: - return f'{self.author.name} removed {self.mentions[0].name} from the thread.' + return f"{self.author.name} removed {self.mentions[0].name} from the thread." if self.type is MessageType.channel_name_change: - return f'{self.author.name} changed the channel name: **{self.content}**' + return f"{self.author.name} changed the channel name: **{self.content}**" if self.type is MessageType.channel_icon_change: - return f'{self.author.name} changed the channel icon.' + return f"{self.author.name} changed the channel icon." if self.type is MessageType.pins_add: - return f'{self.author.name} pinned a message to this channel.' + return f"{self.author.name} pinned a message to this channel." if self.type is MessageType.new_member: formats = [ @@ -1065,62 +1078,62 @@ class Message(Hashable): if self.type is MessageType.premium_guild_subscription: if not self.content: - return f'{self.author.name} just boosted the server!' + return f"{self.author.name} just boosted the server!" else: - return f'{self.author.name} just boosted the server **{self.content}** times!' + return f"{self.author.name} just boosted the server **{self.content}** times!" if self.type is MessageType.premium_guild_tier_1: if not self.content: - return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**' + return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**" else: - return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**' + return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**" if self.type is MessageType.premium_guild_tier_2: if not self.content: - return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**' + return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**" else: - return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**' + return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**" if self.type is MessageType.premium_guild_tier_3: if not self.content: - return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**' + return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**" else: - return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**' + return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**" if self.type is MessageType.channel_follow_add: - return f'{self.author.name} has added {self.content} to this channel' + return f"{self.author.name} has added {self.content} to this channel" if self.type is MessageType.guild_stream: # the author will be a Member - return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore + return f"{self.author.name} is live! Now streaming {self.author.activity.name}" # type: ignore if self.type is MessageType.guild_discovery_disqualified: - return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.' + return "This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details." if self.type is MessageType.guild_discovery_requalified: - return 'This server is eligible for Server Discovery again and has been automatically relisted!' + return "This server is eligible for Server Discovery again and has been automatically relisted!" if self.type is MessageType.guild_discovery_grace_period_initial_warning: - return 'This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery.' + return "This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery." if self.type is MessageType.guild_discovery_grace_period_final_warning: - return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.' + return "This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery." if self.type is MessageType.thread_created: - return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.' + return f"{self.author.name} started a thread: **{self.content}**. See all **threads**." if self.type is MessageType.reply: return self.content if self.type is MessageType.thread_starter_message: if self.reference is None or self.reference.resolved is None: - return 'Sorry, we couldn\'t load the first message in this thread' + return "Sorry, we couldn't load the first message in this thread" # the resolved message for the reference will be a Message return self.reference.resolved.content # type: ignore if self.type is MessageType.guild_invite_reminder: - return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!' + return "Wondering who to invite?\nStart by inviting anyone who can help you build the server!" async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None: """|coro| @@ -1271,45 +1284,45 @@ class Message(Hashable): payload: Dict[str, Any] = {} if content is not MISSING: if content is not None: - payload['content'] = str(content) + payload["content"] = str(content) else: - payload['content'] = None + payload["content"] = None if embed is not MISSING and embeds is not MISSING: - raise InvalidArgument('cannot pass both embed and embeds parameter to edit()') + raise InvalidArgument("cannot pass both embed and embeds parameter to edit()") if embed is not MISSING: if embed is None: - payload['embeds'] = [] + payload["embeds"] = [] else: - payload['embeds'] = [embed.to_dict()] + payload["embeds"] = [embed.to_dict()] elif embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] + payload["embeds"] = [e.to_dict() for e in embeds] if suppress is not MISSING: flags = MessageFlags._from_value(self.flags.value) flags.suppress_embeds = suppress - payload['flags'] = flags.value + payload["flags"] = flags.value if allowed_mentions is MISSING: if self._state.allowed_mentions is not None and self.author.id == self._state.self_id: - payload['allowed_mentions'] = self._state.allowed_mentions.to_dict() + payload["allowed_mentions"] = self._state.allowed_mentions.to_dict() else: if allowed_mentions is not None: if self._state.allowed_mentions is not None: - payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict() + payload["allowed_mentions"] = self._state.allowed_mentions.merge(allowed_mentions).to_dict() else: - payload['allowed_mentions'] = allowed_mentions.to_dict() + payload["allowed_mentions"] = allowed_mentions.to_dict() if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] + payload["attachments"] = [a.to_dict() for a in attachments] if view is not MISSING: self._state.prevent_view_updates_for(self.id) if view: - payload['components'] = view.to_components() + payload["components"] = view.to_components() else: - payload['components'] = [] + payload["components"] = [] data = await self._state.http.edit_message(self.channel.id, self.id, **payload) message = Message(state=self._state, channel=self.channel, data=data) @@ -1551,9 +1564,11 @@ class Message(Hashable): The created thread. """ if self.guild is None: - raise InvalidArgument('This message does not have guild info attached.') + raise InvalidArgument("This message does not have guild info attached.") - default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440) + default_auto_archive_duration: ThreadArchiveDuration = getattr( + self.channel, "default_auto_archive_duration", 1440 + ) data = await self._state.http.start_thread_with_message( self.channel.id, self.id, @@ -1611,12 +1626,12 @@ class Message(Hashable): def to_message_reference_dict(self) -> MessageReferencePayload: data: MessageReferencePayload = { - 'message_id': self.id, - 'channel_id': self.channel.id, + "message_id": self.id, + "channel_id": self.channel.id, } if self.guild is not None: - data['guild_id'] = self.guild.id + data["guild_id"] = self.guild.id return data @@ -1662,7 +1677,7 @@ class PartialMessage(Hashable): The message ID. """ - __slots__ = ('channel', 'id', '_cs_guild', '_state') + __slots__ = ("channel", "id", "_cs_guild", "_state") jump_url: str = Message.jump_url # type: ignore delete = Message.delete @@ -1686,7 +1701,7 @@ class PartialMessage(Hashable): ChannelType.public_thread, ChannelType.private_thread, ): - raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}') + raise TypeError(f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}") self.channel: PartialMessageableChannel = channel self._state: ConnectionState = channel._state @@ -1702,17 +1717,17 @@ class PartialMessage(Hashable): pinned = property(None, lambda x, y: None) def __repr__(self) -> str: - return f'' + return f"" @property def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: The partial message's creation time in UTC.""" return utils.snowflake_time(self.id) - @utils.cached_slot_property('_cs_guild') + @utils.cached_slot_property("_cs_guild") def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable.""" - return getattr(self.channel, 'guild', None) + return getattr(self.channel, "guild", None) async def fetch(self) -> Message: """|coro| @@ -1794,34 +1809,34 @@ class PartialMessage(Hashable): """ try: - content = fields['content'] + content = fields["content"] except KeyError: pass else: if content is not None: - fields['content'] = str(content) + fields["content"] = str(content) try: - embed = fields['embed'] + embed = fields["embed"] except KeyError: pass else: if embed is not None: - fields['embed'] = embed.to_dict() + fields["embed"] = embed.to_dict() try: - suppress: bool = fields.pop('suppress') + suppress: bool = fields.pop("suppress") except KeyError: pass else: flags = MessageFlags._from_value(0) flags.suppress_embeds = suppress - fields['flags'] = flags.value + fields["flags"] = flags.value - delete_after = fields.pop('delete_after', None) + delete_after = fields.pop("delete_after", None) try: - allowed_mentions = fields.pop('allowed_mentions') + allowed_mentions = fields.pop("allowed_mentions") except KeyError: pass else: @@ -1830,19 +1845,19 @@ class PartialMessage(Hashable): allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict() else: allowed_mentions = allowed_mentions.to_dict() - fields['allowed_mentions'] = allowed_mentions + fields["allowed_mentions"] = allowed_mentions try: - view = fields.pop('view') + view = fields.pop("view") except KeyError: # To check for the view afterwards view = None else: self._state.prevent_view_updates_for(self.id) if view: - fields['components'] = view.to_components() + fields["components"] = view.to_components() else: - fields['components'] = [] + fields["components"] = [] if fields: data = await self._state.http.edit_message(self.channel.id, self.id, **fields) diff --git a/discord/mixins.py b/discord/mixins.py index fdacf863..0fd42641 100644 --- a/discord/mixins.py +++ b/discord/mixins.py @@ -23,10 +23,11 @@ DEALINGS IN THE SOFTWARE. """ __all__ = ( - 'EqualityComparable', - 'Hashable', + "EqualityComparable", + "Hashable", ) + class EqualityComparable: __slots__ = () @@ -40,6 +41,7 @@ class EqualityComparable: return other.id != self.id return True + class Hashable(EqualityComparable): __slots__ = () diff --git a/discord/object.py b/discord/object.py index 8061a8be..b63242d9 100644 --- a/discord/object.py +++ b/discord/object.py @@ -35,11 +35,11 @@ from typing import ( if TYPE_CHECKING: import datetime + SupportsIntCast = Union[SupportsInt, str, bytes, bytearray] -__all__ = ( - 'Object', -) +__all__ = ("Object",) + class Object(Hashable): """Represents a generic Discord object. @@ -83,12 +83,12 @@ class Object(Hashable): try: id = int(id) except ValueError: - raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None + raise TypeError(f"id parameter must be convertable to int not {id.__class__!r}") from None else: self.id = id def __repr__(self) -> str: - return f'' + return f"" @property def created_at(self) -> datetime.datetime: diff --git a/discord/oggparse.py b/discord/oggparse.py index e0347d2c..ab837838 100644 --- a/discord/oggparse.py +++ b/discord/oggparse.py @@ -31,20 +31,24 @@ from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional from .errors import DiscordException __all__ = ( - 'OggError', - 'OggPage', - 'OggStream', + "OggError", + "OggPage", + "OggStream", ) + class OggError(DiscordException): """An exception that is thrown for Ogg stream parsing errors.""" + pass + # https://tools.ietf.org/html/rfc3533 # https://tools.ietf.org/html/rfc7845 + class OggPage: - _header: ClassVar[struct.Struct] = struct.Struct(' Generator[Tuple[bytes, bool], None, None]: packetlen = offset = 0 @@ -76,7 +79,7 @@ class OggPage: partial = True else: packetlen += seg - yield self.data[offset:offset+packetlen], True + yield self.data[offset : offset + packetlen], True offset += packetlen packetlen = 0 partial = False @@ -84,18 +87,19 @@ class OggPage: if partial: yield self.data[offset:], False + class OggStream: def __init__(self, stream: IO[bytes]) -> None: self.stream: IO[bytes] = stream def _next_page(self) -> Optional[OggPage]: head = self.stream.read(4) - if head == b'OggS': + if head == b"OggS": return OggPage(self.stream) elif not head: return None else: - raise OggError('invalid header magic') + raise OggError("invalid header magic") def _iter_pages(self) -> Generator[OggPage, None, None]: page = self._next_page() @@ -104,10 +108,10 @@ class OggStream: page = self._next_page() def iter_packets(self) -> Generator[bytes, None, None]: - partial = b'' + partial = b"" for page in self._iter_pages(): for data, complete in page.iter_packets(): partial += data if complete: yield partial - partial = b'' + partial = b"" diff --git a/discord/opus.py b/discord/opus.py index 97d437a3..bb2c4bff 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -38,9 +38,10 @@ import sys from .errors import DiscordException, InvalidArgument if TYPE_CHECKING: - T = TypeVar('T') - BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] - SIGNAL_CTL = Literal['auto', 'voice', 'music'] + T = TypeVar("T") + BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] + SIGNAL_CTL = Literal["auto", "voice", "music"] + class BandCtl(TypedDict): narrow: int @@ -49,81 +50,89 @@ class BandCtl(TypedDict): superwide: int full: int + class SignalCtl(TypedDict): auto: int voice: int music: int + __all__ = ( - 'Encoder', - 'OpusError', - 'OpusNotLoaded', + "Encoder", + "OpusError", + "OpusNotLoaded", ) _log = logging.getLogger(__name__) -c_int_ptr = ctypes.POINTER(ctypes.c_int) +c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) _lib = None + class EncoderStruct(ctypes.Structure): pass + class DecoderStruct(ctypes.Structure): pass + EncoderStructPtr = ctypes.POINTER(EncoderStruct) DecoderStructPtr = ctypes.POINTER(DecoderStruct) ## Some constants from opus_defines.h # Error codes -OK = 0 +OK = 0 BAD_ARG = -1 # Encoder CTLs -APPLICATION_AUDIO = 2049 -APPLICATION_VOIP = 2048 +APPLICATION_AUDIO = 2049 +APPLICATION_VOIP = 2048 APPLICATION_LOWDELAY = 2051 -CTL_SET_BITRATE = 4002 -CTL_SET_BANDWIDTH = 4008 -CTL_SET_FEC = 4012 -CTL_SET_PLP = 4014 -CTL_SET_SIGNAL = 4024 +CTL_SET_BITRATE = 4002 +CTL_SET_BANDWIDTH = 4008 +CTL_SET_FEC = 4012 +CTL_SET_PLP = 4014 +CTL_SET_SIGNAL = 4024 # Decoder CTLs -CTL_SET_GAIN = 4034 +CTL_SET_GAIN = 4034 CTL_LAST_PACKET_DURATION = 4039 band_ctl: BandCtl = { - 'narrow': 1101, - 'medium': 1102, - 'wide': 1103, - 'superwide': 1104, - 'full': 1105, + "narrow": 1101, + "medium": 1102, + "wide": 1103, + "superwide": 1104, + "full": 1105, } signal_ctl: SignalCtl = { - 'auto': -1000, - 'voice': 3001, - 'music': 3002, + "auto": -1000, + "voice": 3001, + "music": 3002, } + def _err_lt(result: int, func: Callable, args: List) -> int: if result < OK: - _log.info('error has happened in %s', func.__name__) + _log.info("error has happened in %s", func.__name__) raise OpusError(result) return result + def _err_ne(result: T, func: Callable, args: List) -> T: ret = args[-1]._obj if ret.value != OK: - _log.info('error has happened in %s', func.__name__) + _log.info("error has happened in %s", func.__name__) raise OpusError(ret.value) return result + # A list of exported functions. # The first argument is obviously the name. # The second one are the types of arguments it takes. @@ -131,54 +140,51 @@ def _err_ne(result: T, func: Callable, args: List) -> T: # The fourth is the error handler. exported_functions: List[Tuple[Any, ...]] = [ # Generic - ('opus_get_version_string', - None, ctypes.c_char_p, None), - ('opus_strerror', - [ctypes.c_int], ctypes.c_char_p, None), - + ("opus_get_version_string", None, ctypes.c_char_p, None), + ("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None), # Encoder functions - ('opus_encoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_encoder_create', - [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), - ('opus_encode', - [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encode_float', - [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_encoder_destroy', - [EncoderStructPtr], None, None), - + ("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ("opus_encoder_create", [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), + ( + "opus_encode", + [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ( + "opus_encode_float", + [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ("opus_encoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_encoder_destroy", [EncoderStructPtr], None, None), # Decoder functions - ('opus_decoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_decoder_create', - [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), - ('opus_decode', + ("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ("opus_decoder_create", [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), + ( + "opus_decode", [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decode_float', + ctypes.c_int, + _err_lt, + ), + ( + "opus_decode_float", [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_decoder_destroy', - [DecoderStructPtr], None, None), - ('opus_decoder_get_nb_samples', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt), - + ctypes.c_int, + _err_lt, + ), + ("opus_decoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_decoder_destroy", [DecoderStructPtr], None, None), + ("opus_decoder_get_nb_samples", [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt), # Packet functions - ('opus_packet_get_bandwidth', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_channels', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_frames', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), - ('opus_packet_get_samples_per_frame', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), + ("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt), + ("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt), + ("opus_packet_get_nb_frames", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), + ("opus_packet_get_samples_per_frame", [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), ] + def libopus_loader(name: str) -> Any: # create the library... lib = ctypes.cdll.LoadLibrary(name) @@ -203,22 +209,24 @@ def libopus_loader(name: str) -> Any: return lib + def _load_default() -> bool: global _lib try: - if sys.platform == 'win32': + if sys.platform == "win32": _basedir = os.path.dirname(os.path.abspath(__file__)) - _bitness = struct.calcsize('P') * 8 - _target = 'x64' if _bitness > 32 else 'x86' - _filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll') + _bitness = struct.calcsize("P") * 8 + _target = "x64" if _bitness > 32 else "x86" + _filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll") _lib = libopus_loader(_filename) else: - _lib = libopus_loader(ctypes.util.find_library('opus')) + _lib = libopus_loader(ctypes.util.find_library("opus")) except Exception: _lib = None return _lib is not None + def load_opus(name: str) -> None: """Loads the libopus shared library for use with voice. @@ -257,6 +265,7 @@ def load_opus(name: str) -> None: global _lib _lib = libopus_loader(name) + def is_loaded() -> bool: """Function to check if opus lib is successfully loaded either via the :func:`ctypes.util.find_library` call of :func:`load_opus`. @@ -271,6 +280,7 @@ def is_loaded() -> bool: global _lib return _lib is not None + class OpusError(DiscordException): """An exception that is thrown for libopus related errors. @@ -282,19 +292,22 @@ class OpusError(DiscordException): def __init__(self, code: int): self.code: int = code - msg = _lib.opus_strerror(self.code).decode('utf-8') + msg = _lib.opus_strerror(self.code).decode("utf-8") _log.info('"%s" has happened', msg) super().__init__(msg) + class OpusNotLoaded(DiscordException): """An exception that is thrown for when libopus is not loaded.""" + pass + class _OpusStruct: SAMPLING_RATE = 48000 CHANNELS = 2 FRAME_LENGTH = 20 # in milliseconds - SAMPLE_SIZE = struct.calcsize('h') * CHANNELS + SAMPLE_SIZE = struct.calcsize("h") * CHANNELS SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH) FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE @@ -304,7 +317,8 @@ class _OpusStruct: if not is_loaded() and not _load_default(): raise OpusNotLoaded() - return _lib.opus_get_version_string().decode('utf-8') + return _lib.opus_get_version_string().decode("utf-8") + class Encoder(_OpusStruct): def __init__(self, application: int = APPLICATION_AUDIO): @@ -315,14 +329,14 @@ class Encoder(_OpusStruct): self.set_bitrate(128) self.set_fec(True) self.set_expected_packet_loss_percent(0.15) - self.set_bandwidth('full') - self.set_signal_type('auto') + self.set_bandwidth("full") + self.set_signal_type("auto") def __del__(self) -> None: - if hasattr(self, '_state'): + if hasattr(self, "_state"): _lib.opus_encoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> EncoderStruct: ret = ctypes.c_int() @@ -352,18 +366,19 @@ class Encoder(_OpusStruct): _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) def set_expected_packet_loss_percent(self, percentage: float) -> None: - _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore def encode(self, pcm: bytes, frame_size: int) -> bytes: max_data_bytes = len(pcm) # bytes can be used to reference pointer - pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore + pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) # array can be initialized with bytes but mypy doesn't know - return array.array('b', data[:ret]).tobytes() # type: ignore + return array.array("b", data[:ret]).tobytes() # type: ignore + class Decoder(_OpusStruct): def __init__(self): @@ -372,10 +387,10 @@ class Decoder(_OpusStruct): self._state: DecoderStruct = self._create_state() def __del__(self) -> None: - if hasattr(self, '_state'): + if hasattr(self, "_state"): _lib.opus_decoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> DecoderStruct: ret = ctypes.c_int() @@ -411,12 +426,12 @@ class Decoder(_OpusStruct): def set_gain(self, dB: float) -> int: """Sets the decoder gain in dB, from -128 to 128.""" - dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) + dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) return self._set_gain(dB_Q8) def set_volume(self, mult: float) -> int: """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" - return self.set_gain(20 * math.log10(mult)) # amplitude ratio + return self.set_gain(20 * math.log10(mult)) # amplitude ratio def _get_last_packet_duration(self) -> int: """Gets the duration (in samples) of the last packet successfully decoded or concealed.""" @@ -428,7 +443,7 @@ class Decoder(_OpusStruct): @overload def decode(self, data: bytes, *, fec: bool) -> bytes: ... - + @overload def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: ... @@ -451,4 +466,4 @@ class Decoder(_OpusStruct): ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) - return array.array('h', pcm[:ret * channel_count]).tobytes() + return array.array("h", pcm[: ret * channel_count]).tobytes() diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index e2c689e2..0973028d 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -31,15 +31,14 @@ from .asset import Asset, AssetMixin from .errors import InvalidArgument from . import utils -__all__ = ( - 'PartialEmoji', -) +__all__ = ("PartialEmoji",) if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime from .types.message import PartialEmoji as PartialEmojiPayload + class _EmojiTag: __slots__ = () @@ -49,7 +48,7 @@ class _EmojiTag: raise NotImplementedError -PE = TypeVar('PE', bound='PartialEmoji') +PE = TypeVar("PE", bound="PartialEmoji") class PartialEmoji(_EmojiTag, AssetMixin): @@ -90,9 +89,9 @@ class PartialEmoji(_EmojiTag, AssetMixin): The ID of the custom emoji, if applicable. """ - __slots__ = ('animated', 'name', 'id', '_state') + __slots__ = ("animated", "name", "id", "_state") - _CUSTOM_EMOJI_RE = re.compile(r'a)?:?(?P[A-Za-z0-9\_]+):(?P[0-9]{13,20})>?') + _CUSTOM_EMOJI_RE = re.compile(r"a)?:?(?P[A-Za-z0-9\_]+):(?P[0-9]{13,20})>?") if TYPE_CHECKING: id: Optional[int] @@ -106,9 +105,9 @@ class PartialEmoji(_EmojiTag, AssetMixin): @classmethod def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE: return cls( - animated=data.get('animated', False), - id=utils._get_as_snowflake(data, 'id'), - name=data.get('name') or '', + animated=data.get("animated", False), + id=utils._get_as_snowflake(data, "id"), + name=data.get("name") or "", ) @classmethod @@ -139,19 +138,19 @@ class PartialEmoji(_EmojiTag, AssetMixin): match = cls._CUSTOM_EMOJI_RE.match(value) if match is not None: groups = match.groupdict() - animated = bool(groups['animated']) - emoji_id = int(groups['id']) - name = groups['name'] + animated = bool(groups["animated"]) + emoji_id = int(groups["id"]) + name = groups["name"] return cls(name=name, animated=animated, id=emoji_id) return cls(name=value, id=None, animated=False) def to_dict(self) -> Dict[str, Any]: - o: Dict[str, Any] = {'name': self.name} + o: Dict[str, Any] = {"name": self.name} if self.id: - o['id'] = self.id + o["id"] = self.id if self.animated: - o['animated'] = self.animated + o["animated"] = self.animated return o def _to_partial(self) -> PartialEmoji: @@ -169,11 +168,11 @@ class PartialEmoji(_EmojiTag, AssetMixin): if self.id is None: return self.name if self.animated: - return f'' - return f'<:{self.name}:{self.id}>' + return f"" + return f"<:{self.name}:{self.id}>" def __repr__(self): - return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' + return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>" def __eq__(self, other: Any) -> bool: if self.is_unicode_emoji(): @@ -200,7 +199,7 @@ class PartialEmoji(_EmojiTag, AssetMixin): def _as_reaction(self) -> str: if self.id is None: return self.name - return f'{self.name}:{self.id}' + return f"{self.name}:{self.id}" @property def created_at(self) -> Optional[datetime]: @@ -220,13 +219,13 @@ class PartialEmoji(_EmojiTag, AssetMixin): If this isn't a custom emoji then an empty string is returned """ if self.is_unicode_emoji(): - return '' + return "" - fmt = 'gif' if self.animated else 'png' - return f'{Asset.BASE}/emojis/{self.id}.{fmt}' + fmt = "gif" if self.animated else "png" + return f"{Asset.BASE}/emojis/{self.id}.{fmt}" async def read(self) -> bytes: if self.is_unicode_emoji(): - raise InvalidArgument('PartialEmoji is not a custom emoji') + raise InvalidArgument("PartialEmoji is not a custom emoji") return await super().read() diff --git a/discord/permissions.py b/discord/permissions.py index 4b3d9830..1bc27bc0 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -28,8 +28,8 @@ from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING, from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value __all__ = ( - 'Permissions', - 'PermissionOverwrite', + "Permissions", + "PermissionOverwrite", ) # A permission alias works like a regular flag but is marked @@ -46,7 +46,9 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis return decorator -P = TypeVar('P', bound='Permissions') + +P = TypeVar("P", bound="Permissions") + @fill_with_flags() class Permissions(BaseFlags): @@ -101,12 +103,12 @@ class Permissions(BaseFlags): def __init__(self, permissions: int = 0, **kwargs: bool): if not isinstance(permissions, int): - raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.') + raise TypeError(f"Expected int parameter, received {permissions.__class__.__name__} instead.") self.value = permissions for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid permission name.') + raise TypeError(f"{key!r} is not a valid permission name.") setattr(self, key, value) def is_subset(self, other: Permissions) -> bool: @@ -299,7 +301,7 @@ class Permissions(BaseFlags): """ return 1 << 3 - @make_permission_alias('administrator') + @make_permission_alias("administrator") def admin(self) -> int: """:class:`bool`: An alias for :attr:`administrator`. .. versionadded:: 2.0 @@ -343,7 +345,7 @@ class Permissions(BaseFlags): """:class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels.""" return 1 << 10 - @make_permission_alias('read_messages') + @make_permission_alias("read_messages") def view_channel(self) -> int: """:class:`bool`: An alias for :attr:`read_messages`. @@ -396,7 +398,7 @@ class Permissions(BaseFlags): """:class:`bool`: Returns ``True`` if a user can use emojis from other guilds.""" return 1 << 18 - @make_permission_alias('external_emojis') + @make_permission_alias("external_emojis") def use_external_emojis(self) -> int: """:class:`bool`: An alias for :attr:`external_emojis`. @@ -460,7 +462,7 @@ class Permissions(BaseFlags): """ return 1 << 28 - @make_permission_alias('manage_roles') + @make_permission_alias("manage_roles") def manage_permissions(self) -> int: """:class:`bool`: An alias for :attr:`manage_roles`. @@ -478,7 +480,7 @@ class Permissions(BaseFlags): """:class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis.""" return 1 << 30 - @make_permission_alias('manage_emojis') + @make_permission_alias("manage_emojis") def manage_emojis_and_stickers(self) -> int: """:class:`bool`: An alias for :attr:`manage_emojis`. @@ -542,7 +544,7 @@ class Permissions(BaseFlags): """ return 1 << 37 - @make_permission_alias('external_stickers') + @make_permission_alias("external_stickers") def use_external_stickers(self) -> int: """:class:`bool`: An alias for :attr:`external_stickers`. @@ -558,7 +560,9 @@ class Permissions(BaseFlags): """ return 1 << 38 -PO = TypeVar('PO', bound='PermissionOverwrite') + +PO = TypeVar("PO", bound="PermissionOverwrite") + def _augment_from_permissions(cls): cls.VALID_NAMES = set(Permissions.VALID_FLAGS) @@ -621,7 +625,7 @@ class PermissionOverwrite: Set the value of permissions by their name. """ - __slots__ = ('_values',) + __slots__ = ("_values",) if TYPE_CHECKING: VALID_NAMES: ClassVar[Set[str]] @@ -677,7 +681,7 @@ class PermissionOverwrite: for key, value in kwargs.items(): if key not in self.VALID_NAMES: - raise ValueError(f'no permission called {key}.') + raise ValueError(f"no permission called {key}.") setattr(self, key, value) @@ -686,7 +690,7 @@ class PermissionOverwrite: def _set(self, key: str, value: Optional[bool]) -> None: if value not in (True, None, False): - raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}') + raise TypeError(f"Expected bool or NoneType, received {value.__class__.__name__}") if value is None: self._values.pop(key, None) diff --git a/discord/player.py b/discord/player.py index 8098d3e3..36f51273 100644 --- a/discord/player.py +++ b/discord/player.py @@ -36,7 +36,7 @@ import sys import re import io -from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union +from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union from .errors import ClientException from .opus import Encoder as OpusEncoder @@ -47,27 +47,28 @@ if TYPE_CHECKING: from .voice_client import VoiceClient -AT = TypeVar('AT', bound='AudioSource') -FT = TypeVar('FT', bound='FFmpegOpusAudio') +AT = TypeVar("AT", bound="AudioSource") +FT = TypeVar("FT", bound="FFmpegOpusAudio") _log = logging.getLogger(__name__) __all__ = ( - 'AudioSource', - 'PCMAudio', - 'FFmpegAudio', - 'FFmpegPCMAudio', - 'FFmpegOpusAudio', - 'PCMVolumeTransformer', + "AudioSource", + "PCMAudio", + "FFmpegAudio", + "FFmpegPCMAudio", + "FFmpegOpusAudio", + "PCMVolumeTransformer", ) CREATE_NO_WINDOW: int -if sys.platform != 'win32': +if sys.platform != "win32": CREATE_NO_WINDOW = 0 else: CREATE_NO_WINDOW = 0x08000000 + class AudioSource: """Represents an audio stream. @@ -114,6 +115,7 @@ class AudioSource: def __del__(self) -> None: self.cleanup() + class PCMAudio(AudioSource): """Represents raw 16-bit 48KHz stereo PCM audio source. @@ -122,15 +124,17 @@ class PCMAudio(AudioSource): stream: :term:`py:file object` A file-like object that reads byte data representing raw PCM. """ + def __init__(self, stream: io.BufferedIOBase) -> None: self.stream: io.BufferedIOBase = stream def read(self) -> bytes: ret = self.stream.read(OpusEncoder.FRAME_SIZE) if len(ret) != OpusEncoder.FRAME_SIZE: - return b'' + return b"" return ret + class FFmpegAudio(AudioSource): """Represents an FFmpeg (or AVConv) based AudioSource. @@ -140,13 +144,15 @@ class FFmpegAudio(AudioSource): .. versionadded:: 1.3 """ - def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any): - piping = subprocess_kwargs.get('stdin') == subprocess.PIPE + def __init__( + self, source: Union[str, io.BufferedIOBase], *, executable: str = "ffmpeg", args: Any, **subprocess_kwargs: Any + ): + piping = subprocess_kwargs.get("stdin") == subprocess.PIPE if piping and isinstance(source, str): raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") args = [executable, *args] - kwargs = {'stdout': subprocess.PIPE} + kwargs = {"stdout": subprocess.PIPE} kwargs.update(subprocess_kwargs) self._process: subprocess.Popen = self._spawn_process(args, **kwargs) @@ -155,7 +161,7 @@ class FFmpegAudio(AudioSource): self._pipe_thread: Optional[threading.Thread] = None if piping: - n = f'popen-stdin-writer:{id(self):#x}' + n = f"popen-stdin-writer:{id(self):#x}" self._stdin = self._process.stdin self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n) self._pipe_thread.start() @@ -165,10 +171,10 @@ class FFmpegAudio(AudioSource): try: process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) except FileNotFoundError: - executable = args.partition(' ')[0] if isinstance(args, str) else args[0] - raise ClientException(executable + ' was not found.') from None + executable = args.partition(" ")[0] if isinstance(args, str) else args[0] + raise ClientException(executable + " was not found.") from None except subprocess.SubprocessError as exc: - raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc + raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc else: return process @@ -177,20 +183,19 @@ class FFmpegAudio(AudioSource): if proc is MISSING: return - _log.info('Preparing to terminate ffmpeg process %s.', proc.pid) + _log.info("Preparing to terminate ffmpeg process %s.", proc.pid) try: proc.kill() except Exception: - _log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid) + _log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid) if proc.poll() is None: - _log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid) + _log.info("ffmpeg process %s has not terminated. Waiting to terminate...", proc.pid) proc.communicate() - _log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode) + _log.info("ffmpeg process %s should have terminated with a return code of %s.", proc.pid, proc.returncode) else: - _log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode) - + _log.info("ffmpeg process %s successfully terminated with return code of %s.", proc.pid, proc.returncode) def _pipe_writer(self, source: io.BufferedIOBase) -> None: while self._process: @@ -202,7 +207,7 @@ class FFmpegAudio(AudioSource): try: self._stdin.write(data) except Exception: - _log.debug('Write error for %s, this is probably not a problem', self, exc_info=True) + _log.debug("Write error for %s, this is probably not a problem", self, exc_info=True) # at this point the source data is either exhausted or the process is fubar self._process.terminate() return @@ -211,6 +216,7 @@ class FFmpegAudio(AudioSource): self._kill_process() self._process = self._stdout = self._stdin = MISSING + class FFmpegPCMAudio(FFmpegAudio): """An audio source from FFmpeg (or AVConv). @@ -250,38 +256,39 @@ class FFmpegPCMAudio(FFmpegAudio): self, source: Union[str, io.BufferedIOBase], *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", pipe: bool = False, stderr: Optional[IO[str]] = None, before_options: Optional[str] = None, - options: Optional[str] = None + options: Optional[str] = None, ) -> None: args = [] - subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} + subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr} if isinstance(before_options, str): args.extend(shlex.split(before_options)) - args.append('-i') - args.append('-' if pipe else source) - args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning')) + args.append("-i") + args.append("-" if pipe else source) + args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning")) if isinstance(options, str): args.extend(shlex.split(options)) - args.append('pipe:1') + args.append("pipe:1") super().__init__(source, executable=executable, args=args, **subprocess_kwargs) def read(self) -> bytes: ret = self._stdout.read(OpusEncoder.FRAME_SIZE) if len(ret) != OpusEncoder.FRAME_SIZE: - return b'' + return b"" return ret def is_opus(self) -> bool: return False + class FFmpegOpusAudio(FFmpegAudio): """An audio source from FFmpeg (or AVConv). @@ -349,7 +356,7 @@ class FFmpegOpusAudio(FFmpegAudio): *, bitrate: int = 128, codec: Optional[str] = None, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", pipe=False, stderr=None, before_options=None, @@ -357,28 +364,39 @@ class FFmpegOpusAudio(FFmpegAudio): ) -> None: args = [] - subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} + subprocess_kwargs = {"stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, "stderr": stderr} if isinstance(before_options, str): args.extend(shlex.split(before_options)) - args.append('-i') - args.append('-' if pipe else source) + args.append("-i") + args.append("-" if pipe else source) - codec = 'copy' if codec in ('opus', 'libopus') else 'libopus' + codec = "copy" if codec in ("opus", "libopus") else "libopus" - args.extend(('-map_metadata', '-1', - '-f', 'opus', - '-c:a', codec, - '-ar', '48000', - '-ac', '2', - '-b:a', f'{bitrate}k', - '-loglevel', 'warning')) + args.extend( + ( + "-map_metadata", + "-1", + "-f", + "opus", + "-c:a", + codec, + "-ar", + "48000", + "-ac", + "2", + "-b:a", + f"{bitrate}k", + "-loglevel", + "warning", + ) + ) if isinstance(options, str): args.extend(shlex.split(options)) - args.append('pipe:1') + args.append("pipe:1") super().__init__(source, executable=executable, args=args, **subprocess_kwargs) self._packet_iter = OggStream(self._stdout).iter_packets() @@ -446,7 +464,7 @@ class FFmpegOpusAudio(FFmpegAudio): An instance of this class. """ - executable = kwargs.get('executable') + executable = kwargs.get("executable") codec, bitrate = await cls.probe(source, method=method, executable=executable) return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore @@ -484,12 +502,12 @@ class FFmpegOpusAudio(FFmpegAudio): A 2-tuple with the codec and bitrate of the input source. """ - method = method or 'native' - executable = executable or 'ffmpeg' + method = method or "native" + executable = executable or "ffmpeg" probefunc = fallback = None if isinstance(method, str): - probefunc = getattr(cls, '_probe_codec_' + method, None) + probefunc = getattr(cls, "_probe_codec_" + method, None) if probefunc is None: raise AttributeError(f"Invalid probe method {method!r}") @@ -500,8 +518,7 @@ class FFmpegOpusAudio(FFmpegAudio): probefunc = method fallback = cls._probe_codec_fallback else: - raise TypeError("Expected str or callable for parameter 'probe', " \ - f"not '{method.__class__.__name__}'") + raise TypeError("Expected str or callable for parameter 'probe', " f"not '{method.__class__.__name__}'") codec = bitrate = None loop = asyncio.get_event_loop() @@ -525,28 +542,28 @@ class FFmpegOpusAudio(FFmpegAudio): return codec, bitrate @staticmethod - def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable - args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] + def _probe_codec_native(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]: + exe = executable[:2] + "probe" if executable in ("ffmpeg", "avconv") else executable + args = [exe, "-v", "quiet", "-print_format", "json", "-show_streams", "-select_streams", "a:0", source] output = subprocess.check_output(args, timeout=20) codec = bitrate = None if output: data = json.loads(output) - streamdata = data['streams'][0] + streamdata = data["streams"][0] - codec = streamdata.get('codec_name') - bitrate = int(streamdata.get('bit_rate', 0)) - bitrate = max(round(bitrate/1000), 512) + codec = streamdata.get("codec_name") + bitrate = int(streamdata.get("bit_rate", 0)) + bitrate = max(round(bitrate / 1000), 512) return codec, bitrate @staticmethod - def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - args = [executable, '-hide_banner', '-i', source] + def _probe_codec_fallback(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]: + args = [executable, "-hide_banner", "-i", source] proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) out, _ = proc.communicate(timeout=20) - output = out.decode('utf8') + output = out.decode("utf8") codec = bitrate = None codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output) @@ -560,11 +577,12 @@ class FFmpegOpusAudio(FFmpegAudio): return codec, bitrate def read(self) -> bytes: - return next(self._packet_iter, b'') + return next(self._packet_iter, b"") def is_opus(self) -> bool: return True + class PCMVolumeTransformer(AudioSource, Generic[AT]): """Transforms a previous :class:`AudioSource` to have volume controls. @@ -589,10 +607,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): def __init__(self, original: AT, volume: float = 1.0): if not isinstance(original, AudioSource): - raise TypeError(f'expected AudioSource not {original.__class__.__name__}.') + raise TypeError(f"expected AudioSource not {original.__class__.__name__}.") if original.is_opus(): - raise ClientException('AudioSource must not be Opus encoded.') + raise ClientException("AudioSource must not be Opus encoded.") self.original: AT = original self.volume = volume @@ -613,6 +631,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): ret = self.original.read() return audioop.mul(ret, 2, min(self._volume, 2.0)) + class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 @@ -625,7 +644,7 @@ class AudioPlayer(threading.Thread): self._end: threading.Event = threading.Event() self._resumed: threading.Event = threading.Event() - self._resumed.set() # we are not paused + self._resumed.set() # we are not paused self._current_error: Optional[Exception] = None self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() @@ -685,11 +704,11 @@ class AudioPlayer(threading.Thread): try: self.after(error) except Exception as exc: - _log.exception('Calling the after function failed.') + _log.exception("Calling the after function failed.") exc.__context__ = error traceback.print_exception(type(exc), exc, exc.__traceback__) elif error: - msg = f'Exception in voice thread {self.name}' + msg = f"Exception in voice thread {self.name}" _log.exception(msg, exc_info=error) print(msg, file=sys.stderr) traceback.print_exception(type(error), error, error.__traceback__) diff --git a/discord/raw_models.py b/discord/raw_models.py index 3c9360ba..929ece2a 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -36,7 +36,7 @@ if TYPE_CHECKING: ReactionClearEvent, ReactionClearEmojiEvent, IntegrationDeleteEvent, - TypingEvent + TypingEvent, ) from .message import Message from .partial_emoji import PartialEmoji @@ -44,21 +44,21 @@ if TYPE_CHECKING: __all__ = ( - 'RawMessageDeleteEvent', - 'RawBulkMessageDeleteEvent', - 'RawMessageUpdateEvent', - 'RawReactionActionEvent', - 'RawReactionClearEvent', - 'RawReactionClearEmojiEvent', - 'RawIntegrationDeleteEvent', - 'RawTypingEvent' + "RawMessageDeleteEvent", + "RawBulkMessageDeleteEvent", + "RawMessageUpdateEvent", + "RawReactionActionEvent", + "RawReactionClearEvent", + "RawReactionClearEmojiEvent", + "RawIntegrationDeleteEvent", + "RawTypingEvent", ) class _RawReprMixin: def __repr__(self) -> str: - value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__) - return f'<{self.__class__.__name__} {value}>' + value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__) + return f"<{self.__class__.__name__} {value}>" class RawMessageDeleteEvent(_RawReprMixin): @@ -76,14 +76,14 @@ class RawMessageDeleteEvent(_RawReprMixin): The cached message, if found in the internal message cache. """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message') + __slots__ = ("message_id", "channel_id", "guild_id", "cached_message") def __init__(self, data: MessageDeleteEvent) -> None: - self.message_id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["id"]) + self.channel_id: int = int(data["channel_id"]) self.cached_message: Optional[Message] = None try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -103,15 +103,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin): The cached messages, if found in the internal message cache. """ - __slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages') + __slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages") def __init__(self, data: BulkMessageDeleteEvent) -> None: - self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])} - self.channel_id: int = int(data['channel_id']) + self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])} + self.channel_id: int = int(data["channel_id"]) self.cached_messages: List[Message] = [] try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -139,16 +139,16 @@ class RawMessageUpdateEvent(_RawReprMixin): it is modified by the data in :attr:`RawMessageUpdateEvent.data`. """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message') + __slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message") def __init__(self, data: MessageUpdateEvent) -> None: - self.message_id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["id"]) + self.channel_id: int = int(data["channel_id"]) self.data: MessageUpdateEvent = data self.cached_message: Optional[Message] = None try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -182,19 +182,18 @@ class RawReactionActionEvent(_RawReprMixin): .. versionadded:: 1.3 """ - __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', - 'event_type', 'member') + __slots__ = ("message_id", "user_id", "channel_id", "guild_id", "emoji", "event_type", "member") def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: - self.message_id: int = int(data['message_id']) - self.channel_id: int = int(data['channel_id']) - self.user_id: int = int(data['user_id']) + self.message_id: int = int(data["message_id"]) + self.channel_id: int = int(data["channel_id"]) + self.user_id: int = int(data["user_id"]) self.emoji: PartialEmoji = emoji self.event_type: str = event_type self.member: Optional[Member] = None try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -212,14 +211,14 @@ class RawReactionClearEvent(_RawReprMixin): The guild ID where the reactions got cleared. """ - __slots__ = ('message_id', 'channel_id', 'guild_id') + __slots__ = ("message_id", "channel_id", "guild_id") def __init__(self, data: ReactionClearEvent) -> None: - self.message_id: int = int(data['message_id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["message_id"]) + self.channel_id: int = int(data["channel_id"]) try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -241,15 +240,15 @@ class RawReactionClearEmojiEvent(_RawReprMixin): The custom or unicode emoji being removed. """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji') + __slots__ = ("message_id", "channel_id", "guild_id", "emoji") def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: self.emoji: PartialEmoji = emoji - self.message_id: int = int(data['message_id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["message_id"]) + self.channel_id: int = int(data["channel_id"]) try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -269,14 +268,14 @@ class RawIntegrationDeleteEvent(_RawReprMixin): The guild ID where the integration got deleted. """ - __slots__ = ('integration_id', 'application_id', 'guild_id') + __slots__ = ("integration_id", "application_id", "guild_id") def __init__(self, data: IntegrationDeleteEvent) -> None: - self.integration_id: int = int(data['id']) - self.guild_id: int = int(data['guild_id']) + self.integration_id: int = int(data["id"]) + self.guild_id: int = int(data["guild_id"]) try: - self.application_id: Optional[int] = int(data['application_id']) + self.application_id: Optional[int] = int(data["application_id"]) except KeyError: self.application_id: Optional[int] = None @@ -303,12 +302,12 @@ class RawTypingEvent(_RawReprMixin): __slots__ = ("channel_id", "user_id", "when", "guild_id", "member") def __init__(self, data: TypingEvent) -> None: - self.channel_id: int = int(data['channel_id']) - self.user_id: int = int(data['user_id']) - self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) + self.channel_id: int = int(data["channel_id"]) + self.user_id: int = int(data["user_id"]) + self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get("timestamp"), tz=datetime.timezone.utc) self.member: Optional[Member] = None - + try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: - self.guild_id: Optional[int] = None \ No newline at end of file + self.guild_id: Optional[int] = None diff --git a/discord/reaction.py b/discord/reaction.py index 04eee342..a7cc3179 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -27,9 +27,7 @@ from typing import Any, TYPE_CHECKING, Union, Optional from .iterators import ReactionIterator -__all__ = ( - 'Reaction', -) +__all__ = ("Reaction",) if TYPE_CHECKING: from .types.message import Reaction as ReactionPayload @@ -38,6 +36,7 @@ if TYPE_CHECKING: from .emoji import Emoji from .abc import Snowflake + class Reaction: """Represents a reaction to a message. @@ -75,13 +74,16 @@ class Reaction: message: :class:`Message` Message this reaction is for. """ - __slots__ = ('message', 'count', 'emoji', 'me') - def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): + __slots__ = ("message", "count", "emoji", "me") + + def __init__( + self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None + ): self.message: Message = message - self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji']) - self.count: int = data.get('count', 1) - self.me: bool = data.get('me') + self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data["emoji"]) + self.count: int = data.get("count", 1) + self.me: bool = data.get("me") # TODO: typeguard def is_custom_emoji(self) -> bool: @@ -103,7 +105,7 @@ class Reaction: return str(self.emoji) def __repr__(self) -> str: - return f'' + return f"" async def remove(self, user: Snowflake) -> None: """|coro| @@ -201,7 +203,7 @@ class Reaction: """ if not isinstance(self.emoji, str): - emoji = f'{self.emoji.name}:{self.emoji.id}' + emoji = f"{self.emoji.name}:{self.emoji.id}" else: emoji = self.emoji diff --git a/discord/role.py b/discord/role.py index b690cbe4..b408e206 100644 --- a/discord/role.py +++ b/discord/role.py @@ -32,8 +32,8 @@ from .mixins import Hashable from .utils import snowflake_time, _get_as_snowflake, MISSING __all__ = ( - 'RoleTags', - 'Role', + "RoleTags", + "Role", ) if TYPE_CHECKING: @@ -68,19 +68,19 @@ class RoleTags: """ __slots__ = ( - 'bot_id', - 'integration_id', - '_premium_subscriber', + "bot_id", + "integration_id", + "_premium_subscriber", ) def __init__(self, data: RoleTagPayload): - self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id') - self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id') + self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id") + self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id") # NOTE: The API returns "null" for this if it's valid, which corresponds to None. # This is different from other fields where "null" means "not there". # So in this case, a value of None is the same as True. # Which means we would need a different sentinel. - self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING) + self._premium_subscriber: Optional[Any] = data.get("premium_subscriber", MISSING) def is_bot_managed(self) -> bool: """:class:`bool`: Whether the role is associated with a bot.""" @@ -96,12 +96,12 @@ class RoleTags: def __repr__(self) -> str: return ( - f'' + f"" ) -R = TypeVar('R', bound='Role') +R = TypeVar("R", bound="Role") class Role(Hashable): @@ -181,23 +181,23 @@ class Role(Hashable): """ __slots__ = ( - 'id', - 'name', - '_permissions', - '_colour', - 'position', - 'managed', - 'mentionable', - 'hoist', - 'guild', - 'tags', - '_state', + "id", + "name", + "_permissions", + "_colour", + "position", + "managed", + "mentionable", + "hoist", + "guild", + "tags", + "_state", ) def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload): self.guild: Guild = guild self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(data) def __str__(self) -> str: @@ -207,14 +207,14 @@ class Role(Hashable): return self.id def __repr__(self) -> str: - return f'' + return f"" def __lt__(self: R, other: R) -> bool: if not isinstance(other, Role) or not isinstance(self, Role): return NotImplemented if self.guild != other.guild: - raise RuntimeError('cannot compare roles from two different guilds.') + raise RuntimeError("cannot compare roles from two different guilds.") # the @everyone role is always the lowest role in hierarchy guild_id = self.guild.id @@ -246,17 +246,17 @@ class Role(Hashable): return not r def _update(self, data: RolePayload): - self.name: str = data['name'] - self._permissions: int = int(data.get('permissions', 0)) - self.position: int = data.get('position', 0) - self._colour: int = data.get('color', 0) - self.hoist: bool = data.get('hoist', False) - self.managed: bool = data.get('managed', False) - self.mentionable: bool = data.get('mentionable', False) + self.name: str = data["name"] + self._permissions: int = int(data.get("permissions", 0)) + self.position: int = data.get("position", 0) + self._colour: int = data.get("color", 0) + self.hoist: bool = data.get("hoist", False) + self.managed: bool = data.get("managed", False) + self.mentionable: bool = data.get("mentionable", False) self.tags: Optional[RoleTags] try: - self.tags = RoleTags(data['tags']) + self.tags = RoleTags(data["tags"]) except KeyError: self.tags = None @@ -316,7 +316,7 @@ class Role(Hashable): @property def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention a role.""" - return f'<@&{self.id}>' + return f"<@&{self.id}>" @property def members(self) -> List[Member]: @@ -420,21 +420,21 @@ class Role(Hashable): if colour is not MISSING: if isinstance(colour, int): - payload['color'] = colour + payload["color"] = colour else: - payload['color'] = colour.value + payload["color"] = colour.value if name is not MISSING: - payload['name'] = name + payload["name"] = name if permissions is not MISSING: - payload['permissions'] = permissions.value + payload["permissions"] = permissions.value if hoist is not MISSING: - payload['hoist'] = hoist + payload["hoist"] = hoist if mentionable is not MISSING: - payload['mentionable'] = mentionable + payload["mentionable"] = mentionable data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) return Role(guild=self.guild, data=data, state=self._state) diff --git a/discord/shard.py b/discord/shard.py index edbdebf4..aaed3b92 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -50,11 +50,11 @@ if TYPE_CHECKING: from .activity import BaseActivity from .enums import Status - EI = TypeVar('EI', bound='EventItem') + EI = TypeVar("EI", bound="EventItem") __all__ = ( - 'AutoShardedClient', - 'ShardInfo', + "AutoShardedClient", + "ShardInfo", ) _log = logging.getLogger(__name__) @@ -70,11 +70,11 @@ class EventType: class EventItem: - __slots__ = ('type', 'shard', 'error') + __slots__ = ("type", "shard", "error") - def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: + def __init__(self, etype: int, shard: Optional["Shard"], error: Optional[Exception]) -> None: self.type: int = etype - self.shard: Optional['Shard'] = shard + self.shard: Optional["Shard"] = shard self.error: Optional[Exception] = error def __lt__(self: EI, other: EI) -> bool: @@ -129,11 +129,11 @@ class Shard: async def disconnect(self) -> None: await self.close() - self._dispatch('shard_disconnect', self.id) + self._dispatch("shard_disconnect", self.id) async def _handle_disconnect(self, e: Exception) -> None: - self._dispatch('disconnect') - self._dispatch('shard_disconnect', self.id) + self._dispatch("disconnect") + self._dispatch("shard_disconnect", self.id) if not self._reconnect: self._queue_put(EventItem(EventType.close, self, e)) return @@ -156,7 +156,7 @@ class Shard: return retry = self._backoff.delay() - _log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) + _log.error("Attempting a reconnect for shard ID %s in %.2fs", self.id, retry, exc_info=e) await asyncio.sleep(retry) self._queue_put(EventItem(EventType.reconnect, self, e)) @@ -179,9 +179,9 @@ class Shard: async def reidentify(self, exc: ReconnectWebSocket) -> None: self._cancel_task() - self._dispatch('disconnect') - self._dispatch('shard_disconnect', self.id) - _log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) + self._dispatch("disconnect") + self._dispatch("shard_disconnect", self.id) + _log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id) try: coro = DiscordWebSocket.from_client( self._client, @@ -231,7 +231,7 @@ class ShardInfo: The shard count for this cluster. If this is ``None`` then the bot has not started yet. """ - __slots__ = ('_parent', 'id', 'shard_count') + __slots__ = ("_parent", "id", "shard_count") def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: self._parent: Shard = parent @@ -321,15 +321,15 @@ class AutoShardedClient(Client): _connection: AutoShardedConnectionState def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: - kwargs.pop('shard_id', None) - self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) + kwargs.pop("shard_id", None) + self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None) super().__init__(*args, loop=loop, **kwargs) if self.shard_ids is not None: if self.shard_count is None: - raise ClientException('When passing manual shard_ids, you must provide a shard_count.') + raise ClientException("When passing manual shard_ids, you must provide a shard_count.") elif not isinstance(self.shard_ids, (list, tuple)): - raise ClientException('shard_ids parameter must be a list or a tuple.') + raise ClientException("shard_ids parameter must be a list or a tuple.") # instead of a single websocket, we have multiple # the key is the shard_id @@ -363,7 +363,7 @@ class AutoShardedClient(Client): :attr:`latencies` property. Returns ``nan`` if there are no shards ready. """ if not self.__shards: - return float('nan') + return float("nan") return sum(latency for _, latency in self.latencies) / len(self.__shards) @property @@ -393,7 +393,7 @@ class AutoShardedClient(Client): coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) ws = await asyncio.wait_for(coro, timeout=180.0) except Exception: - _log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) + _log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id) await asyncio.sleep(5.0) return await self.launch_shard(gateway, shard_id) @@ -503,10 +503,10 @@ class AutoShardedClient(Client): """ if status is None: - status_value = 'online' + status_value = "online" status_enum = Status.online elif status is Status.offline: - status_value = 'invisible' + status_value = "invisible" status_enum = Status.offline else: status_enum = status diff --git a/discord/stage_instance.py b/discord/stage_instance.py index b538eec3..7df08289 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -31,9 +31,7 @@ from .mixins import Hashable from .errors import InvalidArgument from .enums import StagePrivacyLevel, try_enum -__all__ = ( - 'StageInstance', -) +__all__ = ("StageInstance",) if TYPE_CHECKING: from .types.channel import StageInstance as StageInstancePayload @@ -82,14 +80,14 @@ class StageInstance(Hashable): """ __slots__ = ( - '_state', - 'id', - 'guild', - 'channel_id', - 'topic', - 'privacy_level', - 'discoverable_disabled', - '_cs_channel', + "_state", + "id", + "guild", + "channel_id", + "topic", + "privacy_level", + "discoverable_disabled", + "_cs_channel", ) def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: @@ -98,25 +96,27 @@ class StageInstance(Hashable): self._update(data) def _update(self, data: StageInstancePayload): - self.id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) - self.topic: str = data['topic'] - self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level']) - self.discoverable_disabled: bool = data.get('discoverable_disabled', False) + self.id: int = int(data["id"]) + self.channel_id: int = int(data["channel_id"]) + self.topic: str = data["topic"] + self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data["privacy_level"]) + self.discoverable_disabled: bool = data.get("discoverable_disabled", False) def __repr__(self) -> str: - return f'' + return f"" - @cached_slot_property('_cs_channel') + @cached_slot_property("_cs_channel") def channel(self) -> Optional[StageChannel]: """Optional[:class:`StageChannel`]: The channel that stage instance is running in.""" # the returned channel will always be a StageChannel or None - return self._state.get_channel(self.channel_id) # type: ignore + return self._state.get_channel(self.channel_id) # type: ignore def is_public(self) -> bool: return self.privacy_level is StagePrivacyLevel.public - async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None: + async def edit( + self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None + ) -> None: """|coro| Edits the stage instance. @@ -146,13 +146,13 @@ class StageInstance(Hashable): payload = {} if topic is not MISSING: - payload['topic'] = topic + payload["topic"] = topic if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument('privacy_level field must be of type PrivacyLevel') + raise InvalidArgument("privacy_level field must be of type PrivacyLevel") - payload['privacy_level'] = privacy_level.value + payload["privacy_level"] = privacy_level.value if payload: await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason) diff --git a/discord/state.py b/discord/state.py index 09777008..edf39263 100644 --- a/discord/state.py +++ b/discord/state.py @@ -76,8 +76,8 @@ if TYPE_CHECKING: from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload - T = TypeVar('T') - CS = TypeVar('CS', bound='ConnectionState') + T = TypeVar("T") + CS = TypeVar("CS", bound="ConnectionState") Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] @@ -136,7 +136,7 @@ async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> try: await coroutine except Exception: - _log.exception('Exception occurred during %s', info) + _log.exception("Exception occurred during %s", info) class ConnectionState: @@ -158,7 +158,7 @@ class ConnectionState: ) -> None: self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http - self.max_messages: Optional[int] = options.get('max_messages', 1000) + self.max_messages: Optional[int] = options.get("max_messages", 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 @@ -167,52 +167,52 @@ class ConnectionState: self.hooks: Dict[str, Callable] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None - self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id') - self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0) - self.guild_ready_timeout: float = options.get('guild_ready_timeout', 2.0) + self.application_id: Optional[int] = utils._get_as_snowflake(options, "application_id") + self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) + self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) if self.guild_ready_timeout < 0: - raise ValueError('guild_ready_timeout cannot be negative') + raise ValueError("guild_ready_timeout cannot be negative") - allowed_mentions = options.get('allowed_mentions') + allowed_mentions = options.get("allowed_mentions") if allowed_mentions is not None and not isinstance(allowed_mentions, AllowedMentions): - raise TypeError('allowed_mentions parameter must be AllowedMentions') + raise TypeError("allowed_mentions parameter must be AllowedMentions") self.allowed_mentions: Optional[AllowedMentions] = allowed_mentions self._chunk_requests: Dict[Union[int, str], ChunkRequest] = {} - activity = options.get('activity', None) + activity = options.get("activity", None) if activity: if not isinstance(activity, BaseActivity): - raise TypeError('activity parameter must derive from BaseActivity.') + raise TypeError("activity parameter must derive from BaseActivity.") activity = activity.to_dict() - status = options.get('status', None) + status = options.get("status", None) if status: if status is Status.offline: - status = 'invisible' + status = "invisible" else: status = str(status) if not isinstance(intents, Intents): - raise TypeError(f'intents parameter must be Intent not {type(intents)!r}') + raise TypeError(f"intents parameter must be Intent not {type(intents)!r}") if not intents.guilds: - _log.warning('Guilds intent seems to be disabled. This may cause state related issues.') + _log.warning("Guilds intent seems to be disabled. This may cause state related issues.") - self._chunk_guilds: bool = options.get('chunk_guilds_at_startup', intents.members) + self._chunk_guilds: bool = options.get("chunk_guilds_at_startup", intents.members) # Ensure these two are set properly if not intents.members and self._chunk_guilds: - raise ValueError('Intents.members must be enabled to chunk guilds at startup.') + raise ValueError("Intents.members must be enabled to chunk guilds at startup.") - cache_flags = options.get('member_cache_flags', None) + cache_flags = options.get("member_cache_flags", None) if cache_flags is None: cache_flags = MemberCacheFlags.from_intents(intents) else: if not isinstance(cache_flags, MemberCacheFlags): - raise TypeError(f'member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}') + raise TypeError(f"member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}") cache_flags._verify_intents(intents) @@ -227,7 +227,7 @@ class ConnectionState: self.parsers = parsers = {} for attr, func in inspect.getmembers(self): - if attr.startswith('parse_'): + if attr.startswith("parse_"): parsers[attr[6:].upper()] = func self.clear() @@ -264,7 +264,9 @@ class ConnectionState: else: self._messages: Optional[Deque[Message]] = None - def process_chunk_requests(self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool) -> None: + def process_chunk_requests( + self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool + ) -> None: removed = [] for key, request in self._chunk_requests.items(): if request.guild_id == guild_id and request.nonce == nonce: @@ -322,12 +324,12 @@ class ConnectionState: vc.main_ws = ws # type: ignore def store_user(self, data: UserPayload) -> User: - user_id = int(data['id']) + user_id = int(data["id"]) try: return self._users[user_id] except KeyError: user = User(state=self, data=data) - if user.discriminator != '0000': + if user.discriminator != "0000": self._users[user_id] = user user._stored = True return user @@ -347,12 +349,12 @@ class ConnectionState: def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: # the id will be present here - emoji_id = int(data['id']) # type: ignore + emoji_id = int(data["id"]) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: - sticker_id = int(data['id']) + sticker_id = int(data["id"]) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker @@ -460,9 +462,9 @@ class ConnectionState: return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]: - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) try: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) except KeyError: channel = DMChannel._from_message(self, channel_id) guild = None @@ -472,16 +474,18 @@ class ConnectionState: return channel or PartialMessageable(state=self, id=channel_id), guild async def chunker( - self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None + self, guild_id: int, query: str = "", limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None ) -> None: ws = self._get_websocket(guild_id) # This is ignored upstream await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) - async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): + async def query_members( + self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool + ): guild_id = guild.id ws = self._get_websocket(guild_id) if ws is None: - raise RuntimeError('Somehow do not have a websocket for this guild_id') + raise RuntimeError("Somehow do not have a websocket for this guild_id") request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request @@ -493,7 +497,9 @@ class ConnectionState: ) return await asyncio.wait_for(request.wait(), timeout=30.0) except asyncio.TimeoutError: - _log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) + _log.warning( + "Timed out waiting for chunks with query %r and limit %d for guild_id %d", query, limit, guild_id + ) raise async def _delay_ready(self) -> None: @@ -512,20 +518,20 @@ class ConnectionState: states.append((guild, future)) else: if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) for guild, future in states: try: await asyncio.wait_for(future, timeout=5.0) except asyncio.TimeoutError: - _log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id) + _log.warning("Shard ID %s timed out waiting for chunks for guild_id %s.", guild.shard_id, guild.id) if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) # remove the state try: @@ -537,8 +543,8 @@ class ConnectionState: pass else: # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + self.call_handlers("ready") + self.dispatch("ready") finally: self._ready_task = None @@ -548,33 +554,33 @@ class ConnectionState: self._ready_state = asyncio.Queue() self.clear(views=False) - self.user = ClientUser(state=self, data=data['user']) - self.store_user(data['user']) + self.user = ClientUser(state=self, data=data["user"]) + self.store_user(data["user"]) if self.application_id is None: try: - application = data['application'] + application = data["application"] except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') + self.application_id = utils._get_as_snowflake(application, "id") # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore + self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore - for guild_data in data['guilds']: + for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) - self.dispatch('connect') + self.dispatch("connect") self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data) -> None: - self.dispatch('resumed') + self.dispatch("resumed") def parse_message_create(self, data) -> None: channel, _ = self._get_guild_channel(data) # channel would be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore - self.dispatch('message', message) + self.dispatch("message", message) if self._messages is not None: self._messages.append(message) # we ensure that the channel is either a TextChannel or Thread @@ -585,9 +591,9 @@ class ConnectionState: raw = RawMessageDeleteEvent(data) found = self._get_message(raw.message_id) raw.cached_message = found - self.dispatch('raw_message_delete', raw) + self.dispatch("raw_message_delete", raw) if self._messages is not None and found is not None: - self.dispatch('message_delete', found) + self.dispatch("message_delete", found) self._messages.remove(found) def parse_message_delete_bulk(self, data) -> None: @@ -597,9 +603,9 @@ class ConnectionState: else: found_messages = [] raw.cached_messages = found_messages - self.dispatch('raw_bulk_message_delete', raw) + self.dispatch("raw_bulk_message_delete", raw) if found_messages: - self.dispatch('bulk_message_delete', found_messages) + self.dispatch("bulk_message_delete", found_messages) for msg in found_messages: # self._messages won't be None here self._messages.remove(msg) # type: ignore @@ -610,25 +616,25 @@ class ConnectionState: if message is not None: older_message = copy.copy(message) raw.cached_message = older_message - self.dispatch('raw_message_edit', raw) + self.dispatch("raw_message_edit", raw) message._update(data) # Coerce the `after` parameter to take the new updated Member # ref: #5999 older_message.author = message.author - self.dispatch('message_edit', older_message, message) + self.dispatch("message_edit", older_message, message) else: - self.dispatch('raw_message_edit', raw) + self.dispatch("raw_message_edit", raw) - if 'components' in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data['components']) + if "components" in data and self._view_store.is_message_tracked(raw.message_id): + self._view_store.update_from_message(raw.message_id, data["components"]) def parse_message_reaction_add(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) - raw = RawReactionActionEvent(data, emoji, 'REACTION_ADD') + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"]) + raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") - member_data = data.get('member') + member_data = data.get("member") if member_data: guild = self._get_guild(raw.guild_id) if guild is not None: @@ -637,7 +643,7 @@ class ConnectionState: raw.member = None else: raw.member = None - self.dispatch('raw_reaction_add', raw) + self.dispatch("raw_reaction_add", raw) # rich interface here message = self._get_message(raw.message_id) @@ -647,24 +653,24 @@ class ConnectionState: user = raw.member or self._get_reaction_user(message.channel, raw.user_id) if user: - self.dispatch('reaction_add', reaction, user) + self.dispatch("reaction_add", reaction, user) def parse_message_reaction_remove_all(self, data) -> None: raw = RawReactionClearEvent(data) - self.dispatch('raw_reaction_clear', raw) + self.dispatch("raw_reaction_clear", raw) message = self._get_message(raw.message_id) if message is not None: old_reactions = message.reactions.copy() message.reactions.clear() - self.dispatch('reaction_clear', message, old_reactions) + self.dispatch("reaction_clear", message, old_reactions) def parse_message_reaction_remove(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) - raw = RawReactionActionEvent(data, emoji, 'REACTION_REMOVE') - self.dispatch('raw_reaction_remove', raw) + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) + raw = RawReactionActionEvent(data, emoji, "REACTION_REMOVE") + self.dispatch("raw_reaction_remove", raw) message = self._get_message(raw.message_id) if message is not None: @@ -676,14 +682,14 @@ class ConnectionState: else: user = self._get_reaction_user(message.channel, raw.user_id) if user: - self.dispatch('reaction_remove', reaction, user) + self.dispatch("reaction_remove", reaction, user) def parse_message_reaction_remove_emoji(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionClearEmojiEvent(data, emoji) - self.dispatch('raw_reaction_clear_emoji', raw) + self.dispatch("raw_reaction_clear_emoji", raw) message = self._get_message(raw.message_id) if message is not None: @@ -693,38 +699,38 @@ class ConnectionState: pass else: if reaction: - self.dispatch('reaction_clear_emoji', reaction) + self.dispatch("reaction_clear_emoji", reaction) def parse_interaction_create(self, data) -> None: interaction = Interaction(data=data, state=self) - if data['type'] == 3: # interaction component - custom_id = interaction.data['custom_id'] # type: ignore - component_type = interaction.data['component_type'] # type: ignore + if data["type"] == 3: # interaction component + custom_id = interaction.data["custom_id"] # type: ignore + component_type = interaction.data["component_type"] # type: ignore self._view_store.dispatch(component_type, custom_id, interaction) - self.dispatch('interaction', interaction) + self.dispatch("interaction", interaction) def parse_presence_update(self, data) -> None: - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") # guild_id won't be None here guild = self._get_guild(guild_id) if guild is None: - _log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug("PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.", guild_id) return - user = data['user'] - member_id = int(user['id']) + user = data["user"] + member_id = int(user["id"]) member = guild.get_member(member_id) if member is None: - _log.debug('PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding', member_id) + _log.debug("PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding", member_id) return old_member = Member._copy(member) user_update = member._presence_update(data=data, user=user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) - self.dispatch('presence_update', old_member, member) + self.dispatch("presence_update", old_member, member) def parse_user_update(self, data) -> None: # self.user is *always* cached when this is called @@ -736,66 +742,66 @@ class ConnectionState: def parse_invite_create(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) - self.dispatch('invite_create', invite) + self.dispatch("invite_create", invite) def parse_invite_delete(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) - self.dispatch('invite_delete', invite) + self.dispatch("invite_delete", invite) def parse_channel_delete(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) - channel_id = int(data['id']) + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = int(data["id"]) if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: guild._remove_channel(channel) - self.dispatch('guild_channel_delete', channel) + self.dispatch("guild_channel_delete", channel) def parse_channel_update(self, data) -> None: - channel_type = try_enum(ChannelType, data.get('type')) - channel_id = int(data['id']) + channel_type = try_enum(ChannelType, data.get("type")) + channel_id = int(data["id"]) if channel_type is ChannelType.group: channel = self._get_private_channel(channel_id) old_channel = copy.copy(channel) # the channel is a GroupChannel channel._update_group(data) # type: ignore - self.dispatch('private_channel_update', old_channel, channel) + self.dispatch("private_channel_update", old_channel, channel) return - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: old_channel = copy.copy(channel) channel._update(guild, data) - self.dispatch('guild_channel_update', old_channel, channel) + self.dispatch("guild_channel_update", old_channel, channel) else: - _log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + _log.debug("CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.", channel_id) else: - _log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug("CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.", guild_id) def parse_channel_create(self, data) -> None: - factory, ch_type = _channel_factory(data['type']) + factory, ch_type = _channel_factory(data["type"]) if factory is None: - _log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) + _log.debug("CHANNEL_CREATE referencing an unknown channel type %s. Discarding.", data["type"]) return - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here channel = factory(guild=guild, state=self, data=data) # type: ignore guild._add_channel(channel) # type: ignore - self.dispatch('guild_channel_create', channel) + self.dispatch("guild_channel_create", channel) else: - _log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug("CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.", guild_id) return def parse_channel_pins_update(self, data) -> None: - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) try: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) except KeyError: guild = None channel = self._get_private_channel(channel_id) @@ -803,69 +809,69 @@ class ConnectionState: channel = guild and guild._resolve_channel(channel_id) if channel is None: - _log.debug('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + _log.debug("CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.", channel_id) return - last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None + last_pin = utils.parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None if guild is None: - self.dispatch('private_channel_pins_update', channel, last_pin) + self.dispatch("private_channel_pins_update", channel, last_pin) else: - self.dispatch('guild_channel_pins_update', channel, last_pin) + self.dispatch("guild_channel_pins_update", channel, last_pin) def parse_thread_create(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug("THREAD_CREATE referencing an unknown guild ID: %s. Discarding", guild_id) return thread = Thread(guild=guild, state=guild._state, data=data) has_thread = guild.get_thread(thread.id) guild._add_thread(thread) if not has_thread: - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) def parse_thread_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug("THREAD_UPDATE referencing an unknown guild ID: %s. Discarding", guild_id) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread = guild.get_thread(thread_id) if thread is not None: old = copy.copy(thread) thread._update(data) - self.dispatch('thread_update', old, thread) + self.dispatch("thread_update", old, thread) else: thread = Thread(guild=guild, state=guild._state, data=data) guild._add_thread(thread) - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) def parse_thread_delete(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_DELETE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug("THREAD_DELETE referencing an unknown guild ID: %s. Discarding", guild_id) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread = guild.get_thread(thread_id) if thread is not None: guild._remove_thread(thread) # type: ignore - self.dispatch('thread_delete', thread) + self.dispatch("thread_delete", thread) def parse_thread_list_sync(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug("THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding", guild_id) return try: - channel_ids = set(data['channel_ids']) + channel_ids = set(data["channel_ids"]) except KeyError: # If not provided, then the entire guild is being synced # So all previous thread data should be overwritten @@ -874,12 +880,12 @@ class ConnectionState: else: previous_threads = guild._filter_threads(channel_ids) - threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])} + threads = {d["id"]: guild._store_thread(d) for d in data.get("threads", [])} - for member in data.get('members', []): + for member in data.get("members", []): try: # note: member['id'] is the thread_id - thread = threads[member['id']] + thread = threads[member["id"]] except KeyError: continue else: @@ -888,63 +894,63 @@ class ConnectionState: for thread in threads.values(): old = previous_threads.pop(thread.id, None) if old is None: - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) for thread in previous_threads.values(): - self.dispatch('thread_remove', thread) + self.dispatch("thread_remove", thread) def parse_thread_member_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug("THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding", guild_id) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug("THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding", thread_id) return member = ThreadMember(thread, data) thread.me = member def parse_thread_members_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug("THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding", guild_id) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug("THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding", thread_id) return - added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])] - removed_member_ids = [int(x) for x in data.get('removed_member_ids', [])] + added_members = [ThreadMember(thread, d) for d in data.get("added_members", [])] + removed_member_ids = [int(x) for x in data.get("removed_member_ids", [])] self_id = self.self_id for member in added_members: if member.id != self_id: thread._add_member(member) - self.dispatch('thread_member_join', member) + self.dispatch("thread_member_join", member) else: thread.me = member - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) for member_id in removed_member_ids: if member_id != self_id: member = thread._pop_member(member_id) if member is not None: - self.dispatch('thread_member_remove', member) + self.dispatch("thread_member_remove", member) else: - self.dispatch('thread_remove', thread) + self.dispatch("thread_remove", thread) def parse_guild_member_add(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) return member = Member(guild=guild, data=data, state=self) @@ -956,30 +962,30 @@ class ConnectionState: except AttributeError: pass - self.dispatch('member_join', member) + self.dispatch("member_join", member) def parse_guild_member_remove(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: guild._member_count -= 1 except AttributeError: pass - user_id = int(data['user']['id']) + user_id = int(data["user"]["id"]) member = guild.get_member(user_id) if member is not None: guild._remove_member(member) # type: ignore - self.dispatch('member_remove', member) + self.dispatch("member_remove", member) else: - _log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_guild_member_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) - user = data['user'] - user_id = int(user['id']) + guild = self._get_guild(int(data["guild_id"])) + user = data["user"] + user_id = int(user["id"]) if guild is None: - _log.debug('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) return member = guild.get_member(user_id) @@ -988,9 +994,9 @@ class ConnectionState: member._update(data) user_update = member._update_inner_user(user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) - self.dispatch('member_update', old_member, member) + self.dispatch("member_update", old_member, member) else: if self.member_cache_flags.joined: member = Member(data=data, guild=guild, state=self) @@ -998,43 +1004,43 @@ class ConnectionState: # Force an update on the inner user if necessary user_update = member._update_inner_user(user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) guild._add_member(member) - _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + _log.debug("GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.", user_id) def parse_guild_emojis_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) return before_emojis = guild.emojis for emoji in before_emojis: self._emojis.pop(emoji.id, None) # guild won't be None here - guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) # type: ignore - self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data["emojis"])) # type: ignore + self.dispatch("guild_emojis_update", guild, before_emojis, guild.emojis) def parse_guild_stickers_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) return before_stickers = guild.stickers for emoji in before_stickers: self._stickers.pop(emoji.id, None) # guild won't be None here - guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore - self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) + guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data["stickers"])) # type: ignore + self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers) def _get_create_guild(self, data): - if data.get('unavailable') is False: + if data.get("unavailable") is False: # GUILD_CREATE with unavailable in the response # usually means that the guild has become available # and is therefore in the cache - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is not None: guild.unavailable = False guild._from_data(data) @@ -1060,15 +1066,15 @@ class ConnectionState: try: await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) except asyncio.TimeoutError: - _log.info('Somehow timed out waiting for chunks.') + _log.info("Somehow timed out waiting for chunks.") if unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) def parse_guild_create(self, data) -> None: - unavailable = data.get('unavailable') + unavailable = data.get("unavailable") if unavailable is True: # joined a guild with unavailable == True so.. return @@ -1091,30 +1097,30 @@ class ConnectionState: # Dispatch available if newly available if unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) def parse_guild_update(self, data) -> None: - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is not None: old_guild = copy.copy(guild) guild._from_data(data) - self.dispatch('guild_update', old_guild, guild) + self.dispatch("guild_update", old_guild, guild) else: - _log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) + _log.debug("GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.", data["id"]) def parse_guild_delete(self, data) -> None: - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is None: - _log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) + _log.debug("GUILD_DELETE referencing an unknown guild ID: %s. Discarding.", data["id"]) return - if data.get('unavailable', False): + if data.get("unavailable", False): # GUILD_DELETE with unavailable being True means that the # guild that was available is now currently unavailable guild.unavailable = True - self.dispatch('guild_unavailable', guild) + self.dispatch("guild_unavailable", guild) return # do a cleanup of the messages cache @@ -1124,7 +1130,7 @@ class ConnectionState: ) self._remove_guild(guild) - self.dispatch('guild_remove', guild) + self.dispatch("guild_remove", guild) def parse_guild_ban_add(self, data) -> None: # we make the assumption that GUILD_BAN_ADD is done @@ -1132,174 +1138,174 @@ class ConnectionState: # hence we don't remove it from cache or do anything # strange with it, the main purpose of this event # is mainly to dispatch to another event worth listening to for logging - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: - user = User(data=data['user'], state=self) + user = User(data=data["user"], state=self) except KeyError: pass else: member = guild.get_member(user.id) or user - self.dispatch('member_ban', guild, member) + self.dispatch("member_ban", guild, member) def parse_guild_ban_remove(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) - if guild is not None and 'user' in data: - user = self.store_user(data['user']) - self.dispatch('member_unban', guild, user) + guild = self._get_guild(int(data["guild_id"])) + if guild is not None and "user" in data: + user = self.store_user(data["user"]) + self.dispatch("member_unban", guild, user) def parse_guild_role_create(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) return - role_data = data['role'] + role_data = data["role"] role = Role(guild=guild, data=role_data, state=self) guild._add_role(role) - self.dispatch('guild_role_create', role) + self.dispatch("guild_role_create", role) def parse_guild_role_delete(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - role_id = int(data['role_id']) + role_id = int(data["role_id"]) try: role = guild._remove_role(role_id) except KeyError: return else: - self.dispatch('guild_role_delete', role) + self.dispatch("guild_role_delete", role) else: - _log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_guild_role_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - role_data = data['role'] - role_id = int(role_data['id']) + role_data = data["role"] + role_id = int(role_data["id"]) role = guild.get_role(role_id) if role is not None: old_role = copy.copy(role) role._update(role_data) - self.dispatch('guild_role_update', old_role, role) + self.dispatch("guild_role_update", old_role, role) else: - _log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_guild_members_chunk(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) - presences = data.get('presences', []) + presences = data.get("presences", []) # the guild won't be None here - members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore - _log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) + members = [Member(guild=guild, data=member, state=self) for member in data.get("members", [])] # type: ignore + _log.debug("Processed a chunk for %s members in guild ID %s.", len(members), guild_id) if presences: member_dict = {str(member.id): member for member in members} for presence in presences: - user = presence['user'] - member_id = user['id'] + user = presence["user"] + member_id = user["id"] member = member_dict.get(member_id) if member is not None: member._presence_update(presence, user) - complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count') - self.process_chunk_requests(guild_id, data.get('nonce'), members, complete) + complete = data.get("chunk_index", 0) + 1 == data.get("chunk_count") + self.process_chunk_requests(guild_id, data.get("nonce"), members, complete) def parse_guild_integrations_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - self.dispatch('guild_integrations_update', guild) + self.dispatch("guild_integrations_update", guild) else: - _log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_integration_create(self, data) -> None: - guild_id = int(data.pop('guild_id')) + guild_id = int(data.pop("guild_id")) guild = self._get_guild(guild_id) if guild is not None: - cls, _ = _integration_factory(data['type']) + cls, _ = _integration_factory(data["type"]) integration = cls(data=data, guild=guild) - self.dispatch('integration_create', integration) + self.dispatch("integration_create", integration) else: - _log.debug('INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug("INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.", guild_id) def parse_integration_update(self, data) -> None: - guild_id = int(data.pop('guild_id')) + guild_id = int(data.pop("guild_id")) guild = self._get_guild(guild_id) if guild is not None: - cls, _ = _integration_factory(data['type']) + cls, _ = _integration_factory(data["type"]) integration = cls(data=data, guild=guild) - self.dispatch('integration_update', integration) + self.dispatch("integration_update", integration) else: - _log.debug('INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug("INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.", guild_id) def parse_integration_delete(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is not None: raw = RawIntegrationDeleteEvent(data) - self.dispatch('raw_integration_delete', raw) + self.dispatch("raw_integration_delete", raw) else: - _log.debug('INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug("INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.", guild_id) def parse_webhooks_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding', data['guild_id']) + _log.debug("WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding", data["guild_id"]) return - channel = guild.get_channel(int(data['channel_id'])) + channel = guild.get_channel(int(data["channel_id"])) if channel is not None: - self.dispatch('webhooks_update', channel) + self.dispatch("webhooks_update", channel) else: - _log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) + _log.debug("WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.", data["channel_id"]) def parse_stage_instance_create(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: stage_instance = StageInstance(guild=guild, state=self, data=data) guild._stage_instances[stage_instance.id] = stage_instance - self.dispatch('stage_instance_create', stage_instance) + self.dispatch("stage_instance_create", stage_instance) else: - _log.debug('STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_stage_instance_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - stage_instance = guild._stage_instances.get(int(data['id'])) + stage_instance = guild._stage_instances.get(int(data["id"])) if stage_instance is not None: old_stage_instance = copy.copy(stage_instance) stage_instance._update(data) - self.dispatch('stage_instance_update', old_stage_instance, stage_instance) + self.dispatch("stage_instance_update", old_stage_instance, stage_instance) else: - _log.debug('STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.', data['id']) + _log.debug("STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.", data["id"]) else: - _log.debug('STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_stage_instance_delete(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: - stage_instance = guild._stage_instances.pop(int(data['id'])) + stage_instance = guild._stage_instances.pop(int(data["id"])) except KeyError: pass else: - self.dispatch('stage_instance_delete', stage_instance) + self.dispatch("stage_instance_delete", stage_instance) else: - _log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug("STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.", data["guild_id"]) def parse_voice_state_update(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) - channel_id = utils._get_as_snowflake(data, 'channel_id') + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = utils._get_as_snowflake(data, "channel_id") flags = self.member_cache_flags # self.user is *always* cached when this is called self_id = self.user.id # type: ignore if guild is not None: - if int(data['user_id']) == self_id: + if int(data["user_id"]) == self_id: voice = self._get_voice_client(guild.id) if voice is not None: coro = voice.on_voice_state_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) + asyncio.create_task(logging_coroutine(coro, info="Voice Protocol voice state update handler")) member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: @@ -1311,25 +1317,25 @@ class ConnectionState: elif channel_id is not None: guild._add_member(member) - self.dispatch('voice_state_update', member, before, after) + self.dispatch("voice_state_update", member, before, after) else: - _log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) + _log.debug("VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.", data["user_id"]) def parse_voice_server_update(self, data) -> None: try: - key_id = int(data['guild_id']) + key_id = int(data["guild_id"]) except KeyError: - key_id = int(data['channel_id']) + key_id = int(data["channel_id"]) vc = self._get_voice_client(key_id) if vc is not None: coro = vc.on_voice_server_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) + asyncio.create_task(logging_coroutine(coro, info="Voice Protocol voice server update handler")) def parse_typing_start(self, data) -> None: raw = RawTypingEvent(data) - member_data = data.get('member') + member_data = data.get("member") if member_data: guild = self._get_guild(raw.guild_id) if guild is not None: @@ -1338,14 +1344,14 @@ class ConnectionState: raw.member = None else: raw.member = None - self.dispatch('raw_typing', raw) + self.dispatch("raw_typing", raw) channel, guild = self._get_guild_channel(data) if channel is not None: user = raw.member or self._get_typing_user(channel, raw.user_id) if user is not None: - self.dispatch('typing', channel, user, raw.when) + self.dispatch("typing", channel, user, raw.when) def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, DMChannel): @@ -1365,15 +1371,15 @@ class ConnectionState: return self.get_user(user_id) def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: - emoji_id = utils._get_as_snowflake(data, 'id') + emoji_id = utils._get_as_snowflake(data, "id") if not emoji_id: - return data['name'] + return data["name"] try: return self._emojis[emoji_id] except KeyError: - return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) + return PartialEmoji.with_state(self, animated=data.get("animated", False), id=emoji_id, name=data["name"]) def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: emoji_id = emoji.id @@ -1425,7 +1431,7 @@ class AutoShardedConnectionState(ConnectionState): async def chunker( self, guild_id: int, - query: str = '', + query: str = "", limit: int = 0, presences: bool = False, *, @@ -1449,12 +1455,12 @@ class AutoShardedConnectionState(ConnectionState): break else: if self._guild_needs_chunking(guild): - _log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) + _log.debug("Guild ID %d requires chunking, will be done in the background.", guild.id) if len(current_bucket) >= max_concurrency: try: await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) except asyncio.TimeoutError: - fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' + fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" _log.warning(fmt, guild.shard_id, len(current_bucket)) finally: current_bucket = [] @@ -1477,15 +1483,15 @@ class AutoShardedConnectionState(ConnectionState): await utils.sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: _log.warning( - 'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) + "Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds", shard_id, timeout, len(guilds) ) for guild in children: if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) - self.dispatch('shard_ready', shard_id) + self.dispatch("shard_ready", shard_id) # remove the state try: @@ -1499,38 +1505,38 @@ class AutoShardedConnectionState(ConnectionState): self._ready_task = None # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + self.call_handlers("ready") + self.dispatch("ready") def parse_ready(self, data) -> None: - if not hasattr(self, '_ready_state'): + if not hasattr(self, "_ready_state"): self._ready_state = asyncio.Queue() - self.user = user = ClientUser(state=self, data=data['user']) + self.user = user = ClientUser(state=self, data=data["user"]) # self._users is a list of Users, we're setting a ClientUser self._users[user.id] = user # type: ignore if self.application_id is None: try: - application = data['application'] + application = data["application"] except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') - self.application_flags = ApplicationFlags._from_value(application['flags']) + self.application_id = utils._get_as_snowflake(application, "id") + self.application_flags = ApplicationFlags._from_value(application["flags"]) - for guild_data in data['guilds']: + for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) if self._messages: self._update_message_references() - self.dispatch('connect') - self.dispatch('shard_connect', data['__shard_id__']) + self.dispatch("connect") + self.dispatch("shard_connect", data["__shard_id__"]) if self._ready_task is None: self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data) -> None: - self.dispatch('resumed') - self.dispatch('shard_resumed', data['__shard_id__']) + self.dispatch("resumed") + self.dispatch("shard_resumed", data["__shard_id__"]) diff --git a/discord/sticker.py b/discord/sticker.py index 8ec2ad01..a789fde1 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -33,11 +33,11 @@ from .errors import InvalidData from .enums import StickerType, StickerFormatType, try_enum __all__ = ( - 'StickerPack', - 'StickerItem', - 'Sticker', - 'StandardSticker', - 'GuildSticker', + "StickerPack", + "StickerItem", + "Sticker", + "StandardSticker", + "GuildSticker", ) if TYPE_CHECKING: @@ -102,15 +102,15 @@ class StickerPack(Hashable): """ __slots__ = ( - '_state', - 'id', - 'stickers', - 'name', - 'sku_id', - 'cover_sticker_id', - 'cover_sticker', - 'description', - '_banner', + "_state", + "id", + "stickers", + "name", + "sku_id", + "cover_sticker_id", + "cover_sticker", + "description", + "_banner", ) def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None: @@ -118,15 +118,17 @@ class StickerPack(Hashable): self._from_data(data) def _from_data(self, data: StickerPackPayload) -> None: - self.id: int = int(data['id']) - stickers = data['stickers'] - self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers] - self.name: str = data['name'] - self.sku_id: int = int(data['sku_id']) - self.cover_sticker_id: int = int(data['cover_sticker_id']) - self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore - self.description: str = data['description'] - self._banner: int = int(data['banner_asset_id']) + self.id: int = int(data["id"]) + stickers = data["stickers"] + self.stickers: List[StandardSticker] = [ + StandardSticker(state=self._state, data=sticker) for sticker in stickers + ] + self.name: str = data["name"] + self.sku_id: int = int(data["sku_id"]) + self.cover_sticker_id: int = int(data["cover_sticker_id"]) + self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore + self.description: str = data["description"] + self._banner: int = int(data["banner_asset_id"]) @property def banner(self) -> Asset: @@ -134,7 +136,7 @@ class StickerPack(Hashable): return Asset._from_sticker_banner(self._state, self._banner) def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: return self.name @@ -205,17 +207,17 @@ class StickerItem(_StickerTag): The URL for the sticker's image. """ - __slots__ = ('_state', 'name', 'id', 'format', 'url') + __slots__ = ("_state", "name", "id", "format", "url") def __init__(self, *, state: ConnectionState, data: StickerItemPayload): self._state: ConnectionState = state - self.name: str = data['name'] - self.id: int = int(data['id']) - self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type']) - self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}' + self.name: str = data["name"] + self.id: int = int(data["id"]) + self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"]) + self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}" def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: return self.name @@ -236,7 +238,7 @@ class StickerItem(_StickerTag): The retrieved sticker. """ data: StickerPayload = await self._state.http.get_sticker(self.id) - cls, _ = _sticker_factory(data['type']) # type: ignore + cls, _ = _sticker_factory(data["type"]) # type: ignore return cls(state=self._state, data=data) @@ -275,21 +277,21 @@ class Sticker(_StickerTag): The URL for the sticker's image. """ - __slots__ = ('_state', 'id', 'name', 'description', 'format', 'url') + __slots__ = ("_state", "id", "name", "description", "format", "url") def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None: self._state: ConnectionState = state self._from_data(data) def _from_data(self, data: StickerPayload) -> None: - self.id: int = int(data['id']) - self.name: str = data['name'] - self.description: str = data['description'] - self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type']) - self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}' + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.description: str = data["description"] + self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"]) + self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}" def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: return self.name @@ -337,21 +339,21 @@ class StandardSticker(Sticker): The sticker's sort order within its pack. """ - __slots__ = ('sort_value', 'pack_id', 'type', 'tags') + __slots__ = ("sort_value", "pack_id", "type", "tags") def _from_data(self, data: StandardStickerPayload) -> None: super()._from_data(data) - self.sort_value: int = data['sort_value'] - self.pack_id: int = int(data['pack_id']) + self.sort_value: int = data["sort_value"] + self.pack_id: int = int(data["pack_id"]) self.type: StickerType = StickerType.standard try: - self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')] + self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")] except KeyError: self.tags = [] def __repr__(self) -> str: - return f'' + return f"" async def pack(self) -> StickerPack: """|coro| @@ -371,12 +373,12 @@ class StandardSticker(Sticker): The retrieved sticker pack. """ data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs() - packs = data['sticker_packs'] - pack = find(lambda d: int(d['id']) == self.pack_id, packs) + packs = data["sticker_packs"] + pack = find(lambda d: int(d["id"]) == self.pack_id, packs) if pack: return StickerPack(state=self._state, data=pack) - raise InvalidData(f'Could not find corresponding sticker pack for {self!r}') + raise InvalidData(f"Could not find corresponding sticker pack for {self!r}") class GuildSticker(Sticker): @@ -419,21 +421,21 @@ class GuildSticker(Sticker): The name of a unicode emoji that represents this sticker. """ - __slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild') + __slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild") def _from_data(self, data: GuildStickerPayload) -> None: super()._from_data(data) - self.available: bool = data['available'] - self.guild_id: int = int(data['guild_id']) - user = data.get('user') + self.available: bool = data["available"] + self.guild_id: int = int(data["guild_id"]) + user = data.get("user") self.user: Optional[User] = self._state.store_user(user) if user else None - self.emoji: str = data['tags'] + self.emoji: str = data["tags"] self.type: StickerType = StickerType.guild def __repr__(self) -> str: - return f'' + return f"" - @cached_slot_property('_cs_guild') + @cached_slot_property("_cs_guild") def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild that this sticker is from. Could be ``None`` if the bot is not in the guild. @@ -480,10 +482,10 @@ class GuildSticker(Sticker): payload: EditGuildSticker = {} if name is not MISSING: - payload['name'] = name + payload["name"] = name if description is not MISSING: - payload['description'] = description + payload["description"] = description if emoji is not MISSING: try: @@ -491,9 +493,9 @@ class GuildSticker(Sticker): except TypeError: pass else: - emoji = emoji.replace(' ', '_') + emoji = emoji.replace(" ", "_") - payload['tags'] = emoji + payload["tags"] = emoji data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason) return GuildSticker(state=self._state, data=data) @@ -521,7 +523,9 @@ class GuildSticker(Sticker): await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason) -def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]: +def _sticker_factory( + sticker_type: Literal[1, 2] +) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]: value = try_enum(StickerType, sticker_type) if value == StickerType.standard: return StandardSticker, value diff --git a/discord/team.py b/discord/team.py index 538aaba1..62febc27 100644 --- a/discord/team.py +++ b/discord/team.py @@ -40,8 +40,8 @@ if TYPE_CHECKING: ) __all__ = ( - 'Team', - 'TeamMember', + "Team", + "TeamMember", ) @@ -62,26 +62,26 @@ class Team: .. versionadded:: 1.3 """ - __slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members') + __slots__ = ("_state", "id", "name", "_icon", "owner_id", "members") def __init__(self, state: ConnectionState, data: TeamPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) - self.name: str = data['name'] - self._icon: Optional[str] = data['icon'] - self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id') - self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']] + self.id: int = int(data["id"]) + self.name: str = data["name"] + self._icon: Optional[str] = data["icon"] + self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id") + self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]] def __repr__(self) -> str: - return f'<{self.__class__.__name__} id={self.id} name={self.name}>' + return f"<{self.__class__.__name__} id={self.id} name={self.name}>" @property def icon(self) -> Optional[Asset]: """Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='team') + return Asset._from_icon(self._state, self.id, self._icon, path="team") @property def owner(self) -> Optional[TeamMember]: @@ -130,16 +130,16 @@ class TeamMember(BaseUser): The membership state of the member (e.g. invited or accepted) """ - __slots__ = ('team', 'membership_state', 'permissions') + __slots__ = ("team", "membership_state", "permissions") def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload): self.team: Team = team - self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state']) - self.permissions: List[str] = data['permissions'] - super().__init__(state=state, data=data['user']) + self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data["membership_state"]) + self.permissions: List[str] = data["permissions"] + super().__init__(state=state, data=data["user"]) def __repr__(self) -> str: return ( - f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' - f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>' + f"<{self.__class__.__name__} id={self.id} name={self.name!r} " + f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>" ) diff --git a/discord/template.py b/discord/template.py index 30af3a4d..6df6c69e 100644 --- a/discord/template.py +++ b/discord/template.py @@ -29,9 +29,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING from .enums import VoiceRegion from .guild import Guild -__all__ = ( - 'Template', -) +__all__ = ("Template",) if TYPE_CHECKING: import datetime @@ -44,7 +42,7 @@ class _FriendlyHttpAttributeErrorHelper: __slots__ = () def __getattr__(self, attr): - raise AttributeError('PartialTemplateState does not support http methods.') + raise AttributeError("PartialTemplateState does not support http methods.") class _PartialTemplateState: @@ -84,7 +82,7 @@ class _PartialTemplateState: return [] def __getattr__(self, attr): - raise AttributeError(f'PartialTemplateState does not support {attr!r}.') + raise AttributeError(f"PartialTemplateState does not support {attr!r}.") class Template: @@ -118,16 +116,16 @@ class Template: """ __slots__ = ( - 'code', - 'uses', - 'name', - 'description', - 'creator', - 'created_at', - 'updated_at', - 'source_guild', - 'is_dirty', - '_state', + "code", + "uses", + "name", + "description", + "creator", + "created_at", + "updated_at", + "source_guild", + "is_dirty", + "_state", ) def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None: @@ -135,35 +133,35 @@ class Template: self._store(data) def _store(self, data: TemplatePayload) -> None: - self.code: str = data['code'] - self.uses: int = data['usage_count'] - self.name: str = data['name'] - self.description: Optional[str] = data['description'] - creator_data = data.get('creator') + self.code: str = data["code"] + self.uses: int = data["usage_count"] + self.name: str = data["name"] + self.description: Optional[str] = data["description"] + creator_data = data.get("creator") self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data) - self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) - self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at')) + self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at")) + self.updated_at: Optional[datetime.datetime] = parse_time(data.get("updated_at")) - guild_id = int(data['source_guild_id']) + guild_id = int(data["source_guild_id"]) guild: Optional[Guild] = self._state._get_guild(guild_id) self.source_guild: Guild if guild is None: - source_serialised = data['serialized_source_guild'] - source_serialised['id'] = guild_id + source_serialised = data["serialized_source_guild"] + source_serialised["id"] = guild_id state = _PartialTemplateState(state=self._state) # Guild expects a ConnectionState, we're passing a _PartialTemplateState self.source_guild = Guild(data=source_serialised, state=state) # type: ignore else: self.source_guild = guild - self.is_dirty: Optional[bool] = data.get('is_dirty', None) + self.is_dirty: Optional[bool] = data.get("is_dirty", None) def __repr__(self) -> str: return ( - f'