diff --git a/discord/__init__.py b/discord/__init__.py index 1e74cf91..3102c721 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -9,13 +9,13 @@ A basic wrapper for the Discord API. """ -__title__ = 'discord' -__author__ = 'Rapptz' -__license__ = 'MIT' -__copyright__ = 'Copyright 2015-present Rapptz' -__version__ = '2.0.0a' +__title__ = "discord" +__author__ = "Rapptz" +__license__ = "MIT" +__copyright__ = "Copyright 2015-present Rapptz" +__version__ = "2.0.0a" -__path__ = __import__('pkgutil').extend_path(__path__, __name__) +__path__ = __import__("pkgutil").extend_path(__path__, __name__) import logging from typing import NamedTuple, Literal @@ -69,6 +69,8 @@ class VersionInfo(NamedTuple): serial: int -version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel='alpha', serial=0) +version_info: VersionInfo = VersionInfo( + major=2, minor=0, micro=0, releaselevel="alpha", serial=0 +) logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/__main__.py b/discord/__main__.py index 513b0cb3..54d94c04 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -31,26 +31,37 @@ import pkg_resources import aiohttp import platform + def show_version(): entries = [] - entries.append('- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(sys.version_info)) + entries.append( + "- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format( + sys.version_info + ) + ) version_info = discord.version_info - entries.append('- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}'.format(version_info)) - if version_info.releaselevel != 'final': - pkg = pkg_resources.get_distribution('discord.py') + entries.append( + "- discord.py v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format( + version_info + ) + ) + if version_info.releaselevel != "final": + pkg = pkg_resources.get_distribution("discord.py") if pkg: - entries.append(f' - discord.py pkg_resources: v{pkg.version}') + entries.append(f" - discord.py pkg_resources: v{pkg.version}") - entries.append(f'- aiohttp v{aiohttp.__version__}') + entries.append(f"- aiohttp v{aiohttp.__version__}") uname = platform.uname() - entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) - print('\n'.join(entries)) + entries.append("- system info: {0.system} {0.release} {0.version}".format(uname)) + print("\n".join(entries)) + def core(parser, args): if args.version: show_version() + _bot_template = """#!/usr/bin/env python3 from discord.ext import commands @@ -120,7 +131,7 @@ def setup(bot): bot.add_cog({name}(bot)) ''' -_cog_extras = ''' +_cog_extras = """ def cog_unload(self): # clean up logic goes here pass @@ -149,22 +160,22 @@ _cog_extras = ''' # called after a command is called here pass -''' +""" # certain file names and directory names are forbidden # see: https://msdn.microsoft.com/en-us/library/windows/desktop/aa365247%28v=vs.85%29.aspx # although some of this doesn't apply to Linux, we might as well be consistent _base_table = { - '<': '-', - '>': '-', - ':': '-', - '"': '-', + "<": "-", + ">": "-", + ":": "-", + '"': "-", # '/': '-', these are fine # '\\': '-', - '|': '-', - '?': '-', - '*': '-', + "|": "-", + "?": "-", + "*": "-", } # NUL (0) and 1-31 are disallowed @@ -172,21 +183,45 @@ _base_table.update((chr(i), None) for i in range(32)) _translation_table = str.maketrans(_base_table) + def to_path(parser, name, *, replace_spaces=False): if isinstance(name, Path): return name - if sys.platform == 'win32': - forbidden = ('CON', 'PRN', 'AUX', 'NUL', 'COM1', 'COM2', 'COM3', 'COM4', 'COM5', 'COM6', 'COM7', \ - 'COM8', 'COM9', 'LPT1', 'LPT2', 'LPT3', 'LPT4', 'LPT5', 'LPT6', 'LPT7', 'LPT8', 'LPT9') + if sys.platform == "win32": + forbidden = ( + "CON", + "PRN", + "AUX", + "NUL", + "COM1", + "COM2", + "COM3", + "COM4", + "COM5", + "COM6", + "COM7", + "COM8", + "COM9", + "LPT1", + "LPT2", + "LPT3", + "LPT4", + "LPT5", + "LPT6", + "LPT7", + "LPT8", + "LPT9", + ) if len(name) <= 4 and name.upper() in forbidden: - parser.error('invalid directory name given, use a different one') + parser.error("invalid directory name given, use a different one") name = name.translate(_translation_table) if replace_spaces: - name = name.replace(' ', '-') + name = name.replace(" ", "-") return Path(name) + def newbot(parser, args): new_directory = to_path(parser, args.directory) / to_path(parser, args.name) @@ -195,106 +230,147 @@ def newbot(parser, args): try: new_directory.mkdir(exist_ok=True, parents=True) except OSError as exc: - parser.error(f'could not create our bot directory ({exc})') + parser.error(f"could not create our bot directory ({exc})") - cogs = new_directory / 'cogs' + cogs = new_directory / "cogs" try: cogs.mkdir(exist_ok=True) - init = cogs / '__init__.py' + init = cogs / "__init__.py" init.touch() except OSError as exc: - print(f'warning: could not create cogs directory ({exc})') + print(f"warning: could not create cogs directory ({exc})") try: - with open(str(new_directory / 'config.py'), 'w', encoding='utf-8') as fp: + with open(str(new_directory / "config.py"), "w", encoding="utf-8") as fp: fp.write('token = "place your token here"\ncogs = []\n') except OSError as exc: - parser.error(f'could not create config file ({exc})') + parser.error(f"could not create config file ({exc})") try: - with open(str(new_directory / 'bot.py'), 'w', encoding='utf-8') as fp: - base = 'Bot' if not args.sharded else 'AutoShardedBot' + with open(str(new_directory / "bot.py"), "w", encoding="utf-8") as fp: + base = "Bot" if not args.sharded else "AutoShardedBot" fp.write(_bot_template.format(base=base, prefix=args.prefix)) except OSError as exc: - parser.error(f'could not create bot file ({exc})') + parser.error(f"could not create bot file ({exc})") if not args.no_git: try: - with open(str(new_directory / '.gitignore'), 'w', encoding='utf-8') as fp: + with open(str(new_directory / ".gitignore"), "w", encoding="utf-8") as fp: fp.write(_gitignore_template) except OSError as exc: - print(f'warning: could not create .gitignore file ({exc})') + print(f"warning: could not create .gitignore file ({exc})") + + print("successfully made bot at", new_directory) - print('successfully made bot at', new_directory) def newcog(parser, args): cog_dir = to_path(parser, args.directory) try: cog_dir.mkdir(exist_ok=True) except OSError as exc: - print(f'warning: could not create cogs directory ({exc})') + print(f"warning: could not create cogs directory ({exc})") directory = cog_dir / to_path(parser, args.name) - directory = directory.with_suffix('.py') + directory = directory.with_suffix(".py") try: - with open(str(directory), 'w', encoding='utf-8') as fp: - attrs = '' - extra = _cog_extras if args.full else '' + with open(str(directory), "w", encoding="utf-8") as fp: + attrs = "" + extra = _cog_extras if args.full else "" if args.class_name: name = args.class_name else: name = str(directory.stem) - if '-' in name or '_' in name: - translation = str.maketrans('-_', ' ') - name = name.translate(translation).title().replace(' ', '') + if "-" in name or "_" in name: + translation = str.maketrans("-_", " ") + name = name.translate(translation).title().replace(" ", "") else: name = name.title() if args.display_name: attrs += f', name="{args.display_name}"' if args.hide_commands: - attrs += ', command_attrs=dict(hidden=True)' + attrs += ", command_attrs=dict(hidden=True)" fp.write(_cog_template.format(name=name, extra=extra, attrs=attrs)) except OSError as exc: - parser.error(f'could not create cog file ({exc})') + parser.error(f"could not create cog file ({exc})") else: - print('successfully made cog at', directory) + print("successfully made cog at", directory) + def add_newbot_args(subparser): - parser = subparser.add_parser('newbot', help='creates a command bot project quickly') + parser = subparser.add_parser( + "newbot", help="creates a command bot project quickly" + ) parser.set_defaults(func=newbot) - parser.add_argument('name', help='the bot project name') - parser.add_argument('directory', help='the directory to place it in (default: .)', nargs='?', default=Path.cwd()) - parser.add_argument('--prefix', help='the bot prefix (default: $)', default='$', metavar='') - parser.add_argument('--sharded', help='whether to use AutoShardedBot', action='store_true') - parser.add_argument('--no-git', help='do not create a .gitignore file', action='store_true', dest='no_git') + parser.add_argument("name", help="the bot project name") + parser.add_argument( + "directory", + help="the directory to place it in (default: .)", + nargs="?", + default=Path.cwd(), + ) + parser.add_argument( + "--prefix", help="the bot prefix (default: $)", default="$", metavar="" + ) + parser.add_argument( + "--sharded", help="whether to use AutoShardedBot", action="store_true" + ) + parser.add_argument( + "--no-git", + help="do not create a .gitignore file", + action="store_true", + dest="no_git", + ) + def add_newcog_args(subparser): - parser = subparser.add_parser('newcog', help='creates a new cog template quickly') + parser = subparser.add_parser("newcog", help="creates a new cog template quickly") parser.set_defaults(func=newcog) - parser.add_argument('name', help='the cog name') - parser.add_argument('directory', help='the directory to place it in (default: cogs)', nargs='?', default=Path('cogs')) - parser.add_argument('--class-name', help='the class name of the cog (default: )', dest='class_name') - parser.add_argument('--display-name', help='the cog name (default: )') - parser.add_argument('--hide-commands', help='whether to hide all commands in the cog', action='store_true') - parser.add_argument('--full', help='add all special methods as well', action='store_true') + parser.add_argument("name", help="the cog name") + parser.add_argument( + "directory", + help="the directory to place it in (default: cogs)", + nargs="?", + default=Path("cogs"), + ) + parser.add_argument( + "--class-name", + help="the class name of the cog (default: )", + dest="class_name", + ) + parser.add_argument("--display-name", help="the cog name (default: )") + parser.add_argument( + "--hide-commands", + help="whether to hide all commands in the cog", + action="store_true", + ) + parser.add_argument( + "--full", help="add all special methods as well", action="store_true" + ) + def parse_args(): - parser = argparse.ArgumentParser(prog='discord', description='Tools for helping with discord.py') - parser.add_argument('-v', '--version', action='store_true', help='shows the library version') + parser = argparse.ArgumentParser( + prog="discord", description="Tools for helping with discord.py" + ) + parser.add_argument( + "-v", "--version", action="store_true", help="shows the library version" + ) parser.set_defaults(func=core) - subparser = parser.add_subparsers(dest='subcommand', title='subcommands') + subparser = parser.add_subparsers(dest="subcommand", title="subcommands") add_newbot_args(subparser) add_newcog_args(subparser) return parser, parser.parse_args() + def main(): parser, args = parse_args() args.func(parser, args) -if __name__ == '__main__': + +if __name__ == "__main__": main() diff --git a/discord/abc.py b/discord/abc.py index fd2dc4bb..7707cea0 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -56,15 +56,15 @@ from .sticker import GuildSticker, StickerItem from . import utils __all__ = ( - 'Snowflake', - 'User', - 'PrivateChannel', - 'GuildChannel', - 'Messageable', - 'Connectable', + "Snowflake", + "User", + "PrivateChannel", + "GuildChannel", + "Messageable", + "Connectable", ) -T = TypeVar('T', bound=VoiceProtocol) +T = TypeVar("T", bound=VoiceProtocol) if TYPE_CHECKING: from datetime import datetime @@ -89,7 +89,9 @@ if TYPE_CHECKING: OverwriteType, ) - PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable] + PartialMessageableChannel = Union[ + TextChannel, Thread, DMChannel, PartialMessageable + ] MessageableChannel = Union[PartialMessageableChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] @@ -98,7 +100,7 @@ MISSING = utils.MISSING class _Undefined: def __repr__(self) -> str: - return 'see-below' + return "see-below" _undefined: Any = _Undefined() @@ -189,23 +191,23 @@ class PrivateChannel(Snowflake, Protocol): class _Overwrites: - __slots__ = ('id', 'allow', 'deny', 'type') + __slots__ = ("id", "allow", "deny", "type") ROLE = 0 MEMBER = 1 def __init__(self, data: PermissionOverwritePayload): - self.id: int = int(data['id']) - self.allow: int = int(data.get('allow', 0)) - self.deny: int = int(data.get('deny', 0)) - self.type: OverwriteType = data['type'] + self.id: int = int(data["id"]) + self.allow: int = int(data.get("allow", 0)) + self.deny: int = int(data.get("deny", 0)) + self.type: OverwriteType = data["type"] def _asdict(self) -> PermissionOverwritePayload: return { - 'id': self.id, - 'allow': str(self.allow), - 'deny': str(self.deny), - 'type': self.type, + "id": self.id, + "allow": str(self.allow), + "deny": str(self.deny), + "type": self.type, } def is_role(self) -> bool: @@ -215,7 +217,7 @@ class _Overwrites: return self.type == 1 -GCH = TypeVar('GCH', bound='GuildChannel') +GCH = TypeVar("GCH", bound="GuildChannel") class GuildChannel: @@ -254,7 +256,9 @@ class GuildChannel: if TYPE_CHECKING: - def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]): + def __init__( + self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any] + ): ... def __str__(self) -> str: @@ -276,11 +280,13 @@ class GuildChannel: reason: Optional[str], ) -> None: if position < 0: - raise InvalidArgument('Channel position cannot be less than 0.') + raise InvalidArgument("Channel position cannot be less than 0.") http = self._state.http bucket = self._sorting_bucket - channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket] + channels: List[GuildChannel] = [ + c for c in self.guild.channels if c._sorting_bucket == bucket + ] channels.sort(key=lambda c: c.position) @@ -291,106 +297,124 @@ class GuildChannel: # not there somehow lol return else: - index = next((i for i, c in enumerate(channels) if c.position >= position), len(channels)) + index = next( + (i for i, c in enumerate(channels) if c.position >= position), + len(channels), + ) # add ourselves at our designated position channels.insert(index, self) payload = [] for index, c in enumerate(channels): - d: Dict[str, Any] = {'id': c.id, 'position': index} + d: Dict[str, Any] = {"id": c.id, "position": index} if parent_id is not _undefined and c.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) await http.bulk_channel_update(self.guild.id, payload, reason=reason) - async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]: + async def _edit( + self, options: Dict[str, Any], reason: Optional[str] + ) -> Optional[ChannelPayload]: try: - parent = options.pop('category') + parent = options.pop("category") except KeyError: parent_id = _undefined else: parent_id = parent and parent.id try: - options['rate_limit_per_user'] = options.pop('slowmode_delay') + options["rate_limit_per_user"] = options.pop("slowmode_delay") except KeyError: pass try: - rtc_region = options.pop('rtc_region') + rtc_region = options.pop("rtc_region") except KeyError: pass else: - options['rtc_region'] = None if rtc_region is None else str(rtc_region) + options["rtc_region"] = None if rtc_region is None else str(rtc_region) try: - video_quality_mode = options.pop('video_quality_mode') + video_quality_mode = options.pop("video_quality_mode") except KeyError: pass else: - options['video_quality_mode'] = int(video_quality_mode) + options["video_quality_mode"] = int(video_quality_mode) - lock_permissions = options.pop('sync_permissions', False) + lock_permissions = options.pop("sync_permissions", False) try: - position = options.pop('position') + position = options.pop("position") except KeyError: if parent_id is not _undefined: if lock_permissions: category = self.guild.get_channel(parent_id) if category: - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] - options['parent_id'] = parent_id + options["permission_overwrites"] = [ + c._asdict() for c in category._overwrites + ] + options["parent_id"] = parent_id elif lock_permissions and self.category_id is not None: # if we're syncing permissions on a pre-existing channel category without changing it # we need to update the permissions to point to the pre-existing category category = self.guild.get_channel(self.category_id) if category: - options['permission_overwrites'] = [c._asdict() for c in category._overwrites] + options["permission_overwrites"] = [ + c._asdict() for c in category._overwrites + ] else: - await self._move(position, parent_id=parent_id, lock_permissions=lock_permissions, reason=reason) + await self._move( + position, + parent_id=parent_id, + lock_permissions=lock_permissions, + reason=reason, + ) - overwrites = options.get('overwrites', None) + overwrites = options.get("overwrites", None) if overwrites is not None: perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') + raise InvalidArgument( + f"Expected PermissionOverwrite received {perm.__class__.__name__}" + ) allow, deny = perm.pair() payload = { - 'allow': allow.value, - 'deny': deny.value, - 'id': target.id, + "allow": allow.value, + "deny": deny.value, + "id": target.id, } if isinstance(target, Role): - payload['type'] = _Overwrites.ROLE + payload["type"] = _Overwrites.ROLE else: - payload['type'] = _Overwrites.MEMBER + payload["type"] = _Overwrites.MEMBER perms.append(payload) - options['permission_overwrites'] = perms + options["permission_overwrites"] = perms try: - ch_type = options['type'] + ch_type = options["type"] except KeyError: pass else: if not isinstance(ch_type, ChannelType): - raise InvalidArgument('type field must be of type ChannelType') - options['type'] = ch_type.value + raise InvalidArgument("type field must be of type ChannelType") + options["type"] = ch_type.value if options: - return await self._state.http.edit_channel(self.id, reason=reason, **options) + return await self._state.http.edit_channel( + self.id, reason=reason, **options + ) def _fill_overwrites(self, data: GuildChannelPayload) -> None: self._overwrites = [] everyone_index = 0 everyone_id = self.guild.id - for index, overridden in enumerate(data.get('permission_overwrites', [])): + for index, overridden in enumerate(data.get("permission_overwrites", [])): overwrite = _Overwrites(overridden) self._overwrites.append(overwrite) @@ -429,7 +453,7 @@ class GuildChannel: @property def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" - return f'<#{self.id}>' + return f"<#{self.id}>" @property def created_at(self) -> datetime: @@ -589,7 +613,9 @@ class GuildChannel: try: maybe_everyone = self._overwrites[0] if maybe_everyone.id == self.guild.id: - base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + base.handle_overwrite( + allow=maybe_everyone.allow, deny=maybe_everyone.deny + ) except IndexError: pass @@ -620,7 +646,9 @@ class GuildChannel: try: maybe_everyone = self._overwrites[0] if maybe_everyone.id == self.guild.id: - base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) + base.handle_overwrite( + allow=maybe_everyone.allow, deny=maybe_everyone.deny + ) remaining_overwrites = self._overwrites[1:] else: remaining_overwrites = self._overwrites @@ -703,7 +731,9 @@ class GuildChannel: ) -> None: ... - async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): + async def set_permissions( + self, target, *, overwrite=_undefined, reason=None, **permissions + ): r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -779,18 +809,18 @@ class GuildChannel: elif isinstance(target, Role): perm_type = _Overwrites.ROLE else: - raise InvalidArgument('target parameter must be either Member or Role') + raise InvalidArgument("target parameter must be either Member or Role") if overwrite is _undefined: if len(permissions) == 0: - raise InvalidArgument('No overwrite provided.') + raise InvalidArgument("No overwrite provided.") try: overwrite = PermissionOverwrite(**permissions) except (ValueError, TypeError): - raise InvalidArgument('Invalid permissions given to keyword arguments.') + raise InvalidArgument("Invalid permissions given to keyword arguments.") else: if len(permissions) > 0: - raise InvalidArgument('Cannot mix overwrite and keyword arguments.') + raise InvalidArgument("Cannot mix overwrite and keyword arguments.") # TODO: wait for event @@ -798,9 +828,11 @@ class GuildChannel: await http.delete_channel_permissions(self.id, target.id, reason=reason) elif isinstance(overwrite, PermissionOverwrite): (allow, deny) = overwrite.pair() - await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) + await http.edit_channel_permissions( + self.id, target.id, allow.value, deny.value, perm_type, reason=reason + ) else: - raise InvalidArgument('Invalid overwrite type provided.') + raise InvalidArgument("Invalid overwrite type provided.") async def _clone_impl( self: GCH, @@ -809,19 +841,23 @@ class GuildChannel: name: Optional[str] = None, reason: Optional[str] = None, ) -> GCH: - base_attrs['permission_overwrites'] = [x._asdict() for x in self._overwrites] - base_attrs['parent_id'] = self.category_id - base_attrs['name'] = name or self.name + base_attrs["permission_overwrites"] = [x._asdict() for x in self._overwrites] + base_attrs["parent_id"] = self.category_id + base_attrs["name"] = name or self.name guild_id = self.guild.id cls = self.__class__ - data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) + data = await self._state.http.create_channel( + guild_id, self.type.value, reason=reason, **base_attrs + ) obj = cls(state=self._state, guild=self.guild, data=data) # temporarily add it to the cache self.guild._channels[obj.id] = obj # type: ignore return obj - async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH: + async def clone( + self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None + ) -> GCH: """|coro| Clones this channel. This creates a channel with the same properties @@ -964,14 +1000,16 @@ class GuildChannel: if not kwargs: return - beginning, end = kwargs.get('beginning'), kwargs.get('end') - before, after = kwargs.get('before'), kwargs.get('after') - offset = kwargs.get('offset', 0) + beginning, end = kwargs.get("beginning"), kwargs.get("end") + before, after = kwargs.get("before"), kwargs.get("after") + offset = kwargs.get("offset", 0) if sum(bool(a) for a in (beginning, end, before, after)) > 1: - raise InvalidArgument('Only one of [before, after, end, beginning] can be used.') + raise InvalidArgument( + "Only one of [before, after, end, beginning] can be used." + ) bucket = self._sorting_bucket - parent_id = kwargs.get('category', MISSING) + parent_id = kwargs.get("category", MISSING) # fmt: off channels: List[GuildChannel] if parent_id not in (MISSING, None): @@ -1008,22 +1046,26 @@ class GuildChannel: elif before: index = next((i for i, c in enumerate(channels) if c.id == before.id), None) elif after: - index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) + index = next( + (i + 1 for i, c in enumerate(channels) if c.id == after.id), None + ) if index is None: - raise InvalidArgument('Could not resolve appropriate move position') + raise InvalidArgument("Could not resolve appropriate move position") channels.insert(max((index + offset), 0), self) payload = [] - lock_permissions = kwargs.get('sync_permissions', False) - reason = kwargs.get('reason') + lock_permissions = kwargs.get("sync_permissions", False) + reason = kwargs.get("reason") for index, channel in enumerate(channels): - d = {'id': channel.id, 'position': index} + d = {"id": channel.id, "position": index} if parent_id is not MISSING and channel.id == self.id: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) - await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) + await self._state.http.bulk_channel_update( + self.guild.id, payload, reason=reason + ) async def create_invite( self, @@ -1126,7 +1168,10 @@ class GuildChannel: state = self._state data = await state.http.invites_from_channel(self.id) guild = self.guild - return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] + return [ + Invite(state=state, data=invite, channel=self, guild=guild) + for invite in data + ] class Messageable: @@ -1332,14 +1377,18 @@ class Messageable: content = str(content) if content is not None else None if embed is not None and embeds is not None: - raise InvalidArgument('cannot pass both embed and embeds parameter to send()') + raise InvalidArgument( + "cannot pass both embed and embeds parameter to send()" + ) if embed is not None: embed = embed.to_dict() elif embeds is not None: if len(embeds) > 10: - raise InvalidArgument('embeds parameter must be a list of up to 10 elements') + raise InvalidArgument( + "embeds parameter must be a list of up to 10 elements" + ) embeds = [embed.to_dict() for embed in embeds] if stickers is not None: @@ -1347,36 +1396,44 @@ class Messageable: if allowed_mentions is not None: if state.allowed_mentions is not None: - allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() + allowed_mentions = state.allowed_mentions.merge( + allowed_mentions + ).to_dict() else: allowed_mentions = allowed_mentions.to_dict() else: - allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() + allowed_mentions = ( + state.allowed_mentions and state.allowed_mentions.to_dict() + ) if mention_author is not None: allowed_mentions = allowed_mentions or AllowedMentions().to_dict() - allowed_mentions['replied_user'] = bool(mention_author) + allowed_mentions["replied_user"] = bool(mention_author) if reference is not None: try: reference = reference.to_message_reference_dict() except AttributeError: - raise InvalidArgument('reference parameter must be Message, MessageReference, or PartialMessage') from None + raise InvalidArgument( + "reference parameter must be Message, MessageReference, or PartialMessage" + ) from None if view: - if not hasattr(view, '__discord_ui_view__'): - raise InvalidArgument(f'view parameter must be View not {view.__class__!r}') + if not hasattr(view, "__discord_ui_view__"): + raise InvalidArgument( + f"view parameter must be View not {view.__class__!r}" + ) components = view.to_components() else: components = None if file is not None and files is not None: - raise InvalidArgument('cannot pass both file and files parameter to send()') + raise InvalidArgument("cannot pass both file and files parameter to send()") if file is not None: if not isinstance(file, File): - raise InvalidArgument('file parameter must be File') + raise InvalidArgument("file parameter must be File") try: data = await state.http.send_files( @@ -1397,9 +1454,11 @@ class Messageable: elif files is not None: if len(files) > 10: - raise InvalidArgument('files parameter must be a list of up to 10 elements') + raise InvalidArgument( + "files parameter must be a list of up to 10 elements" + ) elif not all(isinstance(file, File) for file in files): - raise InvalidArgument('files parameter must be a list of File') + raise InvalidArgument("files parameter must be a list of File") try: data = await state.http.send_files( @@ -1594,7 +1653,14 @@ class Messageable: :class:`~discord.Message` The message with the message data parsed. """ - return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) + return HistoryIterator( + self, + limit=limit, + before=before, + after=after, + around=around, + oldest_first=oldest_first, + ) class Connectable(Protocol): @@ -1666,13 +1732,13 @@ class Connectable(Protocol): state = self._state if state._get_voice_client(key_id): - raise ClientException('Already connected to a voice channel.') + raise ClientException("Already connected to a voice channel.") client = state._get_client() voice = cls(client, self) if not isinstance(voice, VoiceProtocol): - raise TypeError('Type must meet VoiceProtocol abstract base class.') + raise TypeError("Type must meet VoiceProtocol abstract base class.") state._add_voice_client(key_id, voice) diff --git a/discord/activity.py b/discord/activity.py index 51205377..8daadca9 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -34,12 +34,12 @@ from .partial_emoji import PartialEmoji from .utils import _get_as_snowflake __all__ = ( - 'BaseActivity', - 'Activity', - 'Streaming', - 'Game', - 'Spotify', - 'CustomActivity', + "BaseActivity", + "Activity", + "Streaming", + "Game", + "Spotify", + "CustomActivity", ) """If curious, this is the current schema for an activity. @@ -119,10 +119,10 @@ class BaseActivity: .. versionadded:: 1.3 """ - __slots__ = ('_created_at',) + __slots__ = ("_created_at",) def __init__(self, **kwargs): - self._created_at: Optional[float] = kwargs.pop('created_at', None) + self._created_at: Optional[float] = kwargs.pop("created_at", None) @property def created_at(self) -> Optional[datetime.datetime]: @@ -131,7 +131,9 @@ class BaseActivity: .. versionadded:: 1.3 """ if self._created_at is not None: - return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp( + self._created_at / 1000, tz=datetime.timezone.utc + ) def to_dict(self) -> ActivityPayload: raise NotImplementedError @@ -199,58 +201,62 @@ class Activity(BaseActivity): """ __slots__ = ( - 'state', - 'details', - '_created_at', - 'timestamps', - 'assets', - 'party', - 'flags', - 'sync_id', - 'session_id', - 'type', - 'name', - 'url', - 'application_id', - 'emoji', - 'buttons', + "state", + "details", + "_created_at", + "timestamps", + "assets", + "party", + "flags", + "sync_id", + "session_id", + "type", + "name", + "url", + "application_id", + "emoji", + "buttons", ) def __init__(self, **kwargs): super().__init__(**kwargs) - self.state: Optional[str] = kwargs.pop('state', None) - self.details: Optional[str] = kwargs.pop('details', None) - self.timestamps: ActivityTimestamps = kwargs.pop('timestamps', {}) - self.assets: ActivityAssets = kwargs.pop('assets', {}) - self.party: ActivityParty = kwargs.pop('party', {}) - self.application_id: Optional[int] = _get_as_snowflake(kwargs, 'application_id') - self.name: Optional[str] = kwargs.pop('name', None) - self.url: Optional[str] = kwargs.pop('url', None) - self.flags: int = kwargs.pop('flags', 0) - self.sync_id: Optional[str] = kwargs.pop('sync_id', None) - self.session_id: Optional[str] = kwargs.pop('session_id', None) - self.buttons: List[ActivityButton] = kwargs.pop('buttons', []) + self.state: Optional[str] = kwargs.pop("state", None) + self.details: Optional[str] = kwargs.pop("details", None) + self.timestamps: ActivityTimestamps = kwargs.pop("timestamps", {}) + self.assets: ActivityAssets = kwargs.pop("assets", {}) + self.party: ActivityParty = kwargs.pop("party", {}) + self.application_id: Optional[int] = _get_as_snowflake(kwargs, "application_id") + self.name: Optional[str] = kwargs.pop("name", None) + self.url: Optional[str] = kwargs.pop("url", None) + self.flags: int = kwargs.pop("flags", 0) + self.sync_id: Optional[str] = kwargs.pop("sync_id", None) + self.session_id: Optional[str] = kwargs.pop("session_id", None) + self.buttons: List[ActivityButton] = kwargs.pop("buttons", []) - activity_type = kwargs.pop('type', -1) + activity_type = kwargs.pop("type", -1) self.type: ActivityType = ( - activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type) + activity_type + if isinstance(activity_type, ActivityType) + else try_enum(ActivityType, activity_type) ) - emoji = kwargs.pop('emoji', None) - self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None + emoji = kwargs.pop("emoji", None) + self.emoji: Optional[PartialEmoji] = ( + PartialEmoji.from_dict(emoji) if emoji is not None else None + ) def __repr__(self) -> str: attrs = ( - ('type', self.type), - ('name', self.name), - ('url', self.url), - ('details', self.details), - ('application_id', self.application_id), - ('session_id', self.session_id), - ('emoji', self.emoji), + ("type", self.type), + ("name", self.name), + ("url", self.url), + ("details", self.details), + ("application_id", self.application_id), + ("session_id", self.session_id), + ("emoji", self.emoji), ) - inner = ' '.join('%s=%r' % t for t in attrs) - return f'' + inner = " ".join("%s=%r" % t for t in attrs) + return f"" def to_dict(self) -> Dict[str, Any]: ret: Dict[str, Any] = {} @@ -263,16 +269,16 @@ class Activity(BaseActivity): continue ret[attr] = value - ret['type'] = int(self.type) + ret["type"] = int(self.type) if self.emoji: - ret['emoji'] = self.emoji.to_dict() + ret["emoji"] = self.emoji.to_dict() return ret @property def start(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user started doing this activity in UTC, if applicable.""" try: - timestamp = self.timestamps['start'] / 1000 + timestamp = self.timestamps["start"] / 1000 except KeyError: return None else: @@ -282,7 +288,7 @@ class Activity(BaseActivity): def end(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user will stop doing this activity in UTC, if applicable.""" try: - timestamp = self.timestamps['end'] / 1000 + timestamp = self.timestamps["end"] / 1000 except KeyError: return None else: @@ -295,11 +301,11 @@ class Activity(BaseActivity): return None try: - large_image = self.assets['large_image'] + large_image = self.assets["large_image"] except KeyError: return None else: - return Asset.BASE + f'/app-assets/{self.application_id}/{large_image}.png' + return Asset.BASE + f"/app-assets/{self.application_id}/{large_image}.png" @property def small_image_url(self) -> Optional[str]: @@ -308,21 +314,21 @@ class Activity(BaseActivity): return None try: - small_image = self.assets['small_image'] + small_image = self.assets["small_image"] except KeyError: return None else: - return Asset.BASE + f'/app-assets/{self.application_id}/{small_image}.png' + return Asset.BASE + f"/app-assets/{self.application_id}/{small_image}.png" @property def large_image_text(self) -> Optional[str]: """Optional[:class:`str`]: Returns the large image asset hover text of this activity if applicable.""" - return self.assets.get('large_text', None) + return self.assets.get("large_text", None) @property def small_image_text(self) -> Optional[str]: """Optional[:class:`str`]: Returns the small image asset hover text of this activity if applicable.""" - return self.assets.get('small_text', None) + return self.assets.get("small_text", None) class Game(BaseActivity): @@ -359,20 +365,20 @@ class Game(BaseActivity): The game's name. """ - __slots__ = ('name', '_end', '_start') + __slots__ = ("name", "_end", "_start") def __init__(self, name: str, **extra): super().__init__(**extra) self.name: str = name try: - timestamps: ActivityTimestamps = extra['timestamps'] + timestamps: ActivityTimestamps = extra["timestamps"] except KeyError: self._start = 0 self._end = 0 else: - self._start = timestamps.get('start', 0) - self._end = timestamps.get('end', 0) + self._start = timestamps.get("start", 0) + self._end = timestamps.get("end", 0) @property def type(self) -> ActivityType: @@ -386,29 +392,33 @@ class Game(BaseActivity): def start(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable.""" if self._start: - return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp( + self._start / 1000, tz=datetime.timezone.utc + ) return None @property def end(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable.""" if self._end: - return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp( + self._end / 1000, tz=datetime.timezone.utc + ) return None def __str__(self) -> str: return str(self.name) def __repr__(self) -> str: - return f'' + return f"" def to_dict(self) -> Dict[str, Any]: timestamps: Dict[str, Any] = {} if self._start: - timestamps['start'] = self._start + timestamps["start"] = self._start if self._end: - timestamps['end'] = self._end + timestamps["end"] = self._end # fmt: off return { @@ -473,16 +483,16 @@ class Streaming(BaseActivity): A dictionary comprising of similar keys than those in :attr:`Activity.assets`. """ - __slots__ = ('platform', 'name', 'game', 'url', 'details', 'assets') + __slots__ = ("platform", "name", "game", "url", "details", "assets") def __init__(self, *, name: Optional[str], url: str, **extra: Any): super().__init__(**extra) self.platform: Optional[str] = name - self.name: Optional[str] = extra.pop('details', name) - self.game: Optional[str] = extra.pop('state', None) + self.name: Optional[str] = extra.pop("details", name) + self.game: Optional[str] = extra.pop("state", None) self.url: str = url - self.details: Optional[str] = extra.pop('details', self.name) # compatibility - self.assets: ActivityAssets = extra.pop('assets', {}) + self.details: Optional[str] = extra.pop("details", self.name) # compatibility + self.assets: ActivityAssets = extra.pop("assets", {}) @property def type(self) -> ActivityType: @@ -496,7 +506,7 @@ class Streaming(BaseActivity): return str(self.name) def __repr__(self) -> str: - return f'' + return f"" @property def twitch_name(self): @@ -507,11 +517,11 @@ class Streaming(BaseActivity): """ try: - name = self.assets['large_image'] + name = self.assets["large_image"] except KeyError: return None else: - return name[7:] if name[:7] == 'twitch:' else None + return name[7:] if name[:7] == "twitch:" else None def to_dict(self) -> Dict[str, Any]: # fmt: off @@ -523,11 +533,15 @@ class Streaming(BaseActivity): } # fmt: on if self.details: - ret['details'] = self.details + ret["details"] = self.details return ret def __eq__(self, other: Any) -> bool: - return isinstance(other, Streaming) and other.name == self.name and other.url == self.url + return ( + isinstance(other, Streaming) + and other.name == self.name + and other.url == self.url + ) def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -559,17 +573,26 @@ class Spotify: Returns the string 'Spotify'. """ - __slots__ = ('_state', '_details', '_timestamps', '_assets', '_party', '_sync_id', '_session_id', '_created_at') + __slots__ = ( + "_state", + "_details", + "_timestamps", + "_assets", + "_party", + "_sync_id", + "_session_id", + "_created_at", + ) def __init__(self, **data): - self._state: str = data.pop('state', '') - self._details: str = data.pop('details', '') - self._timestamps: Dict[str, int] = data.pop('timestamps', {}) - self._assets: ActivityAssets = data.pop('assets', {}) - self._party: ActivityParty = data.pop('party', {}) - self._sync_id: str = data.pop('sync_id') - self._session_id: str = data.pop('session_id') - self._created_at: Optional[float] = data.pop('created_at', None) + self._state: str = data.pop("state", "") + self._details: str = data.pop("details", "") + self._timestamps: Dict[str, int] = data.pop("timestamps", {}) + self._assets: ActivityAssets = data.pop("assets", {}) + self._party: ActivityParty = data.pop("party", {}) + self._sync_id: str = data.pop("sync_id") + self._session_id: str = data.pop("session_id") + self._created_at: Optional[float] = data.pop("created_at", None) @property def type(self) -> ActivityType: @@ -586,7 +609,9 @@ class Spotify: .. versionadded:: 1.3 """ if self._created_at is not None: - return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp( + self._created_at / 1000, tz=datetime.timezone.utc + ) @property def colour(self) -> Colour: @@ -604,21 +629,21 @@ class Spotify: def to_dict(self) -> Dict[str, Any]: return { - 'flags': 48, # SYNC | PLAY - 'name': 'Spotify', - 'assets': self._assets, - 'party': self._party, - 'sync_id': self._sync_id, - 'session_id': self._session_id, - 'timestamps': self._timestamps, - 'details': self._details, - 'state': self._state, + "flags": 48, # SYNC | PLAY + "name": "Spotify", + "assets": self._assets, + "party": self._party, + "sync_id": self._sync_id, + "session_id": self._session_id, + "timestamps": self._timestamps, + "details": self._details, + "state": self._state, } @property def name(self) -> str: """:class:`str`: The activity's name. This will always return "Spotify".""" - return 'Spotify' + return "Spotify" def __eq__(self, other: Any) -> bool: return ( @@ -635,10 +660,10 @@ class Spotify: return hash(self._session_id) def __str__(self) -> str: - return 'Spotify' + return "Spotify" def __repr__(self) -> str: - return f'' + return f"" @property def title(self) -> str: @@ -648,7 +673,7 @@ class Spotify: @property def artists(self) -> List[str]: """List[:class:`str`]: The artists of the song being played.""" - return self._state.split('; ') + return self._state.split("; ") @property def artist(self) -> str: @@ -662,16 +687,16 @@ class Spotify: @property def album(self) -> str: """:class:`str`: The album that the song being played belongs to.""" - return self._assets.get('large_text', '') + return self._assets.get("large_text", "") @property def album_cover_url(self) -> str: """:class:`str`: The album cover image URL from Spotify's CDN.""" - large_image = self._assets.get('large_image', '') - if large_image[:8] != 'spotify:': - return '' + large_image = self._assets.get("large_image", "") + if large_image[:8] != "spotify:": + return "" album_image_id = large_image[8:] - return 'https://i.scdn.co/image/' + album_image_id + return "https://i.scdn.co/image/" + album_image_id @property def track_id(self) -> str: @@ -684,17 +709,21 @@ class Spotify: .. versionadded:: 2.0 """ - return f'https://open.spotify.com/track/{self.track_id}' + return f"https://open.spotify.com/track/{self.track_id}" @property def start(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user started playing this song in UTC.""" - return datetime.datetime.fromtimestamp(self._timestamps['start'] / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp( + self._timestamps["start"] / 1000, tz=datetime.timezone.utc + ) @property def end(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" - return datetime.datetime.fromtimestamp(self._timestamps['end'] / 1000, tz=datetime.timezone.utc) + return datetime.datetime.fromtimestamp( + self._timestamps["end"] / 1000, tz=datetime.timezone.utc + ) @property def duration(self) -> datetime.timedelta: @@ -704,7 +733,7 @@ class Spotify: @property def party_id(self) -> str: """:class:`str`: The party ID of the listening party.""" - return self._party.get('id', '') + return self._party.get("id", "") class CustomActivity(BaseActivity): @@ -738,13 +767,15 @@ class CustomActivity(BaseActivity): The emoji to pass to the activity, if any. """ - __slots__ = ('name', 'emoji', 'state') + __slots__ = ("name", "emoji", "state") - def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): + def __init__( + self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any + ): super().__init__(**extra) self.name: Optional[str] = name - self.state: Optional[str] = extra.pop('state', None) - if self.name == 'Custom Status': + self.state: Optional[str] = extra.pop("state", None) + if self.name == "Custom Status": self.name = self.state self.emoji: Optional[PartialEmoji] @@ -757,7 +788,9 @@ class CustomActivity(BaseActivity): elif isinstance(emoji, PartialEmoji): self.emoji = emoji else: - raise TypeError(f'Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.') + raise TypeError( + f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead." + ) @property def type(self) -> ActivityType: @@ -770,22 +803,26 @@ class CustomActivity(BaseActivity): def to_dict(self) -> Dict[str, Any]: if self.name == self.state: o = { - 'type': ActivityType.custom.value, - 'state': self.name, - 'name': 'Custom Status', + "type": ActivityType.custom.value, + "state": self.name, + "name": "Custom Status", } else: o = { - 'type': ActivityType.custom.value, - 'name': self.name, + "type": ActivityType.custom.value, + "name": self.name, } if self.emoji: - o['emoji'] = self.emoji.to_dict() + o["emoji"] = self.emoji.to_dict() return o def __eq__(self, other: Any) -> bool: - return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji + return ( + isinstance(other, CustomActivity) + and other.name == self.name + and other.emoji == self.emoji + ) def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -796,47 +833,54 @@ class CustomActivity(BaseActivity): def __str__(self) -> str: if self.emoji: if self.name: - return f'{self.emoji} {self.name}' + return f"{self.emoji} {self.name}" return str(self.emoji) else: return str(self.name) def __repr__(self) -> str: - return f'' + return f"" ActivityTypes = Union[Activity, Game, CustomActivity, Streaming, Spotify] + @overload def create_activity(data: ActivityPayload) -> ActivityTypes: ... + @overload def create_activity(data: None) -> None: ... + def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: if not data: return None - game_type = try_enum(ActivityType, data.get('type', -1)) + game_type = try_enum(ActivityType, data.get("type", -1)) if game_type is ActivityType.playing: - if 'application_id' in data or 'session_id' in data: + if "application_id" in data or "session_id" in data: return Activity(**data) return Game(**data) elif game_type is ActivityType.custom: try: - name = data.pop('name') + name = data.pop("name") except KeyError: return Activity(**data) else: # we removed the name key from data already - return CustomActivity(name=name, **data) # type: ignore + return CustomActivity(name=name, **data) # type: ignore elif game_type is ActivityType.streaming: - if 'url' in data: + if "url" in data: # the url won't be None here - return Streaming(**data) # type: ignore + return Streaming(**data) # type: ignore return Activity(**data) - elif game_type is ActivityType.listening and 'sync_id' in data and 'session_id' in data: + elif ( + game_type is ActivityType.listening + and "sync_id" in data + and "session_id" in data + ): return Spotify(**data) return Activity(**data) diff --git a/discord/appinfo.py b/discord/appinfo.py index de1f7a73..b50d9328 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -40,8 +40,8 @@ if TYPE_CHECKING: from .state import ConnectionState __all__ = ( - 'AppInfo', - 'PartialAppInfo', + "AppInfo", + "PartialAppInfo", ) @@ -115,58 +115,60 @@ class AppInfo: """ __slots__ = ( - '_state', - 'description', - 'id', - 'name', - 'rpc_origins', - 'bot_public', - 'bot_require_code_grant', - 'owner', - '_icon', - 'summary', - 'verify_key', - 'team', - 'guild_id', - 'primary_sku_id', - 'slug', - '_cover_image', - 'terms_of_service_url', - 'privacy_policy_url', + "_state", + "description", + "id", + "name", + "rpc_origins", + "bot_public", + "bot_require_code_grant", + "owner", + "_icon", + "summary", + "verify_key", + "team", + "guild_id", + "primary_sku_id", + "slug", + "_cover_image", + "terms_of_service_url", + "privacy_policy_url", ) def __init__(self, state: ConnectionState, data: AppInfoPayload): from .team import Team self._state: ConnectionState = state - self.id: int = int(data['id']) - self.name: str = data['name'] - self.description: str = data['description'] - self._icon: Optional[str] = data['icon'] - self.rpc_origins: List[str] = data['rpc_origins'] - self.bot_public: bool = data['bot_public'] - self.bot_require_code_grant: bool = data['bot_require_code_grant'] - self.owner: User = state.create_user(data['owner']) + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.description: str = data["description"] + self._icon: Optional[str] = data["icon"] + self.rpc_origins: List[str] = data["rpc_origins"] + self.bot_public: bool = data["bot_public"] + self.bot_require_code_grant: bool = data["bot_require_code_grant"] + self.owner: User = state.create_user(data["owner"]) - team: Optional[TeamPayload] = data.get('team') + team: Optional[TeamPayload] = data.get("team") self.team: Optional[Team] = Team(state, team) if team else None - self.summary: str = data['summary'] - self.verify_key: str = data['verify_key'] + self.summary: str = data["summary"] + self.verify_key: str = data["verify_key"] - self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') + self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") - self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, 'primary_sku_id') - self.slug: Optional[str] = data.get('slug') - self._cover_image: Optional[str] = data.get('cover_image') - self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') - self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') + self.primary_sku_id: Optional[int] = utils._get_as_snowflake( + data, "primary_sku_id" + ) + self.slug: Optional[str] = data.get("slug") + self._cover_image: Optional[str] = data.get("cover_image") + self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url") + self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url") def __repr__(self) -> str: return ( - f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' - f'description={self.description!r} public={self.bot_public} ' - f'owner={self.owner!r}>' + f"<{self.__class__.__name__} id={self.id} name={self.name!r} " + f"description={self.description!r} public={self.bot_public} " + f"owner={self.owner!r}>" ) @property @@ -174,7 +176,7 @@ class AppInfo: """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='app') + return Asset._from_icon(self._state, self.id, self._icon, path="app") @property def cover_image(self) -> Optional[Asset]: @@ -195,6 +197,7 @@ class AppInfo: """ return self._state._get_guild(self.guild_id) + class PartialAppInfo: """Represents a partial AppInfo given by :func:`~discord.abc.GuildChannel.create_invite` @@ -222,26 +225,37 @@ class PartialAppInfo: The application's privacy policy URL, if set. """ - __slots__ = ('_state', 'id', 'name', 'description', 'rpc_origins', 'summary', 'verify_key', 'terms_of_service_url', 'privacy_policy_url', '_icon') + __slots__ = ( + "_state", + "id", + "name", + "description", + "rpc_origins", + "summary", + "verify_key", + "terms_of_service_url", + "privacy_policy_url", + "_icon", + ) def __init__(self, *, state: ConnectionState, data: PartialAppInfoPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) - self.name: str = data['name'] - self._icon: Optional[str] = data.get('icon') - self.description: str = data['description'] - self.rpc_origins: Optional[List[str]] = data.get('rpc_origins') - self.summary: str = data['summary'] - self.verify_key: str = data['verify_key'] - self.terms_of_service_url: Optional[str] = data.get('terms_of_service_url') - self.privacy_policy_url: Optional[str] = data.get('privacy_policy_url') + self.id: int = int(data["id"]) + self.name: str = data["name"] + self._icon: Optional[str] = data.get("icon") + self.description: str = data["description"] + self.rpc_origins: Optional[List[str]] = data.get("rpc_origins") + self.summary: str = data["summary"] + self.verify_key: str = data["verify_key"] + self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url") + self.privacy_policy_url: Optional[str] = data.get("privacy_policy_url") def __repr__(self) -> str: - return f'<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>' + return f"<{self.__class__.__name__} id={self.id} name={self.name!r} description={self.description!r}>" @property def icon(self) -> Optional[Asset]: """Optional[:class:`.Asset`]: Retrieves the application's icon asset, if any.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='app') + return Asset._from_icon(self._state, self.id, self._icon, path="app") diff --git a/discord/asset.py b/discord/asset.py index 25c72648..716d888f 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -33,13 +33,11 @@ from . import utils import yarl -__all__ = ( - 'Asset', -) +__all__ = ("Asset",) if TYPE_CHECKING: - ValidStaticFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png'] - ValidAssetFormatTypes = Literal['webp', 'jpeg', 'jpg', 'png', 'gif'] + ValidStaticFormatTypes = Literal["webp", "jpeg", "jpg", "png"] + ValidAssetFormatTypes = Literal["webp", "jpeg", "jpg", "png", "gif"] VALID_STATIC_FORMATS = frozenset({"jpeg", "jpg", "webp", "png"}) VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} @@ -47,6 +45,7 @@ VALID_ASSET_FORMATS = VALID_STATIC_FORMATS | {"gif"} MISSING = utils.MISSING + class AssetMixin: url: str _state: Optional[Any] @@ -71,11 +70,16 @@ class AssetMixin: The content of the asset. """ if self._state is None: - raise DiscordException('Invalid state (no ConnectionState provided)') + raise DiscordException("Invalid state (no ConnectionState provided)") return await self._state.http.get_from_cdn(self.url) - async def save(self, fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], *, seek_begin: bool = True) -> int: + async def save( + self, + fp: Union[str, bytes, os.PathLike, io.BufferedIOBase], + *, + seek_begin: bool = True, + ) -> int: """|coro| Saves this asset into a file-like object. @@ -112,7 +116,7 @@ class AssetMixin: fp.seek(0) return written else: - with open(fp, 'wb') as f: + with open(fp, "wb") as f: return f.write(data) @@ -143,13 +147,13 @@ class Asset(AssetMixin): """ __slots__: Tuple[str, ...] = ( - '_state', - '_url', - '_animated', - '_key', + "_state", + "_url", + "_animated", + "_key", ) - BASE = 'https://cdn.discordapp.com' + BASE = "https://cdn.discordapp.com" def __init__(self, state, *, url: str, key: str, animated: bool = False): self._state = state @@ -161,26 +165,28 @@ class Asset(AssetMixin): def _from_default_avatar(cls, state, index: int) -> Asset: return cls( state, - url=f'{cls.BASE}/embed/avatars/{index}.png', + url=f"{cls.BASE}/embed/avatars/{index}.png", key=str(index), animated=False, ) @classmethod def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: - animated = avatar.startswith('a_') - format = 'gif' if animated else 'png' + animated = avatar.startswith("a_") + format = "gif" if animated else "png" return cls( state, - url=f'{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024', + url=f"{cls.BASE}/avatars/{user_id}/{avatar}.{format}?size=1024", key=avatar, animated=animated, ) @classmethod - def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: - animated = avatar.startswith('a_') - format = 'gif' if animated else 'png' + def _from_guild_avatar( + cls, state, guild_id: int, member_id: int, avatar: str + ) -> Asset: + animated = avatar.startswith("a_") + format = "gif" if animated else "png" return cls( state, url=f"{cls.BASE}/guilds/{guild_id}/users/{member_id}/avatars/{avatar}.{format}?size=1024", @@ -192,7 +198,7 @@ class Asset(AssetMixin): def _from_icon(cls, state, object_id: int, icon_hash: str, path: str) -> Asset: return cls( state, - url=f'{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024', + url=f"{cls.BASE}/{path}-icons/{object_id}/{icon_hash}.png?size=1024", key=icon_hash, animated=False, ) @@ -201,7 +207,7 @@ class Asset(AssetMixin): def _from_cover_image(cls, state, object_id: int, cover_image_hash: str) -> Asset: return cls( state, - url=f'{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024', + url=f"{cls.BASE}/app-assets/{object_id}/store/{cover_image_hash}.png?size=1024", key=cover_image_hash, animated=False, ) @@ -210,18 +216,18 @@ class Asset(AssetMixin): def _from_guild_image(cls, state, guild_id: int, image: str, path: str) -> Asset: return cls( state, - url=f'{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024', + url=f"{cls.BASE}/{path}/{guild_id}/{image}.png?size=1024", key=image, animated=False, ) @classmethod def _from_guild_icon(cls, state, guild_id: int, icon_hash: str) -> Asset: - animated = icon_hash.startswith('a_') - format = 'gif' if animated else 'png' + animated = icon_hash.startswith("a_") + format = "gif" if animated else "png" return cls( state, - url=f'{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024', + url=f"{cls.BASE}/icons/{guild_id}/{icon_hash}.{format}?size=1024", key=icon_hash, animated=animated, ) @@ -230,20 +236,20 @@ class Asset(AssetMixin): def _from_sticker_banner(cls, state, banner: int) -> Asset: return cls( state, - url=f'{cls.BASE}/app-assets/710982414301790216/store/{banner}.png', + url=f"{cls.BASE}/app-assets/710982414301790216/store/{banner}.png", key=str(banner), animated=False, ) @classmethod def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: - animated = banner_hash.startswith('a_') - format = 'gif' if animated else 'png' + animated = banner_hash.startswith("a_") + format = "gif" if animated else "png" return cls( state, - url=f'{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512', + url=f"{cls.BASE}/banners/{user_id}/{banner_hash}.{format}?size=512", key=banner_hash, - animated=animated + animated=animated, ) def __str__(self) -> str: @@ -253,8 +259,8 @@ class Asset(AssetMixin): return len(self._url) def __repr__(self): - shorten = self._url.replace(self.BASE, '') - return f'' + shorten = self._url.replace(self.BASE, "") + return f"" def __eq__(self, other): return isinstance(other, Asset) and self._url == other._url @@ -312,21 +318,27 @@ class Asset(AssetMixin): if format is not MISSING: if self._animated: if format not in VALID_ASSET_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') - url = url.with_path(f'{path}.{format}') + raise InvalidArgument( + f"format must be one of {VALID_ASSET_FORMATS}" + ) + url = url.with_path(f"{path}.{format}") elif static_format is MISSING: if format not in VALID_STATIC_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') - url = url.with_path(f'{path}.{format}') + raise InvalidArgument( + f"format must be one of {VALID_STATIC_FORMATS}" + ) + url = url.with_path(f"{path}.{format}") if static_format is not MISSING and not self._animated: if static_format not in VALID_STATIC_FORMATS: - raise InvalidArgument(f'static_format must be one of {VALID_STATIC_FORMATS}') - url = url.with_path(f'{path}.{static_format}') + raise InvalidArgument( + f"static_format must be one of {VALID_STATIC_FORMATS}" + ) + url = url.with_path(f"{path}.{static_format}") if size is not MISSING: if not utils.valid_icon_size(size): - raise InvalidArgument('size must be a power of 2 between 16 and 4096') + raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = url.with_query(size=size) else: url = url.with_query(url.raw_query_string) @@ -353,7 +365,7 @@ class Asset(AssetMixin): The new updated asset. """ if not utils.valid_icon_size(size): - raise InvalidArgument('size must be a power of 2 between 16 and 4096') + raise InvalidArgument("size must be a power of 2 between 16 and 4096") url = str(yarl.URL(self._url).with_query(size=size)) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) @@ -379,14 +391,14 @@ class Asset(AssetMixin): if self._animated: if format not in VALID_ASSET_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_ASSET_FORMATS}') + raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}") else: if format not in VALID_STATIC_FORMATS: - raise InvalidArgument(f'format must be one of {VALID_STATIC_FORMATS}') + raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}") url = yarl.URL(self._url) path, _ = os.path.splitext(url.path) - url = str(url.with_path(f'{path}.{format}').with_query(url.raw_query_string)) + url = str(url.with_path(f"{path}.{format}").with_query(url.raw_query_string)) return Asset(state=self._state, url=url, key=self._key, animated=self._animated) def with_static_format(self, format: ValidStaticFormatTypes, /) -> Asset: diff --git a/discord/audit_logs.py b/discord/audit_logs.py index b74bbfef..7f4eb32c 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -24,7 +24,20 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, Generator, List, Optional, Tuple, Type, TypeVar, Union +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ClassVar, + Dict, + Generator, + List, + Optional, + Tuple, + Type, + TypeVar, + Union, +) from . import enums, utils from .asset import Asset @@ -35,9 +48,9 @@ from .object import Object from .permissions import PermissionOverwrite, Permissions __all__ = ( - 'AuditLogDiff', - 'AuditLogChanges', - 'AuditLogEntry', + "AuditLogDiff", + "AuditLogChanges", + "AuditLogEntry", ) @@ -74,18 +87,25 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: return int(data) -def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]: +def _transform_channel( + entry: AuditLogEntry, data: Optional[Snowflake] +) -> Optional[Union[abc.GuildChannel, Object]]: if data is None: return None return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]: +def _transform_member_id( + entry: AuditLogEntry, data: Optional[Snowflake] +) -> Union[Member, User, None]: if data is None: return None return entry._get_member(int(data)) -def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]: + +def _transform_guild_id( + entry: AuditLogEntry, data: Optional[Snowflake] +) -> Optional[Guild]: if data is None: return None return entry._state._get_guild(data) @@ -96,16 +116,16 @@ def _transform_overwrites( ) -> List[Tuple[Object, PermissionOverwrite]]: overwrites = [] for elem in data: - allow = Permissions(int(elem['allow'])) - deny = Permissions(int(elem['deny'])) + allow = Permissions(int(elem["allow"])) + deny = Permissions(int(elem["deny"])) ow = PermissionOverwrite.from_pair(allow, deny) - ow_type = elem['type'] - ow_id = int(elem['id']) + ow_type = elem["type"] + ow_id = int(elem["id"]) target = None - if ow_type == '0': + if ow_type == "0": target = entry.guild.get_role(ow_id) - elif ow_type == '1': + elif ow_type == "1": target = entry._get_member(ow_id) if target is None: @@ -128,7 +148,9 @@ def _transform_avatar(entry: AuditLogEntry, data: Optional[str]) -> Optional[Ass return Asset._from_avatar(entry._state, entry._target_id, data) # type: ignore -def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]: +def _guild_hash_transformer( + path: str, +) -> Callable[[AuditLogEntry, Optional[str]], Optional[Asset]]: def _transform(entry: AuditLogEntry, data: Optional[str]) -> Optional[Asset]: if data is None: return None @@ -137,7 +159,7 @@ def _guild_hash_transformer(path: str) -> Callable[[AuditLogEntry, Optional[str] return _transform -T = TypeVar('T', bound=enums.Enum) +T = TypeVar("T", bound=enums.Enum) def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]: @@ -146,12 +168,16 @@ def _enum_transformer(enum: Type[T]) -> Callable[[AuditLogEntry, int], T]: return _transform -def _transform_type(entry: AuditLogEntry, data: Union[int]) -> Union[enums.ChannelType, enums.StickerType]: - if entry.action.name.startswith('sticker_'): + +def _transform_type( + entry: AuditLogEntry, data: Union[int] +) -> Union[enums.ChannelType, enums.StickerType]: + if entry.action.name.startswith("sticker_"): return enums.try_enum(enums.StickerType, data) else: return enums.try_enum(enums.ChannelType, data) + class AuditLogDiff: def __len__(self) -> int: return len(self.__dict__) @@ -160,8 +186,8 @@ class AuditLogDiff: yield from self.__dict__.items() def __repr__(self) -> str: - values = ' '.join('%s=%r' % item for item in self.__dict__.items()) - return f'' + values = " ".join("%s=%r" % item for item in self.__dict__.items()) + return f"" if TYPE_CHECKING: @@ -217,14 +243,14 @@ class AuditLogChanges: self.after = AuditLogDiff() for elem in data: - attr = elem['key'] + attr = elem["key"] # special cases for role add/remove - if attr == '$add': - self._handle_role(self.before, self.after, entry, elem['new_value']) # type: ignore + if attr == "$add": + self._handle_role(self.before, self.after, entry, elem["new_value"]) # type: ignore continue - elif attr == '$remove': - self._handle_role(self.after, self.before, entry, elem['new_value']) # type: ignore + elif attr == "$remove": + self._handle_role(self.after, self.before, entry, elem["new_value"]) # type: ignore continue try: @@ -238,7 +264,7 @@ class AuditLogChanges: transformer: Optional[Transformer] try: - before = elem['old_value'] + before = elem["old_value"] except KeyError: before = None else: @@ -248,7 +274,7 @@ class AuditLogChanges: setattr(self.before, attr, before) try: - after = elem['new_value'] + after = elem["new_value"] except KeyError: after = None else: @@ -258,34 +284,40 @@ class AuditLogChanges: setattr(self.after, attr, after) # add an alias - if hasattr(self.after, 'colour'): + if hasattr(self.after, "colour"): self.after.color = self.after.colour self.before.color = self.before.colour - if hasattr(self.after, 'expire_behavior'): + if hasattr(self.after, "expire_behavior"): self.after.expire_behaviour = self.after.expire_behavior self.before.expire_behaviour = self.before.expire_behavior def __repr__(self) -> str: - return f'' + return f"" - def _handle_role(self, first: AuditLogDiff, second: AuditLogDiff, entry: AuditLogEntry, elem: List[RolePayload]) -> None: - if not hasattr(first, 'roles'): - setattr(first, 'roles', []) + def _handle_role( + self, + first: AuditLogDiff, + second: AuditLogDiff, + entry: AuditLogEntry, + elem: List[RolePayload], + ) -> None: + if not hasattr(first, "roles"): + setattr(first, "roles", []) data = [] g: Guild = entry.guild # type: ignore for e in elem: - role_id = int(e['id']) + role_id = int(e["id"]) role = g.get_role(role_id) if role is None: role = Object(id=role_id) - role.name = e['name'] # type: ignore + role.name = e["name"] # type: ignore data.append(role) - setattr(second, 'roles', data) + setattr(second, "roles", data) class _AuditLogProxyMemberPrune: @@ -358,63 +390,81 @@ class AuditLogEntry(Hashable): which actions have this field filled out. """ - def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): + def __init__( + self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild + ): self._state = guild._state self.guild = guild self._users = users self._from_data(data) def _from_data(self, data: AuditLogEntryPayload) -> None: - self.action = enums.try_enum(enums.AuditLogAction, data['action_type']) - self.id = int(data['id']) + self.action = enums.try_enum(enums.AuditLogAction, data["action_type"]) + self.id = int(data["id"]) # this key is technically not usually present - self.reason = data.get('reason') - self.extra = data.get('options') + self.reason = data.get("reason") + self.extra = data.get("options") if isinstance(self.action, enums.AuditLogAction) and self.extra: if self.action is enums.AuditLogAction.member_prune: # member prune has two keys with useful information self.extra: _AuditLogProxyMemberPrune = type( - '_AuditLogProxy', (), {k: int(v) for k, v in self.extra.items()} + "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()} )() - elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: - channel_id = int(self.extra['channel_id']) + elif ( + self.action is enums.AuditLogAction.member_move + or self.action is enums.AuditLogAction.message_delete + ): + channel_id = int(self.extra["channel_id"]) elems = { - 'count': int(self.extra['count']), - 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), + "count": int(self.extra["count"]), + "channel": self.guild.get_channel(channel_id) + or Object(id=channel_id), } - self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type('_AuditLogProxy', (), elems)() + self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type( + "_AuditLogProxy", (), elems + )() elif self.action is enums.AuditLogAction.member_disconnect: # The member disconnect action has a dict with some information elems = { - 'count': int(self.extra['count']), + "count": int(self.extra["count"]), } - self.extra: _AuditLogProxyMemberDisconnect = type('_AuditLogProxy', (), elems)() - elif self.action.name.endswith('pin'): + self.extra: _AuditLogProxyMemberDisconnect = type( + "_AuditLogProxy", (), elems + )() + elif self.action.name.endswith("pin"): # the pin actions have a dict with some information - channel_id = int(self.extra['channel_id']) + channel_id = int(self.extra["channel_id"]) elems = { - 'channel': self.guild.get_channel(channel_id) or Object(id=channel_id), - 'message_id': int(self.extra['message_id']), + "channel": self.guild.get_channel(channel_id) + or Object(id=channel_id), + "message_id": int(self.extra["message_id"]), } - self.extra: _AuditLogProxyPinAction = type('_AuditLogProxy', (), elems)() - elif self.action.name.startswith('overwrite_'): + self.extra: _AuditLogProxyPinAction = type( + "_AuditLogProxy", (), elems + )() + elif self.action.name.startswith("overwrite_"): # the overwrite_ actions have a dict with some information - instance_id = int(self.extra['id']) - the_type = self.extra.get('type') - if the_type == '1': + instance_id = int(self.extra["id"]) + the_type = self.extra.get("type") + if the_type == "1": self.extra = self._get_member(instance_id) - elif the_type == '0': + elif the_type == "0": role = self.guild.get_role(instance_id) if role is None: role = Object(id=instance_id) - role.name = self.extra.get('role_name') # type: ignore + role.name = self.extra.get("role_name") # type: ignore self.extra: Role = role - elif self.action.name.startswith('stage_instance'): - channel_id = int(self.extra['channel_id']) - elems = {'channel': self.guild.get_channel(channel_id) or Object(id=channel_id)} - self.extra: _AuditLogProxyStageInstanceAction = type('_AuditLogProxy', (), elems)() + elif self.action.name.startswith("stage_instance"): + channel_id = int(self.extra["channel_id"]) + elems = { + "channel": self.guild.get_channel(channel_id) + or Object(id=channel_id) + } + self.extra: _AuditLogProxyStageInstanceAction = type( + "_AuditLogProxy", (), elems + )() # fmt: off self.extra: Union[ @@ -433,16 +483,16 @@ class AuditLogEntry(Hashable): # where new_value and old_value are not guaranteed to be there depending # on the action type, so let's just fetch it for now and only turn it # into meaningful data when requested - self._changes = data.get('changes', []) + self._changes = data.get("changes", []) - self.user = self._get_member(utils._get_as_snowflake(data, 'user_id')) # type: ignore - self._target_id = utils._get_as_snowflake(data, 'target_id') + self.user = self._get_member(utils._get_as_snowflake(data, "user_id")) # type: ignore + self._target_id = utils._get_as_snowflake(data, "target_id") def _get_member(self, user_id: int) -> Union[Member, User, None]: return self.guild.get_member(user_id) or self._users.get(user_id) def __repr__(self) -> str: - return f'' + return f"" @utils.cached_property def created_at(self) -> datetime.datetime: @@ -450,9 +500,24 @@ class AuditLogEntry(Hashable): return utils.snowflake_time(self.id) @utils.cached_property - def target(self) -> Union[Guild, abc.GuildChannel, Member, User, Role, Invite, Emoji, StageInstance, GuildSticker, Thread, Object, None]: + def target( + self, + ) -> Union[ + Guild, + abc.GuildChannel, + Member, + User, + Role, + Invite, + Emoji, + StageInstance, + GuildSticker, + Thread, + Object, + None, + ]: try: - converter = getattr(self, '_convert_target_' + self.action.target_type) + converter = getattr(self, "_convert_target_" + self.action.target_type) except AttributeError: return Object(id=self._target_id) else: @@ -483,7 +548,9 @@ class AuditLogEntry(Hashable): def _convert_target_guild(self, target_id: int) -> Guild: return self.guild - def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]: + def _convert_target_channel( + self, target_id: int + ) -> Union[abc.GuildChannel, Object]: return self.guild.get_channel(target_id) or Object(id=target_id) def _convert_target_user(self, target_id: int) -> Union[Member, User, None]: @@ -495,14 +562,18 @@ class AuditLogEntry(Hashable): def _convert_target_invite(self, target_id: int) -> Invite: # invites have target_id set to null # so figure out which change has the full invite data - changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after + changeset = ( + self.before + if self.action is enums.AuditLogAction.invite_delete + else self.after + ) fake_payload = { - 'max_age': changeset.max_age, - 'max_uses': changeset.max_uses, - 'code': changeset.code, - 'temporary': changeset.temporary, - 'uses': changeset.uses, + "max_age": changeset.max_age, + "max_uses": changeset.max_uses, + "code": changeset.code, + "temporary": changeset.temporary, + "uses": changeset.uses, } obj = Invite(state=self._state, data=fake_payload, guild=self.guild, channel=changeset.channel) # type: ignore @@ -518,7 +589,9 @@ class AuditLogEntry(Hashable): def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: return self._get_member(target_id) - def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]: + def _convert_target_stage_instance( + self, target_id: int + ) -> Union[StageInstance, Object]: return self.guild.get_stage_instance(target_id) or Object(id=target_id) def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]: diff --git a/discord/backoff.py b/discord/backoff.py index 903ecf76..3b2d9f77 100644 --- a/discord/backoff.py +++ b/discord/backoff.py @@ -29,11 +29,10 @@ import time import random from typing import Callable, Generic, Literal, TypeVar, overload, Union -T = TypeVar('T', bool, Literal[True], Literal[False]) +T = TypeVar("T", bool, Literal[True], Literal[False]) + +__all__ = ("ExponentialBackoff",) -__all__ = ( - 'ExponentialBackoff', -) class ExponentialBackoff(Generic[T]): """An implementation of the exponential backoff algorithm @@ -69,7 +68,7 @@ class ExponentialBackoff(Generic[T]): rand = random.Random() rand.seed() - self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore + self._randfunc: Callable[..., Union[int, float]] = rand.randrange if integral else rand.uniform # type: ignore @overload def delay(self: ExponentialBackoff[Literal[False]]) -> float: diff --git a/discord/channel.py b/discord/channel.py index dc3967c4..0eb63340 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -45,7 +45,13 @@ import datetime import discord.abc from .permissions import PermissionOverwrite, Permissions -from .enums import ChannelType, StagePrivacyLevel, try_enum, VoiceRegion, VideoQualityMode +from .enums import ( + ChannelType, + StagePrivacyLevel, + try_enum, + VoiceRegion, + VideoQualityMode, +) from .mixins import Hashable from .object import Object from . import utils @@ -57,14 +63,14 @@ from .threads import Thread from .iterators import ArchivedThreadIterator __all__ = ( - 'TextChannel', - 'VoiceChannel', - 'StageChannel', - 'DMChannel', - 'CategoryChannel', - 'StoreChannel', - 'GroupChannel', - 'PartialMessageable', + "TextChannel", + "VoiceChannel", + "StageChannel", + "DMChannel", + "CategoryChannel", + "StoreChannel", + "GroupChannel", + "PartialMessageable", ) if TYPE_CHECKING: @@ -155,51 +161,57 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ __slots__ = ( - 'name', - 'id', - 'guild', - 'topic', - '_state', - 'nsfw', - 'category_id', - 'position', - 'slowmode_delay', - '_overwrites', - '_type', - 'last_message_id', - 'default_auto_archive_duration', + "name", + "id", + "guild", + "topic", + "_state", + "nsfw", + "category_id", + "position", + "slowmode_delay", + "_overwrites", + "_type", + "last_message_id", + "default_auto_archive_duration", ) - def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): + def __init__( + self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload + ): self._state: ConnectionState = state - self.id: int = int(data['id']) - self._type: int = data['type'] + self.id: int = int(data["id"]) + self._type: int = data["type"] self._update(guild, data) def __repr__(self) -> str: attrs = [ - ('id', self.id), - ('name', self.name), - ('position', self.position), - ('nsfw', self.nsfw), - ('news', self.is_news()), - ('category_id', self.category_id), + ("id", self.id), + ("name", self.name), + ("position", self.position), + ("nsfw", self.nsfw), + ("news", self.is_news()), + ("category_id", self.category_id), ] - joined = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {joined}>' + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" def _update(self, guild: Guild, data: TextChannelPayload) -> None: self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.topic: Optional[str] = data.get('topic') - self.position: int = data['position'] - self.nsfw: bool = data.get('nsfw', False) + self.name: str = data["name"] + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.topic: Optional[str] = data.get("topic") + self.position: int = data["position"] + self.nsfw: bool = data.get("nsfw", False) # Does this need coercion into `int`? No idea yet. - self.slowmode_delay: int = data.get('rate_limit_per_user', 0) - self.default_auto_archive_duration: ThreadArchiveDuration = data.get('default_auto_archive_duration', 1440) - self._type: int = data.get('type', self._type) - self.last_message_id: Optional[int] = utils._get_as_snowflake(data, 'last_message_id') + self.slowmode_delay: int = data.get("rate_limit_per_user", 0) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get( + "default_auto_archive_duration", 1440 + ) + self._type: int = data.get("type", self._type) + self.last_message_id: Optional[int] = utils._get_as_snowflake( + data, "last_message_id" + ) self._fill_overwrites(data) async def _get_channel(self): @@ -234,7 +246,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): .. versionadded:: 2.0 """ - return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] + return [ + thread + for thread in self.guild._threads.values() + if thread.parent_id == self.id + ] def is_nsfw(self) -> bool: """:class:`bool`: Checks if the channel is NSFW.""" @@ -263,7 +279,11 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return self._state._get_message(self.last_message_id) if self.last_message_id else None + return ( + self._state._get_message(self.last_message_id) + if self.last_message_id + else None + ) @overload async def edit( @@ -359,9 +379,17 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel: + async def clone( + self, *, name: Optional[str] = None, reason: Optional[str] = None + ) -> TextChannel: return await self._clone_impl( - {'topic': self.topic, 'nsfw': self.nsfw, 'rate_limit_per_user': self.slowmode_delay}, name=name, reason=reason + { + "topic": self.topic, + "nsfw": self.nsfw, + "rate_limit_per_user": self.slowmode_delay, + }, + name=name, + reason=reason, ) async def delete_messages(self, messages: Iterable[Snowflake]) -> None: @@ -408,7 +436,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): return if len(messages) > 100: - raise ClientException('Can only bulk delete messages up to 100 messages') + raise ClientException("Can only bulk delete messages up to 100 messages") message_ids: SnowflakeList = [m.id for m in messages] await self._state.http.delete_messages(self.id, message_ids) @@ -483,11 +511,19 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): if check is MISSING: check = lambda m: True - iterator = self.history(limit=limit, before=before, after=after, oldest_first=oldest_first, around=around) + iterator = self.history( + limit=limit, + before=before, + after=after, + oldest_first=oldest_first, + around=around, + ) ret: List[Message] = [] count = 0 - minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 + minimum_time = ( + int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 + ) strategy = self.delete_messages if bulk else _single_delete_strategy async for message in iterator: @@ -548,7 +584,9 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): data = await self._state.http.channel_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def create_webhook(self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None) -> Webhook: + async def create_webhook( + self, *, name: str, avatar: Optional[bytes] = None, reason: Optional[str] = None + ) -> Webhook: """|coro| Creates a webhook for this channel. @@ -586,10 +624,14 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): if avatar is not None: avatar = utils._bytes_to_base64_data(avatar) # type: ignore - data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) + data = await self._state.http.create_webhook( + self.id, name=str(name), avatar=avatar, reason=reason + ) return Webhook.from_state(data, state=self._state) - async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook: + async def follow( + self, *, destination: TextChannel, reason: Optional[str] = None + ) -> Webhook: """ Follows a channel using a webhook. @@ -625,14 +667,18 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): """ if not self.is_news(): - raise ClientException('The channel must be a news channel.') + raise ClientException("The channel must be a news channel.") if not isinstance(destination, TextChannel): - raise InvalidArgument(f'Expected TextChannel received {destination.__class__.__name__}') + raise InvalidArgument( + f"Expected TextChannel received {destination.__class__.__name__}" + ) from .webhook import Webhook - data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) + data = await self._state.http.follow_webhook( + self.id, webhook_channel_id=destination.id, reason=reason + ) return Webhook._as_follower(data, channel=destination, user=self._state.user) def get_partial_message(self, message_id: int, /) -> PartialMessage: @@ -731,7 +777,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): data = await self._state.http.start_thread_without_message( self.id, name=name, - auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration + or self.default_auto_archive_duration, type=type.value, reason=reason, ) @@ -740,7 +787,8 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): self.id, message.id, name=name, - auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration + or self.default_auto_archive_duration, reason=reason, ) @@ -787,45 +835,64 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): :class:`Thread` The archived threads. """ - return ArchivedThreadIterator(self.id, self.guild, limit=limit, joined=joined, private=private, before=before) + return ArchivedThreadIterator( + self.id, + self.guild, + limit=limit, + joined=joined, + private=private, + before=before, + ) class VocalGuildChannel(discord.abc.Connectable, discord.abc.GuildChannel, Hashable): __slots__ = ( - 'name', - 'id', - 'guild', - 'bitrate', - 'user_limit', - '_state', - 'position', - '_overwrites', - 'category_id', - 'rtc_region', - 'video_quality_mode', + "name", + "id", + "guild", + "bitrate", + "user_limit", + "_state", + "position", + "_overwrites", + "category_id", + "rtc_region", + "video_quality_mode", ) - def __init__(self, *, state: ConnectionState, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]): + def __init__( + self, + *, + state: ConnectionState, + guild: Guild, + data: Union[VoiceChannelPayload, StageChannelPayload], + ): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(guild, data) def _get_voice_client_key(self) -> Tuple[int, str]: - return self.guild.id, 'guild_id' + return self.guild.id, "guild_id" def _get_voice_state_pair(self) -> Tuple[int, int]: return self.guild.id, self.id - def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: + def _update( + self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload] + ) -> None: self.guild = guild - self.name: str = data['name'] - rtc = data.get('rtc_region') - self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None - self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get('video_quality_mode', 1)) - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.position: int = data['position'] - self.bitrate: int = data.get('bitrate') - self.user_limit: int = data.get('user_limit') + self.name: str = data["name"] + rtc = data.get("rtc_region") + self.rtc_region: Optional[VoiceRegion] = ( + try_enum(VoiceRegion, rtc) if rtc is not None else None + ) + self.video_quality_mode: VideoQualityMode = try_enum( + VideoQualityMode, data.get("video_quality_mode", 1) + ) + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.position: int = data["position"] + self.bitrate: int = data.get("bitrate") + self.user_limit: int = data.get("user_limit") self._fill_overwrites(data) @property @@ -933,17 +1000,17 @@ class VoiceChannel(VocalGuildChannel): def __repr__(self) -> str: attrs = [ - ('id', self.id), - ('name', self.name), - ('rtc_region', self.rtc_region), - ('position', self.position), - ('bitrate', self.bitrate), - ('video_quality_mode', self.video_quality_mode), - ('user_limit', self.user_limit), - ('category_id', self.category_id), + ("id", self.id), + ("name", self.name), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), ] - joined = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {joined}>' + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" @property def type(self) -> ChannelType: @@ -951,8 +1018,14 @@ class VoiceChannel(VocalGuildChannel): return ChannelType.voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: - return await self._clone_impl({'bitrate': self.bitrate, 'user_limit': self.user_limit}, name=name, reason=reason) + async def clone( + self, *, name: Optional[str] = None, reason: Optional[str] = None + ) -> VoiceChannel: + return await self._clone_impl( + {"bitrate": self.bitrate, "user_limit": self.user_limit}, + name=name, + reason=reason, + ) @overload async def edit( @@ -1093,31 +1166,35 @@ class StageChannel(VocalGuildChannel): .. versionadded:: 2.0 """ - __slots__ = ('topic',) + __slots__ = ("topic",) def __repr__(self) -> str: attrs = [ - ('id', self.id), - ('name', self.name), - ('topic', self.topic), - ('rtc_region', self.rtc_region), - ('position', self.position), - ('bitrate', self.bitrate), - ('video_quality_mode', self.video_quality_mode), - ('user_limit', self.user_limit), - ('category_id', self.category_id), + ("id", self.id), + ("name", self.name), + ("topic", self.topic), + ("rtc_region", self.rtc_region), + ("position", self.position), + ("bitrate", self.bitrate), + ("video_quality_mode", self.video_quality_mode), + ("user_limit", self.user_limit), + ("category_id", self.category_id), ] - joined = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {joined}>' + joined = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {joined}>" def _update(self, guild: Guild, data: StageChannelPayload) -> None: super()._update(guild, data) - self.topic = data.get('topic') + self.topic = data.get("topic") @property def requesting_to_speak(self) -> List[Member]: """List[:class:`Member`]: A list of members who are requesting to speak in the stage channel.""" - return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None] + return [ + member + for member in self.members + if member.voice and member.voice.requested_to_speak_at is not None + ] @property def speakers(self) -> List[Member]: @@ -1128,7 +1205,9 @@ class StageChannel(VocalGuildChannel): return [ member for member in self.members - if member.voice and not member.voice.suppress and member.voice.requested_to_speak_at is None + if member.voice + and not member.voice.suppress + and member.voice.requested_to_speak_at is None ] @property @@ -1137,7 +1216,9 @@ class StageChannel(VocalGuildChannel): .. versionadded:: 2.0 """ - return [member for member in self.members if member.voice and member.voice.suppress] + return [ + member for member in self.members if member.voice and member.voice.suppress + ] @property def moderators(self) -> List[Member]: @@ -1146,7 +1227,11 @@ class StageChannel(VocalGuildChannel): .. versionadded:: 2.0 """ required_permissions = Permissions.stage_moderator() - return [member for member in self.members if self.permissions_for(member) >= required_permissions] + return [ + member + for member in self.members + if self.permissions_for(member) >= required_permissions + ] @property def type(self) -> ChannelType: @@ -1154,7 +1239,9 @@ class StageChannel(VocalGuildChannel): return ChannelType.stage_voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel: + async def clone( + self, *, name: Optional[str] = None, reason: Optional[str] = None + ) -> StageChannel: return await self._clone_impl({}, name=name, reason=reason) @property @@ -1166,7 +1253,11 @@ class StageChannel(VocalGuildChannel): return utils.get(self.guild.stage_instances, channel_id=self.id) async def create_instance( - self, *, topic: str, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None + self, + *, + topic: str, + privacy_level: StagePrivacyLevel = MISSING, + reason: Optional[str] = None, ) -> StageInstance: """|coro| @@ -1201,13 +1292,15 @@ class StageChannel(VocalGuildChannel): The newly created stage instance. """ - payload: Dict[str, Any] = {'channel_id': self.id, 'topic': topic} + payload: Dict[str, Any] = {"channel_id": self.id, "topic": topic} if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument('privacy_level field must be of type PrivacyLevel') + raise InvalidArgument( + "privacy_level field must be of type PrivacyLevel" + ) - payload['privacy_level'] = privacy_level.value + payload["privacy_level"] = privacy_level.value data = await self._state.http.create_stage_instance(**payload, reason=reason) return StageInstance(guild=self.guild, state=self._state, data=data) @@ -1361,22 +1454,33 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): To check if the channel or the guild of that channel are marked as NSFW, consider :meth:`is_nsfw` instead. """ - __slots__ = ('name', 'id', 'guild', 'nsfw', '_state', 'position', '_overwrites', 'category_id') + __slots__ = ( + "name", + "id", + "guild", + "nsfw", + "_state", + "position", + "_overwrites", + "category_id", + ) - def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): + def __init__( + self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload + ): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(guild, data) def __repr__(self) -> str: - return f'' + return f"" def _update(self, guild: Guild, data: CategoryChannelPayload) -> None: self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.nsfw: bool = data.get('nsfw', False) - self.position: int = data['position'] + self.name: str = data["name"] + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.nsfw: bool = data.get("nsfw", False) + self.position: int = data["position"] self._fill_overwrites(data) @property @@ -1393,8 +1497,10 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel: - return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) + async def clone( + self, *, name: Optional[str] = None, reason: Optional[str] = None + ) -> CategoryChannel: + return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload async def edit( @@ -1463,7 +1569,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): @utils.copy_doc(discord.abc.GuildChannel.move) async def move(self, **kwargs): - kwargs.pop('category', None) + kwargs.pop("category", None) await super().move(**kwargs) @property @@ -1483,14 +1589,22 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): @property def text_channels(self) -> List[TextChannel]: """List[:class:`TextChannel`]: Returns the text channels that are under this category.""" - ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, TextChannel) + ] ret.sort(key=lambda c: (c.position, c.id)) return ret @property def voice_channels(self) -> List[VoiceChannel]: """List[:class:`VoiceChannel`]: Returns the voice channels that are under this category.""" - ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, VoiceChannel) + ] ret.sort(key=lambda c: (c.position, c.id)) return ret @@ -1500,7 +1614,11 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): .. versionadded:: 1.7 """ - ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] + ret = [ + c + for c in self.guild.channels + if c.category_id == self.id and isinstance(c, StageChannel) + ] ret.sort(key=lambda c: (c.position, c.id)) return ret @@ -1590,30 +1708,32 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): """ __slots__ = ( - 'name', - 'id', - 'guild', - '_state', - 'nsfw', - 'category_id', - 'position', - '_overwrites', + "name", + "id", + "guild", + "_state", + "nsfw", + "category_id", + "position", + "_overwrites", ) - def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload): + def __init__( + self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload + ): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(guild, data) def __repr__(self) -> str: - return f'' + return f"" def _update(self, guild: Guild, data: StoreChannelPayload) -> None: self.guild: Guild = guild - self.name: str = data['name'] - self.category_id: Optional[int] = utils._get_as_snowflake(data, 'parent_id') - self.position: int = data['position'] - self.nsfw: bool = data.get('nsfw', False) + self.name: str = data["name"] + self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") + self.position: int = data["position"] + self.nsfw: bool = data.get("nsfw", False) self._fill_overwrites(data) @property @@ -1639,8 +1759,10 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel: - return await self._clone_impl({'nsfw': self.nsfw}, name=name, reason=reason) + async def clone( + self, *, name: Optional[str] = None, reason: Optional[str] = None + ) -> StoreChannel: + return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload async def edit( @@ -1716,7 +1838,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore -DMC = TypeVar('DMC', bound='DMChannel') +DMC = TypeVar("DMC", bound="DMChannel") class DMChannel(discord.abc.Messageable, Hashable): @@ -1756,24 +1878,26 @@ class DMChannel(discord.abc.Messageable, Hashable): The direct message channel ID. """ - __slots__ = ('id', 'recipient', 'me', '_state') + __slots__ = ("id", "recipient", "me", "_state") - def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): + def __init__( + self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload + ): self._state: ConnectionState = state - self.recipient: Optional[User] = state.store_user(data['recipients'][0]) + self.recipient: Optional[User] = state.store_user(data["recipients"][0]) self.me: ClientUser = me - self.id: int = int(data['id']) + self.id: int = int(data["id"]) async def _get_channel(self): return self def __str__(self) -> str: if self.recipient: - return f'Direct Message with {self.recipient}' - return 'Direct Message with Unknown User' + return f"Direct Message with {self.recipient}" + return "Direct Message with Unknown User" def __repr__(self) -> str: - return f'' + return f"" @classmethod def _from_message(cls: Type[DMC], state: ConnectionState, channel_id: int) -> DMC: @@ -1892,19 +2016,32 @@ class GroupChannel(discord.abc.Messageable, Hashable): The group channel's name if provided. """ - __slots__ = ('id', 'recipients', 'owner_id', 'owner', '_icon', 'name', 'me', '_state') + __slots__ = ( + "id", + "recipients", + "owner_id", + "owner", + "_icon", + "name", + "me", + "_state", + ) - def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): + def __init__( + self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload + ): self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self.me: ClientUser = me self._update_group(data) def _update_group(self, data: GroupChannelPayload) -> None: - self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_id') - self._icon: Optional[str] = data.get('icon') - self.name: Optional[str] = data.get('name') - self.recipients: List[User] = [self._state.store_user(u) for u in data.get('recipients', [])] + self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id") + self._icon: Optional[str] = data.get("icon") + self.name: Optional[str] = data.get("name") + self.recipients: List[User] = [ + self._state.store_user(u) for u in data.get("recipients", []) + ] self.owner: Optional[BaseUser] if self.owner_id == self.me.id: @@ -1920,12 +2057,12 @@ class GroupChannel(discord.abc.Messageable, Hashable): return self.name if len(self.recipients) == 0: - return 'Unnamed' + return "Unnamed" - return ', '.join(map(lambda x: x.name, self.recipients)) + return ", ".join(map(lambda x: x.name, self.recipients)) def __repr__(self) -> str: - return f'' + return f"" @property def type(self) -> ChannelType: @@ -1937,7 +2074,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): """Optional[:class:`Asset`]: Returns the channel's icon asset if available.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='channel') + return Asset._from_icon(self._state, self.id, self._icon, path="channel") @property def created_at(self) -> datetime.datetime: @@ -2032,7 +2169,9 @@ class PartialMessageable(discord.abc.Messageable, Hashable): The channel type associated with this partial messageable, if given. """ - def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None): + def __init__( + self, state: ConnectionState, id: int, type: Optional[ChannelType] = None + ): self._state: ConnectionState = state self._channel: Object = Object(id=id) self.id: int = id @@ -2093,13 +2232,21 @@ def _channel_factory(channel_type: int): def _threaded_channel_factory(channel_type: int): cls, value = _channel_factory(channel_type) - if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): return Thread, value return cls, value def _threaded_guild_channel_factory(channel_type: int): cls, value = _guild_channel_factory(channel_type) - if value in (ChannelType.private_thread, ChannelType.public_thread, ChannelType.news_thread): + if value in ( + ChannelType.private_thread, + ChannelType.public_thread, + ChannelType.news_thread, + ): return Thread, value return cls, value diff --git a/discord/client.py b/discord/client.py index b4f1db17..494c5322 100644 --- a/discord/client.py +++ b/discord/client.py @@ -29,7 +29,20 @@ import logging import signal import sys import traceback -from typing import Any, Callable, Coroutine, Dict, Generator, List, Optional, Sequence, TYPE_CHECKING, Tuple, TypeVar, Union +from typing import ( + Any, + Callable, + Coroutine, + Dict, + Generator, + List, + Optional, + Sequence, + TYPE_CHECKING, + Tuple, + TypeVar, + Union, +) import aiohttp @@ -69,46 +82,49 @@ if TYPE_CHECKING: from .member import Member from .voice_client import VoiceProtocol -__all__ = ( - 'Client', -) +__all__ = ("Client",) -Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) +Coro = TypeVar("Coro", bound=Callable[..., Coroutine[Any, Any, Any]]) _log = logging.getLogger(__name__) + def _cancel_tasks(loop: asyncio.AbstractEventLoop) -> None: tasks = {t for t in asyncio.all_tasks(loop=loop) if not t.done()} if not tasks: return - _log.info('Cleaning up after %d tasks.', len(tasks)) + _log.info("Cleaning up after %d tasks.", len(tasks)) for task in tasks: task.cancel() loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - _log.info('All tasks finished cancelling.') + _log.info("All tasks finished cancelling.") for task in tasks: if task.cancelled(): continue if task.exception() is not None: - loop.call_exception_handler({ - 'message': 'Unhandled exception during Client.run shutdown.', - 'exception': task.exception(), - 'task': task - }) + loop.call_exception_handler( + { + "message": "Unhandled exception during Client.run shutdown.", + "exception": task.exception(), + "task": task, + } + ) + def _cleanup_loop(loop: asyncio.AbstractEventLoop) -> None: try: _cancel_tasks(loop) loop.run_until_complete(loop.shutdown_asyncgens()) finally: - _log.info('Closing the event loop.') + _log.info("Closing the event loop.") loop.close() + class Client: r"""Represents a client connection that connects to Discord. This class is used to interact with the Discord WebSocket and API. @@ -199,6 +215,7 @@ class Client: loop: :class:`asyncio.AbstractEventLoop` The event loop that the client uses for asynchronous operations. """ + def __init__( self, *, @@ -210,26 +227,34 @@ class Client: # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop - self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} - self.shard_id: Optional[int] = options.get('shard_id') - self.shard_count: Optional[int] = options.get('shard_count') + self.loop: asyncio.AbstractEventLoop = ( + asyncio.get_event_loop() if loop is None else loop + ) + self._listeners: Dict[ + str, List[Tuple[asyncio.Future, Callable[..., bool]]] + ] = {} + self.shard_id: Optional[int] = options.get("shard_id") + self.shard_count: Optional[int] = options.get("shard_count") - connector: Optional[aiohttp.BaseConnector] = options.pop('connector', None) - proxy: Optional[str] = options.pop('proxy', None) - proxy_auth: Optional[aiohttp.BasicAuth] = options.pop('proxy_auth', None) - unsync_clock: bool = options.pop('assume_unsync_clock', True) - self.http: HTTPClient = HTTPClient(connector, proxy=proxy, proxy_auth=proxy_auth, unsync_clock=unsync_clock, loop=self.loop) + connector: Optional[aiohttp.BaseConnector] = options.pop("connector", None) + proxy: Optional[str] = options.pop("proxy", None) + proxy_auth: Optional[aiohttp.BasicAuth] = options.pop("proxy_auth", None) + unsync_clock: bool = options.pop("assume_unsync_clock", True) + self.http: HTTPClient = HTTPClient( + connector, + proxy=proxy, + proxy_auth=proxy_auth, + unsync_clock=unsync_clock, + loop=self.loop, + ) - self._handlers: Dict[str, Callable] = { - 'ready': self._handle_ready - } + self._handlers: Dict[str, Callable] = {"ready": self._handle_ready} self._hooks: Dict[str, Callable] = { - 'before_identify': self._call_before_identify_hook + "before_identify": self._call_before_identify_hook } - self._enable_debug_events: bool = options.pop('enable_debug_events', False) + self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) self._connection.shard_count = self.shard_count self._closed: bool = False @@ -243,12 +268,20 @@ class Client: # internals - def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: + def _get_websocket( + self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None + ) -> DiscordWebSocket: return self.ws def _get_state(self, **options: Any) -> ConnectionState: - return ConnectionState(dispatch=self.dispatch, handlers=self._handlers, - hooks=self._hooks, http=self.http, loop=self.loop, **options) + return ConnectionState( + dispatch=self.dispatch, + handlers=self._handlers, + hooks=self._hooks, + http=self.http, + loop=self.loop, + **options, + ) def _handle_ready(self) -> None: self._ready.set() @@ -260,7 +293,7 @@ class Client: This could be referred to as the Discord WebSocket protocol latency. """ ws = self.ws - return float('nan') if not ws else ws.latency + return float("nan") if not ws else ws.latency def is_ws_ratelimited(self) -> bool: """:class:`bool`: Whether the websocket is currently rate limited. @@ -331,7 +364,7 @@ class Client: If this is not passed via ``__init__`` then this is retrieved through the gateway when an event contains the data. Usually after :func:`~discord.on_connect` is called. - + .. versionadded:: 2.0 """ return self._connection.application_id @@ -348,7 +381,13 @@ class Client: """:class:`bool`: Specifies if the client's internal cache is ready for use.""" return self._ready.is_set() - async def _run_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> None: + async def _run_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> None: try: await coro(*args, **kwargs) except asyncio.CancelledError: @@ -359,14 +398,20 @@ class Client: except asyncio.CancelledError: pass - def _schedule_event(self, coro: Callable[..., Coroutine[Any, Any, Any]], event_name: str, *args: Any, **kwargs: Any) -> asyncio.Task: + def _schedule_event( + self, + coro: Callable[..., Coroutine[Any, Any, Any]], + event_name: str, + *args: Any, + **kwargs: Any, + ) -> asyncio.Task: wrapped = self._run_event(coro, event_name, *args, **kwargs) # Schedules the task - return asyncio.create_task(wrapped, name=f'discord.py: {event_name}') + return asyncio.create_task(wrapped, name=f"discord.py: {event_name}") def dispatch(self, event: str, *args: Any, **kwargs: Any) -> None: - _log.debug('Dispatching event %s', event) - method = 'on_' + event + _log.debug("Dispatching event %s", event) + method = "on_" + event listeners = self._listeners.get(event) if listeners: @@ -413,18 +458,22 @@ class Client: overridden to have a different implementation. Check :func:`~discord.on_error` for more details. """ - print(f'Ignoring exception in {event_method}', file=sys.stderr) + print(f"Ignoring exception in {event_method}", file=sys.stderr) traceback.print_exc() # hooks - async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: + async def _call_before_identify_hook( + self, shard_id: Optional[int], *, initial: bool = False + ) -> None: # This hook is an internal hook that actually calls the public one. # It allows the library to have its own hook without stepping on the # toes of those who need to override their own hook. await self.before_identify_hook(shard_id, initial=initial) - async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: + async def before_identify_hook( + self, shard_id: Optional[int], *, initial: bool = False + ) -> None: """|coro| A hook that is called before IDENTIFYing a session. This is useful @@ -470,7 +519,7 @@ class Client: passing status code. """ - _log.info('logging in using static token') + _log.info("logging in using static token") data = await self.http.static_login(token.strip()) self._connection.user = ClientUser(state=self._connection, data=data) @@ -502,29 +551,35 @@ class Client: backoff = ExponentialBackoff() ws_params = { - 'initial': True, - 'shard_id': self.shard_id, + "initial": True, + "shard_id": self.shard_id, } while not self.is_closed(): try: coro = DiscordWebSocket.from_client(self, **ws_params) self.ws = await asyncio.wait_for(coro, timeout=60.0) - ws_params['initial'] = False + ws_params["initial"] = False while True: await self.ws.poll_event() except ReconnectWebSocket as e: - _log.info('Got a request to %s the websocket.', e.op) - self.dispatch('disconnect') - ws_params.update(sequence=self.ws.sequence, resume=e.resume, session=self.ws.session_id) + _log.info("Got a request to %s the websocket.", e.op) + self.dispatch("disconnect") + ws_params.update( + sequence=self.ws.sequence, + resume=e.resume, + session=self.ws.session_id, + ) continue - except (OSError, - HTTPException, - GatewayNotFound, - ConnectionClosed, - aiohttp.ClientError, - asyncio.TimeoutError) as exc: + except ( + OSError, + HTTPException, + GatewayNotFound, + ConnectionClosed, + aiohttp.ClientError, + asyncio.TimeoutError, + ) as exc: - self.dispatch('disconnect') + self.dispatch("disconnect") if not reconnect: await self.close() if isinstance(exc, ConnectionClosed) and exc.code == 1000: @@ -537,7 +592,12 @@ class Client: # If we get connection reset by peer then try to RESUME if isinstance(exc, OSError) and exc.errno in (54, 10054): - ws_params.update(sequence=self.ws.sequence, initial=False, resume=True, session=self.ws.session_id) + ws_params.update( + sequence=self.ws.sequence, + initial=False, + resume=True, + session=self.ws.session_id, + ) continue # We should only get this when an unhandled close code happens, @@ -557,7 +617,9 @@ class Client: # Always try to RESUME the connection # If the connection is not RESUME-able then the gateway will invalidate the session. # This is apparently what the official Discord client does. - ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) + ws_params.update( + sequence=self.ws.sequence, resume=True, session=self.ws.session_id + ) async def close(self) -> None: """|coro| @@ -654,10 +716,10 @@ class Client: try: loop.run_forever() except KeyboardInterrupt: - _log.info('Received signal to terminate bot and event loop.') + _log.info("Received signal to terminate bot and event loop.") finally: future.remove_done_callback(stop_loop_on_completion) - _log.info('Cleaning up tasks.') + _log.info("Cleaning up tasks.") _cleanup_loop(loop) if not future.cancelled(): @@ -686,10 +748,10 @@ class Client: self._connection._activity = None elif isinstance(value, BaseActivity): # ConnectionState._activity is typehinted as ActivityPayload, we're passing Dict[str, Any] - self._connection._activity = value.to_dict() # type: ignore + self._connection._activity = value.to_dict() # type: ignore else: - raise TypeError('activity must derive from BaseActivity.') - + raise TypeError("activity must derive from BaseActivity.") + @property def status(self): """:class:`.Status`: @@ -704,11 +766,11 @@ class Client: @status.setter def status(self, value): if value is Status.offline: - self._connection._status = 'invisible' + self._connection._status = "invisible" elif isinstance(value, Status): self._connection._status = str(value) else: - raise TypeError('status must derive from Status.') + raise TypeError("status must derive from Status.") @property def allowed_mentions(self) -> Optional[AllowedMentions]: @@ -723,7 +785,9 @@ class Client: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: - raise TypeError(f'allowed_mentions must be AllowedMentions not {value.__class__!r}') + raise TypeError( + f"allowed_mentions must be AllowedMentions not {value.__class__!r}" + ) @property def intents(self) -> Intents: @@ -740,7 +804,9 @@ class Client: """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: + def get_channel( + self, id: int, / + ) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: """Returns a channel or thread with the given ID. Parameters @@ -755,12 +821,14 @@ class Client: """ return self._connection.get_channel(id) - def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable: + def get_partial_messageable( + self, id: int, *, type: Optional[ChannelType] = None + ) -> PartialMessageable: """Returns a partial messageable with the given channel ID. This is useful if you have a channel_id but don't want to do an API call to send messages to it. - + .. versionadded:: 2.0 Parameters @@ -1001,8 +1069,10 @@ class Client: future = self.loop.create_future() if check is None: + def _check(*args): return True + check = _check ev = event.lower() @@ -1040,10 +1110,10 @@ class Client: """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('event registered must be a coroutine function') + raise TypeError("event registered must be a coroutine function") setattr(self, coro.__name__, coro) - _log.debug('%s has successfully been registered as an event', coro.__name__) + _log.debug("%s has successfully been registered as an event", coro.__name__) return coro async def change_presence( @@ -1082,10 +1152,10 @@ class Client: """ if status is None: - status_str = 'online' + status_str = "online" status = Status.online elif status is Status.offline: - status_str = 'invisible' + status_str = "invisible" status = Status.offline else: status_str = str(status) @@ -1111,7 +1181,7 @@ class Client: *, limit: Optional[int] = 100, before: SnowflakeTime = None, - after: SnowflakeTime = None + after: SnowflakeTime = None, ) -> GuildIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving your guilds. @@ -1191,7 +1261,7 @@ class Client: """ code = utils.resolve_template(code) data = await self.http.get_template(code) - return Template(data=data, state=self._connection) # type: ignore + return Template(data=data, state=self._connection) # type: ignore async def fetch_guild(self, guild_id: int, /) -> Guild: """|coro| @@ -1277,7 +1347,9 @@ class Client: region_value = str(region) if code: - data = await self.http.create_from_template(code, name, region_value, icon_base64) + data = await self.http.create_from_template( + code, name, region_value, icon_base64 + ) else: data = await self.http.create_guild(name, region_value, icon_base64) return Guild(data=data, state=self._connection) @@ -1307,12 +1379,18 @@ class Client: The stage instance from the stage channel ID. """ data = await self.http.get_stage_instance(channel_id) - guild = self.get_guild(int(data['guild_id'])) + guild = self.get_guild(int(data["guild_id"])) return StageInstance(guild=guild, state=self._connection, data=data) # type: ignore # Invite management - async def fetch_invite(self, url: Union[Invite, str], *, with_counts: bool = True, with_expiration: bool = True) -> Invite: + async def fetch_invite( + self, + url: Union[Invite, str], + *, + with_counts: bool = True, + with_expiration: bool = True, + ) -> Invite: """|coro| Gets an :class:`.Invite` from a discord.gg URL or ID. @@ -1351,7 +1429,9 @@ class Client: """ invite_id = utils.resolve_invite(url) - data = await self.http.get_invite(invite_id, with_counts=with_counts, with_expiration=with_expiration) + data = await self.http.get_invite( + invite_id, with_counts=with_counts, with_expiration=with_expiration + ) return Invite.from_incomplete(state=self._connection, data=data) async def delete_invite(self, invite: Union[Invite, str]) -> None: @@ -1428,8 +1508,8 @@ class Client: The bot's application information. """ data = await self.http.application_info() - if 'rpc_origins' not in data: - data['rpc_origins'] = None + if "rpc_origins" not in data: + data["rpc_origins"] = None return AppInfo(self._connection, data) async def fetch_user(self, user_id: int, /) -> User: @@ -1463,7 +1543,9 @@ class Client: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]: + async def fetch_channel( + self, channel_id: int, / + ) -> Union[GuildChannel, PrivateChannel, Thread]: """|coro| Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. @@ -1492,19 +1574,21 @@ class Client: """ data = await self.http.get_channel(channel_id) - factory, ch_type = _threaded_channel_factory(data['type']) + factory, ch_type = _threaded_channel_factory(data["type"]) if factory is None: - raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) + raise InvalidData( + "Unknown channel type {type} for channel ID {id}.".format_map(data) + ) if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here - channel = factory(me=self.user, data=data, state=self._connection) # type: ignore + channel = factory(me=self.user, data=data, state=self._connection) # type: ignore else: # the factory can't be a DMChannel or GroupChannel here - guild_id = int(data['guild_id']) # type: ignore + guild_id = int(data["guild_id"]) # type: ignore guild = self.get_guild(guild_id) or Object(id=guild_id) # GuildChannels expect a Guild, we may be passing an Object - channel = factory(guild=guild, state=self._connection, data=data) # type: ignore + channel = factory(guild=guild, state=self._connection, data=data) # type: ignore return channel @@ -1530,7 +1614,9 @@ class Client: data = await self.http.get_webhook(webhook_id) return Webhook.from_state(data, state=self._connection) - async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]: + async def fetch_sticker( + self, sticker_id: int, / + ) -> Union[StandardSticker, GuildSticker]: """|coro| Retrieves a :class:`.Sticker` with the specified ID. @@ -1550,8 +1636,8 @@ class Client: The sticker you requested. """ data = await self.http.get_sticker(sticker_id) - cls, _ = _sticker_factory(data['type']) # type: ignore - return cls(state=self._connection, data=data) # type: ignore + cls, _ = _sticker_factory(data["type"]) # type: ignore + return cls(state=self._connection, data=data) # type: ignore async def fetch_premium_sticker_packs(self) -> List[StickerPack]: """|coro| @@ -1571,7 +1657,10 @@ class Client: All available premium sticker packs. """ data = await self.http.list_premium_sticker_packs() - return [StickerPack(state=self._connection, data=pack) for pack in data['sticker_packs']] + return [ + StickerPack(state=self._connection, data=pack) + for pack in data["sticker_packs"] + ] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -1606,7 +1695,7 @@ class Client: This method should be used for when a view is comprised of components that last longer than the lifecycle of the program. - + .. versionadded:: 2.0 Parameters @@ -1628,17 +1717,19 @@ class Client: """ if not isinstance(view, View): - raise TypeError(f'expected an instance of View not {view.__class__!r}') + raise TypeError(f"expected an instance of View not {view.__class__!r}") if not view.is_persistent(): - raise ValueError('View is not persistent. Items need to have a custom_id set and View must have no timeout') + raise ValueError( + "View is not persistent. Items need to have a custom_id set and View must have no timeout" + ) self._connection.store_view(view, message_id) @property def persistent_views(self) -> Sequence[View]: """Sequence[:class:`.View`]: A sequence of persistent views added to the client. - + .. versionadded:: 2.0 """ return self._connection.persistent_views diff --git a/discord/colour.py b/discord/colour.py index 927addc1..e1caa8db 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -35,11 +35,11 @@ from typing import ( ) __all__ = ( - 'Colour', - 'Color', + "Colour", + "Color", ) -CT = TypeVar('CT', bound='Colour') +CT = TypeVar("CT", bound="Colour") class Colour: @@ -76,16 +76,18 @@ class Colour: The raw integer colour value. """ - __slots__ = ('value',) + __slots__ = ("value",) def __init__(self, value: int): if not isinstance(value, int): - raise TypeError(f'Expected int parameter, received {value.__class__.__name__} instead.') + raise TypeError( + f"Expected int parameter, received {value.__class__.__name__} instead." + ) self.value: int = value def _get_byte(self, byte: int) -> int: - return (self.value >> (8 * byte)) & 0xff + return (self.value >> (8 * byte)) & 0xFF def __eq__(self, other: Any) -> bool: return isinstance(other, Colour) and self.value == other.value @@ -94,13 +96,13 @@ class Colour: return not self.__eq__(other) def __str__(self) -> str: - return f'#{self.value:0>6x}' + return f"#{self.value:0>6x}" def __int__(self) -> int: return self.value def __repr__(self) -> str: - return f'' + return f"" def __hash__(self) -> int: return hash(self.value) @@ -141,7 +143,11 @@ class Colour: return cls(0) @classmethod - def random(cls: Type[CT], *, seed: Optional[Union[int, str, float, bytes, bytearray]] = None) -> CT: + def random( + cls: Type[CT], + *, + seed: Optional[Union[int, str, float, bytes, bytearray]] = None, + ) -> CT: """A factory method that returns a :class:`Colour` with a random hue. .. note:: @@ -164,12 +170,12 @@ class Colour: @classmethod def teal(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x1abc9c``.""" - return cls(0x1abc9c) + return cls(0x1ABC9C) @classmethod def dark_teal(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x11806a``.""" - return cls(0x11806a) + return cls(0x11806A) @classmethod def brand_green(cls: Type[CT]) -> CT: @@ -182,17 +188,17 @@ class Colour: @classmethod def green(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x2ecc71``.""" - return cls(0x2ecc71) + return cls(0x2ECC71) @classmethod def dark_green(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x1f8b4c``.""" - return cls(0x1f8b4c) + return cls(0x1F8B4C) @classmethod def blue(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x3498db``.""" - return cls(0x3498db) + return cls(0x3498DB) @classmethod def dark_blue(cls: Type[CT]) -> CT: @@ -202,42 +208,42 @@ class Colour: @classmethod def purple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x9b59b6``.""" - return cls(0x9b59b6) + return cls(0x9B59B6) @classmethod def dark_purple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x71368a``.""" - return cls(0x71368a) + return cls(0x71368A) @classmethod def magenta(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe91e63``.""" - return cls(0xe91e63) + return cls(0xE91E63) @classmethod def dark_magenta(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xad1457``.""" - return cls(0xad1457) + return cls(0xAD1457) @classmethod def gold(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xf1c40f``.""" - return cls(0xf1c40f) + return cls(0xF1C40F) @classmethod def dark_gold(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xc27c0e``.""" - return cls(0xc27c0e) + return cls(0xC27C0E) @classmethod def orange(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe67e22``.""" - return cls(0xe67e22) + return cls(0xE67E22) @classmethod def dark_orange(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xa84300``.""" - return cls(0xa84300) + return cls(0xA84300) @classmethod def brand_red(cls: Type[CT]) -> CT: @@ -250,45 +256,45 @@ class Colour: @classmethod def red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0xe74c3c``.""" - return cls(0xe74c3c) + return cls(0xE74C3C) @classmethod def dark_red(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x992d22``.""" - return cls(0x992d22) + return cls(0x992D22) @classmethod def lighter_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x95a5a6``.""" - return cls(0x95a5a6) + return cls(0x95A5A6) lighter_gray = lighter_grey @classmethod def dark_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x607d8b``.""" - return cls(0x607d8b) + return cls(0x607D8B) dark_gray = dark_grey @classmethod def light_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x979c9f``.""" - return cls(0x979c9f) + return cls(0x979C9F) light_gray = light_grey @classmethod def darker_grey(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x546e7a``.""" - return cls(0x546e7a) + return cls(0x546E7A) darker_gray = darker_grey @classmethod def og_blurple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x7289da``.""" - return cls(0x7289da) + return cls(0x7289DA) @classmethod def blurple(cls: Type[CT]) -> CT: @@ -298,7 +304,7 @@ class Colour: @classmethod def greyple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x99aab5``.""" - return cls(0x99aab5) + return cls(0x99AAB5) @classmethod def dark_theme(cls: Type[CT]) -> CT: @@ -324,7 +330,7 @@ class Colour: .. versionadded:: 2.0 """ return cls(0xFEE75C) - + @classmethod def dark_blurple(cls: Type[CT]) -> CT: """A factory method that returns a :class:`Colour` with a value of ``0x4E5D94``. diff --git a/discord/components.py b/discord/components.py index 74c7be3d..6ef1a4f9 100644 --- a/discord/components.py +++ b/discord/components.py @@ -24,7 +24,18 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union +from typing import ( + Any, + ClassVar, + Dict, + List, + Optional, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, +) from .enums import try_enum, ComponentType, ButtonStyle from .utils import get_slots, MISSING from .partial_emoji import PartialEmoji, _EmojiTag @@ -41,14 +52,14 @@ if TYPE_CHECKING: __all__ = ( - 'Component', - 'ActionRow', - 'Button', - 'SelectMenu', - 'SelectOption', + "Component", + "ActionRow", + "Button", + "SelectMenu", + "SelectOption", ) -C = TypeVar('C', bound='Component') +C = TypeVar("C", bound="Component") class Component: @@ -70,14 +81,14 @@ class Component: The type of component. """ - __slots__: Tuple[str, ...] = ('type',) + __slots__: Tuple[str, ...] = ("type",) __repr_info__: ClassVar[Tuple[str, ...]] type: ComponentType def __repr__(self) -> str: - attrs = ' '.join(f'{key}={getattr(self, key)!r}' for key in self.__repr_info__) - return f'<{self.__class__.__name__} {attrs}>' + attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__repr_info__) + return f"<{self.__class__.__name__} {attrs}>" @classmethod def _raw_construct(cls: Type[C], **kwargs) -> C: @@ -112,18 +123,20 @@ class ActionRow(Component): The children components that this holds, if any. """ - __slots__: Tuple[str, ...] = ('children',) + __slots__: Tuple[str, ...] = ("children",) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.children: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.children: List[Component] = [ + _component_factory(d) for d in data.get("components", []) + ] def to_dict(self) -> ActionRowPayload: return { - 'type': int(self.type), - 'components': [child.to_dict() for child in self.children], + "type": int(self.type), + "components": [child.to_dict() for child in self.children], } # type: ignore @@ -157,44 +170,44 @@ class Button(Component): """ __slots__: Tuple[str, ...] = ( - 'style', - 'custom_id', - 'url', - 'disabled', - 'label', - 'emoji', + "style", + "custom_id", + "url", + "disabled", + "label", + "emoji", ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: ButtonComponentPayload): - self.type: ComponentType = try_enum(ComponentType, data['type']) - self.style: ButtonStyle = try_enum(ButtonStyle, data['style']) - self.custom_id: Optional[str] = data.get('custom_id') - self.url: Optional[str] = data.get('url') - self.disabled: bool = data.get('disabled', False) - self.label: Optional[str] = data.get('label') + self.type: ComponentType = try_enum(ComponentType, data["type"]) + self.style: ButtonStyle = try_enum(ButtonStyle, data["style"]) + self.custom_id: Optional[str] = data.get("custom_id") + self.url: Optional[str] = data.get("url") + self.disabled: bool = data.get("disabled", False) + self.label: Optional[str] = data.get("label") self.emoji: Optional[PartialEmoji] try: - self.emoji = PartialEmoji.from_dict(data['emoji']) + self.emoji = PartialEmoji.from_dict(data["emoji"]) except KeyError: self.emoji = None def to_dict(self) -> ButtonComponentPayload: payload = { - 'type': 2, - 'style': int(self.style), - 'label': self.label, - 'disabled': self.disabled, + "type": 2, + "style": int(self.style), + "label": self.label, + "disabled": self.disabled, } if self.custom_id: - payload['custom_id'] = self.custom_id + payload["custom_id"] = self.custom_id if self.url: - payload['url'] = self.url + payload["url"] = self.url if self.emoji: - payload['emoji'] = self.emoji.to_dict() + payload["emoji"] = self.emoji.to_dict() return payload # type: ignore @@ -231,37 +244,39 @@ class SelectMenu(Component): """ __slots__: Tuple[str, ...] = ( - 'custom_id', - 'placeholder', - 'min_values', - 'max_values', - 'options', - 'disabled', + "custom_id", + "placeholder", + "min_values", + "max_values", + "options", + "disabled", ) __repr_info__: ClassVar[Tuple[str, ...]] = __slots__ def __init__(self, data: SelectMenuPayload): self.type = ComponentType.select - self.custom_id: str = data['custom_id'] - self.placeholder: Optional[str] = data.get('placeholder') - self.min_values: int = data.get('min_values', 1) - self.max_values: int = data.get('max_values', 1) - self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get('options', [])] - self.disabled: bool = data.get('disabled', False) + self.custom_id: str = data["custom_id"] + self.placeholder: Optional[str] = data.get("placeholder") + self.min_values: int = data.get("min_values", 1) + self.max_values: int = data.get("max_values", 1) + self.options: List[SelectOption] = [ + SelectOption.from_dict(option) for option in data.get("options", []) + ] + self.disabled: bool = data.get("disabled", False) def to_dict(self) -> SelectMenuPayload: payload: SelectMenuPayload = { - 'type': self.type.value, - 'custom_id': self.custom_id, - 'min_values': self.min_values, - 'max_values': self.max_values, - 'options': [op.to_dict() for op in self.options], - 'disabled': self.disabled, + "type": self.type.value, + "custom_id": self.custom_id, + "min_values": self.min_values, + "max_values": self.max_values, + "options": [op.to_dict() for op in self.options], + "disabled": self.disabled, } if self.placeholder: - payload['placeholder'] = self.placeholder + payload["placeholder"] = self.placeholder return payload @@ -292,11 +307,11 @@ class SelectOption: """ __slots__: Tuple[str, ...] = ( - 'label', - 'value', - 'description', - 'emoji', - 'default', + "label", + "value", + "description", + "emoji", + "default", ) def __init__( @@ -318,60 +333,62 @@ class SelectOption: elif isinstance(emoji, _EmojiTag): emoji = emoji._to_partial() else: - raise TypeError(f'expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}') + raise TypeError( + f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}" + ) self.emoji = emoji self.default = default def __repr__(self) -> str: return ( - f'' + f"" ) def __str__(self) -> str: if self.emoji: - base = f'{self.emoji} {self.label}' + base = f"{self.emoji} {self.label}" else: base = self.label if self.description: - return f'{base}\n{self.description}' + return f"{base}\n{self.description}" return base @classmethod def from_dict(cls, data: SelectOptionPayload) -> SelectOption: try: - emoji = PartialEmoji.from_dict(data['emoji']) + emoji = PartialEmoji.from_dict(data["emoji"]) except KeyError: emoji = None return cls( - label=data['label'], - value=data['value'], - description=data.get('description'), + label=data["label"], + value=data["value"], + description=data.get("description"), emoji=emoji, - default=data.get('default', False), + default=data.get("default", False), ) def to_dict(self) -> SelectOptionPayload: payload: SelectOptionPayload = { - 'label': self.label, - 'value': self.value, - 'default': self.default, + "label": self.label, + "value": self.value, + "default": self.default, } if self.emoji: - payload['emoji'] = self.emoji.to_dict() # type: ignore + payload["emoji"] = self.emoji.to_dict() # type: ignore if self.description: - payload['description'] = self.description + payload["description"] = self.description return payload def _component_factory(data: ComponentPayload) -> Component: - component_type = data['type'] + component_type = data["type"] if component_type == 1: return ActionRow(data) elif component_type == 2: diff --git a/discord/context_managers.py b/discord/context_managers.py index a3ab0d19..226e5662 100644 --- a/discord/context_managers.py +++ b/discord/context_managers.py @@ -32,11 +32,10 @@ if TYPE_CHECKING: from types import TracebackType - TypingT = TypeVar('TypingT', bound='Typing') + TypingT = TypeVar("TypingT", bound="Typing") + +__all__ = ("Typing",) -__all__ = ( - 'Typing', -) def _typing_done_callback(fut: asyncio.Future) -> None: # just retrieve any exception and call it a day @@ -45,6 +44,7 @@ def _typing_done_callback(fut: asyncio.Future) -> None: except (asyncio.CancelledError, Exception): pass + class Typing: def __init__(self, messageable: Messageable) -> None: self.loop: asyncio.AbstractEventLoop = messageable._state.loop @@ -67,7 +67,8 @@ class Typing: self.task.add_done_callback(_typing_done_callback) return self - def __exit__(self, + def __exit__( + self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], @@ -79,7 +80,8 @@ class Typing: await channel._state.http.send_typing(channel.id) return self.__enter__() - async def __aexit__(self, + async def __aexit__( + self, exc_type: Optional[Type[BaseException]], exc_value: Optional[BaseException], traceback: Optional[TracebackType], diff --git a/discord/embeds.py b/discord/embeds.py index 25f05aef..8589a055 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -25,14 +25,23 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations import datetime -from typing import Any, Dict, Final, List, Mapping, Protocol, TYPE_CHECKING, Type, TypeVar, Union +from typing import ( + Any, + Dict, + Final, + List, + Mapping, + Protocol, + TYPE_CHECKING, + Type, + TypeVar, + Union, +) from . import utils from .colour import Colour -__all__ = ( - 'Embed', -) +__all__ = ("Embed",) class _EmptyEmbed: @@ -40,7 +49,7 @@ class _EmptyEmbed: return False def __repr__(self) -> str: - return 'Embed.Empty' + return "Embed.Empty" def __len__(self) -> int: return 0 @@ -57,51 +66,47 @@ class EmbedProxy: return len(self.__dict__) def __repr__(self) -> str: - inner = ', '.join((f'{k}={v!r}' for k, v in self.__dict__.items() if not k.startswith('_'))) - return f'EmbedProxy({inner})' + inner = ", ".join( + (f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")) + ) + return f"EmbedProxy({inner})" def __getattr__(self, attr: str) -> _EmptyEmbed: return EmptyEmbed -E = TypeVar('E', bound='Embed') +E = TypeVar("E", bound="Embed") if TYPE_CHECKING: from discord.types.embed import Embed as EmbedData, EmbedType - T = TypeVar('T') + T = TypeVar("T") MaybeEmpty = Union[T, _EmptyEmbed] - class _EmbedFooterProxy(Protocol): text: MaybeEmpty[str] icon_url: MaybeEmpty[str] - class _EmbedFieldProxy(Protocol): name: MaybeEmpty[str] value: MaybeEmpty[str] inline: bool - class _EmbedMediaProxy(Protocol): url: MaybeEmpty[str] proxy_url: MaybeEmpty[str] height: MaybeEmpty[int] width: MaybeEmpty[int] - class _EmbedVideoProxy(Protocol): url: MaybeEmpty[str] height: MaybeEmpty[int] width: MaybeEmpty[int] - class _EmbedProviderProxy(Protocol): name: MaybeEmpty[str] url: MaybeEmpty[str] - class _EmbedAuthorProxy(Protocol): name: MaybeEmpty[str] url: MaybeEmpty[str] @@ -163,33 +168,33 @@ class Embed: """ __slots__ = ( - 'title', - 'url', - 'type', - '_timestamp', - '_colour', - '_footer', - '_image', - '_thumbnail', - '_video', - '_provider', - '_author', - '_fields', - 'description', + "title", + "url", + "type", + "_timestamp", + "_colour", + "_footer", + "_image", + "_thumbnail", + "_video", + "_provider", + "_author", + "_fields", + "description", ) Empty: Final = EmptyEmbed def __init__( - self, - *, - colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, - color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, - title: MaybeEmpty[Any] = EmptyEmbed, - type: EmbedType = 'rich', - url: MaybeEmpty[Any] = EmptyEmbed, - description: MaybeEmpty[Any] = EmptyEmbed, - timestamp: datetime.datetime = None, + self, + *, + colour: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, + color: Union[int, Colour, _EmptyEmbed] = EmptyEmbed, + title: MaybeEmpty[Any] = EmptyEmbed, + type: EmbedType = "rich", + url: MaybeEmpty[Any] = EmptyEmbed, + description: MaybeEmpty[Any] = EmptyEmbed, + timestamp: datetime.datetime = None, ): self.colour = colour if colour is not EmptyEmbed else color @@ -231,10 +236,10 @@ class Embed: # fill in the basic fields - self.title = data.get('title', EmptyEmbed) - self.type = data.get('type', EmptyEmbed) - self.description = data.get('description', EmptyEmbed) - self.url = data.get('url', EmptyEmbed) + self.title = data.get("title", EmptyEmbed) + self.type = data.get("type", EmptyEmbed) + self.description = data.get("description", EmptyEmbed) + self.url = data.get("url", EmptyEmbed) if self.title is not EmptyEmbed: self.title = str(self.title) @@ -248,22 +253,30 @@ class Embed: # try to fill in the more rich fields try: - self._colour = Colour(value=data['color']) + self._colour = Colour(value=data["color"]) except KeyError: pass try: - self._timestamp = utils.parse_time(data['timestamp']) + self._timestamp = utils.parse_time(data["timestamp"]) except KeyError: pass - for attr in ('thumbnail', 'video', 'provider', 'author', 'fields', 'image', 'footer'): + for attr in ( + "thumbnail", + "video", + "provider", + "author", + "fields", + "image", + "footer", + ): try: value = data[attr] except KeyError: continue else: - setattr(self, '_' + attr, value) + setattr(self, "_" + attr, value) return self @@ -273,11 +286,11 @@ class Embed: def __len__(self) -> int: total = len(self.title) + len(self.description) - for field in getattr(self, '_fields', []): - total += len(field['name']) + len(field['value']) + for field in getattr(self, "_fields", []): + total += len(field["name"]) + len(field["value"]) try: - footer_text = self._footer['text'] + footer_text = self._footer["text"] except (AttributeError, KeyError): pass else: @@ -288,7 +301,7 @@ class Embed: except AttributeError: pass else: - total += len(author['name']) + total += len(author["name"]) return total @@ -312,7 +325,7 @@ class Embed: @property def colour(self) -> MaybeEmpty[Colour]: - return getattr(self, '_colour', EmptyEmbed) + return getattr(self, "_colour", EmptyEmbed) @colour.setter def colour(self, value: Union[int, Colour, _EmptyEmbed]): # type: ignore @@ -321,13 +334,15 @@ class Embed: elif isinstance(value, int): self._colour = Colour(value=value) else: - raise TypeError(f'Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead.') + raise TypeError( + f"Expected discord.Colour, int, or Embed.Empty but received {value.__class__.__name__} instead." + ) color = colour @property def timestamp(self) -> MaybeEmpty[datetime.datetime]: - return getattr(self, '_timestamp', EmptyEmbed) + return getattr(self, "_timestamp", EmptyEmbed) @timestamp.setter def timestamp(self, value: MaybeEmpty[datetime.datetime]): @@ -338,7 +353,9 @@ class Embed: elif isinstance(value, _EmptyEmbed): self._timestamp = value else: - raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead") + raise TypeError( + f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead" + ) @property def footer(self) -> _EmbedFooterProxy: @@ -348,9 +365,14 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_footer', {})) # type: ignore + return EmbedProxy(getattr(self, "_footer", {})) # type: ignore - def set_footer(self: E, *, text: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: + def set_footer( + self: E, + *, + text: MaybeEmpty[Any] = EmptyEmbed, + icon_url: MaybeEmpty[Any] = EmptyEmbed, + ) -> E: """Sets the footer for the embed content. This function returns the class instance to allow for fluent-style @@ -366,10 +388,10 @@ class Embed: self._footer = {} if text is not EmptyEmbed: - self._footer['text'] = str(text) + self._footer["text"] = str(text) if icon_url is not EmptyEmbed: - self._footer['icon_url'] = str(icon_url) + self._footer["icon_url"] = str(icon_url) return self @@ -401,12 +423,12 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_image', {})) # type: ignore + return EmbedProxy(getattr(self, "_image", {})) # type: ignore @image.setter def image(self: E, *, url: Any): self._image = { - 'url': str(url), + "url": str(url), } @image.deleter @@ -451,15 +473,14 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_thumbnail', {})) # type: ignore + return EmbedProxy(getattr(self, "_thumbnail", {})) # type: ignore @thumbnail.setter def thumbnail(self: E, *, url: Any): - """Sets the thumbnail for the embed content. - """ + """Sets the thumbnail for the embed content.""" self._thumbnail = { - 'url': str(url), + "url": str(url), } return @@ -504,7 +525,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_video', {})) # type: ignore + return EmbedProxy(getattr(self, "_video", {})) # type: ignore @property def provider(self) -> _EmbedProviderProxy: @@ -514,7 +535,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_provider', {})) # type: ignore + return EmbedProxy(getattr(self, "_provider", {})) # type: ignore @property def author(self) -> _EmbedAuthorProxy: @@ -524,9 +545,15 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return EmbedProxy(getattr(self, '_author', {})) # type: ignore + return EmbedProxy(getattr(self, "_author", {})) # type: ignore - def set_author(self: E, *, name: Any, url: MaybeEmpty[Any] = EmptyEmbed, icon_url: MaybeEmpty[Any] = EmptyEmbed) -> E: + def set_author( + self: E, + *, + name: Any, + url: MaybeEmpty[Any] = EmptyEmbed, + icon_url: MaybeEmpty[Any] = EmptyEmbed, + ) -> E: """Sets the author for the embed content. This function returns the class instance to allow for fluent-style @@ -543,14 +570,14 @@ class Embed: """ self._author = { - 'name': str(name), + "name": str(name), } if url is not EmptyEmbed: - self._author['url'] = str(url) + self._author["url"] = str(url) if icon_url is not EmptyEmbed: - self._author['icon_url'] = str(icon_url) + self._author["icon_url"] = str(icon_url) return self @@ -577,7 +604,7 @@ class Embed: If the attribute has no value then :attr:`Empty` is returned. """ - return [EmbedProxy(d) for d in getattr(self, '_fields', [])] # type: ignore + return [EmbedProxy(d) for d in getattr(self, "_fields", [])] # type: ignore def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E: """Adds a field to the embed object. @@ -596,9 +623,9 @@ class Embed: """ field = { - 'inline': inline, - 'name': str(name), - 'value': str(value), + "inline": inline, + "name": str(name), + "value": str(value), } try: @@ -608,7 +635,9 @@ class Embed: return self - def insert_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: + def insert_field_at( + self: E, index: int, *, name: Any, value: Any, inline: bool = True + ) -> E: """Inserts a field before a specified index to the embed. This function returns the class instance to allow for fluent-style @@ -629,9 +658,9 @@ class Embed: """ field = { - 'inline': inline, - 'name': str(name), - 'value': str(value), + "inline": inline, + "name": str(name), + "value": str(value), } try: @@ -669,7 +698,9 @@ class Embed: except (AttributeError, IndexError): pass - def set_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: + def set_field_at( + self: E, index: int, *, name: Any, value: Any, inline: bool = True + ) -> E: """Modifies a field to the embed object. The index must point to a valid pre-existing field. @@ -697,11 +728,11 @@ class Embed: try: field = self._fields[index] except (TypeError, IndexError, AttributeError): - raise IndexError('field index out of range') + raise IndexError("field index out of range") - field['name'] = str(name) - field['value'] = str(value) - field['inline'] = inline + field["name"] = str(name) + field["value"] = str(value) + field["inline"] = inline return self def to_dict(self) -> EmbedData: @@ -719,35 +750,39 @@ class Embed: # deal with basic convenience wrappers try: - colour = result.pop('colour') + colour = result.pop("colour") except KeyError: pass else: if colour: - result['color'] = colour.value + result["color"] = colour.value try: - timestamp = result.pop('timestamp') + timestamp = result.pop("timestamp") except KeyError: pass else: if timestamp: if timestamp.tzinfo: - result['timestamp'] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat() + result["timestamp"] = timestamp.astimezone( + tz=datetime.timezone.utc + ).isoformat() else: - result['timestamp'] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat() + result["timestamp"] = timestamp.replace( + tzinfo=datetime.timezone.utc + ).isoformat() # add in the non raw attribute ones if self.type: - result['type'] = self.type + result["type"] = self.type if self.description: - result['description'] = self.description + result["description"] = self.description if self.url: - result['url'] = self.url + result["url"] = self.url if self.title: - result['title'] = self.title + result["title"] = self.title return result # type: ignore diff --git a/discord/emoji.py b/discord/emoji.py index 39fa0218..a2f8a3ee 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -30,9 +30,7 @@ from .utils import SnowflakeList, snowflake_time, MISSING from .partial_emoji import _EmojiTag, PartialEmoji from .user import User -__all__ = ( - 'Emoji', -) +__all__ = ("Emoji",) if TYPE_CHECKING: from .types.emoji import Emoji as EmojiPayload @@ -98,16 +96,16 @@ class Emoji(_EmojiTag, AssetMixin): """ __slots__: Tuple[str, ...] = ( - 'require_colons', - 'animated', - 'managed', - 'id', - 'name', - '_roles', - 'guild_id', - '_state', - 'user', - 'available', + "require_colons", + "animated", + "managed", + "id", + "name", + "_roles", + "guild_id", + "_state", + "user", + "available", ) def __init__(self, *, guild: Guild, state: ConnectionState, data: EmojiPayload): @@ -116,14 +114,14 @@ class Emoji(_EmojiTag, AssetMixin): self._from_data(data) def _from_data(self, emoji: EmojiPayload): - self.require_colons: bool = emoji.get('require_colons', False) - self.managed: bool = emoji.get('managed', False) - self.id: int = int(emoji['id']) # type: ignore - self.name: str = emoji['name'] # type: ignore - self.animated: bool = emoji.get('animated', False) - self.available: bool = emoji.get('available', True) - self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get('roles', []))) - user = emoji.get('user') + self.require_colons: bool = emoji.get("require_colons", False) + self.managed: bool = emoji.get("managed", False) + self.id: int = int(emoji["id"]) # type: ignore + self.name: str = emoji["name"] # type: ignore + self.animated: bool = emoji.get("animated", False) + self.available: bool = emoji.get("available", True) + self._roles: SnowflakeList = SnowflakeList(map(int, emoji.get("roles", []))) + user = emoji.get("user") self.user: Optional[User] = User(state=self._state, data=user) if user else None def _to_partial(self) -> PartialEmoji: @@ -131,21 +129,21 @@ class Emoji(_EmojiTag, AssetMixin): def __iter__(self) -> Iterator[Tuple[str, Any]]: for attr in self.__slots__: - if attr[0] != '_': + if attr[0] != "_": value = getattr(self, attr, None) if value is not None: yield (attr, value) def __str__(self) -> str: if self.animated: - return f'' - return f'<:{self.name}:{self.id}>' + return f"" + return f"<:{self.name}:{self.id}>" def __int__(self) -> int: return self.id def __repr__(self) -> str: - return f'' + return f"" def __eq__(self, other: Any) -> bool: return isinstance(other, _EmojiTag) and self.id == other.id @@ -164,8 +162,8 @@ class Emoji(_EmojiTag, AssetMixin): @property def url(self) -> str: """:class:`str`: Returns the URL of the emoji.""" - fmt = 'gif' if self.animated else 'png' - return f'{Asset.BASE}/emojis/{self.id}.{fmt}' + fmt = "gif" if self.animated else "png" + return f"{Asset.BASE}/emojis/{self.id}.{fmt}" @property def roles(self) -> List[Role]: @@ -217,9 +215,17 @@ class Emoji(_EmojiTag, AssetMixin): An error occurred deleting the emoji. """ - await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) + await self._state.http.delete_custom_emoji( + self.guild.id, self.id, reason=reason + ) - async def edit(self, *, name: str = MISSING, roles: List[Snowflake] = MISSING, reason: Optional[str] = None) -> Emoji: + async def edit( + self, + *, + name: str = MISSING, + roles: List[Snowflake] = MISSING, + reason: Optional[str] = None, + ) -> Emoji: r"""|coro| Edits the custom emoji. @@ -254,9 +260,11 @@ class Emoji(_EmojiTag, AssetMixin): payload = {} if name is not MISSING: - payload['name'] = name + payload["name"] = name if roles is not MISSING: - payload['roles'] = [role.id for role in roles] + payload["roles"] = [role.id for role in roles] - data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason) + data = await self._state.http.edit_custom_emoji( + self.guild.id, self.id, payload=payload, reason=reason + ) return Emoji(guild=self.guild, data=data, state=self._state) diff --git a/discord/enums.py b/discord/enums.py index af8ee2b0..7eaa6283 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -27,50 +27,65 @@ from collections import namedtuple from typing import Any, ClassVar, Dict, List, Optional, TYPE_CHECKING, Type, TypeVar __all__ = ( - 'Enum', - 'ChannelType', - 'MessageType', - 'VoiceRegion', - 'SpeakingState', - 'VerificationLevel', - 'ContentFilter', - 'Status', - 'DefaultAvatar', - 'AuditLogAction', - 'AuditLogActionCategory', - 'UserFlags', - 'ActivityType', - 'NotificationLevel', - 'TeamMembershipState', - 'WebhookType', - 'ExpireBehaviour', - 'ExpireBehavior', - 'StickerType', - 'StickerFormatType', - 'InviteTarget', - 'VideoQualityMode', - 'ComponentType', - 'ButtonStyle', - 'StagePrivacyLevel', - 'InteractionType', - 'InteractionResponseType', - 'NSFWLevel', + "Enum", + "ChannelType", + "MessageType", + "VoiceRegion", + "SpeakingState", + "VerificationLevel", + "ContentFilter", + "Status", + "DefaultAvatar", + "AuditLogAction", + "AuditLogActionCategory", + "UserFlags", + "ActivityType", + "NotificationLevel", + "TeamMembershipState", + "WebhookType", + "ExpireBehaviour", + "ExpireBehavior", + "StickerType", + "StickerFormatType", + "InviteTarget", + "VideoQualityMode", + "ComponentType", + "ButtonStyle", + "StagePrivacyLevel", + "InteractionType", + "InteractionResponseType", + "NSFWLevel", ) def _create_value_cls(name, comparable): - cls = namedtuple('_EnumValue_' + name, 'name value') - cls.__repr__ = lambda self: f'<{name}.{self.name}: {self.value!r}>' - cls.__str__ = lambda self: f'{name}.{self.name}' + cls = namedtuple("_EnumValue_" + name, "name value") + cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" + cls.__str__ = lambda self: f"{name}.{self.name}" if comparable: - cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value - cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value - cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value - cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value + cls.__le__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value <= other.value + ) + cls.__ge__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value >= other.value + ) + cls.__lt__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value < other.value + ) + cls.__gt__ = ( + lambda self, other: isinstance(other, self.__class__) + and self.value > other.value + ) return cls + def _is_descriptor(obj): - return hasattr(obj, '__get__') or hasattr(obj, '__set__') or hasattr(obj, '__delete__') + return ( + hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") + ) class EnumMeta(type): @@ -88,7 +103,7 @@ class EnumMeta(type): value_cls = _create_value_cls(name, comparable) for key, value in list(attrs.items()): is_descriptor = _is_descriptor(value) - if key[0] == '_' and not is_descriptor: + if key[0] == "_" and not is_descriptor: continue # Special case classmethod to just pass through @@ -110,10 +125,10 @@ class EnumMeta(type): member_mapping[key] = new_value attrs[key] = new_value - attrs['_enum_value_map_'] = value_mapping - attrs['_enum_member_map_'] = member_mapping - attrs['_enum_member_names_'] = member_names - attrs['_enum_value_cls_'] = value_cls + attrs["_enum_value_map_"] = value_mapping + attrs["_enum_member_map_"] = member_mapping + attrs["_enum_member_names_"] = member_names + attrs["_enum_value_cls_"] = value_cls actual_cls = super().__new__(cls, name, bases, attrs) value_cls._actual_enum_cls_ = actual_cls # type: ignore return actual_cls @@ -122,13 +137,15 @@ class EnumMeta(type): return (cls._enum_member_map_[name] for name in cls._enum_member_names_) def __reversed__(cls): - return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) + return ( + cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_) + ) def __len__(cls): return len(cls._enum_member_names_) def __repr__(cls): - return f'' + return f"" @property def __members__(cls): @@ -144,10 +161,10 @@ class EnumMeta(type): return cls._enum_member_map_[key] def __setattr__(cls, name, value): - raise TypeError('Enums are immutable.') + raise TypeError("Enums are immutable.") def __delattr__(cls, attr): - raise TypeError('Enums are immutable') + raise TypeError("Enums are immutable") def __instancecheck__(self, instance): # isinstance(x, Y) @@ -215,29 +232,29 @@ class MessageType(Enum): class VoiceRegion(Enum): - us_west = 'us-west' - us_east = 'us-east' - us_south = 'us-south' - us_central = 'us-central' - eu_west = 'eu-west' - eu_central = 'eu-central' - singapore = 'singapore' - london = 'london' - sydney = 'sydney' - amsterdam = 'amsterdam' - frankfurt = 'frankfurt' - brazil = 'brazil' - hongkong = 'hongkong' - russia = 'russia' - japan = 'japan' - southafrica = 'southafrica' - south_korea = 'south-korea' - india = 'india' - europe = 'europe' - dubai = 'dubai' - vip_us_east = 'vip-us-east' - vip_us_west = 'vip-us-west' - vip_amsterdam = 'vip-amsterdam' + us_west = "us-west" + us_east = "us-east" + us_south = "us-south" + us_central = "us-central" + eu_west = "eu-west" + eu_central = "eu-central" + singapore = "singapore" + london = "london" + sydney = "sydney" + amsterdam = "amsterdam" + frankfurt = "frankfurt" + brazil = "brazil" + hongkong = "hongkong" + russia = "russia" + japan = "japan" + southafrica = "southafrica" + south_korea = "south-korea" + india = "india" + europe = "europe" + dubai = "dubai" + vip_us_east = "vip-us-east" + vip_us_west = "vip-us-west" + vip_amsterdam = "vip-amsterdam" def __str__(self): return self.value @@ -277,12 +294,12 @@ class ContentFilter(Enum, comparable=True): class Status(Enum): - online = 'online' - offline = 'offline' - idle = 'idle' - dnd = 'dnd' - do_not_disturb = 'dnd' - invisible = 'invisible' + online = "online" + offline = "offline" + idle = "idle" + dnd = "dnd" + do_not_disturb = "dnd" + invisible = "invisible" def __str__(self): return self.value @@ -415,33 +432,33 @@ class AuditLogAction(Enum): def target_type(self) -> Optional[str]: v = self.value if v == -1: - return 'all' + return "all" elif v < 10: - return 'guild' + return "guild" elif v < 20: - return 'channel' + return "channel" elif v < 30: - return 'user' + return "user" elif v < 40: - return 'role' + return "role" elif v < 50: - return 'invite' + return "invite" elif v < 60: - return 'webhook' + return "webhook" elif v < 70: - return 'emoji' + return "emoji" elif v == 73: - return 'channel' + return "channel" elif v < 80: - return 'message' + return "message" elif v < 83: - return 'integration' + return "integration" elif v < 90: - return 'stage_instance' + return "stage_instance" elif v < 93: - return 'sticker' + return "sticker" elif v < 113: - return 'thread' + return "thread" class UserFlags(Enum): @@ -589,12 +606,12 @@ class NSFWLevel(Enum, comparable=True): age_restricted = 3 -T = TypeVar('T') +T = TypeVar("T") def create_unknown_value(cls: Type[T], val: Any) -> T: value_cls = cls._enum_value_cls_ # type: ignore - name = f'unknown_{val}' + name = f"unknown_{val}" return value_cls(name=name, value=val) diff --git a/discord/errors.py b/discord/errors.py index bc2398d5..4dec03e4 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -38,20 +38,20 @@ if TYPE_CHECKING: from .interactions import Interaction __all__ = ( - 'DiscordException', - 'ClientException', - 'NoMoreItems', - 'GatewayNotFound', - 'HTTPException', - 'Forbidden', - 'NotFound', - 'DiscordServerError', - 'InvalidData', - 'InvalidArgument', - 'LoginFailure', - 'ConnectionClosed', - 'PrivilegedIntentsRequired', - 'InteractionResponded', + "DiscordException", + "ClientException", + "NoMoreItems", + "GatewayNotFound", + "HTTPException", + "Forbidden", + "NotFound", + "DiscordServerError", + "InvalidData", + "InvalidArgument", + "LoginFailure", + "ConnectionClosed", + "PrivilegedIntentsRequired", + "InteractionResponded", ) @@ -83,22 +83,22 @@ class GatewayNotFound(DiscordException): """An exception that is raised when the gateway for Discord could not be found""" def __init__(self): - message = 'The gateway to connect to discord was not found.' + message = "The gateway to connect to discord was not found." super().__init__(message) -def _flatten_error_dict(d: Dict[str, Any], key: str = '') -> Dict[str, str]: +def _flatten_error_dict(d: Dict[str, Any], key: str = "") -> Dict[str, str]: items: List[Tuple[str, str]] = [] for k, v in d.items(): - new_key = key + '.' + k if key else k + new_key = key + "." + k if key else k if isinstance(v, dict): try: - _errors: List[Dict[str, Any]] = v['_errors'] + _errors: List[Dict[str, Any]] = v["_errors"] except KeyError: items.extend(_flatten_error_dict(v, new_key).items()) else: - items.append((new_key, ' '.join(x.get('message', '') for x in _errors))) + items.append((new_key, " ".join(x.get("message", "") for x in _errors))) else: items.append((new_key, v)) @@ -123,28 +123,30 @@ class HTTPException(DiscordException): The Discord specific error code for the failure. """ - def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]): + def __init__( + self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]] + ): self.response: _ResponseType = response self.status: int = response.status # type: ignore self.code: int self.text: str if isinstance(message, dict): - self.code = message.get('code', 0) - base = message.get('message', '') - errors = message.get('errors') + self.code = message.get("code", 0) + base = message.get("message", "") + errors = message.get("errors") if errors: errors = _flatten_error_dict(errors) - helpful = '\n'.join('In %s: %s' % t for t in errors.items()) - self.text = base + '\n' + helpful + helpful = "\n".join("In %s: %s" % t for t in errors.items()) + self.text = base + "\n" + helpful else: self.text = base else: - self.text = message or '' + self.text = message or "" self.code = 0 - fmt = '{0.status} {0.reason} (error code: {1})' + fmt = "{0.status} {0.reason} (error code: {1})" if len(self.text): - fmt += ': {2}' + fmt += ": {2}" super().__init__(fmt.format(self.response, self.code, self.text)) @@ -221,14 +223,20 @@ class ConnectionClosed(ClientException): The shard ID that got closed if applicable. """ - def __init__(self, socket: ClientWebSocketResponse, *, shard_id: Optional[int], code: Optional[int] = None): + def __init__( + self, + socket: ClientWebSocketResponse, + *, + shard_id: Optional[int], + code: Optional[int] = None, + ): # This exception is just the same exception except # reconfigured to subclass ClientException for users self.code: int = code or socket.close_code or -1 # aiohttp doesn't seem to consistently provide close reason - self.reason: str = '' + self.reason: str = "" self.shard_id: Optional[int] = shard_id - super().__init__(f'Shard ID {self.shard_id} WebSocket closed with {self.code}') + super().__init__(f"Shard ID {self.shard_id} WebSocket closed with {self.code}") class PrivilegedIntentsRequired(ClientException): @@ -250,10 +258,10 @@ class PrivilegedIntentsRequired(ClientException): def __init__(self, shard_id: Optional[int]): self.shard_id: Optional[int] = shard_id msg = ( - 'Shard ID %s is requesting privileged intents that have not been explicitly enabled in the ' - 'developer portal. It is recommended to go to https://discord.com/developers/applications/ ' - 'and explicitly enable the privileged intents within your application\'s page. If this is not ' - 'possible, then consider disabling the privileged intents instead.' + "Shard ID %s is requesting privileged intents that have not been explicitly enabled in the " + "developer portal. It is recommended to go to https://discord.com/developers/applications/ " + "and explicitly enable the privileged intents within your application's page. If this is not " + "possible, then consider disabling the privileged intents instead." ) super().__init__(msg % shard_id) @@ -274,4 +282,4 @@ class InteractionResponded(ClientException): def __init__(self, interaction: Interaction): self.interaction: Interaction = interaction - super().__init__('This interaction has already been responded to before') + super().__init__("This interaction has already been responded to before") diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 9b155987..42a5cc6f 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -31,15 +31,23 @@ if TYPE_CHECKING: from .cog import Cog from .errors import CommandError -T = TypeVar('T') +T = TypeVar("T") Coro = Coroutine[Any, Any, T] MaybeCoro = Union[T, Coro[T]] CoroFunc = Callable[..., Coro[Any]] -Check = Union[Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]]] -Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] -Error = Union[Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]]] +Check = Union[ + Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], + Callable[["Context[Any]"], MaybeCoro[bool]], +] +Hook = Union[ + Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]] +] +Error = Union[ + Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], + Callable[["Context[Any]", "CommandError"], Coro[Any]], +] # This is merely a tag type to avoid circular import issues. diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index c089b87d..b24a99c5 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -33,7 +33,18 @@ import importlib.util import sys import traceback import types -from typing import Any, Callable, Mapping, List, Dict, TYPE_CHECKING, Optional, TypeVar, Type, Union +from typing import ( + Any, + Callable, + Mapping, + List, + Dict, + TYPE_CHECKING, + Optional, + TypeVar, + Type, + Union, +) import discord @@ -54,17 +65,18 @@ if TYPE_CHECKING: ) __all__ = ( - 'when_mentioned', - 'when_mentioned_or', - 'Bot', - 'AutoShardedBot', + "when_mentioned", + "when_mentioned_or", + "Bot", + "AutoShardedBot", ) MISSING: Any = discord.utils.MISSING -T = TypeVar('T') -CFT = TypeVar('CFT', bound='CoroFunc') -CXT = TypeVar('CXT', bound='Context') +T = TypeVar("T") +CFT = TypeVar("CFT", bound="CoroFunc") +CXT = TypeVar("CXT", bound="Context") + def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: """A callable that implements a command prefix equivalent to being mentioned. @@ -72,9 +84,12 @@ def when_mentioned(bot: Union[Bot, AutoShardedBot], msg: Message) -> List[str]: These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. """ # bot.user will never be None when this is called - return [f'<@{bot.user.id}> ', f'<@!{bot.user.id}> '] # type: ignore + return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "] # type: ignore -def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: + +def when_mentioned_or( + *prefixes: str, +) -> Callable[[Union[Bot, AutoShardedBot], Message], List[str]]: """A callable that implements when mentioned or other prefixes provided. These are meant to be passed into the :attr:`.Bot.command_prefix` attribute. @@ -103,6 +118,7 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M ---------- :func:`.when_mentioned` """ + def inner(bot, msg): r = list(prefixes) r = when_mentioned(bot, msg) + r @@ -110,17 +126,29 @@ def when_mentioned_or(*prefixes: str) -> Callable[[Union[Bot, AutoShardedBot], M return inner + def _is_submodule(parent: str, child: str) -> bool: return parent == child or child.startswith(parent + ".") + class _DefaultRepr: def __repr__(self): - return '' + return "" + _default = _DefaultRepr() + class BotBase(GroupMixin): - def __init__(self, command_prefix, help_command=_default, description=None, *, intents: discord.Intents, **options): + def __init__( + self, + command_prefix, + help_command=_default, + description=None, + *, + intents: discord.Intents, + **options, + ): super().__init__(**options, intents=intents) self.command_prefix = command_prefix self.extra_events: Dict[str, List[CoroFunc]] = {} @@ -131,16 +159,20 @@ class BotBase(GroupMixin): self._before_invoke = None self._after_invoke = None self._help_command = None - self.description = inspect.cleandoc(description) if description else '' - self.owner_id = options.get('owner_id') - self.owner_ids = options.get('owner_ids', set()) - self.strip_after_prefix = options.get('strip_after_prefix', False) + self.description = inspect.cleandoc(description) if description else "" + self.owner_id = options.get("owner_id") + self.owner_ids = options.get("owner_ids", set()) + self.strip_after_prefix = options.get("strip_after_prefix", False) if self.owner_id and self.owner_ids: - raise TypeError('Both owner_id and owner_ids are set.') + raise TypeError("Both owner_id and owner_ids are set.") - if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): - raise TypeError(f'owner_ids must be a collection not {self.owner_ids.__class__!r}') + if self.owner_ids and not isinstance( + self.owner_ids, collections.abc.Collection + ): + raise TypeError( + f"owner_ids must be a collection not {self.owner_ids.__class__!r}" + ) if help_command is _default: self.help_command = DefaultHelpCommand() @@ -152,7 +184,7 @@ class BotBase(GroupMixin): def dispatch(self, event_name: str, *args: Any, **kwargs: Any) -> None: # super() will resolve to Client super().dispatch(event_name, *args, **kwargs) # type: ignore - ev = 'on_' + event_name + ev = "on_" + event_name for event in self.extra_events.get(ev, []): self._schedule_event(event, ev, *args, **kwargs) # type: ignore @@ -172,7 +204,9 @@ class BotBase(GroupMixin): await super().close() # type: ignore - async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: + async def on_command_error( + self, context: Context, exception: errors.CommandError + ) -> None: """|coro| The default command error handler provided by the bot. @@ -182,7 +216,7 @@ class BotBase(GroupMixin): This only fires if you do not specify any listeners for command error. """ - if self.extra_events.get('on_command_error', None): + if self.extra_events.get("on_command_error", None): return command = context.command @@ -193,8 +227,10 @@ class BotBase(GroupMixin): if cog and cog.has_error_handler(): return - print(f'Ignoring exception in command {context.command}:', file=sys.stderr) - traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) + print(f"Ignoring exception in command {context.command}:", file=sys.stderr) + traceback.print_exception( + type(exception), exception, exception.__traceback__, file=sys.stderr + ) # global check registration @@ -380,7 +416,7 @@ class BotBase(GroupMixin): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The pre-invoke hook must be a coroutine.') + raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro return coro @@ -413,7 +449,7 @@ class BotBase(GroupMixin): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The post-invoke hook must be a coroutine.') + raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro return coro @@ -445,7 +481,7 @@ class BotBase(GroupMixin): name = func.__name__ if name is MISSING else name if not asyncio.iscoroutinefunction(func): - raise TypeError('Listeners must be coroutines') + raise TypeError("Listeners must be coroutines") if name in self.extra_events: self.extra_events[name].append(func) @@ -541,14 +577,14 @@ class BotBase(GroupMixin): """ if not isinstance(cog, Cog): - raise TypeError('cogs must derive from Cog') + raise TypeError("cogs must derive from Cog") cog_name = cog.__cog_name__ existing = self.__cogs.get(cog_name) if existing is not None: if not override: - raise discord.ClientException(f'Cog named {cog_name!r} already loaded') + raise discord.ClientException(f"Cog named {cog_name!r} already loaded") self.remove_cog(cog_name) cog = cog._inject(self) @@ -628,7 +664,9 @@ class BotBase(GroupMixin): for event_list in self.extra_events.copy().values(): remove = [] for index, event in enumerate(event_list): - if event.__module__ is not None and _is_submodule(name, event.__module__): + if event.__module__ is not None and _is_submodule( + name, event.__module__ + ): remove.append(index) for index in reversed(remove): @@ -636,7 +674,7 @@ class BotBase(GroupMixin): def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: try: - func = getattr(lib, 'teardown') + func = getattr(lib, "teardown") except AttributeError: pass else: @@ -652,7 +690,9 @@ class BotBase(GroupMixin): if _is_submodule(name, module): del sys.modules[module] - def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: + def _load_from_module_spec( + self, spec: importlib.machinery.ModuleSpec, key: str + ) -> None: # precondition: key not in self.__extensions lib = importlib.util.module_from_spec(spec) sys.modules[key] = lib @@ -663,7 +703,7 @@ class BotBase(GroupMixin): raise errors.ExtensionFailed(key, e) from e try: - setup = getattr(lib, 'setup') + setup = getattr(lib, "setup") except AttributeError: del sys.modules[key] raise errors.NoEntryPointError(key) @@ -850,7 +890,7 @@ class BotBase(GroupMixin): def help_command(self, value: Optional[HelpCommand]) -> None: if value is not None: if not isinstance(value, HelpCommand): - raise TypeError('help_command must be a subclass of HelpCommand') + raise TypeError("help_command must be a subclass of HelpCommand") if self._help_command is not None: self._help_command._remove_from_bot(self) self._help_command = value @@ -893,11 +933,15 @@ class BotBase(GroupMixin): if isinstance(ret, collections.abc.Iterable): raise - raise TypeError("command_prefix must be plain string, iterable of strings, or callable " - f"returning either of these, not {ret.__class__.__name__}") + raise TypeError( + "command_prefix must be plain string, iterable of strings, or callable " + f"returning either of these, not {ret.__class__.__name__}" + ) if not ret: - raise ValueError("Iterable command_prefix must contain at least one prefix") + raise ValueError( + "Iterable command_prefix must contain at least one prefix" + ) return ret @@ -954,14 +998,18 @@ class BotBase(GroupMixin): except TypeError: if not isinstance(prefix, list): - raise TypeError("get_prefix must return either a string or a list of string, " - f"not {prefix.__class__.__name__}") + raise TypeError( + "get_prefix must return either a string or a list of string, " + f"not {prefix.__class__.__name__}" + ) # It's possible a bad command_prefix got us here. for value in prefix: if not isinstance(value, str): - raise TypeError("Iterable command_prefix or list returned from get_prefix must " - f"contain only strings, not {value.__class__.__name__}") + raise TypeError( + "Iterable command_prefix or list returned from get_prefix must " + f"contain only strings, not {value.__class__.__name__}" + ) # Getting here shouldn't happen raise @@ -988,19 +1036,19 @@ class BotBase(GroupMixin): The invocation context to invoke. """ if ctx.command is not None: - self.dispatch('command', ctx) + self.dispatch("command", ctx) try: if await self.can_run(ctx, call_once=True): await ctx.command.invoke(ctx) else: - raise errors.CheckFailure('The global check once functions failed.') + raise errors.CheckFailure("The global check once functions failed.") except errors.CommandError as exc: await ctx.command.dispatch_error(ctx, exc) else: - self.dispatch('command_completion', ctx) + self.dispatch("command_completion", ctx) elif ctx.invoked_with: exc = errors.CommandNotFound(f'Command "{ctx.invoked_with}" is not found') - self.dispatch('command_error', ctx, exc) + self.dispatch("command_error", ctx, exc) async def process_commands(self, message: Message) -> None: """|coro| @@ -1033,6 +1081,7 @@ class BotBase(GroupMixin): async def on_message(self, message): await self.process_commands(message) + class Bot(BotBase, discord.Client): """Represents a discord bot. @@ -1103,10 +1152,13 @@ class Bot(BotBase, discord.Client): .. versionadded:: 1.7 """ + pass + class AutoShardedBot(BotBase, discord.AutoShardedClient): """This is similar to :class:`.Bot` except that it is inherited from :class:`discord.AutoShardedClient` instead. """ + pass diff --git a/discord/ext/commands/cog.py b/discord/ext/commands/cog.py index 9931557d..455b8bcc 100644 --- a/discord/ext/commands/cog.py +++ b/discord/ext/commands/cog.py @@ -26,7 +26,19 @@ from __future__ import annotations import inspect import discord.utils -from typing import Any, Callable, ClassVar, Dict, Generator, List, Optional, TYPE_CHECKING, Tuple, TypeVar, Type +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generator, + List, + Optional, + TYPE_CHECKING, + Tuple, + TypeVar, + Type, +) from ._types import _BaseCommand @@ -36,15 +48,16 @@ if TYPE_CHECKING: from .core import Command __all__ = ( - 'CogMeta', - 'Cog', + "CogMeta", + "Cog", ) -CogT = TypeVar('CogT', bound='Cog') -FuncT = TypeVar('FuncT', bound=Callable[..., Any]) +CogT = TypeVar("CogT", bound="Cog") +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) MISSING: Any = discord.utils.MISSING + class CogMeta(type): """A metaclass for defining a cog. @@ -104,6 +117,7 @@ class CogMeta(type): async def bar(self, ctx): pass # hidden -> False """ + __cog_name__: str __cog_settings__: Dict[str, Any] __cog_commands__: List[Command] @@ -111,17 +125,17 @@ class CogMeta(type): def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: name, bases, attrs = args - attrs['__cog_name__'] = kwargs.pop('name', name) - attrs['__cog_settings__'] = kwargs.pop('command_attrs', {}) + attrs["__cog_name__"] = kwargs.pop("name", name) + attrs["__cog_settings__"] = kwargs.pop("command_attrs", {}) - description = kwargs.pop('description', None) + description = kwargs.pop("description", None) if description is None: - description = inspect.cleandoc(attrs.get('__doc__', '')) - attrs['__cog_description__'] = description + description = inspect.cleandoc(attrs.get("__doc__", "")) + attrs["__cog_description__"] = description commands = {} listeners = {} - no_bot_cog = 'Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})' + no_bot_cog = "Commands or listeners must not start with cog_ or bot_ (in method {0.__name__}.{1})" new_cls = super().__new__(cls, name, bases, attrs, **kwargs) for base in reversed(new_cls.__mro__): @@ -136,21 +150,25 @@ class CogMeta(type): value = value.__func__ if isinstance(value, _BaseCommand): if is_static_method: - raise TypeError(f'Command in method {base}.{elem!r} must not be staticmethod.') - if elem.startswith(('cog_', 'bot_')): + raise TypeError( + f"Command in method {base}.{elem!r} must not be staticmethod." + ) + if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value elif inspect.iscoroutinefunction(value): try: - getattr(value, '__cog_listener__') + getattr(value, "__cog_listener__") except AttributeError: continue else: - if elem.startswith(('cog_', 'bot_')): + if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) listeners[elem] = value - new_cls.__cog_commands__ = list(commands.values()) # this will be copied in Cog.__new__ + new_cls.__cog_commands__ = list( + commands.values() + ) # this will be copied in Cog.__new__ listeners_as_list = [] for listener in listeners.values(): @@ -169,10 +187,12 @@ class CogMeta(type): def qualified_name(cls) -> str: return cls.__cog_name__ + def _cog_special_method(func: FuncT) -> FuncT: func.__cog_special_method__ = None return func + class Cog(metaclass=CogMeta): """The base class that all cogs must inherit from. @@ -183,6 +203,7 @@ class Cog(metaclass=CogMeta): When inheriting from this class, the options shown in :class:`CogMeta` are equally valid here. """ + __cog_name__: ClassVar[str] __cog_settings__: ClassVar[Dict[str, Any]] __cog_commands__: ClassVar[List[Command]] @@ -199,10 +220,7 @@ class Cog(metaclass=CogMeta): # r.e type ignore, type-checker complains about overriding a ClassVar self.__cog_commands__ = tuple(c._update_copy(cmd_attrs) for c in cls.__cog_commands__) # type: ignore - lookup = { - cmd.qualified_name: cmd - for cmd in self.__cog_commands__ - } + lookup = {cmd.qualified_name: cmd for cmd in self.__cog_commands__} # Update the Command instances dynamically as well for command in self.__cog_commands__: @@ -255,6 +273,7 @@ class Cog(metaclass=CogMeta): A command or group from the cog. """ from .core import GroupMixin + for command in self.__cog_commands__: if command.parent is None: yield command @@ -269,12 +288,15 @@ class Cog(metaclass=CogMeta): List[Tuple[:class:`str`, :ref:`coroutine `]] The listeners defined in this cog. """ - return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] + return [ + (name, getattr(self, method_name)) + for name, method_name in self.__cog_listeners__ + ] @classmethod def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: """Return None if the method is not overridden. Otherwise returns the overridden method.""" - return getattr(method.__func__, '__cog_special_method__', method) + return getattr(method.__func__, "__cog_special_method__", method) @classmethod def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: @@ -296,14 +318,16 @@ class Cog(metaclass=CogMeta): """ if name is not MISSING and not isinstance(name, str): - raise TypeError(f'Cog.listener expected str but received {name.__class__.__name__!r} instead.') + raise TypeError( + f"Cog.listener expected str but received {name.__class__.__name__!r} instead." + ) def decorator(func: FuncT) -> FuncT: actual = func if isinstance(actual, staticmethod): actual = actual.__func__ if not inspect.iscoroutinefunction(actual): - raise TypeError('Listener function must be a coroutine function.') + raise TypeError("Listener function must be a coroutine function.") actual.__cog_listener__ = True to_assign = name or actual.__name__ try: @@ -315,6 +339,7 @@ class Cog(metaclass=CogMeta): # to pick it up but the metaclass unfurls the function and # thus the assignments need to be on the actual function return func + return decorator def has_error_handler(self) -> bool: @@ -322,7 +347,7 @@ class Cog(metaclass=CogMeta): .. versionadded:: 1.7 """ - return not hasattr(self.cog_command_error.__func__, '__cog_special_method__') + return not hasattr(self.cog_command_error.__func__, "__cog_special_method__") @_cog_special_method def cog_unload(self) -> None: diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index 38a24d1d..2b282308 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -49,21 +49,19 @@ if TYPE_CHECKING: from .help import HelpCommand from .view import StringView -__all__ = ( - 'Context', -) +__all__ = ("Context",) MISSING: Any = discord.utils.MISSING -T = TypeVar('T') -BotT = TypeVar('BotT', bound="Union[Bot, AutoShardedBot]") -CogT = TypeVar('CogT', bound="Cog") +T = TypeVar("T") +BotT = TypeVar("BotT", bound="Union[Bot, AutoShardedBot]") +CogT = TypeVar("CogT", bound="Cog") if TYPE_CHECKING: - P = ParamSpec('P') + P = ParamSpec("P") else: - P = TypeVar('P') + P = TypeVar("P") class Context(discord.abc.Messageable, Generic[BotT]): @@ -122,7 +120,8 @@ class Context(discord.abc.Messageable, Generic[BotT]): or invoked. """ - def __init__(self, + def __init__( + self, *, message: Message, bot: BotT, @@ -153,7 +152,9 @@ class Context(discord.abc.Messageable, Generic[BotT]): self.current_parameter: Optional[inspect.Parameter] = current_parameter self._state: ConnectionState = self.message._state - async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: + async def invoke( + self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs + ) -> T: r"""|coro| Calls a command with the arguments given. @@ -219,7 +220,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): cmd = self.command view = self.view if cmd is None: - raise ValueError('This context is not valid.') + raise ValueError("This context is not valid.") # some state to revert to when we're done index, previous = view.index, view.previous @@ -230,10 +231,10 @@ class Context(discord.abc.Messageable, Generic[BotT]): if restart: to_call = cmd.root_parent or cmd - view.index = len(self.prefix or '') + view.index = len(self.prefix or "") view.previous = 0 self.invoked_parents = [] - self.invoked_with = view.get_word() # advance to get the root command + self.invoked_with = view.get_word() # advance to get the root command else: to_call = cmd @@ -263,7 +264,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): .. versionadded:: 2.0 """ if self.prefix is None: - return '' + return "" user = self.me # this breaks if the prefix mention is not the bot itself but I @@ -271,7 +272,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): # for this common use case rather than waste performance for the # odd one. pattern = re.compile(r"<@!?%s>" % user.id) - return pattern.sub("@%s" % user.display_name.replace('\\', r'\\'), self.prefix) + return pattern.sub("@%s" % user.display_name.replace("\\", r"\\"), self.prefix) @property def cog(self) -> Optional[Cog]: @@ -381,7 +382,7 @@ class Context(discord.abc.Messageable, Generic[BotT]): await cmd.prepare_help_command(self, entity.qualified_name) try: - if hasattr(entity, '__cog_commands__'): + if hasattr(entity, "__cog_commands__"): injected = wrap_callback(cmd.send_cog_help) return await injected(entity) elif isinstance(entity, Group): diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 5740a188..c79f7b83 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -52,32 +52,32 @@ if TYPE_CHECKING: __all__ = ( - 'Converter', - 'ObjectConverter', - 'MemberConverter', - 'UserConverter', - 'MessageConverter', - 'PartialMessageConverter', - 'TextChannelConverter', - 'InviteConverter', - 'GuildConverter', - 'RoleConverter', - 'GameConverter', - 'ColourConverter', - 'ColorConverter', - 'VoiceChannelConverter', - 'StageChannelConverter', - 'EmojiConverter', - 'PartialEmojiConverter', - 'CategoryChannelConverter', - 'IDConverter', - 'StoreChannelConverter', - 'ThreadConverter', - 'GuildChannelConverter', - 'GuildStickerConverter', - 'clean_content', - 'Greedy', - 'run_converters', + "Converter", + "ObjectConverter", + "MemberConverter", + "UserConverter", + "MessageConverter", + "PartialMessageConverter", + "TextChannelConverter", + "InviteConverter", + "GuildConverter", + "RoleConverter", + "GameConverter", + "ColourConverter", + "ColorConverter", + "VoiceChannelConverter", + "StageChannelConverter", + "EmojiConverter", + "PartialEmojiConverter", + "CategoryChannelConverter", + "IDConverter", + "StoreChannelConverter", + "ThreadConverter", + "GuildChannelConverter", + "GuildStickerConverter", + "clean_content", + "Greedy", + "run_converters", ) @@ -91,10 +91,10 @@ def _get_from_guilds(bot, getter, argument): _utils_get = discord.utils.get -T = TypeVar('T') -T_co = TypeVar('T_co', covariant=True) -CT = TypeVar('CT', bound=discord.abc.GuildChannel) -TT = TypeVar('TT', bound=discord.Thread) +T = TypeVar("T") +T_co = TypeVar("T_co", covariant=True) +CT = TypeVar("CT", bound=discord.abc.GuildChannel) +TT = TypeVar("TT", bound=discord.Thread) @runtime_checkable @@ -132,10 +132,10 @@ class Converter(Protocol[T_co]): :exc:`.BadArgument` The converter failed to convert the argument. """ - raise NotImplementedError('Derived classes need to implement this.') + raise NotImplementedError("Derived classes need to implement this.") -_ID_REGEX = re.compile(r'([0-9]{15,20})$') +_ID_REGEX = re.compile(r"([0-9]{15,20})$") class IDConverter(Converter[T_co]): @@ -158,7 +158,9 @@ class ObjectConverter(IDConverter[discord.Object]): """ async def convert(self, ctx: Context, argument: str) -> discord.Object: - match = self._get_id_match(argument) or re.match(r'<(?:@(?:!|&)?|#)([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match( + r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument + ) if match is None: raise ObjectNotFound(argument) @@ -192,13 +194,17 @@ class MemberConverter(IDConverter[discord.Member]): async def query_member_named(self, guild, argument): cache = guild._state.member_cache_flags.joined - if len(argument) > 5 and argument[-5] == '#': - username, _, discriminator = argument.rpartition('#') + if len(argument) > 5 and argument[-5] == "#": + username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get(members, name=username, discriminator=discriminator) + return discord.utils.get( + members, name=username, discriminator=discriminator + ) else: members = await guild.query_members(argument, limit=100, cache=cache) - return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members) + return discord.utils.find( + lambda m: m.name == argument or m.nick == argument, members + ) async def query_member_by_id(self, bot, guild, user_id): ws = bot._get_websocket(shard_id=guild.shard_id) @@ -223,7 +229,9 @@ class MemberConverter(IDConverter[discord.Member]): async def convert(self, ctx: Context, argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match( + r"<@!?([0-9]{15,20})>$", argument + ) guild = ctx.guild result = None user_id = None @@ -232,13 +240,15 @@ class MemberConverter(IDConverter[discord.Member]): if guild: result = guild.get_member_named(argument) else: - result = _get_from_guilds(bot, 'get_member_named', argument) + result = _get_from_guilds(bot, "get_member_named", argument) else: user_id = int(match.group(1)) if guild: - result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id) + result = guild.get_member(user_id) or _utils_get( + ctx.message.mentions, id=user_id + ) else: - result = _get_from_guilds(bot, 'get_member', user_id) + result = _get_from_guilds(bot, "get_member", user_id) if result is None: if guild is None: @@ -276,13 +286,17 @@ class UserConverter(IDConverter[discord.User]): """ async def convert(self, ctx: Context, argument: str) -> discord.User: - match = self._get_id_match(argument) or re.match(r'<@!?([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match( + r"<@!?([0-9]{15,20})>$", argument + ) result = None state = ctx._state if match is not None: user_id = int(match.group(1)) - result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id) + result = ctx.bot.get_user(user_id) or _utils_get( + ctx.message.mentions, id=user_id + ) if result is None: try: result = await ctx.bot.fetch_user(user_id) @@ -294,12 +308,12 @@ class UserConverter(IDConverter[discord.User]): arg = argument # Remove the '@' character if this is the first character from the argument - if arg[0] == '@': + if arg[0] == "@": # Remove first character arg = arg[1:] # check for discriminator if it exists, - if len(arg) > 5 and arg[-5] == '#': + if len(arg) > 5 and arg[-5] == "#": discrim = arg[-4:] name = arg[:-5] predicate = lambda u: u.name == name and u.discriminator == discrim @@ -330,29 +344,33 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _get_id_matches(ctx, argument): - id_regex = re.compile(r'(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$') + id_regex = re.compile( + r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$" + ) link_regex = re.compile( - r'https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/' - r'(?P[0-9]{15,20}|@me)' - r'/(?P[0-9]{15,20})/(?P[0-9]{15,20})/?$' + r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" + r"(?P[0-9]{15,20}|@me)" + r"/(?P[0-9]{15,20})/(?P[0-9]{15,20})/?$" ) match = id_regex.match(argument) or link_regex.match(argument) if not match: raise MessageNotFound(argument) data = match.groupdict() - channel_id = discord.utils._get_as_snowflake(data, 'channel_id') - message_id = int(data['message_id']) - guild_id = data.get('guild_id') + channel_id = discord.utils._get_as_snowflake(data, "channel_id") + message_id = int(data["message_id"]) + guild_id = data.get("guild_id") if guild_id is None: guild_id = ctx.guild and ctx.guild.id - elif guild_id == '@me': + elif guild_id == "@me": guild_id = None else: guild_id = int(guild_id) return guild_id, message_id, channel_id @staticmethod - def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: + def _resolve_channel( + ctx, guild_id, channel_id + ) -> Optional[PartialMessageableChannel]: if guild_id is not None: guild = ctx.bot.get_guild(guild_id) if guild is not None and channel_id is not None: @@ -386,7 +404,9 @@ class MessageConverter(IDConverter[discord.Message]): """ async def convert(self, ctx: Context, argument: str) -> discord.Message: - guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches( + ctx, argument + ) message = ctx.bot._connection._get_message(message_id) if message: return message @@ -417,13 +437,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return self._resolve_channel(ctx, argument, 'channels', discord.abc.GuildChannel) + return self._resolve_channel( + ctx, argument, "channels", discord.abc.GuildChannel + ) @staticmethod - def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: + def _resolve_channel( + ctx: Context, argument: str, attribute: str, type: Type[CT] + ) -> CT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) + match = IDConverter._get_id_match(argument) or re.match( + r"<#([0-9]{15,20})>$", argument + ) result = None guild = ctx.guild @@ -443,7 +469,7 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): if guild: result = guild.get_channel(channel_id) else: - result = _get_from_guilds(bot, 'get_channel', channel_id) + result = _get_from_guilds(bot, "get_channel", channel_id) if not isinstance(result, type): raise ChannelNotFound(argument) @@ -451,10 +477,14 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): return result @staticmethod - def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: + def _resolve_thread( + ctx: Context, argument: str, attribute: str, type: Type[TT] + ) -> TT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r'<#([0-9]{15,20})>$', argument) + match = IDConverter._get_id_match(argument) or re.match( + r"<#([0-9]{15,20})>$", argument + ) result = None guild = ctx.guild @@ -491,7 +521,9 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'text_channels', discord.TextChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "text_channels", discord.TextChannel + ) class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): @@ -511,7 +543,9 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'voice_channels', discord.VoiceChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "voice_channels", discord.VoiceChannel + ) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -530,7 +564,9 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'stage_channels', discord.StageChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "stage_channels", discord.StageChannel + ) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -550,7 +586,9 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'categories', discord.CategoryChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "categories", discord.CategoryChannel + ) class StoreChannelConverter(IDConverter[discord.StoreChannel]): @@ -569,7 +607,9 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, 'channels', discord.StoreChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "channels", discord.StoreChannel + ) class ThreadConverter(IDConverter[discord.Thread]): @@ -587,7 +627,9 @@ class ThreadConverter(IDConverter[discord.Thread]): """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: - return GuildChannelConverter._resolve_thread(ctx, argument, 'threads', discord.Thread) + return GuildChannelConverter._resolve_thread( + ctx, argument, "threads", discord.Thread + ) class ColourConverter(Converter[discord.Colour]): @@ -616,10 +658,12 @@ class ColourConverter(Converter[discord.Colour]): Added support for ``rgb`` function and 3-digit hex shortcuts """ - RGB_REGEX = re.compile(r'rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)') + RGB_REGEX = re.compile( + r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)" + ) def parse_hex_number(self, argument): - arg = ''.join(i * 2 for i in argument) if len(argument) == 3 else argument + arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument try: value = int(arg, base=16) if not (0 <= value <= 0xFFFFFF): @@ -630,7 +674,7 @@ class ColourConverter(Converter[discord.Colour]): return discord.Color(value=value) def parse_rgb_number(self, argument, number): - if number[-1] == '%': + if number[-1] == "%": value = int(number[:-1]) if not (0 <= value <= 100): raise BadColourArgument(argument) @@ -646,29 +690,29 @@ class ColourConverter(Converter[discord.Colour]): if match is None: raise BadColourArgument(argument) - red = self.parse_rgb_number(argument, match.group('r')) - green = self.parse_rgb_number(argument, match.group('g')) - blue = self.parse_rgb_number(argument, match.group('b')) + red = self.parse_rgb_number(argument, match.group("r")) + green = self.parse_rgb_number(argument, match.group("g")) + blue = self.parse_rgb_number(argument, match.group("b")) return discord.Color.from_rgb(red, green, blue) async def convert(self, ctx: Context, argument: str) -> discord.Colour: - if argument[0] == '#': + if argument[0] == "#": return self.parse_hex_number(argument[1:]) - if argument[0:2] == '0x': + if argument[0:2] == "0x": rest = argument[2:] # Legacy backwards compatible syntax - if rest.startswith('#'): + if rest.startswith("#"): return self.parse_hex_number(rest[1:]) return self.parse_hex_number(rest) arg = argument.lower() - if arg[0:3] == 'rgb': + if arg[0:3] == "rgb": return self.parse_rgb(arg) - arg = arg.replace(' ', '_') + arg = arg.replace(" ", "_") method = getattr(discord.Colour, arg, None) - if arg.startswith('from_') or method is None or not inspect.ismethod(method): + if arg.startswith("from_") or method is None or not inspect.ismethod(method): raise BadColourArgument(arg) return method() @@ -697,7 +741,9 @@ class RoleConverter(IDConverter[discord.Role]): if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match(r'<@&([0-9]{15,20})>$', argument) + match = self._get_id_match(argument) or re.match( + r"<@&([0-9]{15,20})>$", argument + ) if match: result = guild.get_role(int(match.group(1))) else: @@ -776,7 +822,9 @@ class EmojiConverter(IDConverter[discord.Emoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.Emoji: - match = self._get_id_match(argument) or re.match(r'$', argument) + match = self._get_id_match(argument) or re.match( + r"$", argument + ) result = None bot = ctx.bot guild = ctx.guild @@ -810,7 +858,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.PartialEmoji: - match = re.match(r'<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$', argument) + match = re.match(r"<(a?):([a-zA-Z0-9\_]{1,32}):([0-9]{15,20})>$", argument) if match: emoji_animated = bool(match.group(1)) @@ -818,7 +866,10 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): emoji_id = int(match.group(3)) return discord.PartialEmoji.with_state( - ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id + ctx.bot._connection, + animated=emoji_animated, + name=emoji_name, + id=emoji_id, ) raise PartialEmojiConversionFailure(argument) @@ -903,37 +954,41 @@ class clean_content(Converter[str]): def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) - return f'@{m.display_name if self.use_nicknames else m.name}' if m else '@deleted-user' + return ( + f"@{m.display_name if self.use_nicknames else m.name}" + if m + else "@deleted-user" + ) def resolve_role(id: int) -> str: r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) - return f'@{r.name}' if r else '@deleted-role' + return f"@{r.name}" if r else "@deleted-role" else: def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.bot.get_user(id) - return f'@{m.name}' if m else '@deleted-user' + return f"@{m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: - return '@deleted-role' + return "@deleted-role" if self.fix_channel_mentions and ctx.guild: def resolve_channel(id: int) -> str: c = ctx.guild.get_channel(id) - return f'#{c.name}' if c else '#deleted-channel' + return f"#{c.name}" if c else "#deleted-channel" else: def resolve_channel(id: int) -> str: - return f'<#{id}>' + return f"<#{id}>" transforms = { - '@': resolve_member, - '@!': resolve_member, - '#': resolve_channel, - '@&': resolve_role, + "@": resolve_member, + "@!": resolve_member, + "#": resolve_channel, + "@&": resolve_role, } def repl(match: re.Match) -> str: @@ -942,7 +997,7 @@ class clean_content(Converter[str]): transformed = transforms[type](id) return transformed - result = re.sub(r'<(@[!&]?|#)([0-9]{15,20})>', repl, argument) + result = re.sub(r"<(@[!&]?|#)([0-9]{15,20})>", repl, argument) if self.escape_markdown: result = discord.utils.escape_markdown(result) elif self.remove_markdown: @@ -974,42 +1029,46 @@ class Greedy(List[T]): For more information, check :ref:`ext_commands_special_converters`. """ - __slots__ = ('converter',) + __slots__ = ("converter",) def __init__(self, *, converter: T): self.converter = converter def __repr__(self): - converter = getattr(self.converter, '__name__', repr(self.converter)) - return f'Greedy[{converter}]' + converter = getattr(self.converter, "__name__", repr(self.converter)) + return f"Greedy[{converter}]" def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: if not isinstance(params, tuple): params = (params,) if len(params) != 1: - raise TypeError('Greedy[...] only takes a single argument') + raise TypeError("Greedy[...] only takes a single argument") converter = params[0] - origin = getattr(converter, '__origin__', None) - args = getattr(converter, '__args__', ()) + origin = getattr(converter, "__origin__", None) + args = getattr(converter, "__args__", ()) - if not (callable(converter) or isinstance(converter, Converter) or origin is not None): - raise TypeError('Greedy[...] expects a type or a Converter instance.') + if not ( + callable(converter) + or isinstance(converter, Converter) + or origin is not None + ): + raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: - raise TypeError(f'Greedy[{converter.__name__}] is invalid.') + raise TypeError(f"Greedy[{converter.__name__}] is invalid.") if origin is Union and type(None) in args: - raise TypeError(f'Greedy[{converter!r}] is invalid.') + raise TypeError(f"Greedy[{converter!r}] is invalid.") return cls(converter=converter) def _convert_to_bool(argument: str) -> bool: lowered = argument.lower() - if lowered in ('yes', 'y', 'true', 't', '1', 'enable', 'on'): + if lowered in ("yes", "y", "true", "t", "1", "enable", "on"): return True - elif lowered in ('no', 'n', 'false', 'f', '0', 'disable', 'off'): + elif lowered in ("no", "n", "false", "f", "0", "disable", "off"): return False else: raise BadBoolArgument(lowered) @@ -1056,7 +1115,9 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = { } -async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): +async def _actual_conversion( + ctx: Context, converter, argument: str, param: inspect.Parameter +): if converter is bool: return _convert_to_bool(argument) @@ -1065,7 +1126,9 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp except AttributeError: pass else: - if module is not None and (module.startswith('discord.') and not module.endswith('converter')): + if module is not None and ( + module.startswith("discord.") and not module.endswith("converter") + ): converter = CONVERTER_MAPPING.get(converter, converter) try: @@ -1091,10 +1154,14 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp except AttributeError: name = converter.__class__.__name__ - raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc + raise BadArgument( + f'Converting to "{name}" failed for parameter "{param.name}".' + ) from exc -async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): +async def run_converters( + ctx: Context, converter, argument: str, param: inspect.Parameter +): """|coro| Runs converters for a given converter, argument, and parameter. @@ -1124,7 +1191,7 @@ async def run_converters(ctx: Context, converter, argument: str, param: inspect. Any The resulting conversion. """ - origin = getattr(converter, '__origin__', None) + origin = getattr(converter, "__origin__", None) if origin is Union: errors = [] diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 2e008aed..844b2214 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -38,24 +38,25 @@ if TYPE_CHECKING: from ...message import Message __all__ = ( - 'BucketType', - 'Cooldown', - 'CooldownMapping', - 'DynamicCooldownMapping', - 'MaxConcurrency', + "BucketType", + "Cooldown", + "CooldownMapping", + "DynamicCooldownMapping", + "MaxConcurrency", ) -C = TypeVar('C', bound='CooldownMapping') -MC = TypeVar('MC', bound='MaxConcurrency') +C = TypeVar("C", bound="CooldownMapping") +MC = TypeVar("MC", bound="MaxConcurrency") + class BucketType(Enum): - default = 0 - user = 1 - guild = 2 - channel = 3 - member = 4 + default = 0 + user = 1 + guild = 2 + channel = 3 + member = 4 category = 5 - role = 6 + role = 6 def get_key(self, msg: Message) -> Any: if self is BucketType.user: @@ -90,7 +91,7 @@ class Cooldown: The length of the cooldown period in seconds. """ - __slots__ = ('rate', 'per', '_window', '_tokens', '_last') + __slots__ = ("rate", "per", "_window", "_tokens", "_last") def __init__(self, rate: float, per: float) -> None: self.rate: int = int(rate) @@ -190,7 +191,8 @@ class Cooldown: return Cooldown(self.rate, self.per) def __repr__(self) -> str: - return f'' + return f"" + class CooldownMapping: def __init__( @@ -199,7 +201,7 @@ class CooldownMapping: type: Callable[[Message], Any], ) -> None: if not callable(type): - raise TypeError('Cooldown type must be a BucketType or callable') + raise TypeError("Cooldown type must be a BucketType or callable") self._cache: Dict[Any, Cooldown] = {} self._cooldown: Optional[Cooldown] = original @@ -252,16 +254,16 @@ class CooldownMapping: return bucket - def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: + def update_rate_limit( + self, message: Message, current: Optional[float] = None + ) -> Optional[float]: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) -class DynamicCooldownMapping(CooldownMapping): +class DynamicCooldownMapping(CooldownMapping): def __init__( - self, - factory: Callable[[Message], Cooldown], - type: Callable[[Message], Any] + self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] ) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory @@ -278,6 +280,7 @@ class DynamicCooldownMapping(CooldownMapping): def create_bucket(self, message: Message) -> Cooldown: return self._factory(message) + class _Semaphore: """This class is a version of a semaphore. @@ -291,7 +294,7 @@ class _Semaphore: overkill for what is basically a counter. """ - __slots__ = ('value', 'loop', '_waiters') + __slots__ = ("value", "loop", "_waiters") def __init__(self, number: int) -> None: self.value: int = number @@ -299,7 +302,7 @@ class _Semaphore: self._waiters: Deque[asyncio.Future] = deque() def __repr__(self) -> str: - return f'<_Semaphore value={self.value} waiters={len(self._waiters)}>' + return f"<_Semaphore value={self.value} waiters={len(self._waiters)}>" def locked(self) -> bool: return self.value == 0 @@ -337,8 +340,9 @@ class _Semaphore: self.value += 1 self.wake_up() + class MaxConcurrency: - __slots__ = ('number', 'per', 'wait', '_mapping') + __slots__ = ("number", "per", "wait", "_mapping") def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: self._mapping: Dict[Any, _Semaphore] = {} @@ -347,16 +351,20 @@ class MaxConcurrency: self.wait: bool = wait if number <= 0: - raise ValueError('max_concurrency \'number\' cannot be less than 1') + raise ValueError("max_concurrency 'number' cannot be less than 1") if not isinstance(per, BucketType): - raise TypeError(f'max_concurrency \'per\' must be of type BucketType not {type(per)!r}') + raise TypeError( + f"max_concurrency 'per' must be of type BucketType not {type(per)!r}" + ) def copy(self: MC) -> MC: return self.__class__(self.number, per=self.per, wait=self.wait) def __repr__(self) -> str: - return f'' + return ( + f"" + ) def get_key(self, message: Message) -> Any: return self.per.get_key(message) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 35b7e840..2c10fd0d 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -48,7 +48,13 @@ import datetime import discord from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping +from .cooldowns import ( + Cooldown, + BucketType, + CooldownMapping, + MaxConcurrency, + DynamicCooldownMapping, +) from .converter import run_converters, get_converter, Greedy from ._types import _BaseCommand from .cog import Cog @@ -70,52 +76,53 @@ if TYPE_CHECKING: __all__ = ( - 'Command', - 'Group', - 'GroupMixin', - 'command', - 'group', - 'has_role', - 'has_permissions', - 'has_any_role', - 'check', - 'check_any', - 'before_invoke', - 'after_invoke', - 'bot_has_role', - 'bot_has_permissions', - 'bot_has_any_role', - 'cooldown', - 'dynamic_cooldown', - 'max_concurrency', - 'dm_only', - 'guild_only', - 'is_owner', - 'is_nsfw', - 'has_guild_permissions', - 'bot_has_guild_permissions' + "Command", + "Group", + "GroupMixin", + "command", + "group", + "has_role", + "has_permissions", + "has_any_role", + "check", + "check_any", + "before_invoke", + "after_invoke", + "bot_has_role", + "bot_has_permissions", + "bot_has_any_role", + "cooldown", + "dynamic_cooldown", + "max_concurrency", + "dm_only", + "guild_only", + "is_owner", + "is_nsfw", + "has_guild_permissions", + "bot_has_guild_permissions", ) MISSING: Any = discord.utils.MISSING -T = TypeVar('T') -CogT = TypeVar('CogT', bound='Cog') -CommandT = TypeVar('CommandT', bound='Command') -ContextT = TypeVar('ContextT', bound='Context') +T = TypeVar("T") +CogT = TypeVar("CogT", bound="Cog") +CommandT = TypeVar("CommandT", bound="Command") +ContextT = TypeVar("ContextT", bound="Context") # CHT = TypeVar('CHT', bound='Check') -GroupT = TypeVar('GroupT', bound='Group') -HookT = TypeVar('HookT', bound='Hook') -ErrorT = TypeVar('ErrorT', bound='Error') +GroupT = TypeVar("GroupT", bound="Group") +HookT = TypeVar("HookT", bound="Hook") +ErrorT = TypeVar("ErrorT", bound="Error") if TYPE_CHECKING: - P = ParamSpec('P') + P = ParamSpec("P") else: - P = TypeVar('P') + P = TypeVar("P") + def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: partial = functools.partial while True: - if hasattr(function, '__wrapped__'): + if hasattr(function, "__wrapped__"): function = function.__wrapped__ elif isinstance(function, partial): function = function.func @@ -123,7 +130,9 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: return function -def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]: +def get_signature_parameters( + function: Callable[..., Any], globalns: Dict[str, Any] +) -> Dict[str, inspect.Parameter]: signature = inspect.signature(function) params = {} cache: Dict[str, Any] = {} @@ -139,7 +148,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A annotation = eval_annotation(annotation, globalns, globalns, cache) if annotation is Greedy: - raise TypeError('Unparameterized Greedy[...] is disallowed in signature.') + raise TypeError("Unparameterized Greedy[...] is disallowed in signature.") params[name] = parameter.replace(annotation=annotation) @@ -158,8 +167,10 @@ def wrap_callback(coro): except Exception as exc: raise CommandInvokeError(exc) from exc return ret + return wrapped + def hooked_wrapped_callback(command, ctx, coro): @functools.wraps(coro) async def wrapped(*args, **kwargs): @@ -180,6 +191,7 @@ def hooked_wrapped_callback(command, ctx, coro): await command.call_after_hooks(ctx) return ret + return wrapped @@ -202,6 +214,7 @@ class _CaseInsensitiveDict(dict): def __setitem__(self, k, v): super().__setitem__(k.casefold(), v) + class Command(_BaseCommand, Generic[CogT, P, T]): r"""A class that implements the protocol for a bot text command. @@ -269,8 +282,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): which calls converters. If ``False`` then cooldown processing is done first and then the converters are called second. Defaults to ``False``. extras: :class:`dict` - A dict of user provided extras to attach to the Command. - + A dict of user provided extras to attach to the Command. + .. note:: This object may be copied by the library. @@ -295,78 +308,86 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.__original_kwargs__ = kwargs.copy() return self - def __init__(self, func: Union[ + def __init__( + self, + func: Union[ Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], - ], **kwargs: Any): + ], + **kwargs: Any, + ): if not asyncio.iscoroutinefunction(func): - raise TypeError('Callback must be a coroutine.') + raise TypeError("Callback must be a coroutine.") - name = kwargs.get('name') or func.__name__ + name = kwargs.get("name") or func.__name__ if not isinstance(name, str): - raise TypeError('Name of a command must be a string.') + raise TypeError("Name of a command must be a string.") self.name: str = name self.callback = func - self.enabled: bool = kwargs.get('enabled', True) + self.enabled: bool = kwargs.get("enabled", True) - help_doc = kwargs.get('help') + help_doc = kwargs.get("help") if help_doc is not None: help_doc = inspect.cleandoc(help_doc) else: help_doc = inspect.getdoc(func) if isinstance(help_doc, bytes): - help_doc = help_doc.decode('utf-8') + help_doc = help_doc.decode("utf-8") self.help: Optional[str] = help_doc - self.brief: Optional[str] = kwargs.get('brief') - self.usage: Optional[str] = kwargs.get('usage') - self.rest_is_raw: bool = kwargs.get('rest_is_raw', False) - self.aliases: Union[List[str], Tuple[str]] = kwargs.get('aliases', []) - self.extras: Dict[str, Any] = kwargs.get('extras', {}) + self.brief: Optional[str] = kwargs.get("brief") + self.usage: Optional[str] = kwargs.get("usage") + self.rest_is_raw: bool = kwargs.get("rest_is_raw", False) + self.aliases: Union[List[str], Tuple[str]] = kwargs.get("aliases", []) + self.extras: Dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError("Aliases of a command must be a list or a tuple of strings.") + raise TypeError( + "Aliases of a command must be a list or a tuple of strings." + ) - self.description: str = inspect.cleandoc(kwargs.get('description', '')) - self.hidden: bool = kwargs.get('hidden', False) + self.description: str = inspect.cleandoc(kwargs.get("description", "")) + self.hidden: bool = kwargs.get("hidden", False) try: checks = func.__commands_checks__ checks.reverse() except AttributeError: - checks = kwargs.get('checks', []) + checks = kwargs.get("checks", []) self.checks: List[Check] = checks try: cooldown = func.__commands_cooldown__ except AttributeError: - cooldown = kwargs.get('cooldown') - + cooldown = kwargs.get("cooldown") + if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) self._buckets: CooldownMapping = buckets try: max_concurrency = func.__commands_max_concurrency__ except AttributeError: - max_concurrency = kwargs.get('max_concurrency') + max_concurrency = kwargs.get("max_concurrency") self._max_concurrency: Optional[MaxConcurrency] = max_concurrency - self.require_var_positional: bool = kwargs.get('require_var_positional', False) - self.ignore_extra: bool = kwargs.get('ignore_extra', True) - self.cooldown_after_parsing: bool = kwargs.get('cooldown_after_parsing', False) + self.require_var_positional: bool = kwargs.get("require_var_positional", False) + self.ignore_extra: bool = kwargs.get("ignore_extra", True) + self.cooldown_after_parsing: bool = kwargs.get("cooldown_after_parsing", False) self.cog: Optional[CogT] = None # bandaid for the fact that sometimes parent can be the bot instance - parent = kwargs.get('parent') + parent = kwargs.get("parent") self.parent: Optional[GroupMixin] = parent if isinstance(parent, _BaseCommand) else None # type: ignore self._before_invoke: Optional[Hook] = None @@ -386,17 +407,22 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.after_invoke(after_invoke) @property - def callback(self) -> Union[ - Callable[Concatenate[CogT, Context, P], Coro[T]], - Callable[Concatenate[Context, P], Coro[T]], - ]: + def callback( + self, + ) -> Union[ + Callable[Concatenate[CogT, Context, P], Coro[T]], + Callable[Concatenate[Context, P], Coro[T]], + ]: return self._callback @callback.setter - def callback(self, function: Union[ + def callback( + self, + function: Union[ Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]], - ]) -> None: + ], + ) -> None: self._callback = function unwrap = unwrap_function(function) self.module = unwrap.__module__ @@ -527,12 +553,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]): wrapped = wrap_callback(local) await wrapped(ctx, error) finally: - ctx.bot.dispatch('command_error', ctx, error) + ctx.bot.dispatch("command_error", ctx, error) async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: required = param.default is param.empty converter = get_converter(param) - consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + consume_rest_is_special = ( + param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + ) view = ctx.view view.skip_ws() @@ -540,9 +568,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # it undos the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos(ctx, param, required, converter.converter) + return await self._transform_greedy_pos( + ctx, param, required, converter.converter + ) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos(ctx, param, converter.converter) + return await self._transform_greedy_var_pos( + ctx, param, converter.converter + ) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] @@ -551,11 +583,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if view.eof: if param.kind == param.VAR_POSITIONAL: - raise RuntimeError() # break the loop + raise RuntimeError() # break the loop if required: if self._is_typing_optional(param.annotation): return None - if hasattr(converter, '__commands_is_flag__') and converter._can_be_constructible(): + if ( + hasattr(converter, "__commands_is_flag__") + and converter._can_be_constructible() + ): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return param.default @@ -577,7 +612,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # type-checker fails to narrow argument return await run_converters(ctx, converter, argument, param) # type: ignore - async def _transform_greedy_pos(self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any) -> Any: + async def _transform_greedy_pos( + self, ctx: Context, param: inspect.Parameter, required: bool, converter: Any + ) -> Any: view = ctx.view result = [] while not view.eof: @@ -598,7 +635,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return param.default return result - async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: + async def _transform_greedy_var_pos( + self, ctx: Context, param: inspect.Parameter, converter: Any + ) -> Any: view = ctx.view previous = view.index try: @@ -606,7 +645,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): value = await run_converters(ctx, converter, argument, param) # type: ignore except (CommandError, ArgumentParsingError): view.index = previous - raise RuntimeError() from None # break loop + raise RuntimeError() from None # break loop else: return value @@ -643,11 +682,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]): entries = [] command = self # command.parent is type-hinted as GroupMixin some attributes are resolved via MRO - while command.parent is not None: # type: ignore - command = command.parent # type: ignore - entries.append(command.name) # type: ignore + while command.parent is not None: # type: ignore + command = command.parent # type: ignore + entries.append(command.name) # type: ignore - return ' '.join(reversed(entries)) + return " ".join(reversed(entries)) @property def parents(self) -> List[Group]: @@ -661,8 +700,8 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ entries = [] command = self - while command.parent is not None: # type: ignore - command = command.parent # type: ignore + while command.parent is not None: # type: ignore + command = command.parent # type: ignore entries.append(command) return entries @@ -690,7 +729,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): parent = self.full_parent_name if parent: - return parent + ' ' + self.name + return parent + " " + self.name else: return self.name @@ -712,13 +751,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]): try: next(iterator) except StopIteration: - raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.') + raise discord.ClientException( + f'Callback for {self.name} command is missing "self" parameter.' + ) # next we have the 'ctx' as the next parameter try: next(iterator) except StopIteration: - raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') + raise discord.ClientException( + f'Callback for {self.name} command is missing "ctx" parameter.' + ) for name, param in iterator: ctx.current_parameter = param @@ -745,7 +788,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): break if not self.ignore_extra and not view.eof: - raise TooManyArguments('Too many arguments passed to ' + self.qualified_name) + raise TooManyArguments( + "Too many arguments passed to " + self.qualified_name + ) async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -753,7 +798,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): cog = self.cog if self._before_invoke is not None: # should be cog if @commands.before_invoke is used - instance = getattr(self._before_invoke, '__self__', cog) + instance = getattr(self._before_invoke, "__self__", cog) # __self__ only exists for methods, not functions # however, if @command.before_invoke is used, it will be a function if instance: @@ -775,7 +820,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): async def call_after_hooks(self, ctx: Context) -> None: cog = self.cog if self._after_invoke is not None: - instance = getattr(self._after_invoke, '__self__', cog) + instance = getattr(self._after_invoke, "__self__", cog) if instance: await self._after_invoke(instance, ctx) # type: ignore else: @@ -805,7 +850,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.command = self if not await self.can_run(ctx): - raise CheckFailure(f'The check functions for command {self.qualified_name} failed.') + raise CheckFailure( + f"The check functions for command {self.qualified_name} failed." + ) if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -929,7 +976,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The error handler must be a coroutine.') + raise TypeError("The error handler must be a coroutine.") self.on_error: Error = coro return coro @@ -939,7 +986,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): .. versionadded:: 1.7 """ - return hasattr(self, 'on_error') + return hasattr(self, "on_error") def before_invoke(self, coro: HookT) -> HookT: """A decorator that registers a coroutine as a pre-invoke hook. @@ -963,7 +1010,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The pre-invoke hook must be a coroutine.') + raise TypeError("The pre-invoke hook must be a coroutine.") self._before_invoke = coro return coro @@ -990,7 +1037,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): The coroutine passed is not actually a coroutine. """ if not asyncio.iscoroutinefunction(coro): - raise TypeError('The post-invoke hook must be a coroutine.') + raise TypeError("The post-invoke hook must be a coroutine.") self._after_invoke = coro return coro @@ -1011,11 +1058,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if self.brief is not None: return self.brief if self.help is not None: - return self.help.split('\n', 1)[0] - return '' + return self.help.split("\n", 1)[0] + return "" - def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: - return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore + def _is_typing_optional( + self, annotation: Union[T, Optional[T]] + ) -> TypeGuard[Optional[T]]: + return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore @property def signature(self) -> str: @@ -1025,7 +1074,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): params = self.clean_params if not params: - return '' + return "" result = [] for name, param in params.items(): @@ -1035,41 +1084,51 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the # parameter signature is a literal list of it's values annotation = param.annotation.converter if greedy else param.annotation - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if not greedy and origin is Union: none_cls = type(None) union_args = annotation.__args__ optional = union_args[-1] is none_cls if len(union_args) == 2 and optional: annotation = union_args[0] - origin = getattr(annotation, '__origin__', None) + origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) + name = "|".join( + f'"{v}"' if isinstance(v, str) else str(v) + for v in annotation.__args__ + ) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = param.default if isinstance(param.default, str) else param.default is not None + should_print = ( + param.default + if isinstance(param.default, str) + else param.default is not None + ) if should_print: - result.append(f'[{name}={param.default}]' if not greedy else - f'[{name}={param.default}]...') + result.append( + f"[{name}={param.default}]" + if not greedy + else f"[{name}={param.default}]..." + ) continue else: - result.append(f'[{name}]') + result.append(f"[{name}]") elif param.kind == param.VAR_POSITIONAL: if self.require_var_positional: - result.append(f'<{name}...>') + result.append(f"<{name}...>") else: - result.append(f'[{name}...]') + result.append(f"[{name}...]") elif greedy: - result.append(f'[{name}]...') + result.append(f"[{name}]...") elif optional: - result.append(f'[{name}]') + result.append(f"[{name}]") else: - result.append(f'<{name}>') + result.append(f"<{name}>") - return ' '.join(result) + return " ".join(result) async def can_run(self, ctx: Context) -> bool: """|coro| @@ -1099,14 +1158,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]): """ if not self.enabled: - raise DisabledCommand(f'{self.name} command is disabled') + raise DisabledCommand(f"{self.name} command is disabled") original = ctx.command ctx.command = self try: if not await ctx.bot.can_run(ctx): - raise CheckFailure(f'The global check functions for command {self.qualified_name} failed.') + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) cog = self.cog if cog is not None: @@ -1125,6 +1186,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.command = original + class GroupMixin(Generic[CogT]): """A mixin that implements common functionality for classes that behave similar to :class:`.Group` and are allowed to register commands. @@ -1137,9 +1199,12 @@ class GroupMixin(Generic[CogT]): case_insensitive: :class:`bool` Whether the commands should be case insensitive. Defaults to ``False``. """ + def __init__(self, *args: Any, **kwargs: Any) -> None: - case_insensitive = kwargs.get('case_insensitive', False) - self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} + case_insensitive = kwargs.get("case_insensitive", False) + self.all_commands: Dict[str, Command[CogT, Any, Any]] = ( + _CaseInsensitiveDict() if case_insensitive else {} + ) self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1177,7 +1242,7 @@ class GroupMixin(Generic[CogT]): """ if not isinstance(command, Command): - raise TypeError('The command passed must be a subclass of Command') + raise TypeError("The command passed must be a subclass of Command") if isinstance(self, Command): command.parent = self @@ -1267,7 +1332,7 @@ class GroupMixin(Generic[CogT]): """ # fast path, no space in name. - if ' ' not in name: + if " " not in name: return self.all_commands.get(name) names = name.split() @@ -1298,7 +1363,9 @@ class GroupMixin(Generic[CogT]): Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ], Command[CogT, P, T]]: + ], + Command[CogT, P, T], + ]: ... @overload @@ -1326,8 +1393,9 @@ class GroupMixin(Generic[CogT]): Callable[..., :class:`Command`] A decorator that converts the provided method into a Command, adds it to the bot, then returns it. """ + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> CommandT: - kwargs.setdefault('parent', self) + kwargs.setdefault("parent", self) result = command(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result @@ -1341,12 +1409,15 @@ class GroupMixin(Generic[CogT]): cls: Type[Group[CogT, P, T]] = ..., *args: Any, **kwargs: Any, - ) -> Callable[[ + ) -> Callable[ + [ Union[ Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]] + Callable[Concatenate[ContextT, P], Coro[T]], ] - ], Group[CogT, P, T]]: + ], + Group[CogT, P, T], + ]: ... @overload @@ -1374,14 +1445,16 @@ class GroupMixin(Generic[CogT]): Callable[..., :class:`Group`] A decorator that converts the provided method into a Group, adds it to the bot, then returns it. """ + def decorator(func: Callable[Concatenate[ContextT, P], Coro[Any]]) -> GroupT: - kwargs.setdefault('parent', self) + kwargs.setdefault("parent", self) result = group(name=name, cls=cls, *args, **kwargs)(func) self.add_command(result) return result return decorator + class Group(GroupMixin[CogT], Command[CogT, P, T]): """A class that implements a grouping protocol for commands to be executed as subcommands. @@ -1404,8 +1477,9 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): Indicates if the group's commands should be case insensitive. Defaults to ``False``. """ + def __init__(self, *args: Any, **attrs: Any) -> None: - self.invoke_without_command: bool = attrs.pop('invoke_without_command', False) + self.invoke_without_command: bool = attrs.pop("invoke_without_command", False) super().__init__(*args, **attrs) def copy(self: GroupT) -> GroupT: @@ -1492,8 +1566,10 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) + # Decorators + @overload def command( name: str = ..., @@ -1505,10 +1581,12 @@ def command( Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ] -, Command[CogT, P, T]]: + ], + Command[CogT, P, T], +]: ... + @overload def command( name: str = ..., @@ -1520,22 +1598,23 @@ def command( Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], ] - ] -, CommandT]: + ], + CommandT, +]: ... + def command( - name: str = MISSING, - cls: Type[CommandT] = MISSING, - **attrs: Any + name: str = MISSING, cls: Type[CommandT] = MISSING, **attrs: Any ) -> Callable[ [ Union[ Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[T]], ] - ] -, Union[Command[CogT, P, T], CommandT]]: + ], + Union[Command[CogT, P, T], CommandT], +]: """A decorator that transforms a function into a :class:`.Command` or if called with :func:`.group`, :class:`.Group`. @@ -1568,16 +1647,19 @@ def command( if cls is MISSING: cls = Command # type: ignore - def decorator(func: Union[ + def decorator( + func: Union[ Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[Any]], - ]) -> CommandT: + ] + ) -> CommandT: if isinstance(func, Command): - raise TypeError('Callback is already a command.') + raise TypeError("Callback is already a command.") return cls(func, name=name, **attrs) return decorator + @overload def group( name: str = ..., @@ -1589,10 +1671,12 @@ def group( Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]], ] - ] -, Group[CogT, P, T]]: + ], + Group[CogT, P, T], +]: ... + @overload def group( name: str = ..., @@ -1604,10 +1688,12 @@ def group( Callable[Concatenate[CogT, ContextT, P], Coro[Any]], Callable[Concatenate[ContextT, P], Coro[Any]], ] - ] -, GroupT]: + ], + GroupT, +]: ... + def group( name: str = MISSING, cls: Type[GroupT] = MISSING, @@ -1618,8 +1704,9 @@ def group( Callable[Concatenate[ContextT, P], Coro[Any]], Callable[Concatenate[CogT, ContextT, P], Coro[T]], ] - ] -, Union[Group[CogT, P, T], GroupT]]: + ], + Union[Group[CogT, P, T], GroupT], +]: """A decorator that transforms a function into a :class:`.Group`. This is similar to the :func:`.command` decorator but the ``cls`` @@ -1632,6 +1719,7 @@ def group( cls = Group # type: ignore return command(name=name, cls=cls, **attrs) # type: ignore + def check(predicate: Check) -> Callable[[T], T]: r"""A decorator that adds a check to the :class:`.Command` or its subclasses. These checks could be accessed via :attr:`.Command.checks`. @@ -1707,7 +1795,7 @@ def check(predicate: Check) -> Callable[[T], T]: if isinstance(func, Command): func.checks.append(predicate) else: - if not hasattr(func, '__commands_checks__'): + if not hasattr(func, "__commands_checks__"): func.__commands_checks__ = [] func.__commands_checks__.append(predicate) @@ -1717,13 +1805,16 @@ def check(predicate: Check) -> Callable[[T], T]: if inspect.iscoroutinefunction(predicate): decorator.predicate = predicate else: + @functools.wraps(predicate) async def wrapper(ctx): return predicate(ctx) # type: ignore + decorator.predicate = wrapper return decorator # type: ignore + def check_any(*checks: Check) -> Callable[[T], T]: r"""A :func:`check` that is added that checks if any of the checks passed will pass, i.e. using logical OR. @@ -1773,7 +1864,9 @@ def check_any(*checks: Check) -> Callable[[T], T]: try: pred = wrapped.predicate except AttributeError: - raise TypeError(f'{wrapped!r} must be wrapped by commands.check decorator') from None + raise TypeError( + f"{wrapped!r} must be wrapped by commands.check decorator" + ) from None else: unwrapped.append(pred) @@ -1792,6 +1885,7 @@ def check_any(*checks: Check) -> Callable[[T], T]: return check(predicate) + def has_role(item: Union[int, str]) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member invoking the command has the role specified via the name or ID specified. @@ -1834,6 +1928,7 @@ def has_role(item: Union[int, str]) -> Callable[[T], T]: return check(predicate) + def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: r"""A :func:`.check` that is added that checks if the member invoking the command has **any** of the roles specified. This means that if they have @@ -1865,18 +1960,25 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: async def cool(ctx): await ctx.send('You are cool indeed') """ + def predicate(ctx): if ctx.guild is None: raise NoPrivateMessage() # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore - if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): + if any( + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items + ): return True raise MissingAnyRole(list(items)) return check(predicate) + def bot_has_role(item: int) -> Callable[[T], T]: """Similar to :func:`.has_role` except checks if the bot itself has the role. @@ -1903,8 +2005,10 @@ def bot_has_role(item: int) -> Callable[[T], T]: if role is None: raise BotMissingRole(item) return True + return check(predicate) + def bot_has_any_role(*items: int) -> Callable[[T], T]: """Similar to :func:`.has_any_role` except checks if the bot itself has any of the roles listed. @@ -1918,17 +2022,25 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: Raise :exc:`.BotMissingAnyRole` or :exc:`.NoPrivateMessage` instead of generic checkfailure """ + def predicate(ctx): if ctx.guild is None: raise NoPrivateMessage() me = ctx.me getter = functools.partial(discord.utils.get, me.roles) - if any(getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items): + if any( + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items + ): return True raise BotMissingAnyRole(list(items)) + return check(predicate) + def has_permissions(**perms: bool) -> Callable[[T], T]: """A :func:`.check` that is added that checks if the member has all of the permissions necessary. @@ -1967,7 +2079,9 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: ch = ctx.channel permissions = ch.permissions_for(ctx.author) # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -1976,6 +2090,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def bot_has_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions` except checks if the bot itself has the permissions listed. @@ -1993,7 +2108,9 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: me = guild.me if guild is not None else ctx.bot.user permissions = ctx.channel.permissions_for(me) # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2002,6 +2119,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_permissions`, but operates on guild wide permissions instead of the current channel permissions. @@ -2021,7 +2139,9 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.author.guild_permissions # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2030,6 +2150,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: """Similar to :func:`.has_guild_permissions`, but checks the bot members guild permissions. @@ -2046,7 +2167,9 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.me.guild_permissions # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2055,6 +2178,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: return check(predicate) + def dm_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a DM context. Only private messages are allowed when @@ -2073,6 +2197,7 @@ def dm_only() -> Callable[[T], T]: return check(predicate) + def guild_only() -> Callable[[T], T]: """A :func:`.check` that indicates this command must only be used in a guild context only. Basically, no private messages are allowed when @@ -2089,6 +2214,7 @@ def guild_only() -> Callable[[T], T]: return check(predicate) + def is_owner() -> Callable[[T], T]: """A :func:`.check` that checks if the person invoking this command is the owner of the bot. @@ -2101,11 +2227,12 @@ def is_owner() -> Callable[[T], T]: async def predicate(ctx: Context) -> bool: if not await ctx.bot.is_owner(ctx.author): - raise NotOwner('You do not own this bot.') + raise NotOwner("You do not own this bot.") return True return check(predicate) + def is_nsfw() -> Callable[[T], T]: """A :func:`.check` that checks if the channel is a NSFW channel. @@ -2117,14 +2244,23 @@ def is_nsfw() -> Callable[[T], T]: Raise :exc:`.NSFWChannelRequired` instead of generic :exc:`.CheckFailure`. DM channels will also now pass this check. """ + def pred(ctx: Context) -> bool: ch = ctx.channel - if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): + if ctx.guild is None or ( + isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() + ): return True raise NSFWChannelRequired(ch) # type: ignore + return check(pred) -def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default) -> Callable[[T], T]: + +def cooldown( + rate: int, + per: float, + type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, +) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` A cooldown allows a command to only be used a specific amount @@ -2157,9 +2293,14 @@ def cooldown(rate: int, per: float, type: Union[BucketType, Callable[[Message], else: func.__commands_cooldown__ = CooldownMapping(Cooldown(rate, per), type) return func + return decorator # type: ignore -def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default) -> Callable[[T], T]: + +def dynamic_cooldown( + cooldown: Union[BucketType, Callable[[Message], Any]], + type: BucketType = BucketType.default, +) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` This differs from :func:`.cooldown` in that it takes a function that @@ -2197,9 +2338,13 @@ def dynamic_cooldown(cooldown: Union[BucketType, Callable[[Message], Any]], type else: func.__commands_cooldown__ = DynamicCooldownMapping(cooldown, type) return func + return decorator # type: ignore -def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: + +def max_concurrency( + number: int, per: BucketType = BucketType.default, *, wait: bool = False +) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. This enables you to only allow a certain number of command invocations at the same time, @@ -2230,8 +2375,10 @@ def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: else: func.__commands_max_concurrency__ = value return func + return decorator # type: ignore + def before_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a pre-invoke hook. @@ -2270,14 +2417,17 @@ def before_invoke(coro) -> Callable[[T], T]: bot.add_cog(What()) """ + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.before_invoke(coro) else: func.__before_invoke__ = coro return func + return decorator # type: ignore + def after_invoke(coro) -> Callable[[T], T]: """A decorator that registers a coroutine as a post-invoke hook. @@ -2286,10 +2436,12 @@ def after_invoke(coro) -> Callable[[T], T]: .. versionadded:: 1.4 """ + def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: if isinstance(func, Command): func.after_invoke(coro) else: func.__after_invoke__ = coro return func + return decorator # type: ignore diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index 93834385..b22959bd 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -41,65 +41,66 @@ if TYPE_CHECKING: __all__ = ( - 'CommandError', - 'MissingRequiredArgument', - 'BadArgument', - 'PrivateMessageOnly', - 'NoPrivateMessage', - 'CheckFailure', - 'CheckAnyFailure', - 'CommandNotFound', - 'DisabledCommand', - 'CommandInvokeError', - 'TooManyArguments', - 'UserInputError', - 'CommandOnCooldown', - 'MaxConcurrencyReached', - 'NotOwner', - 'MessageNotFound', - 'ObjectNotFound', - 'MemberNotFound', - 'GuildNotFound', - 'UserNotFound', - 'ChannelNotFound', - 'ThreadNotFound', - 'ChannelNotReadable', - 'BadColourArgument', - 'BadColorArgument', - 'RoleNotFound', - 'BadInviteArgument', - 'EmojiNotFound', - 'GuildStickerNotFound', - 'PartialEmojiConversionFailure', - 'BadBoolArgument', - 'MissingRole', - 'BotMissingRole', - 'MissingAnyRole', - 'BotMissingAnyRole', - 'MissingPermissions', - 'BotMissingPermissions', - 'NSFWChannelRequired', - 'ConversionError', - 'BadUnionArgument', - 'BadLiteralArgument', - 'ArgumentParsingError', - 'UnexpectedQuoteError', - 'InvalidEndOfQuotedStringError', - 'ExpectedClosingQuoteError', - 'ExtensionError', - 'ExtensionAlreadyLoaded', - 'ExtensionNotLoaded', - 'NoEntryPointError', - 'ExtensionFailed', - 'ExtensionNotFound', - 'CommandRegistrationError', - 'FlagError', - 'BadFlagArgument', - 'MissingFlagArgument', - 'TooManyFlags', - 'MissingRequiredFlag', + "CommandError", + "MissingRequiredArgument", + "BadArgument", + "PrivateMessageOnly", + "NoPrivateMessage", + "CheckFailure", + "CheckAnyFailure", + "CommandNotFound", + "DisabledCommand", + "CommandInvokeError", + "TooManyArguments", + "UserInputError", + "CommandOnCooldown", + "MaxConcurrencyReached", + "NotOwner", + "MessageNotFound", + "ObjectNotFound", + "MemberNotFound", + "GuildNotFound", + "UserNotFound", + "ChannelNotFound", + "ThreadNotFound", + "ChannelNotReadable", + "BadColourArgument", + "BadColorArgument", + "RoleNotFound", + "BadInviteArgument", + "EmojiNotFound", + "GuildStickerNotFound", + "PartialEmojiConversionFailure", + "BadBoolArgument", + "MissingRole", + "BotMissingRole", + "MissingAnyRole", + "BotMissingAnyRole", + "MissingPermissions", + "BotMissingPermissions", + "NSFWChannelRequired", + "ConversionError", + "BadUnionArgument", + "BadLiteralArgument", + "ArgumentParsingError", + "UnexpectedQuoteError", + "InvalidEndOfQuotedStringError", + "ExpectedClosingQuoteError", + "ExtensionError", + "ExtensionAlreadyLoaded", + "ExtensionNotLoaded", + "NoEntryPointError", + "ExtensionFailed", + "ExtensionNotFound", + "CommandRegistrationError", + "FlagError", + "BadFlagArgument", + "MissingFlagArgument", + "TooManyFlags", + "MissingRequiredFlag", ) + class CommandError(DiscordException): r"""The base exception type for all command related errors. @@ -109,14 +110,18 @@ class CommandError(DiscordException): in a special way as they are caught and passed into a special event from :class:`.Bot`\, :func:`.on_command_error`. """ + def __init__(self, message: Optional[str] = None, *args: Any) -> None: if message is not None: # clean-up @everyone and @here mentions - m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') + m = message.replace("@everyone", "@\u200beveryone").replace( + "@here", "@\u200bhere" + ) super().__init__(m, *args) else: super().__init__(*args) + class ConversionError(CommandError): """Exception raised when a Converter class raises non-CommandError. @@ -130,18 +135,22 @@ class ConversionError(CommandError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, converter: Converter, original: Exception) -> None: self.converter: Converter = converter self.original: Exception = original + class UserInputError(CommandError): """The base exception type for errors that involve errors regarding user input. This inherits from :exc:`CommandError`. """ + pass + class CommandNotFound(CommandError): """Exception raised when a command is attempted to be invoked but no command under that name is found. @@ -151,8 +160,10 @@ class CommandNotFound(CommandError): This inherits from :exc:`CommandError`. """ + pass + class MissingRequiredArgument(UserInputError): """Exception raised when parsing a command and a parameter that is required is not encountered. @@ -164,9 +175,11 @@ class MissingRequiredArgument(UserInputError): param: :class:`inspect.Parameter` The argument that is missing. """ + def __init__(self, param: Parameter) -> None: self.param: Parameter = param - super().__init__(f'{param.name} is a required argument that is missing.') + super().__init__(f"{param.name} is a required argument that is missing.") + class TooManyArguments(UserInputError): """Exception raised when the command was passed too many arguments and its @@ -174,23 +187,29 @@ class TooManyArguments(UserInputError): This inherits from :exc:`UserInputError` """ + pass + class BadArgument(UserInputError): """Exception raised when a parsing or conversion failure is encountered on an argument to pass into a command. This inherits from :exc:`UserInputError` """ + pass + class CheckFailure(CommandError): """Exception raised when the predicates in :attr:`.Command.checks` have failed. This inherits from :exc:`CommandError` """ + pass + class CheckAnyFailure(CheckFailure): """Exception raised when all predicates in :func:`check_any` fail. @@ -206,10 +225,13 @@ class CheckAnyFailure(CheckFailure): A list of check predicates that failed. """ - def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: + def __init__( + self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]] + ) -> None: self.checks: List[CheckFailure] = checks self.errors: List[Callable[[Context], bool]] = errors - super().__init__('You do not have permission to run this command.') + super().__init__("You do not have permission to run this command.") + class PrivateMessageOnly(CheckFailure): """Exception raised when an operation does not work outside of private @@ -217,8 +239,12 @@ class PrivateMessageOnly(CheckFailure): This inherits from :exc:`CheckFailure` """ + def __init__(self, message: Optional[str] = None) -> None: - super().__init__(message or 'This command can only be used in private messages.') + super().__init__( + message or "This command can only be used in private messages." + ) + class NoPrivateMessage(CheckFailure): """Exception raised when an operation does not work in private message @@ -228,15 +254,18 @@ class NoPrivateMessage(CheckFailure): """ def __init__(self, message: Optional[str] = None) -> None: - super().__init__(message or 'This command cannot be used in private messages.') + super().__init__(message or "This command cannot be used in private messages.") + class NotOwner(CheckFailure): """Exception raised when the message author is not the owner of the bot. This inherits from :exc:`CheckFailure` """ + pass + class ObjectNotFound(BadArgument): """Exception raised when the argument provided did not match the format of an ID or a mention. @@ -250,9 +279,11 @@ class ObjectNotFound(BadArgument): argument: :class:`str` The argument supplied by the caller that was not matched """ + def __init__(self, argument: str) -> None: self.argument: str = argument - super().__init__(f'{argument!r} does not follow a valid ID or mention format.') + super().__init__(f"{argument!r} does not follow a valid ID or mention format.") + class MemberNotFound(BadArgument): """Exception raised when the member provided was not found in the bot's @@ -267,10 +298,12 @@ class MemberNotFound(BadArgument): argument: :class:`str` The member supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Member "{argument}" not found.') + class GuildNotFound(BadArgument): """Exception raised when the guild provided was not found in the bot's cache. @@ -283,10 +316,12 @@ class GuildNotFound(BadArgument): argument: :class:`str` The guild supplied by the called that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Guild "{argument}" not found.') + class UserNotFound(BadArgument): """Exception raised when the user provided was not found in the bot's cache. @@ -300,10 +335,12 @@ class UserNotFound(BadArgument): argument: :class:`str` The user supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'User "{argument}" not found.') + class MessageNotFound(BadArgument): """Exception raised when the message provided was not found in the channel. @@ -316,10 +353,12 @@ class MessageNotFound(BadArgument): argument: :class:`str` The message supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Message "{argument}" not found.') + class ChannelNotReadable(BadArgument): """Exception raised when the bot does not have permission to read messages in the channel. @@ -333,10 +372,12 @@ class ChannelNotReadable(BadArgument): argument: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel supplied by the caller that was not readable """ + def __init__(self, argument: Union[GuildChannel, Thread]) -> None: self.argument: Union[GuildChannel, Thread] = argument super().__init__(f"Can't read messages in {argument.mention}.") + class ChannelNotFound(BadArgument): """Exception raised when the bot can not find the channel. @@ -349,10 +390,12 @@ class ChannelNotFound(BadArgument): argument: :class:`str` The channel supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Channel "{argument}" not found.') + class ThreadNotFound(BadArgument): """Exception raised when the bot can not find the thread. @@ -365,10 +408,12 @@ class ThreadNotFound(BadArgument): argument: :class:`str` The thread supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Thread "{argument}" not found.') + class BadColourArgument(BadArgument): """Exception raised when the colour is not valid. @@ -381,12 +426,15 @@ class BadColourArgument(BadArgument): argument: :class:`str` The colour supplied by the caller that was not valid """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Colour "{argument}" is invalid.') + BadColorArgument = BadColourArgument + class RoleNotFound(BadArgument): """Exception raised when the bot can not find the role. @@ -399,10 +447,12 @@ class RoleNotFound(BadArgument): argument: :class:`str` The role supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Role "{argument}" not found.') + class BadInviteArgument(BadArgument): """Exception raised when the invite is invalid or expired. @@ -410,10 +460,12 @@ class BadInviteArgument(BadArgument): .. versionadded:: 1.5 """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Invite "{argument}" is invalid or expired.') + class EmojiNotFound(BadArgument): """Exception raised when the bot can not find the emoji. @@ -426,10 +478,12 @@ class EmojiNotFound(BadArgument): argument: :class:`str` The emoji supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Emoji "{argument}" not found.') + class PartialEmojiConversionFailure(BadArgument): """Exception raised when the emoji provided does not match the correct format. @@ -443,10 +497,12 @@ class PartialEmojiConversionFailure(BadArgument): argument: :class:`str` The emoji supplied by the caller that did not match the regex """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Couldn\'t convert "{argument}" to PartialEmoji.') + class GuildStickerNotFound(BadArgument): """Exception raised when the bot can not find the sticker. @@ -459,10 +515,12 @@ class GuildStickerNotFound(BadArgument): argument: :class:`str` The sticker supplied by the caller that was not found """ + def __init__(self, argument: str) -> None: self.argument: str = argument super().__init__(f'Sticker "{argument}" not found.') + class BadBoolArgument(BadArgument): """Exception raised when a boolean argument was not convertable. @@ -475,17 +533,21 @@ class BadBoolArgument(BadArgument): argument: :class:`str` The boolean argument supplied by the caller that is not in the predefined list """ + def __init__(self, argument: str) -> None: self.argument: str = argument - super().__init__(f'{argument} is not a recognised boolean option') + super().__init__(f"{argument} is not a recognised boolean option") + class DisabledCommand(CommandError): """Exception raised when the command being invoked is disabled. This inherits from :exc:`CommandError` """ + pass + class CommandInvokeError(CommandError): """Exception raised when the command being invoked raised an exception. @@ -497,9 +559,11 @@ class CommandInvokeError(CommandError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, e: Exception) -> None: self.original: Exception = e - super().__init__(f'Command raised an exception: {e.__class__.__name__}: {e}') + super().__init__(f"Command raised an exception: {e.__class__.__name__}: {e}") + class CommandOnCooldown(CommandError): """Exception raised when the command being invoked is on cooldown. @@ -516,11 +580,15 @@ class CommandOnCooldown(CommandError): retry_after: :class:`float` The amount of seconds to wait before you can retry again. """ - def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: + + def __init__( + self, cooldown: Cooldown, retry_after: float, type: BucketType + ) -> None: self.cooldown: Cooldown = cooldown self.retry_after: float = retry_after self.type: BucketType = type - super().__init__(f'You are on cooldown. Try again in {retry_after:.2f}s') + super().__init__(f"You are on cooldown. Try again in {retry_after:.2f}s") + class MaxConcurrencyReached(CommandError): """Exception raised when the command being invoked has reached its maximum concurrency. @@ -539,10 +607,13 @@ class MaxConcurrencyReached(CommandError): self.number: int = number self.per: BucketType = per name = per.name - suffix = 'per %s' % name if per.name != 'default' else 'globally' - plural = '%s times %s' if number > 1 else '%s time %s' + suffix = "per %s" % name if per.name != "default" else "globally" + plural = "%s times %s" if number > 1 else "%s time %s" fmt = plural % (number, suffix) - super().__init__(f'Too many people are using this command. It can only be used {fmt} concurrently.') + super().__init__( + f"Too many people are using this command. It can only be used {fmt} concurrently." + ) + class MissingRole(CheckFailure): """Exception raised when the command invoker lacks a role to run a command. @@ -557,11 +628,13 @@ class MissingRole(CheckFailure): The required role that is missing. This is the parameter passed to :func:`~.commands.has_role`. """ + def __init__(self, missing_role: Snowflake) -> None: self.missing_role: Snowflake = missing_role - message = f'Role {missing_role!r} is required to run this command.' + message = f"Role {missing_role!r} is required to run this command." super().__init__(message) + class BotMissingRole(CheckFailure): """Exception raised when the bot's member lacks a role to run a command. @@ -575,11 +648,13 @@ class BotMissingRole(CheckFailure): The required role that is missing. This is the parameter passed to :func:`~.commands.has_role`. """ + def __init__(self, missing_role: Snowflake) -> None: self.missing_role: Snowflake = missing_role - message = f'Bot requires the role {missing_role!r} to run this command' + message = f"Bot requires the role {missing_role!r} to run this command" super().__init__(message) + class MissingAnyRole(CheckFailure): """Exception raised when the command invoker lacks any of the roles specified to run a command. @@ -594,15 +669,16 @@ class MissingAnyRole(CheckFailure): The roles that the invoker is missing. These are the parameters passed to :func:`~.commands.has_any_role`. """ + def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles missing = [f"'{role}'" for role in missing_roles] if len(missing) > 2: - fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' or '.join(missing) + fmt = " or ".join(missing) message = f"You are missing at least one of the required roles: {fmt}" super().__init__(message) @@ -623,19 +699,21 @@ class BotMissingAnyRole(CheckFailure): These are the parameters passed to :func:`~.commands.has_any_role`. """ + def __init__(self, missing_roles: SnowflakeList) -> None: self.missing_roles: SnowflakeList = missing_roles missing = [f"'{role}'" for role in missing_roles] if len(missing) > 2: - fmt = '{}, or {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, or {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' or '.join(missing) + fmt = " or ".join(missing) message = f"Bot is missing at least one of the required roles: {fmt}" super().__init__(message) + class NSFWChannelRequired(CheckFailure): """Exception raised when a channel does not have the required NSFW setting. @@ -648,9 +726,13 @@ class NSFWChannelRequired(CheckFailure): channel: Union[:class:`.abc.GuildChannel`, :class:`.Thread`] The channel that does not have NSFW enabled. """ + def __init__(self, channel: Union[GuildChannel, Thread]) -> None: self.channel: Union[GuildChannel, Thread] = channel - super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") + super().__init__( + f"Channel '{channel}' needs to be NSFW for this command to work." + ) + class MissingPermissions(CheckFailure): """Exception raised when the command invoker lacks permissions to run a @@ -663,18 +745,23 @@ class MissingPermissions(CheckFailure): missing_permissions: List[:class:`str`] The required permissions that are missing. """ + def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions - missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] + missing = [ + perm.replace("_", " ").replace("guild", "server").title() + for perm in missing_permissions + ] if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' and '.join(missing) - message = f'You are missing {fmt} permission(s) to run this command.' + fmt = " and ".join(missing) + message = f"You are missing {fmt} permission(s) to run this command." super().__init__(message, *args) + class BotMissingPermissions(CheckFailure): """Exception raised when the bot's member lacks permissions to run a command. @@ -686,18 +773,23 @@ class BotMissingPermissions(CheckFailure): missing_permissions: List[:class:`str`] The required permissions that are missing. """ + def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions - missing = [perm.replace('_', ' ').replace('guild', 'server').title() for perm in missing_permissions] + missing = [ + perm.replace("_", " ").replace("guild", "server").title() + for perm in missing_permissions + ] if len(missing) > 2: - fmt = '{}, and {}'.format(", ".join(missing[:-1]), missing[-1]) + fmt = "{}, and {}".format(", ".join(missing[:-1]), missing[-1]) else: - fmt = ' and '.join(missing) - message = f'Bot requires {fmt} permission(s) to run this command.' + fmt = " and ".join(missing) + message = f"Bot requires {fmt} permission(s) to run this command." super().__init__(message, *args) + class BadUnionArgument(UserInputError): """Exception raised when a :data:`typing.Union` converter fails for all its associated types. @@ -713,7 +805,10 @@ class BadUnionArgument(UserInputError): errors: List[:class:`CommandError`] A list of errors that were caught from failing the conversion. """ - def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: + + def __init__( + self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError] + ) -> None: self.param: Parameter = param self.converters: Tuple[Type, ...] = converters self.errors: List[CommandError] = errors @@ -722,18 +817,19 @@ class BadUnionArgument(UserInputError): try: return x.__name__ except AttributeError: - if hasattr(x, '__origin__'): + if hasattr(x, "__origin__"): return repr(x) return x.__class__.__name__ to_string = [_get_name(x) for x in converters] if len(to_string) > 2: - fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) + fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1]) else: - fmt = ' or '.join(to_string) + fmt = " or ".join(to_string) super().__init__(f'Could not convert "{param.name}" into {fmt}.') + class BadLiteralArgument(UserInputError): """Exception raised when a :data:`typing.Literal` converter fails for all its associated values. @@ -751,19 +847,23 @@ class BadLiteralArgument(UserInputError): errors: List[:class:`CommandError`] A list of errors that were caught from failing the conversion. """ - def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None: + + def __init__( + self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError] + ) -> None: self.param: Parameter = param self.literals: Tuple[Any, ...] = literals self.errors: List[CommandError] = errors to_string = [repr(l) for l in literals] if len(to_string) > 2: - fmt = '{}, or {}'.format(', '.join(to_string[:-1]), to_string[-1]) + fmt = "{}, or {}".format(", ".join(to_string[:-1]), to_string[-1]) else: - fmt = ' or '.join(to_string) + fmt = " or ".join(to_string) super().__init__(f'Could not convert "{param.name}" into the literal {fmt}.') + class ArgumentParsingError(UserInputError): """An exception raised when the parser fails to parse a user's input. @@ -772,8 +872,10 @@ class ArgumentParsingError(UserInputError): There are child classes that implement more granular parsing errors for i18n purposes. """ + pass + class UnexpectedQuoteError(ArgumentParsingError): """An exception raised when the parser encounters a quote mark inside a non-quoted string. @@ -784,9 +886,11 @@ class UnexpectedQuoteError(ArgumentParsingError): quote: :class:`str` The quote mark that was found inside the non-quoted string. """ + def __init__(self, quote: str) -> None: self.quote: str = quote - super().__init__(f'Unexpected quote mark, {quote!r}, in non-quoted string') + super().__init__(f"Unexpected quote mark, {quote!r}, in non-quoted string") + class InvalidEndOfQuotedStringError(ArgumentParsingError): """An exception raised when a space is expected after the closing quote in a string @@ -799,9 +903,13 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError): char: :class:`str` The character found instead of the expected string. """ + def __init__(self, char: str) -> None: self.char: str = char - super().__init__(f'Expected space after closing quotation but received {char!r}') + super().__init__( + f"Expected space after closing quotation but received {char!r}" + ) + class ExpectedClosingQuoteError(ArgumentParsingError): """An exception raised when a quote character is expected but not found. @@ -816,7 +924,8 @@ class ExpectedClosingQuoteError(ArgumentParsingError): def __init__(self, close_quote: str) -> None: self.close_quote: str = close_quote - super().__init__(f'Expected closing {close_quote}.') + super().__init__(f"Expected closing {close_quote}.") + class ExtensionError(DiscordException): """Base exception for extension related errors. @@ -828,37 +937,47 @@ class ExtensionError(DiscordException): name: :class:`str` The extension that had an error. """ + def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None: self.name: str = name - message = message or f'Extension {name!r} had an error.' + message = message or f"Extension {name!r} had an error." # clean-up @everyone and @here mentions - m = message.replace('@everyone', '@\u200beveryone').replace('@here', '@\u200bhere') + m = message.replace("@everyone", "@\u200beveryone").replace( + "@here", "@\u200bhere" + ) super().__init__(m, *args) + class ExtensionAlreadyLoaded(ExtensionError): """An exception raised when an extension has already been loaded. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: - super().__init__(f'Extension {name!r} is already loaded.', name=name) + super().__init__(f"Extension {name!r} is already loaded.", name=name) + class ExtensionNotLoaded(ExtensionError): """An exception raised when an extension was not loaded. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: - super().__init__(f'Extension {name!r} has not been loaded.', name=name) + super().__init__(f"Extension {name!r} has not been loaded.", name=name) + class NoEntryPointError(ExtensionError): """An exception raised when an extension does not have a ``setup`` entry point function. This inherits from :exc:`ExtensionError` """ + def __init__(self, name: str) -> None: super().__init__(f"Extension {name!r} has no 'setup' function.", name=name) + class ExtensionFailed(ExtensionError): """An exception raised when an extension failed to load during execution of the module or ``setup`` entry point. @@ -872,11 +991,13 @@ class ExtensionFailed(ExtensionError): The original exception that was raised. You can also get this via the ``__cause__`` attribute. """ + def __init__(self, name: str, original: Exception) -> None: self.original: Exception = original - msg = f'Extension {name!r} raised an error: {original.__class__.__name__}: {original}' + msg = f"Extension {name!r} raised an error: {original.__class__.__name__}: {original}" super().__init__(msg, name=name) + class ExtensionNotFound(ExtensionError): """An exception raised when an extension is not found. @@ -890,10 +1011,12 @@ class ExtensionNotFound(ExtensionError): name: :class:`str` The extension that had the error. """ + def __init__(self, name: str) -> None: - msg = f'Extension {name!r} could not be loaded.' + msg = f"Extension {name!r} could not be loaded." super().__init__(msg, name=name) + class CommandRegistrationError(ClientException): """An exception raised when the command can't be added because the name is already taken by a different command. @@ -909,11 +1032,13 @@ class CommandRegistrationError(ClientException): alias_conflict: :class:`bool` Whether the name that conflicts is an alias of the command we try to add. """ + def __init__(self, name: str, *, alias_conflict: bool = False) -> None: self.name: str = name self.alias_conflict: bool = alias_conflict - type_ = 'alias' if alias_conflict else 'command' - super().__init__(f'The {type_} {name} is already an existing command or alias.') + type_ = "alias" if alias_conflict else "command" + super().__init__(f"The {type_} {name} is already an existing command or alias.") + class FlagError(BadArgument): """The base exception type for all flag parsing related errors. @@ -922,8 +1047,10 @@ class FlagError(BadArgument): .. versionadded:: 2.0 """ + pass + class TooManyFlags(FlagError): """An exception raised when a flag has received too many values. @@ -938,10 +1065,14 @@ class TooManyFlags(FlagError): values: List[:class:`str`] The values that were passed. """ + def __init__(self, flag: Flag, values: List[str]) -> None: self.flag: Flag = flag self.values: List[str] = values - super().__init__(f'Too many flag values, expected {flag.max_args} but received {len(values)}.') + super().__init__( + f"Too many flag values, expected {flag.max_args} but received {len(values)}." + ) + class BadFlagArgument(FlagError): """An exception raised when a flag failed to convert a value. @@ -955,6 +1086,7 @@ class BadFlagArgument(FlagError): flag: :class:`~discord.ext.commands.Flag` The flag that failed to convert. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag try: @@ -962,7 +1094,8 @@ class BadFlagArgument(FlagError): except AttributeError: name = flag.annotation.__class__.__name__ - super().__init__(f'Could not convert to {name!r} for flag {flag.name!r}') + super().__init__(f"Could not convert to {name!r} for flag {flag.name!r}") + class MissingRequiredFlag(FlagError): """An exception raised when a required flag was not given. @@ -976,9 +1109,11 @@ class MissingRequiredFlag(FlagError): flag: :class:`~discord.ext.commands.Flag` The required flag that was not found. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag - super().__init__(f'Flag {flag.name!r} is required and missing') + super().__init__(f"Flag {flag.name!r} is required and missing") + class MissingFlagArgument(FlagError): """An exception raised when a flag did not get a value. @@ -992,6 +1127,7 @@ class MissingFlagArgument(FlagError): flag: :class:`~discord.ext.commands.Flag` The flag that did not get a value. """ + def __init__(self, flag: Flag) -> None: self.flag: Flag = flag - super().__init__(f'Flag {flag.name!r} does not have an argument') + super().__init__(f"Flag {flag.name!r} does not have an argument") diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index b356af34..52c62c0a 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -59,9 +59,9 @@ import sys import re __all__ = ( - 'Flag', - 'flag', - 'FlagConverter', + "Flag", + "flag", + "FlagConverter", ) @@ -143,25 +143,35 @@ def flag( Whether multiple given values overrides the previous value. The default value depends on the annotation given. """ - return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) + return Flag( + name=name, + aliases=aliases, + default=default, + max_args=max_args, + override=override, + ) def validate_flag_name(name: str, forbidden: Set[str]): if not name: - raise ValueError('flag names should not be empty') + raise ValueError("flag names should not be empty") for ch in name: if ch.isspace(): - raise ValueError(f'flag name {name!r} cannot have spaces') - if ch == '\\': - raise ValueError(f'flag name {name!r} cannot have backslashes') + raise ValueError(f"flag name {name!r} cannot have spaces") + if ch == "\\": + raise ValueError(f"flag name {name!r} cannot have backslashes") if ch in forbidden: - raise ValueError(f'flag name {name!r} cannot have any of {forbidden!r} within them') + raise ValueError( + f"flag name {name!r} cannot have any of {forbidden!r} within them" + ) -def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: - annotations = namespace.get('__annotations__', {}) - case_insensitive = namespace['__commands_flag_case_insensitive__'] +def get_flags( + namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any] +) -> Dict[str, Flag]: + annotations = namespace.get("__annotations__", {}) + case_insensitive = namespace["__commands_flag_case_insensitive__"] flags: Dict[str, Flag] = {} cache: Dict[str, Any] = {} names: Set[str] = set() @@ -176,9 +186,15 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s if flag.name is MISSING: flag.name = name - annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) + annotation = flag.annotation = resolve_annotation( + flag.annotation, globals, locals, cache + ) - if flag.default is MISSING and hasattr(annotation, '__commands_is_flag__') and annotation._can_be_constructible(): + if ( + flag.default is MISSING + and hasattr(annotation, "__commands_is_flag__") + and annotation._can_be_constructible() + ): flag.default = annotation._construct_default if flag.aliases is MISSING: @@ -229,7 +245,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s if flag.max_args is MISSING: flag.max_args = 1 else: - raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag') + raise TypeError( + f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag" + ) if flag.override is MISSING: flag.override = False @@ -237,7 +255,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s # Validate flag names are unique name = flag.name.casefold() if case_insensitive else flag.name if name in names: - raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.') + raise TypeError( + f"{flag.name!r} flag conflicts with previous flag or alias." + ) else: names.add(name) @@ -245,7 +265,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s # Validate alias is unique alias = alias.casefold() if case_insensitive else alias if alias in names: - raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.') + raise TypeError( + f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias." + ) else: names.add(alias) @@ -274,10 +296,10 @@ class FlagsMeta(type): delimiter: str = MISSING, prefix: str = MISSING, ): - attrs['__commands_is_flag__'] = True + attrs["__commands_is_flag__"] = True try: - global_ns = sys.modules[attrs['__module__']].__dict__ + global_ns = sys.modules[attrs["__module__"]].__dict__ except KeyError: global_ns = {} @@ -296,26 +318,32 @@ class FlagsMeta(type): flags: Dict[str, Flag] = {} aliases: Dict[str, str] = {} for base in reversed(bases): - if base.__dict__.get('__commands_is_flag__', False): - flags.update(base.__dict__['__commands_flags__']) - aliases.update(base.__dict__['__commands_flag_aliases__']) + if base.__dict__.get("__commands_is_flag__", False): + flags.update(base.__dict__["__commands_flags__"]) + aliases.update(base.__dict__["__commands_flag_aliases__"]) if case_insensitive is MISSING: - attrs['__commands_flag_case_insensitive__'] = base.__dict__['__commands_flag_case_insensitive__'] + attrs["__commands_flag_case_insensitive__"] = base.__dict__[ + "__commands_flag_case_insensitive__" + ] if delimiter is MISSING: - attrs['__commands_flag_delimiter__'] = base.__dict__['__commands_flag_delimiter__'] + attrs["__commands_flag_delimiter__"] = base.__dict__[ + "__commands_flag_delimiter__" + ] if prefix is MISSING: - attrs['__commands_flag_prefix__'] = base.__dict__['__commands_flag_prefix__'] + attrs["__commands_flag_prefix__"] = base.__dict__[ + "__commands_flag_prefix__" + ] if case_insensitive is not MISSING: - attrs['__commands_flag_case_insensitive__'] = case_insensitive + attrs["__commands_flag_case_insensitive__"] = case_insensitive if delimiter is not MISSING: - attrs['__commands_flag_delimiter__'] = delimiter + attrs["__commands_flag_delimiter__"] = delimiter if prefix is not MISSING: - attrs['__commands_flag_prefix__'] = prefix + attrs["__commands_flag_prefix__"] = prefix - case_insensitive = attrs.setdefault('__commands_flag_case_insensitive__', False) - delimiter = attrs.setdefault('__commands_flag_delimiter__', ':') - prefix = attrs.setdefault('__commands_flag_prefix__', '') + case_insensitive = attrs.setdefault("__commands_flag_case_insensitive__", False) + delimiter = attrs.setdefault("__commands_flag_delimiter__", ":") + prefix = attrs.setdefault("__commands_flag_prefix__", "") for flag_name, flag in get_flags(attrs, global_ns, local_ns).items(): flags[flag_name] = flag @@ -330,23 +358,30 @@ class FlagsMeta(type): regex_flags = 0 if case_insensitive: flags = {key.casefold(): value for key, value in flags.items()} - aliases = {key.casefold(): value.casefold() for key, value in aliases.items()} + aliases = { + key.casefold(): value.casefold() for key, value in aliases.items() + } regex_flags = re.IGNORECASE keys = list(re.escape(k) for k in flags) keys.extend(re.escape(a) for a in aliases) keys = sorted(keys, key=lambda t: len(t), reverse=True) - joined = '|'.join(keys) - pattern = re.compile(f'(({re.escape(prefix)})(?P{joined}){re.escape(delimiter)})', regex_flags) - attrs['__commands_flag_regex__'] = pattern - attrs['__commands_flags__'] = flags - attrs['__commands_flag_aliases__'] = aliases + joined = "|".join(keys) + pattern = re.compile( + f"(({re.escape(prefix)})(?P{joined}){re.escape(delimiter)})", + regex_flags, + ) + attrs["__commands_flag_regex__"] = pattern + attrs["__commands_flags__"] = flags + attrs["__commands_flag_aliases__"] = aliases return type.__new__(cls, name, bases, attrs) -async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: +async def tuple_convert_all( + ctx: Context, argument: str, flag: Flag, converter: Any +) -> Tuple[Any, ...]: view = StringView(argument) results = [] param: inspect.Parameter = ctx.current_parameter # type: ignore @@ -371,7 +406,9 @@ async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: return tuple(results) -async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: +async def tuple_convert_flag( + ctx: Context, argument: str, flag: Flag, converters: Any +) -> Tuple[Any, ...]: view = StringView(argument) results = [] param: inspect.Parameter = ctx.current_parameter # type: ignore @@ -409,9 +446,13 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) - else: if origin is tuple: if annotation.__args__[-1] is Ellipsis: - return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0]) + return await tuple_convert_all( + ctx, argument, flag, annotation.__args__[0] + ) else: - return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) + return await tuple_convert_flag( + ctx, argument, flag, annotation.__args__ + ) elif origin is list: # typing.List[x] annotation = annotation.__args__[0] @@ -432,7 +473,7 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) - raise BadFlagArgument(flag) from e -F = TypeVar('F', bound='FlagConverter') +F = TypeVar("F", bound="FlagConverter") class FlagConverter(metaclass=FlagsMeta): @@ -493,8 +534,13 @@ class FlagConverter(metaclass=FlagsMeta): return self def __repr__(self) -> str: - pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()]) - return f'<{self.__class__.__name__} {pairs}>' + pairs = " ".join( + [ + f"{flag.attribute}={getattr(self, flag.attribute)!r}" + for flag in self.get_flags().values() + ] + ) + return f"<{self.__class__.__name__} {pairs}>" @classmethod def parse_flags(cls, argument: str) -> Dict[str, List[str]]: @@ -507,7 +553,7 @@ class FlagConverter(metaclass=FlagsMeta): case_insensitive = cls.__commands_flag_case_insensitive__ for match in cls.__commands_flag_regex__.finditer(argument): begin, end = match.span(0) - key = match.group('flag') + key = match.group("flag") if case_insensitive: key = key.casefold() diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index afaacbfb..37c6e226 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -39,10 +39,10 @@ if TYPE_CHECKING: from .context import Context __all__ = ( - 'Paginator', - 'HelpCommand', - 'DefaultHelpCommand', - 'MinimalHelpCommand', + "Paginator", + "HelpCommand", + "DefaultHelpCommand", + "MinimalHelpCommand", ) # help -> shows info of bot on top/bottom and lists subcommands @@ -89,7 +89,7 @@ class Paginator: .. versionadded:: 1.7 """ - def __init__(self, prefix='```', suffix='```', max_size=2000, linesep='\n'): + def __init__(self, prefix="```", suffix="```", max_size=2000, linesep="\n"): self.prefix = prefix self.suffix = suffix self.max_size = max_size @@ -118,7 +118,7 @@ class Paginator: def _linesep_len(self): return len(self.linesep) - def add_line(self, line='', *, empty=False): + def add_line(self, line="", *, empty=False): """Adds a line to the current page. If the line exceeds the :attr:`max_size` then an exception @@ -136,18 +136,23 @@ class Paginator: RuntimeError The line was too big for the current :attr:`max_size`. """ - max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len + max_page_size = ( + self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len + ) if len(line) > max_page_size: - raise RuntimeError(f'Line exceeds maximum page size {max_page_size}') + raise RuntimeError(f"Line exceeds maximum page size {max_page_size}") - if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len: + if ( + self._count + len(line) + self._linesep_len + > self.max_size - self._suffix_len + ): self.close_page() self._count += len(line) + self._linesep_len self._current_page.append(line) if empty: - self._current_page.append('') + self._current_page.append("") self._count += self._linesep_len def close_page(self): @@ -176,7 +181,7 @@ class Paginator: return self._pages def __repr__(self): - fmt = '' + fmt = "" return fmt.format(self) @@ -197,7 +202,7 @@ class _HelpCommandImpl(Command): self.callback = injected.command_callback on_error = injected.on_help_command_error - if not hasattr(on_error, '__help_command_not_overriden__'): + if not hasattr(on_error, "__help_command_not_overriden__"): if self.cog is not None: self.on_error = self._on_error_cog_implementation else: @@ -224,7 +229,7 @@ class _HelpCommandImpl(Command): try: del result[next(iter(result))] except StopIteration: - raise ValueError('Missing context parameter') from None + raise ValueError("Missing context parameter") from None else: return result @@ -296,13 +301,13 @@ class HelpCommand: """ MENTION_TRANSFORMS = { - '@everyone': '@\u200beveryone', - '@here': '@\u200bhere', - r'<@!?[0-9]{17,22}>': '@deleted-user', - r'<@&[0-9]{17,22}>': '@deleted-role', + "@everyone": "@\u200beveryone", + "@here": "@\u200bhere", + r"<@!?[0-9]{17,22}>": "@deleted-user", + r"<@&[0-9]{17,22}>": "@deleted-role", } - MENTION_PATTERN = re.compile('|'.join(MENTION_TRANSFORMS.keys())) + MENTION_PATTERN = re.compile("|".join(MENTION_TRANSFORMS.keys())) def __new__(cls, *args, **kwargs): # To prevent race conditions of a single instance while also allowing @@ -321,11 +326,11 @@ class HelpCommand: return self def __init__(self, **options): - self.show_hidden = options.pop('show_hidden', False) - self.verify_checks = options.pop('verify_checks', True) - self.command_attrs = attrs = options.pop('command_attrs', {}) - attrs.setdefault('name', 'help') - attrs.setdefault('help', 'Shows this message') + self.show_hidden = options.pop("show_hidden", False) + self.verify_checks = options.pop("verify_checks", True) + self.command_attrs = attrs = options.pop("command_attrs", {}) + attrs.setdefault("name", "help") + attrs.setdefault("help", "Shows this message") self.context: Context = discord.utils.MISSING self._command_impl = _HelpCommandImpl(self, **self.command_attrs) @@ -398,7 +403,11 @@ class HelpCommand: """ command_name = self._command_impl.name ctx = self.context - if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name: + if ( + ctx is None + or ctx.command is None + or ctx.command.qualified_name != command_name + ): return command_name return ctx.invoked_with @@ -422,20 +431,20 @@ class HelpCommand: if not parent.signature or parent.invoke_without_command: entries.append(parent.name) else: - entries.append(parent.name + ' ' + parent.signature) + entries.append(parent.name + " " + parent.signature) parent = parent.parent - parent_sig = ' '.join(reversed(entries)) + parent_sig = " ".join(reversed(entries)) if len(command.aliases) > 0: - aliases = '|'.join(command.aliases) - fmt = f'[{command.name}|{aliases}]' + aliases = "|".join(command.aliases) + fmt = f"[{command.name}|{aliases}]" if parent_sig: - fmt = parent_sig + ' ' + fmt + fmt = parent_sig + " " + fmt alias = fmt else: - alias = command.name if not parent_sig else parent_sig + ' ' + command.name + alias = command.name if not parent_sig else parent_sig + " " + command.name - return f'{self.context.clean_prefix}{alias} {command.signature}' + return f"{self.context.clean_prefix}{alias} {command.signature}" def remove_mentions(self, string): """Removes mentions from the string to prevent abuse. @@ -449,7 +458,7 @@ class HelpCommand: """ def replace(obj, *, transforms=self.MENTION_TRANSFORMS): - return transforms.get(obj.group(0), '@invalid') + return transforms.get(obj.group(0), "@invalid") return self.MENTION_PATTERN.sub(replace, string) @@ -527,7 +536,9 @@ class HelpCommand: The string to use when the command did not have the subcommand requested. """ if isinstance(command, Group) and len(command.all_commands) > 0: - return f'Command "{command.qualified_name}" has no subcommand named {string}' + return ( + f'Command "{command.qualified_name}" has no subcommand named {string}' + ) return f'Command "{command.qualified_name}" has no subcommands.' async def filter_commands(self, commands, *, sort=False, key=None): @@ -558,7 +569,9 @@ class HelpCommand: if sort and key is None: key = lambda c: c.name - iterator = commands if self.show_hidden else filter(lambda c: not c.hidden, commands) + iterator = ( + commands if self.show_hidden else filter(lambda c: not c.hidden, commands) + ) if self.verify_checks is False: # if we do not need to verify the checks then we can just @@ -846,21 +859,27 @@ class HelpCommand: # Since we want to have detailed errors when someone # passes an invalid subcommand, we need to walk through # the command group chain ourselves. - keys = command.split(' ') + keys = command.split(" ") cmd = bot.all_commands.get(keys[0]) if cmd is None: - string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0])) + string = await maybe_coro( + self.command_not_found, self.remove_mentions(keys[0]) + ) return await self.send_error_message(string) for key in keys[1:]: try: found = cmd.all_commands.get(key) except AttributeError: - string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) + string = await maybe_coro( + self.subcommand_not_found, cmd, self.remove_mentions(key) + ) return await self.send_error_message(string) else: if found is None: - string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) + string = await maybe_coro( + self.subcommand_not_found, cmd, self.remove_mentions(key) + ) return await self.send_error_message(string) cmd = found @@ -907,14 +926,14 @@ class DefaultHelpCommand(HelpCommand): """ def __init__(self, **options): - self.width = options.pop('width', 80) - self.indent = options.pop('indent', 2) - self.sort_commands = options.pop('sort_commands', True) - self.dm_help = options.pop('dm_help', False) - self.dm_help_threshold = options.pop('dm_help_threshold', 1000) - self.commands_heading = options.pop('commands_heading', "Commands:") - self.no_category = options.pop('no_category', 'No Category') - self.paginator = options.pop('paginator', None) + self.width = options.pop("width", 80) + self.indent = options.pop("indent", 2) + self.sort_commands = options.pop("sort_commands", True) + self.dm_help = options.pop("dm_help", False) + self.dm_help_threshold = options.pop("dm_help_threshold", 1000) + self.commands_heading = options.pop("commands_heading", "Commands:") + self.no_category = options.pop("no_category", "No Category") + self.paginator = options.pop("paginator", None) if self.paginator is None: self.paginator = Paginator() @@ -924,7 +943,7 @@ class DefaultHelpCommand(HelpCommand): def shorten_text(self, text): """:class:`str`: Shortens text to fit into the :attr:`width`.""" if len(text) > self.width: - return text[:self.width - 3].rstrip() + '...' + return text[: self.width - 3].rstrip() + "..." return text def get_ending_note(self): @@ -1021,11 +1040,11 @@ class DefaultHelpCommand(HelpCommand): # portion self.paginator.add_line(bot.description, empty=True) - no_category = f'\u200b{self.no_category}:' + no_category = f"\u200b{self.no_category}:" def get_category(command, *, no_category=no_category): cog = command.cog - return cog.qualified_name + ':' if cog is not None else no_category + return cog.qualified_name + ":" if cog is not None else no_category filtered = await self.filter_commands(bot.commands, sort=True, key=get_category) max_size = self.get_max_size(filtered) @@ -1033,7 +1052,11 @@ class DefaultHelpCommand(HelpCommand): # Now we can add the commands to the page. for category, commands in to_iterate: - commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands) + commands = ( + sorted(commands, key=lambda c: c.name) + if self.sort_commands + else list(commands) + ) self.add_indented_commands(commands, heading=category, max_size=max_size) note = self.get_ending_note() @@ -1066,7 +1089,9 @@ class DefaultHelpCommand(HelpCommand): if cog.description: self.paginator.add_line(cog.description, empty=True) - filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands) + filtered = await self.filter_commands( + cog.get_commands(), sort=self.sort_commands + ) self.add_indented_commands(filtered, heading=self.commands_heading) note = self.get_ending_note() @@ -1110,13 +1135,13 @@ class MinimalHelpCommand(HelpCommand): """ def __init__(self, **options): - self.sort_commands = options.pop('sort_commands', True) - self.commands_heading = options.pop('commands_heading', "Commands") - self.dm_help = options.pop('dm_help', False) - self.dm_help_threshold = options.pop('dm_help_threshold', 1000) - self.aliases_heading = options.pop('aliases_heading', "Aliases:") - self.no_category = options.pop('no_category', 'No Category') - self.paginator = options.pop('paginator', None) + self.sort_commands = options.pop("sort_commands", True) + self.commands_heading = options.pop("commands_heading", "Commands") + self.dm_help = options.pop("dm_help", False) + self.dm_help_threshold = options.pop("dm_help_threshold", 1000) + self.aliases_heading = options.pop("aliases_heading", "Aliases:") + self.no_category = options.pop("no_category", "No Category") + self.paginator = options.pop("paginator", None) if self.paginator is None: self.paginator = Paginator(suffix=None, prefix=None) @@ -1149,7 +1174,9 @@ class MinimalHelpCommand(HelpCommand): ) def get_command_signature(self, command): - return f'{self.context.clean_prefix}{command.qualified_name} {command.signature}' + return ( + f"{self.context.clean_prefix}{command.qualified_name} {command.signature}" + ) def get_ending_note(self): """Return the help command's ending note. This is mainly useful to override for i18n purposes. @@ -1180,8 +1207,8 @@ class MinimalHelpCommand(HelpCommand): """ if commands: # U+2002 Middle Dot - joined = '\u2002'.join(c.name for c in commands) - self.paginator.add_line(f'__**{heading}**__') + joined = "\u2002".join(c.name for c in commands) + self.paginator.add_line(f"__**{heading}**__") self.paginator.add_line(joined) def add_subcommand_formatting(self, command): @@ -1197,8 +1224,12 @@ class MinimalHelpCommand(HelpCommand): command: :class:`Command` The command to show information of. """ - fmt = '{0}{1} \N{EN DASH} {2}' if command.short_doc else '{0}{1}' - self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) + fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}" + self.paginator.add_line( + fmt.format( + self.context.clean_prefix, command.qualified_name, command.short_doc + ) + ) def add_aliases_formatting(self, aliases): """Adds the formatting information on a command's aliases. @@ -1215,7 +1246,9 @@ class MinimalHelpCommand(HelpCommand): aliases: Sequence[:class:`str`] A list of aliases to format. """ - self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True) + self.paginator.add_line( + f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True + ) def add_command_formatting(self, command): """A utility function to format commands and groups. @@ -1268,7 +1301,7 @@ class MinimalHelpCommand(HelpCommand): if note: self.paginator.add_line(note, empty=True) - no_category = f'\u200b{self.no_category}' + no_category = f"\u200b{self.no_category}" def get_category(command, *, no_category=no_category): cog = command.cog @@ -1278,7 +1311,11 @@ class MinimalHelpCommand(HelpCommand): to_iterate = itertools.groupby(filtered, key=get_category) for category, commands in to_iterate: - commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands) + commands = ( + sorted(commands, key=lambda c: c.name) + if self.sort_commands + else list(commands) + ) self.add_bot_commands_formatting(commands, category) note = self.get_ending_note() @@ -1300,9 +1337,11 @@ class MinimalHelpCommand(HelpCommand): if cog.description: self.paginator.add_line(cog.description, empty=True) - filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands) + filtered = await self.filter_commands( + cog.get_commands(), sort=self.sort_commands + ) if filtered: - self.paginator.add_line(f'**{cog.qualified_name} {self.commands_heading}**') + self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**") for command in filtered: self.add_subcommand_formatting(command) @@ -1322,7 +1361,7 @@ class MinimalHelpCommand(HelpCommand): if note: self.paginator.add_line(note, empty=True) - self.paginator.add_line(f'**{self.commands_heading}**') + self.paginator.add_line(f"**{self.commands_heading}**") for command in filtered: self.add_subcommand_formatting(command) diff --git a/discord/ext/commands/view.py b/discord/ext/commands/view.py index a7dc7236..31dfefb2 100644 --- a/discord/ext/commands/view.py +++ b/discord/ext/commands/view.py @@ -22,7 +22,11 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ -from .errors import UnexpectedQuoteError, InvalidEndOfQuotedStringError, ExpectedClosingQuoteError +from .errors import ( + UnexpectedQuoteError, + InvalidEndOfQuotedStringError, + ExpectedClosingQuoteError, +) # map from opening quotes to closing quotes _quotes = { @@ -46,6 +50,7 @@ _quotes = { } _all_quotes = set(_quotes.keys()) | set(_quotes.values()) + class StringView: def __init__(self, buffer): self.index = 0 @@ -81,20 +86,20 @@ class StringView: def skip_string(self, string): strlen = len(string) - if self.buffer[self.index:self.index + strlen] == string: + if self.buffer[self.index : self.index + strlen] == string: self.previous = self.index self.index += strlen return True return False def read_rest(self): - result = self.buffer[self.index:] + result = self.buffer[self.index :] self.previous = self.index self.index = self.end return result def read(self, n): - result = self.buffer[self.index:self.index + n] + result = self.buffer[self.index : self.index + n] self.previous = self.index self.index += n return result @@ -120,7 +125,7 @@ class StringView: except IndexError: break self.previous = self.index - result = self.buffer[self.index:self.index + pos] + result = self.buffer[self.index : self.index + pos] self.index += pos return result @@ -144,11 +149,11 @@ class StringView: if is_quoted: # unexpected EOF raise ExpectedClosingQuoteError(close_quote) - return ''.join(result) + return "".join(result) # currently we accept strings in the format of "hello world" # to embed a quote inside the string you must escape it: "a \"world\"" - if current == '\\': + if current == "\\": next_char = self.get() if not next_char: # string ends with \ and no character after it @@ -156,7 +161,7 @@ class StringView: # if we're quoted then we're expecting a closing quote raise ExpectedClosingQuoteError(close_quote) # if we aren't then we just let it through - return ''.join(result) + return "".join(result) if next_char in _escaped_quotes: # escaped quote @@ -179,14 +184,13 @@ class StringView: raise InvalidEndOfQuotedStringError(next_char) # we're quoted so it's okay - return ''.join(result) + return "".join(result) if current.isspace() and not is_quoted: # end of word found - return ''.join(result) + return "".join(result) result.append(current) - def __repr__(self): - return f'' + return f"" diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index 5b78f10e..37815509 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -48,21 +48,21 @@ from collections.abc import Sequence from discord.backoff import ExponentialBackoff from discord.utils import MISSING -__all__ = ( - 'loop', -) +__all__ = ("loop",) -T = TypeVar('T') +T = TypeVar("T") _func = Callable[..., Awaitable[Any]] -LF = TypeVar('LF', bound=_func) -FT = TypeVar('FT', bound=_func) -ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) +LF = TypeVar("LF", bound=_func) +FT = TypeVar("FT", bound=_func) +ET = TypeVar("ET", bound=Callable[[Any, BaseException], Awaitable[Any]]) class SleepHandle: - __slots__ = ('future', 'loop', 'handle') + __slots__ = ("future", "loop", "handle") - def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: + def __init__( + self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop + ) -> None: self.loop = loop self.future = future = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) @@ -124,7 +124,7 @@ class Loop(Generic[LF]): self._stop_next_iteration = False if self.count is not None and self.count <= 0: - raise ValueError('count must be greater than 0 or None.') + raise ValueError("count must be greater than 0 or None.") self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) self._last_iteration_failed = False @@ -132,10 +132,12 @@ class Loop(Generic[LF]): self._next_iteration = None if not inspect.iscoroutinefunction(self.coro): - raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.') + raise TypeError( + f"Expected coroutine function, not {type(self.coro).__name__!r}." + ) async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: - coro = getattr(self, '_' + name) + coro = getattr(self, "_" + name) if coro is None: return @@ -150,7 +152,7 @@ class Loop(Generic[LF]): async def _loop(self, *args: Any, **kwargs: Any) -> None: backoff = ExponentialBackoff() - await self._call_loop_function('before_loop') + await self._call_loop_function("before_loop") self._last_iteration_failed = False if self._time is not MISSING: # the time index should be prepared every time the internal loop is started @@ -193,10 +195,10 @@ class Loop(Generic[LF]): raise except Exception as exc: self._has_failed = True - await self._call_loop_function('error', exc) + await self._call_loop_function("error", exc) raise exc finally: - await self._call_loop_function('after_loop') + await self._call_loop_function("after_loop") self._handle.cancel() self._is_being_cancelled = False self._current_loop = 0 @@ -323,7 +325,7 @@ class Loop(Generic[LF]): """ if self._task is not MISSING and not self._task.done(): - raise RuntimeError('Task is already launched and is not completed.') + raise RuntimeError("Task is already launched and is not completed.") if self._injected is not None: args = (self._injected, *args) @@ -356,7 +358,9 @@ class Loop(Generic[LF]): self._stop_next_iteration = True def _can_be_cancelled(self) -> bool: - return bool(not self._is_being_cancelled and self._task and not self._task.done()) + return bool( + not self._is_being_cancelled and self._task and not self._task.done() + ) def cancel(self) -> None: """Cancels the internal task, if it is running.""" @@ -379,7 +383,9 @@ class Loop(Generic[LF]): The keyword arguments to use. """ - def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None: + def restart_when_over( + fut: Any, *, args: Any = args, kwargs: Any = kwargs + ) -> None: self._task.remove_done_callback(restart_when_over) self.start(*args, **kwargs) @@ -410,9 +416,9 @@ class Loop(Generic[LF]): for exc in exceptions: if not inspect.isclass(exc): - raise TypeError(f'{exc!r} must be a class.') + raise TypeError(f"{exc!r} must be a class.") if not issubclass(exc, BaseException): - raise TypeError(f'{exc!r} must inherit from BaseException.') + raise TypeError(f"{exc!r} must inherit from BaseException.") self._valid_exception = (*self._valid_exception, *exceptions) @@ -439,7 +445,9 @@ class Loop(Generic[LF]): Whether all exceptions were successfully removed. """ old_length = len(self._valid_exception) - self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) + self._valid_exception = tuple( + x for x in self._valid_exception if x not in exceptions + ) return len(self._valid_exception) == old_length - len(exceptions) def get_task(self) -> Optional[asyncio.Task[None]]: @@ -466,8 +474,13 @@ class Loop(Generic[LF]): async def _error(self, *args: Any) -> None: exception: Exception = args[-1] - print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr) - traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) + print( + f"Unhandled exception in internal background task {self.coro.__name__!r}.", + file=sys.stderr, + ) + traceback.print_exception( + type(exception), exception, exception.__traceback__, file=sys.stderr + ) def before_loop(self, coro: FT) -> FT: """A decorator that registers a coroutine to be called before the loop starts running. @@ -489,7 +502,9 @@ class Loop(Generic[LF]): """ if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + raise TypeError( + f"Expected coroutine function, received {coro.__class__.__name__!r}." + ) self._before_loop = coro return coro @@ -517,7 +532,9 @@ class Loop(Generic[LF]): """ if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + raise TypeError( + f"Expected coroutine function, received {coro.__class__.__name__!r}." + ) self._after_loop = coro return coro @@ -543,7 +560,9 @@ class Loop(Generic[LF]): The function was not a coroutine. """ if not inspect.iscoroutinefunction(coro): - raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.') + raise TypeError( + f"Expected coroutine function, received {coro.__class__.__name__!r}." + ) self._error = coro # type: ignore return coro @@ -557,14 +576,18 @@ class Loop(Generic[LF]): if self._current_loop == 0: # if we're at the last index on the first iteration, we need to sleep until tomorrow return datetime.datetime.combine( - datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0] + datetime.datetime.now(datetime.timezone.utc) + + datetime.timedelta(days=1), + self._time[0], ) next_time = self._time[self._time_index] if self._current_loop == 0: self._time_index += 1 - return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) + return datetime.datetime.combine( + datetime.datetime.now(datetime.timezone.utc), next_time + ) next_date = self._last_iteration if self._time_index == 0: @@ -580,7 +603,9 @@ class Loop(Generic[LF]): # pre-condition: self._time is set time_now = ( - now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) + now + if now is not MISSING + else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) ).timetz() for idx, time in enumerate(self._time): if time >= time_now: @@ -601,16 +626,16 @@ class Loop(Generic[LF]): return [inner] if not isinstance(time, Sequence): raise TypeError( - f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.' + f"Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead." ) if not time: - raise ValueError('time parameter must not be an empty sequence.') + raise ValueError("time parameter must not be an empty sequence.") ret: List[datetime.time] = [] for index, t in enumerate(time): if not isinstance(t, dt): raise TypeError( - f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.' + f"Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead." ) ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) @@ -663,7 +688,7 @@ class Loop(Generic[LF]): hours = hours or 0 sleep = seconds + (minutes * 60.0) + (hours * 3600.0) if sleep < 0: - raise ValueError('Total number of seconds cannot be less than zero.') + raise ValueError("Total number of seconds cannot be less than zero.") self._sleep = sleep self._seconds = float(seconds) @@ -672,7 +697,7 @@ class Loop(Generic[LF]): self._time: List[datetime.time] = MISSING else: if any((seconds, minutes, hours)): - raise TypeError('Cannot mix explicit time with relative time') + raise TypeError("Cannot mix explicit time with relative time") self._time = self._get_time_parameter(time) self._sleep = self._seconds = self._minutes = self._hours = MISSING diff --git a/discord/file.py b/discord/file.py index 5303e325..0e1558a0 100644 --- a/discord/file.py +++ b/discord/file.py @@ -28,9 +28,7 @@ from typing import Optional, TYPE_CHECKING, Union import os import io -__all__ = ( - 'File', -) +__all__ = ("File",) class File: @@ -64,7 +62,7 @@ class File: Whether the attachment is a spoiler. """ - __slots__ = ('fp', 'filename', 'spoiler', '_original_pos', '_owner', '_closer') + __slots__ = ("fp", "filename", "spoiler", "_original_pos", "_owner", "_closer") if TYPE_CHECKING: fp: io.BufferedIOBase @@ -80,12 +78,12 @@ class File: ): if isinstance(fp, io.IOBase): if not (fp.seekable() and fp.readable()): - raise ValueError(f'File buffer {fp!r} must be seekable and readable') + raise ValueError(f"File buffer {fp!r} must be seekable and readable") self.fp = fp self._original_pos = fp.tell() self._owner = False else: - self.fp = open(fp, 'rb') + self.fp = open(fp, "rb") self._original_pos = 0 self._owner = True @@ -100,14 +98,20 @@ class File: if isinstance(fp, str): _, self.filename = os.path.split(fp) else: - self.filename = getattr(fp, 'name', None) + self.filename = getattr(fp, "name", None) else: self.filename = filename - if spoiler and self.filename is not None and not self.filename.startswith('SPOILER_'): - self.filename = 'SPOILER_' + self.filename + if ( + spoiler + and self.filename is not None + and not self.filename.startswith("SPOILER_") + ): + self.filename = "SPOILER_" + self.filename - self.spoiler = spoiler or (self.filename is not None and self.filename.startswith('SPOILER_')) + self.spoiler = spoiler or ( + self.filename is not None and self.filename.startswith("SPOILER_") + ) def reset(self, *, seek: Union[int, bool] = True) -> None: # The `seek` parameter is needed because diff --git a/discord/flags.py b/discord/flags.py index 3c5956a4..55822f45 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -24,21 +24,34 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Any, Callable, ClassVar, Dict, Generic, Iterator, List, Optional, Tuple, Type, TypeVar, overload +from typing import ( + Any, + Callable, + ClassVar, + Dict, + Generic, + Iterator, + List, + Optional, + Tuple, + Type, + TypeVar, + overload, +) from .enums import UserFlags __all__ = ( - 'SystemChannelFlags', - 'MessageFlags', - 'PublicUserFlags', - 'Intents', - 'MemberCacheFlags', - 'ApplicationFlags', + "SystemChannelFlags", + "MessageFlags", + "PublicUserFlags", + "Intents", + "MemberCacheFlags", + "ApplicationFlags", ) -FV = TypeVar('FV', bound='flag_value') -BF = TypeVar('BF', bound='BaseFlags') +FV = TypeVar("FV", bound="flag_value") +BF = TypeVar("BF", bound="BaseFlags") class flag_value: @@ -63,7 +76,7 @@ class flag_value: instance._set_flag(self.flag, value) def __repr__(self): - return f'' + return f"" class alias_flag_value(flag_value): @@ -98,13 +111,13 @@ class BaseFlags: value: int - __slots__ = ('value',) + __slots__ = ("value",) def __init__(self, **kwargs: bool): self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid flag name.') + raise TypeError(f"{key!r} is not a valid flag name.") setattr(self, key, value) @classmethod @@ -123,7 +136,7 @@ class BaseFlags: return hash(self.value) def __repr__(self) -> str: - return f'<{self.__class__.__name__} value={self.value}>' + return f"<{self.__class__.__name__} value={self.value}>" def __iter__(self) -> Iterator[Tuple[str, bool]]: for name, value in self.__class__.__dict__.items(): @@ -142,7 +155,9 @@ class BaseFlags: elif toggle is False: self.value &= ~o else: - raise TypeError(f'Value to set for {self.__class__.__name__} must be a bool.') + raise TypeError( + f"Value to set for {self.__class__.__name__} must be a bool." + ) @fill_with_flags(inverted=True) @@ -196,7 +211,7 @@ class SystemChannelFlags(BaseFlags): elif toggle is False: self.value |= o else: - raise TypeError('Value to set for SystemChannelFlags must be a bool.') + raise TypeError("Value to set for SystemChannelFlags must be a bool.") @flag_value def join_notifications(self): @@ -412,7 +427,11 @@ class PublicUserFlags(BaseFlags): def all(self) -> List[UserFlags]: """List[:class:`UserFlags`]: Returns all public flags the user has.""" - return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] + return [ + public_flag + for public_flag in UserFlags + if self._has_flag(public_flag.value) + ] @fill_with_flags() @@ -461,7 +480,7 @@ class Intents(BaseFlags): self.value = self.DEFAULT_VALUE for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid flag name.') + raise TypeError(f"{key!r} is not a valid flag name.") setattr(self, key, value) @classmethod @@ -907,7 +926,7 @@ class MemberCacheFlags(BaseFlags): self.value = (1 << bits) - 1 for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid flag name.') + raise TypeError(f"{key!r} is not a valid flag name.") setattr(self, key, value) @classmethod @@ -977,10 +996,10 @@ class MemberCacheFlags(BaseFlags): def _verify_intents(self, intents: Intents): if self.voice and not intents.voice_states: - raise ValueError('MemberCacheFlags.voice requires Intents.voice_states') + raise ValueError("MemberCacheFlags.voice requires Intents.voice_states") if self.joined and not intents.members: - raise ValueError('MemberCacheFlags.joined requires Intents.members') + raise ValueError("MemberCacheFlags.joined requires Intents.members") @property def _voice_only(self): diff --git a/discord/gateway.py b/discord/gateway.py index aa0c6ba0..b32a8492 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -43,25 +43,31 @@ from .errors import ConnectionClosed, InvalidArgument _log = logging.getLogger(__name__) __all__ = ( - 'DiscordWebSocket', - 'KeepAliveHandler', - 'VoiceKeepAliveHandler', - 'DiscordVoiceWebSocket', - 'ReconnectWebSocket', + "DiscordWebSocket", + "KeepAliveHandler", + "VoiceKeepAliveHandler", + "DiscordVoiceWebSocket", + "ReconnectWebSocket", ) + class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" + def __init__(self, shard_id, *, resume=True): self.shard_id = shard_id self.resume = resume - self.op = 'RESUME' if resume else 'IDENTIFY' + self.op = "RESUME" if resume else "IDENTIFY" + class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" + pass -EventListener = namedtuple('EventListener', 'predicate event result future') + +EventListener = namedtuple("EventListener", "predicate event result future") + class GatewayRatelimiter: def __init__(self, count=110, per=60.0): @@ -101,48 +107,57 @@ class GatewayRatelimiter: async with self.lock: delta = self.get_delay() if delta: - _log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta) + _log.warning( + "WebSocket in shard ID %s is ratelimited, waiting %.2f seconds", + self.shard_id, + delta, + ) await asyncio.sleep(delta) class KeepAliveHandler(threading.Thread): def __init__(self, *args, **kwargs): - ws = kwargs.pop('ws', None) - interval = kwargs.pop('interval', None) - shard_id = kwargs.pop('shard_id', None) + ws = kwargs.pop("ws", None) + interval = kwargs.pop("interval", None) + shard_id = kwargs.pop("shard_id", None) threading.Thread.__init__(self, *args, **kwargs) self.ws = ws self._main_thread_id = ws.thread_id self.interval = interval self.daemon = True self.shard_id = shard_id - self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' - self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' - self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' + self.msg = "Keeping shard ID %s websocket alive with sequence %s." + self.block_msg = "Shard ID %s heartbeat blocked for more than %s seconds." + self.behind_msg = "Can't keep up, shard ID %s websocket is %.1fs behind." self._stop_ev = threading.Event() self._last_ack = time.perf_counter() self._last_send = time.perf_counter() self._last_recv = time.perf_counter() - self.latency = float('inf') + self.latency = float("inf") self.heartbeat_timeout = ws._max_heartbeat_timeout def run(self): while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): - _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) + _log.warning( + "Shard ID %s has stopped responding to the gateway. Closing and restarting.", + self.shard_id, + ) coro = self.ws.close(4000) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: f.result() except Exception: - _log.exception('An error occurred while stopping the gateway. Ignoring.') + _log.exception( + "An error occurred while stopping the gateway. Ignoring." + ) finally: self.stop() return data = self.get_payload() - _log.debug(self.msg, self.shard_id, data['d']) + _log.debug(self.msg, self.shard_id, data["d"]) coro = self.ws.send_heartbeat(data) f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: @@ -159,8 +174,8 @@ class KeepAliveHandler(threading.Thread): except KeyError: msg = self.block_msg else: - stack = ''.join(traceback.format_stack(frame)) - msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}' + stack = "".join(traceback.format_stack(frame)) + msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}" _log.warning(msg, self.shard_id, total) except Exception: @@ -169,10 +184,7 @@ class KeepAliveHandler(threading.Thread): self._last_send = time.perf_counter() def get_payload(self): - return { - 'op': self.ws.HEARTBEAT, - 'd': self.ws.sequence - } + return {"op": self.ws.HEARTBEAT, "d": self.ws.sequence} def stop(self): self._stop_ev.set() @@ -187,19 +199,17 @@ class KeepAliveHandler(threading.Thread): if self.latency > 10: _log.warning(self.behind_msg, self.shard_id, self.latency) + class VoiceKeepAliveHandler(KeepAliveHandler): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.recent_ack_latencies = deque(maxlen=20) - self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' - self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' - self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' + self.msg = "Keeping shard ID %s voice websocket alive with timestamp %s." + self.block_msg = "Shard ID %s voice heartbeat blocked for more than %s seconds" + self.behind_msg = "High socket latency, shard ID %s heartbeat is %.1fs behind" def get_payload(self): - return { - 'op': self.ws.HEARTBEAT, - 'd': int(time.time() * 1000) - } + return {"op": self.ws.HEARTBEAT, "d": int(time.time() * 1000)} def ack(self): ack_time = time.perf_counter() @@ -208,10 +218,12 @@ class VoiceKeepAliveHandler(KeepAliveHandler): self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) + class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): - async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: + async def close(self, *, code: int = 4000, message: bytes = b"") -> bool: return await super().close(code=code, message=message) + class DiscordWebSocket: """Implements a WebSocket for Discord's gateway v6. @@ -252,19 +264,19 @@ class DiscordWebSocket: The authentication token for discord. """ - DISPATCH = 0 - HEARTBEAT = 1 - IDENTIFY = 2 - PRESENCE = 3 - VOICE_STATE = 4 - VOICE_PING = 5 - RESUME = 6 - RECONNECT = 7 - REQUEST_MEMBERS = 8 + DISPATCH = 0 + HEARTBEAT = 1 + IDENTIFY = 2 + PRESENCE = 3 + VOICE_STATE = 4 + VOICE_PING = 5 + RESUME = 6 + RECONNECT = 7 + REQUEST_MEMBERS = 8 INVALIDATE_SESSION = 9 - HELLO = 10 - HEARTBEAT_ACK = 11 - GUILD_SYNC = 12 + HELLO = 10 + HEARTBEAT_ACK = 11 + GUILD_SYNC = 12 def __init__(self, socket, *, loop): self.socket = socket @@ -294,13 +306,23 @@ class DiscordWebSocket: return self._rate_limiter.is_ratelimited() def debug_log_receive(self, data, /): - self._dispatch('socket_raw_receive', data) + self._dispatch("socket_raw_receive", data) def log_receive(self, _, /): pass @classmethod - async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): + async def from_client( + cls, + client, + *, + initial=False, + gateway=None, + shard_id=None, + session=None, + sequence=None, + resume=False, + ): """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -330,7 +352,7 @@ class DiscordWebSocket: client._connection._update_references(ws) - _log.debug('Created websocket connected to %s', gateway) + _log.debug("Created websocket connected to %s", gateway) # poll event for OP Hello await ws.poll_event() @@ -363,83 +385,87 @@ class DiscordWebSocket: """ future = self.loop.create_future() - entry = EventListener(event=event, predicate=predicate, result=result, future=future) + entry = EventListener( + event=event, predicate=predicate, result=result, future=future + ) self._dispatch_listeners.append(entry) return future async def identify(self): """Sends the IDENTIFY packet.""" payload = { - 'op': self.IDENTIFY, - 'd': { - 'token': self.token, - 'properties': { - '$os': sys.platform, - '$browser': 'discord.py', - '$device': 'discord.py', - '$referrer': '', - '$referring_domain': '' + "op": self.IDENTIFY, + "d": { + "token": self.token, + "properties": { + "$os": sys.platform, + "$browser": "discord.py", + "$device": "discord.py", + "$referrer": "", + "$referring_domain": "", }, - 'compress': True, - 'large_threshold': 250, - 'v': 3 - } + "compress": True, + "large_threshold": 250, + "v": 3, + }, } if self.shard_id is not None and self.shard_count is not None: - payload['d']['shard'] = [self.shard_id, self.shard_count] + payload["d"]["shard"] = [self.shard_id, self.shard_count] state = self._connection if state._activity is not None or state._status is not None: - payload['d']['presence'] = { - 'status': state._status, - 'game': state._activity, - 'since': 0, - 'afk': False + payload["d"]["presence"] = { + "status": state._status, + "game": state._activity, + "since": 0, + "afk": False, } if state._intents is not None: - payload['d']['intents'] = state._intents.value + payload["d"]["intents"] = state._intents.value - await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify) + await self.call_hooks( + "before_identify", self.shard_id, initial=self._initial_identify + ) await self.send_as_json(payload) - _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) + _log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id) async def resume(self): """Sends the RESUME packet.""" payload = { - 'op': self.RESUME, - 'd': { - 'seq': self.sequence, - 'session_id': self.session_id, - 'token': self.token - } + "op": self.RESUME, + "d": { + "seq": self.sequence, + "session_id": self.session_id, + "token": self.token, + }, } await self.send_as_json(payload) - _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) + _log.info("Shard ID %s has sent the RESUME payload.", self.shard_id) async def received_message(self, msg, /): if type(msg) is bytes: self._buffer.extend(msg) - if len(msg) < 4 or msg[-4:] != b'\x00\x00\xff\xff': + if len(msg) < 4 or msg[-4:] != b"\x00\x00\xff\xff": return msg = self._zlib.decompress(self._buffer) - msg = msg.decode('utf-8') + msg = msg.decode("utf-8") self._buffer = bytearray() self.log_receive(msg) msg = utils._from_json(msg) - _log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg) - event = msg.get('t') + _log.debug("For Shard ID %s: WebSocket Event: %s", self.shard_id, msg) + event = msg.get("t") if event: - self._dispatch('socket_event_type', event) + self._dispatch("socket_event_type", event) - op = msg.get('op') - data = msg.get('d') - seq = msg.get('s') + op = msg.get("op") + data = msg.get("d") + seq = msg.get("s") if seq is not None: self.sequence = seq @@ -451,7 +477,7 @@ class DiscordWebSocket: # "reconnect" can only be handled by the Client # so we terminate our connection and raise an # internal exception signalling to reconnect. - _log.debug('Received RECONNECT opcode.') + _log.debug("Received RECONNECT opcode.") await self.close() raise ReconnectWebSocket(self.shard_id) @@ -467,8 +493,10 @@ class DiscordWebSocket: return if op == self.HELLO: - interval = data['heartbeat_interval'] / 1000.0 - self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id) + interval = data["heartbeat_interval"] / 1000.0 + self._keep_alive = KeepAliveHandler( + ws=self, interval=interval, shard_id=self.shard_id + ) # send a heartbeat immediately await self.send_as_json(self._keep_alive.get_payload()) self._keep_alive.start() @@ -481,33 +509,41 @@ class DiscordWebSocket: self.sequence = None self.session_id = None - _log.info('Shard ID %s session has been invalidated.', self.shard_id) + _log.info("Shard ID %s session has been invalidated.", self.shard_id) await self.close(code=1000) raise ReconnectWebSocket(self.shard_id, resume=False) - _log.warning('Unknown OP code %s.', op) + _log.warning("Unknown OP code %s.", op) return - if event == 'READY': - self._trace = trace = data.get('_trace', []) - self.sequence = msg['s'] - self.session_id = data['session_id'] + if event == "READY": + self._trace = trace = data.get("_trace", []) + self.sequence = msg["s"] + self.session_id = data["session_id"] # pass back shard ID to ready handler - data['__shard_id__'] = self.shard_id - _log.info('Shard ID %s has connected to Gateway: %s (Session ID: %s).', - self.shard_id, ', '.join(trace), self.session_id) + data["__shard_id__"] = self.shard_id + _log.info( + "Shard ID %s has connected to Gateway: %s (Session ID: %s).", + self.shard_id, + ", ".join(trace), + self.session_id, + ) - elif event == 'RESUMED': - self._trace = trace = data.get('_trace', []) + elif event == "RESUMED": + self._trace = trace = data.get("_trace", []) # pass back the shard ID to the resumed handler - data['__shard_id__'] = self.shard_id - _log.info('Shard ID %s has successfully RESUMED session %s under trace %s.', - self.shard_id, self.session_id, ', '.join(trace)) + data["__shard_id__"] = self.shard_id + _log.info( + "Shard ID %s has successfully RESUMED session %s under trace %s.", + self.shard_id, + self.session_id, + ", ".join(trace), + ) try: func = self._discord_parsers[event] except KeyError: - _log.debug('Unknown event %s.', event) + _log.debug("Unknown event %s.", event) else: func(data) @@ -540,7 +576,7 @@ class DiscordWebSocket: def latency(self): """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive - return float('inf') if heartbeat is None else heartbeat.latency + return float("inf") if heartbeat is None else heartbeat.latency def _can_handle_close(self): code = self._close_code or self.socket.close_code @@ -561,10 +597,14 @@ class DiscordWebSocket: elif msg.type is aiohttp.WSMsgType.BINARY: await self.received_message(msg.data) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise msg.data - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE): - _log.debug('Received %s', msg) + elif msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSING, + aiohttp.WSMsgType.CLOSE, + ): + _log.debug("Received %s", msg) raise WebSocketClosure except (asyncio.TimeoutError, WebSocketClosure) as e: # Ensure the keep alive handler is closed @@ -573,20 +613,22 @@ class DiscordWebSocket: self._keep_alive = None if isinstance(e, asyncio.TimeoutError): - _log.info('Timed out receiving packet. Attempting a reconnect.') + _log.info("Timed out receiving packet. Attempting a reconnect.") raise ReconnectWebSocket(self.shard_id) from None code = self._close_code or self.socket.close_code if self._can_handle_close(): - _log.info('Websocket closed with %s, attempting a reconnect.', code) + _log.info("Websocket closed with %s, attempting a reconnect.", code) raise ReconnectWebSocket(self.shard_id) from None else: - _log.info('Websocket closed with %s, cannot reconnect.', code) - raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None + _log.info("Websocket closed with %s, cannot reconnect.", code) + raise ConnectionClosed( + self.socket, shard_id=self.shard_id, code=code + ) from None async def debug_send(self, data, /): await self._rate_limiter.block() - self._dispatch('socket_raw_send', data) + self._dispatch("socket_raw_send", data) await self.socket.send_str(data) async def send(self, data, /): @@ -611,62 +653,59 @@ class DiscordWebSocket: async def change_presence(self, *, activity=None, status=None, since=0.0): if activity is not None: if not isinstance(activity, BaseActivity): - raise InvalidArgument('activity must derive from BaseActivity.') + raise InvalidArgument("activity must derive from BaseActivity.") activity = [activity.to_dict()] else: activity = [] - if status == 'idle': + if status == "idle": since = int(time.time() * 1000) payload = { - 'op': self.PRESENCE, - 'd': { - 'activities': activity, - 'afk': False, - 'since': since, - 'status': status - } + "op": self.PRESENCE, + "d": { + "activities": activity, + "afk": False, + "since": since, + "status": status, + }, } sent = utils._to_json(payload) _log.debug('Sending "%s" to change status', sent) await self.send(sent) - async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): + async def request_chunks( + self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None + ): payload = { - 'op': self.REQUEST_MEMBERS, - 'd': { - 'guild_id': guild_id, - 'presences': presences, - 'limit': limit - } + "op": self.REQUEST_MEMBERS, + "d": {"guild_id": guild_id, "presences": presences, "limit": limit}, } if nonce: - payload['d']['nonce'] = nonce + payload["d"]["nonce"] = nonce if user_ids: - payload['d']['user_ids'] = user_ids + payload["d"]["user_ids"] = user_ids if query is not None: - payload['d']['query'] = query - + payload["d"]["query"] = query await self.send_as_json(payload) async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): payload = { - 'op': self.VOICE_STATE, - 'd': { - 'guild_id': guild_id, - 'channel_id': channel_id, - 'self_mute': self_mute, - 'self_deaf': self_deaf - } + "op": self.VOICE_STATE, + "d": { + "guild_id": guild_id, + "channel_id": channel_id, + "self_mute": self_mute, + "self_deaf": self_deaf, + }, } - _log.debug('Updating our voice state to %s.', payload) + _log.debug("Updating our voice state to %s.", payload) await self.send_as_json(payload) async def close(self, code=4000): @@ -677,6 +716,7 @@ class DiscordWebSocket: self._close_code = code await self.socket.close(code=code) + class DiscordVoiceWebSocket: """Implements the websocket protocol for handling voice connections. @@ -708,18 +748,18 @@ class DiscordVoiceWebSocket: Receive only. Indicates a user has disconnected from voice. """ - IDENTIFY = 0 - SELECT_PROTOCOL = 1 - READY = 2 - HEARTBEAT = 3 + IDENTIFY = 0 + SELECT_PROTOCOL = 1 + READY = 2 + HEARTBEAT = 3 SESSION_DESCRIPTION = 4 - SPEAKING = 5 - HEARTBEAT_ACK = 6 - RESUME = 7 - HELLO = 8 - RESUMED = 9 - CLIENT_CONNECT = 12 - CLIENT_DISCONNECT = 13 + SPEAKING = 5 + HEARTBEAT_ACK = 6 + RESUME = 7 + HELLO = 8 + RESUMED = 9 + CLIENT_CONNECT = 12 + CLIENT_DISCONNECT = 13 def __init__(self, socket, loop, *, hook=None): self.ws = socket @@ -734,7 +774,7 @@ class DiscordVoiceWebSocket: pass async def send_as_json(self, data): - _log.debug('Sending voice websocket frame: %s.', data) + _log.debug("Sending voice websocket frame: %s.", data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json @@ -742,32 +782,32 @@ class DiscordVoiceWebSocket: async def resume(self): state = self._connection payload = { - 'op': self.RESUME, - 'd': { - 'token': state.token, - 'server_id': str(state.server_id), - 'session_id': state.session_id - } + "op": self.RESUME, + "d": { + "token": state.token, + "server_id": str(state.server_id), + "session_id": state.session_id, + }, } await self.send_as_json(payload) async def identify(self): state = self._connection payload = { - 'op': self.IDENTIFY, - 'd': { - 'server_id': str(state.server_id), - 'user_id': str(state.user.id), - 'session_id': state.session_id, - 'token': state.token - } + "op": self.IDENTIFY, + "d": { + "server_id": str(state.server_id), + "user_id": str(state.user.id), + "session_id": state.session_id, + "token": state.token, + }, } await self.send_as_json(payload) @classmethod async def from_client(cls, client, *, resume=False, hook=None): """Creates a voice websocket for the :class:`VoiceClient`.""" - gateway = 'wss://' + client.endpoint + '/?v=4' + gateway = "wss://" + client.endpoint + "/?v=4" http = client._state.http socket = await http.ws_connect(gateway, compress=15) ws = cls(socket, loop=client.loop, hook=hook) @@ -785,109 +825,101 @@ class DiscordVoiceWebSocket: async def select_protocol(self, ip, port, mode): payload = { - 'op': self.SELECT_PROTOCOL, - 'd': { - 'protocol': 'udp', - 'data': { - 'address': ip, - 'port': port, - 'mode': mode - } - } + "op": self.SELECT_PROTOCOL, + "d": { + "protocol": "udp", + "data": {"address": ip, "port": port, "mode": mode}, + }, } await self.send_as_json(payload) async def client_connect(self): payload = { - 'op': self.CLIENT_CONNECT, - 'd': { - 'audio_ssrc': self._connection.ssrc - } + "op": self.CLIENT_CONNECT, + "d": {"audio_ssrc": self._connection.ssrc}, } await self.send_as_json(payload) async def speak(self, state=SpeakingState.voice): - payload = { - 'op': self.SPEAKING, - 'd': { - 'speaking': int(state), - 'delay': 0 - } - } + payload = {"op": self.SPEAKING, "d": {"speaking": int(state), "delay": 0}} await self.send_as_json(payload) async def received_message(self, msg): - _log.debug('Voice websocket frame received: %s', msg) - op = msg['op'] - data = msg.get('d') + _log.debug("Voice websocket frame received: %s", msg) + op = msg["op"] + data = msg.get("d") if op == self.READY: await self.initial_connection(data) elif op == self.HEARTBEAT_ACK: self._keep_alive.ack() elif op == self.RESUMED: - _log.info('Voice RESUME succeeded.') + _log.info("Voice RESUME succeeded.") elif op == self.SESSION_DESCRIPTION: - self._connection.mode = data['mode'] + self._connection.mode = data["mode"] await self.load_secret_key(data) elif op == self.HELLO: - interval = data['heartbeat_interval'] / 1000.0 - self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) + interval = data["heartbeat_interval"] / 1000.0 + self._keep_alive = VoiceKeepAliveHandler( + ws=self, interval=min(interval, 5.0) + ) self._keep_alive.start() await self._hook(self, msg) async def initial_connection(self, data): state = self._connection - state.ssrc = data['ssrc'] - state.voice_port = data['port'] - state.endpoint_ip = data['ip'] + state.ssrc = data["ssrc"] + state.voice_port = data["port"] + state.endpoint_ip = data["ip"] packet = bytearray(70) - struct.pack_into('>H', packet, 0, 1) # 1 = Send - struct.pack_into('>H', packet, 2, 70) # 70 = Length - struct.pack_into('>I', packet, 4, state.ssrc) + struct.pack_into(">H", packet, 0, 1) # 1 = Send + struct.pack_into(">H", packet, 2, 70) # 70 = Length + struct.pack_into(">I", packet, 4, state.ssrc) state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) recv = await self.loop.sock_recv(state.socket, 70) - _log.debug('received packet in initial_connection: %s', recv) + _log.debug("received packet in initial_connection: %s", recv) # the ip is ascii starting at the 4th byte and ending at the first null ip_start = 4 ip_end = recv.index(0, ip_start) - state.ip = recv[ip_start:ip_end].decode('ascii') + state.ip = recv[ip_start:ip_end].decode("ascii") - state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] - _log.debug('detected ip: %s port: %s', state.ip, state.port) + state.port = struct.unpack_from(">H", recv, len(recv) - 2)[0] + _log.debug("detected ip: %s port: %s", state.ip, state.port) # there *should* always be at least one supported mode (xsalsa20_poly1305) - modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] - _log.debug('received supported encryption modes: %s', ", ".join(modes)) + modes = [ + mode for mode in data["modes"] if mode in self._connection.supported_modes + ] + _log.debug("received supported encryption modes: %s", ", ".join(modes)) mode = modes[0] await self.select_protocol(state.ip, state.port, mode) - _log.info('selected the voice protocol for use (%s)', mode) + _log.info("selected the voice protocol for use (%s)", mode) @property def latency(self): """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive - return float('inf') if heartbeat is None else heartbeat.latency + return float("inf") if heartbeat is None else heartbeat.latency @property def average_latency(self): """:class:`list`: Average of last 20 HEARTBEAT latencies.""" heartbeat = self._keep_alive if heartbeat is None or not heartbeat.recent_ack_latencies: - return float('inf') + return float("inf") return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) async def load_secret_key(self, data): - _log.info('received secret key for voice connection') - self.secret_key = self._connection.secret_key = data.get('secret_key') + _log.info("received secret key for voice connection") + self.secret_key = self._connection.secret_key = data.get("secret_key") await self.speak() await self.speak(False) @@ -897,10 +929,14 @@ class DiscordVoiceWebSocket: if msg.type is aiohttp.WSMsgType.TEXT: await self.received_message(utils._from_json(msg.data)) elif msg.type is aiohttp.WSMsgType.ERROR: - _log.debug('Received %s', msg) + _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None) from msg.data - elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): - _log.debug('Received %s', msg) + elif msg.type in ( + aiohttp.WSMsgType.CLOSED, + aiohttp.WSMsgType.CLOSE, + aiohttp.WSMsgType.CLOSING, + ): + _log.debug("Received %s", msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) async def close(self, code=1000): diff --git a/discord/guild.py b/discord/guild.py index 4ed89821..eb24ac30 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -78,21 +78,30 @@ from .sticker import GuildSticker from .file import File -__all__ = ( - 'Guild', -) +__all__ = ("Guild",) MISSING = utils.MISSING if TYPE_CHECKING: from .abc import Snowflake, SnowflakeTime - from .types.guild import Ban as BanPayload, Guild as GuildPayload, MFALevel, GuildFeature + from .types.guild import ( + Ban as BanPayload, + Guild as GuildPayload, + MFALevel, + GuildFeature, + ) from .types.threads import ( Thread as ThreadPayload, ) from .types.voice import GuildVoiceState from .permissions import Permissions - from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel + from .channel import ( + VoiceChannel, + StageChannel, + TextChannel, + CategoryChannel, + StoreChannel, + ) from .template import Template from .webhook import Webhook from .state import ConnectionState @@ -101,7 +110,9 @@ if TYPE_CHECKING: import datetime VocalGuildChannel = Union[VoiceChannel, StageChannel] - GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel] + GuildChannel = Union[ + VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel + ] ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]] @@ -239,45 +250,45 @@ class Guild(Hashable): """ __slots__ = ( - 'afk_timeout', - 'afk_channel', - 'name', - 'id', - 'unavailable', - 'region', - 'owner_id', - 'mfa_level', - 'emojis', - 'stickers', - 'features', - 'verification_level', - 'explicit_content_filter', - 'default_notifications', - 'description', - 'max_presences', - 'max_members', - 'max_video_channel_users', - 'premium_tier', - 'premium_subscription_count', - 'preferred_locale', - 'nsfw_level', - '_members', - '_channels', - '_icon', - '_banner', - '_state', - '_roles', - '_member_count', - '_large', - '_splash', - '_voice_states', - '_system_channel_id', - '_system_channel_flags', - '_discovery_splash', - '_rules_channel_id', - '_public_updates_channel_id', - '_stage_instances', - '_threads', + "afk_timeout", + "afk_channel", + "name", + "id", + "unavailable", + "region", + "owner_id", + "mfa_level", + "emojis", + "stickers", + "features", + "verification_level", + "explicit_content_filter", + "default_notifications", + "description", + "max_presences", + "max_members", + "max_video_channel_users", + "premium_tier", + "premium_subscription_count", + "preferred_locale", + "nsfw_level", + "_members", + "_channels", + "_icon", + "_banner", + "_state", + "_roles", + "_member_count", + "_large", + "_splash", + "_voice_states", + "_system_channel_id", + "_system_channel_flags", + "_discovery_splash", + "_rules_channel_id", + "_public_updates_channel_id", + "_stage_instances", + "_threads", ) _PREMIUM_GUILD_LIMITS: ClassVar[Dict[Optional[int], _GuildLimit]] = { @@ -331,27 +342,31 @@ class Guild(Hashable): del self._threads[k] def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]: - to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids} + to_remove: Dict[int, Thread] = { + k: t for k, t in self._threads.items() if t.parent_id in channel_ids + } for k in to_remove: del self._threads[k] return to_remove def __str__(self) -> str: - return self.name or '' + return self.name or "" def __repr__(self) -> str: attrs = ( - ('id', self.id), - ('name', self.name), - ('shard_id', self.shard_id), - ('chunked', self.chunked), - ('member_count', getattr(self, '_member_count', None)), + ("id", self.id), + ("name", self.name), + ("shard_id", self.shard_id), + ("chunked", self.chunked), + ("member_count", getattr(self, "_member_count", None)), ) - inner = ' '.join('%s=%r' % t for t in attrs) - return f'' + inner = " ".join("%s=%r" % t for t in attrs) + return f"" - def _update_voice_state(self, data: GuildVoiceState, channel_id: int) -> Tuple[Optional[Member], VoiceState, VoiceState]: - user_id = int(data['user_id']) + def _update_voice_state( + self, data: GuildVoiceState, channel_id: int + ) -> Tuple[Optional[Member], VoiceState, VoiceState]: + user_id = int(data["user_id"]) channel = self.get_channel(channel_id) try: # check if we should remove the voice state from cache @@ -371,7 +386,7 @@ class Guild(Hashable): member = self.get_member(user_id) if member is None: try: - member = Member(data=data['member'], state=self._state, guild=self) + member = Member(data=data["member"], state=self._state, guild=self) except KeyError: member = None @@ -403,93 +418,111 @@ class Guild(Hashable): def _from_data(self, guild: GuildPayload) -> None: # according to Stan, this is always available even if the guild is unavailable # I don't have this guarantee when someone updates the guild. - member_count = guild.get('member_count', None) + member_count = guild.get("member_count", None) if member_count is not None: self._member_count: int = member_count - self.name: str = guild.get('name') - self.region: VoiceRegion = try_enum(VoiceRegion, guild.get('region')) - self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get('verification_level')) - self.default_notifications: NotificationLevel = try_enum( - NotificationLevel, guild.get('default_message_notifications') + self.name: str = guild.get("name") + self.region: VoiceRegion = try_enum(VoiceRegion, guild.get("region")) + self.verification_level: VerificationLevel = try_enum( + VerificationLevel, guild.get("verification_level") ) - self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get('explicit_content_filter', 0)) - self.afk_timeout: int = guild.get('afk_timeout') - self._icon: Optional[str] = guild.get('icon') - self._banner: Optional[str] = guild.get('banner') - self.unavailable: bool = guild.get('unavailable', False) - self.id: int = int(guild['id']) + self.default_notifications: NotificationLevel = try_enum( + NotificationLevel, guild.get("default_message_notifications") + ) + self.explicit_content_filter: ContentFilter = try_enum( + ContentFilter, guild.get("explicit_content_filter", 0) + ) + self.afk_timeout: int = guild.get("afk_timeout") + self._icon: Optional[str] = guild.get("icon") + self._banner: Optional[str] = guild.get("banner") + self.unavailable: bool = guild.get("unavailable", False) + self.id: int = int(guild["id"]) self._roles: Dict[int, Role] = {} state = self._state # speed up attribute access - for r in guild.get('roles', []): + for r in guild.get("roles", []): role = Role(guild=self, data=r, state=state) self._roles[role.id] = role - self.mfa_level: MFALevel = guild.get('mfa_level') - self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get('emojis', []))) - self.stickers: Tuple[GuildSticker, ...] = tuple( - map(lambda d: state.store_sticker(self, d), guild.get('stickers', [])) + self.mfa_level: MFALevel = guild.get("mfa_level") + self.emojis: Tuple[Emoji, ...] = tuple( + map(lambda d: state.store_emoji(self, d), guild.get("emojis", [])) ) - self.features: List[GuildFeature] = guild.get('features', []) - self._splash: Optional[str] = guild.get('splash') - self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'system_channel_id') - self.description: Optional[str] = guild.get('description') - self.max_presences: Optional[int] = guild.get('max_presences') - self.max_members: Optional[int] = guild.get('max_members') - self.max_video_channel_users: Optional[int] = guild.get('max_video_channel_users') - self.premium_tier: int = guild.get('premium_tier', 0) - self.premium_subscription_count: int = guild.get('premium_subscription_count') or 0 - self._system_channel_flags: int = guild.get('system_channel_flags', 0) - self.preferred_locale: Optional[str] = guild.get('preferred_locale') - self._discovery_splash: Optional[str] = guild.get('discovery_splash') - self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'rules_channel_id') - self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, 'public_updates_channel_id') - self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get('nsfw_level', 0)) + self.stickers: Tuple[GuildSticker, ...] = tuple( + map(lambda d: state.store_sticker(self, d), guild.get("stickers", [])) + ) + self.features: List[GuildFeature] = guild.get("features", []) + self._splash: Optional[str] = guild.get("splash") + self._system_channel_id: Optional[int] = utils._get_as_snowflake( + guild, "system_channel_id" + ) + self.description: Optional[str] = guild.get("description") + self.max_presences: Optional[int] = guild.get("max_presences") + self.max_members: Optional[int] = guild.get("max_members") + self.max_video_channel_users: Optional[int] = guild.get( + "max_video_channel_users" + ) + self.premium_tier: int = guild.get("premium_tier", 0) + self.premium_subscription_count: int = ( + guild.get("premium_subscription_count") or 0 + ) + self._system_channel_flags: int = guild.get("system_channel_flags", 0) + self.preferred_locale: Optional[str] = guild.get("preferred_locale") + self._discovery_splash: Optional[str] = guild.get("discovery_splash") + self._rules_channel_id: Optional[int] = utils._get_as_snowflake( + guild, "rules_channel_id" + ) + self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake( + guild, "public_updates_channel_id" + ) + self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0)) self._stage_instances: Dict[int, StageInstance] = {} - for s in guild.get('stage_instances', []): + for s in guild.get("stage_instances", []): stage_instance = StageInstance(guild=self, data=s, state=state) self._stage_instances[stage_instance.id] = stage_instance cache_joined = self._state.member_cache_flags.joined self_id = self._state.self_id - for mdata in guild.get('members', []): + for mdata in guild.get("members", []): member = Member(data=mdata, guild=self, state=state) if cache_joined or member.id == self_id: self._add_member(member) self._sync(guild) - self._large: Optional[bool] = None if member_count is None else self._member_count >= 250 + self._large: Optional[bool] = ( + None if member_count is None else self._member_count >= 250 + ) - self.owner_id: Optional[int] = utils._get_as_snowflake(guild, 'owner_id') - self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, 'afk_channel_id')) # type: ignore + self.owner_id: Optional[int] = utils._get_as_snowflake(guild, "owner_id") + self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore - for obj in guild.get('voice_states', []): - self._update_voice_state(obj, int(obj['channel_id'])) + for obj in guild.get("voice_states", []): + self._update_voice_state(obj, int(obj["channel_id"])) # TODO: refactor/remove? def _sync(self, data: GuildPayload) -> None: try: - self._large = data['large'] + self._large = data["large"] except KeyError: pass empty_tuple = tuple() - for presence in data.get('presences', []): - user_id = int(presence['user']['id']) + for presence in data.get("presences", []): + user_id = int(presence["user"]["id"]) member = self.get_member(user_id) if member is not None: member._presence_update(presence, empty_tuple) # type: ignore - if 'channels' in data: - channels = data['channels'] + if "channels" in data: + channels = data["channels"] for c in channels: - factory, ch_type = _guild_channel_factory(c['type']) + factory, ch_type = _guild_channel_factory(c["type"]) if factory: self._add_channel(factory(guild=self, data=c, state=self._state)) # type: ignore - if 'threads' in data: - threads = data['threads'] + if "threads" in data: + threads = data["threads"] for thread in threads: self._add_thread(Thread(guild=self, state=self._state, data=thread)) @@ -611,13 +644,17 @@ class Guild(Hashable): channels.sort(key=lambda c: (c._sorting_bucket, c.position, c.id)) return as_list - def _resolve_channel(self, id: Optional[int], /) -> Optional[Union[GuildChannel, Thread]]: + def _resolve_channel( + self, id: Optional[int], / + ) -> Optional[Union[GuildChannel, Thread]]: if id is None: return return self._channels.get(id) or self._threads.get(id) - def get_channel_or_thread(self, channel_id: int, /) -> Optional[Union[Thread, GuildChannel]]: + def get_channel_or_thread( + self, channel_id: int, / + ) -> Optional[Union[Thread, GuildChannel]]: """Returns a channel or thread with the given ID. .. versionadded:: 2.0 @@ -712,7 +749,7 @@ class Guild(Hashable): @property def emoji_limit(self) -> int: """:class:`int`: The maximum number of emoji slots this guild has.""" - more_emoji = 200 if 'MORE_EMOJI' in self.features else 50 + more_emoji = 200 if "MORE_EMOJI" in self.features else 50 return max(more_emoji, self._PREMIUM_GUILD_LIMITS[self.premium_tier].emoji) @property @@ -721,13 +758,19 @@ class Guild(Hashable): .. versionadded:: 2.0 """ - more_stickers = 60 if 'MORE_STICKERS' in self.features else 0 - return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers) + more_stickers = 60 if "MORE_STICKERS" in self.features else 0 + return max( + more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers + ) @property def bitrate_limit(self) -> float: """:class:`float`: The maximum bitrate for voice channels this guild can have.""" - vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if 'VIP_REGIONS' in self.features else 96e3 + vip_guild = ( + self._PREMIUM_GUILD_LIMITS[1].bitrate + if "VIP_REGIONS" in self.features + else 96e3 + ) return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate) @property @@ -743,15 +786,15 @@ class Guild(Hashable): @property def humans(self) -> List[Member]: """List[:class:`Member`]: A list of human members that belong to this guild. - - .. versionadded:: 2.0 """ + + .. versionadded:: 2.0""" return [member for member in self.members if not member.bot] @property def bots(self) -> List[Member]: """List[:class:`Member`]: A list of bots that belong to this guild. - - .. versionadded:: 2.0 """ + + .. versionadded:: 2.0""" return [member for member in self.members if member.bot] def get_member(self, user_id: int, /) -> Optional[Member]: @@ -871,21 +914,27 @@ class Guild(Hashable): """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None - return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') + return Asset._from_guild_image( + self._state, self.id, self._banner, path="banners" + ) @property def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None - return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') + return Asset._from_guild_image( + self._state, self.id, self._splash, path="splashes" + ) @property def discovery_splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available.""" if self._discovery_splash is None: return None - return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path='discovery-splashes') + return Asset._from_guild_image( + self._state, self.id, self._discovery_splash, path="discovery-splashes" + ) @property def member_count(self) -> int: @@ -909,7 +958,7 @@ class Guild(Hashable): If this value returns ``False``, then you should request for offline members. """ - count = getattr(self, '_member_count', None) + count = getattr(self, "_member_count", None) if count is None: return False return count == len(self._members) @@ -956,7 +1005,7 @@ class Guild(Hashable): result = None members = self.members - if len(name) > 5 and name[-5] == '#': + if len(name) > 5 and name[-5] == "#": # The 5 length is checking to see if #0000 is in the string, # as a#0000 has a length of 6, the minimum for a potential # discriminator lookup. @@ -964,7 +1013,9 @@ class Guild(Hashable): # do the actual lookup and return if found # if it isn't found then we'll do a full name lookup below. - result = utils.get(members, name=name[:-5], discriminator=potential_discriminator) + result = utils.get( + members, name=name[:-5], discriminator=potential_discriminator + ) if result is not None: return result @@ -984,26 +1035,33 @@ class Guild(Hashable): if overwrites is MISSING: overwrites = {} elif not isinstance(overwrites, dict): - raise InvalidArgument('overwrites parameter expects a dict.') + raise InvalidArgument("overwrites parameter expects a dict.") perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument(f'Expected PermissionOverwrite received {perm.__class__.__name__}') + raise InvalidArgument( + f"Expected PermissionOverwrite received {perm.__class__.__name__}" + ) allow, deny = perm.pair() - payload = {'allow': allow.value, 'deny': deny.value, 'id': target.id} + payload = {"allow": allow.value, "deny": deny.value, "id": target.id} if isinstance(target, Role): - payload['type'] = abc._Overwrites.ROLE + payload["type"] = abc._Overwrites.ROLE else: - payload['type'] = abc._Overwrites.MEMBER + payload["type"] = abc._Overwrites.MEMBER perms.append(payload) parent_id = category.id if category else None return self._state.http.create_channel( - self.id, channel_type.value, name=name, parent_id=parent_id, permission_overwrites=perms, **options + self.id, + channel_type.value, + name=name, + parent_id=parent_id, + permission_overwrites=perms, + **options, ) async def create_text_channel( @@ -1098,19 +1156,24 @@ class Guild(Hashable): options = {} if position is not MISSING: - options['position'] = position + options["position"] = position if topic is not MISSING: - options['topic'] = topic + options["topic"] = topic if slowmode_delay is not MISSING: - options['rate_limit_per_user'] = slowmode_delay + options["rate_limit_per_user"] = slowmode_delay if nsfw is not MISSING: - options['nsfw'] = nsfw + options["nsfw"] = nsfw data = await self._create_channel( - name, overwrites=overwrites, channel_type=ChannelType.text, category=category, reason=reason, **options + name, + overwrites=overwrites, + channel_type=ChannelType.text, + category=category, + reason=reason, + **options, ) channel = TextChannel(state=self._state, guild=self, data=data) @@ -1182,22 +1245,27 @@ class Guild(Hashable): """ options = {} if position is not MISSING: - options['position'] = position + options["position"] = position if bitrate is not MISSING: - options['bitrate'] = bitrate + options["bitrate"] = bitrate if user_limit is not MISSING: - options['user_limit'] = user_limit + options["user_limit"] = user_limit if rtc_region is not MISSING: - options['rtc_region'] = None if rtc_region is None else str(rtc_region) + options["rtc_region"] = None if rtc_region is None else str(rtc_region) if video_quality_mode is not MISSING: - options['video_quality_mode'] = video_quality_mode.value + options["video_quality_mode"] = video_quality_mode.value data = await self._create_channel( - name, overwrites=overwrites, channel_type=ChannelType.voice, category=category, reason=reason, **options + name, + overwrites=overwrites, + channel_type=ChannelType.voice, + category=category, + reason=reason, + **options, ) channel = VoiceChannel(state=self._state, guild=self, data=data) @@ -1257,13 +1325,18 @@ class Guild(Hashable): """ options: Dict[str, Any] = { - 'topic': topic, + "topic": topic, } if position is not MISSING: - options['position'] = position + options["position"] = position data = await self._create_channel( - name, overwrites=overwrites, channel_type=ChannelType.stage_voice, category=category, reason=reason, **options + name, + overwrites=overwrites, + channel_type=ChannelType.stage_voice, + category=category, + reason=reason, + **options, ) channel = StageChannel(state=self._state, guild=self, data=data) @@ -1304,10 +1377,14 @@ class Guild(Hashable): """ options: Dict[str, Any] = {} if position is not MISSING: - options['position'] = position + options["position"] = position data = await self._create_channel( - name, overwrites=overwrites, channel_type=ChannelType.category, reason=reason, **options + name, + overwrites=overwrites, + channel_type=ChannelType.category, + reason=reason, + **options, ) channel = CategoryChannel(state=self._state, guild=self, data=data) @@ -1479,108 +1556,123 @@ class Guild(Hashable): fields: Dict[str, Any] = {} if name is not MISSING: - fields['name'] = name + fields["name"] = name if description is not MISSING: - fields['description'] = description + fields["description"] = description if preferred_locale is not MISSING: - fields['preferred_locale'] = preferred_locale + fields["preferred_locale"] = preferred_locale if afk_timeout is not MISSING: - fields['afk_timeout'] = afk_timeout + fields["afk_timeout"] = afk_timeout if icon is not MISSING: if icon is None: - fields['icon'] = icon + fields["icon"] = icon else: - fields['icon'] = utils._bytes_to_base64_data(icon) + fields["icon"] = utils._bytes_to_base64_data(icon) if banner is not MISSING: if banner is None: - fields['banner'] = banner + fields["banner"] = banner else: - fields['banner'] = utils._bytes_to_base64_data(banner) + fields["banner"] = utils._bytes_to_base64_data(banner) if splash is not MISSING: if splash is None: - fields['splash'] = splash + fields["splash"] = splash else: - fields['splash'] = utils._bytes_to_base64_data(splash) + fields["splash"] = utils._bytes_to_base64_data(splash) if discovery_splash is not MISSING: if discovery_splash is None: - fields['discovery_splash'] = discovery_splash + fields["discovery_splash"] = discovery_splash else: - fields['discovery_splash'] = utils._bytes_to_base64_data(discovery_splash) + fields["discovery_splash"] = utils._bytes_to_base64_data( + discovery_splash + ) if default_notifications is not MISSING: if not isinstance(default_notifications, NotificationLevel): - raise InvalidArgument('default_notifications field must be of type NotificationLevel') - fields['default_message_notifications'] = default_notifications.value + raise InvalidArgument( + "default_notifications field must be of type NotificationLevel" + ) + fields["default_message_notifications"] = default_notifications.value if afk_channel is not MISSING: if afk_channel is None: - fields['afk_channel_id'] = afk_channel + fields["afk_channel_id"] = afk_channel else: - fields['afk_channel_id'] = afk_channel.id + fields["afk_channel_id"] = afk_channel.id if system_channel is not MISSING: if system_channel is None: - fields['system_channel_id'] = system_channel + fields["system_channel_id"] = system_channel else: - fields['system_channel_id'] = system_channel.id + fields["system_channel_id"] = system_channel.id if rules_channel is not MISSING: if rules_channel is None: - fields['rules_channel_id'] = rules_channel + fields["rules_channel_id"] = rules_channel else: - fields['rules_channel_id'] = rules_channel.id + fields["rules_channel_id"] = rules_channel.id if public_updates_channel is not MISSING: if public_updates_channel is None: - fields['public_updates_channel_id'] = public_updates_channel + fields["public_updates_channel_id"] = public_updates_channel else: - fields['public_updates_channel_id'] = public_updates_channel.id + fields["public_updates_channel_id"] = public_updates_channel.id if owner is not MISSING: if self.owner_id != self._state.self_id: - raise InvalidArgument('To transfer ownership you must be the owner of the guild.') + raise InvalidArgument( + "To transfer ownership you must be the owner of the guild." + ) - fields['owner_id'] = owner.id + fields["owner_id"] = owner.id if region is not MISSING: - fields['region'] = str(region) + fields["region"] = str(region) if verification_level is not MISSING: if not isinstance(verification_level, VerificationLevel): - raise InvalidArgument('verification_level field must be of type VerificationLevel') + raise InvalidArgument( + "verification_level field must be of type VerificationLevel" + ) - fields['verification_level'] = verification_level.value + fields["verification_level"] = verification_level.value if explicit_content_filter is not MISSING: if not isinstance(explicit_content_filter, ContentFilter): - raise InvalidArgument('explicit_content_filter field must be of type ContentFilter') + raise InvalidArgument( + "explicit_content_filter field must be of type ContentFilter" + ) - fields['explicit_content_filter'] = explicit_content_filter.value + fields["explicit_content_filter"] = explicit_content_filter.value if system_channel_flags is not MISSING: if not isinstance(system_channel_flags, SystemChannelFlags): - raise InvalidArgument('system_channel_flags field must be of type SystemChannelFlags') + raise InvalidArgument( + "system_channel_flags field must be of type SystemChannelFlags" + ) - fields['system_channel_flags'] = system_channel_flags.value + fields["system_channel_flags"] = system_channel_flags.value if community is not MISSING: features = [] if community: - if 'rules_channel_id' in fields and 'public_updates_channel_id' in fields: - features.append('COMMUNITY') + if ( + "rules_channel_id" in fields + and "public_updates_channel_id" in fields + ): + features.append("COMMUNITY") else: raise InvalidArgument( - 'community field requires both rules_channel and public_updates_channel fields to be provided' + "community field requires both rules_channel and public_updates_channel fields to be provided" ) - fields['features'] = features + fields["features"] = features data = await http.edit_guild(self.id, reason=reason, **fields) return Guild(data=data, state=self._state) @@ -1611,9 +1703,11 @@ class Guild(Hashable): data = await self._state.http.get_all_guild_channels(self.id) def convert(d): - factory, ch_type = _guild_channel_factory(d['type']) + factory, ch_type = _guild_channel_factory(d["type"]) if factory is None: - raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(d)) + raise InvalidData( + "Unknown channel type {type} for channel ID {id}.".format_map(d) + ) channel = factory(guild=self, state=self._state, data=d) return channel @@ -1640,17 +1734,22 @@ class Guild(Hashable): The active threads """ data = await self._state.http.get_active_threads(self.id) - threads = [Thread(guild=self, state=self._state, data=d) for d in data.get('threads', [])] + threads = [ + Thread(guild=self, state=self._state, data=d) + for d in data.get("threads", []) + ] thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads} - for member in data.get('members', []): - thread = thread_lookup.get(int(member['id'])) + for member in data.get("members", []): + thread = thread_lookup.get(int(member["id"])) if thread is not None: thread._add_member(ThreadMember(parent=thread, data=member)) return threads # TODO: Remove Optional typing here when async iterators are refactored - def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator: + def fetch_members( + self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None + ) -> MemberIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this, :meth:`Intents.members` must be enabled. @@ -1699,7 +1798,7 @@ class Guild(Hashable): """ if not self._state._intents.members: - raise ClientException('Intents.members must be enabled to use this.') + raise ClientException("Intents.members must be enabled to use this.") return MemberIterator(self, limit=limit, after=after) @@ -1760,7 +1859,9 @@ class Guild(Hashable): The :class:`BanEntry` object for the specified user. """ data: BanPayload = await self._state.http.get_ban(user.id, self.id) - return BanEntry(user=User(state=self._state, data=data['user']), reason=data['reason']) + return BanEntry( + user=User(state=self._state, data=data["user"]), reason=data["reason"] + ) async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: """|coro| @@ -1793,16 +1894,18 @@ class Guild(Hashable): """ data = await self._state.http.get_channel(channel_id) - factory, ch_type = _threaded_guild_channel_factory(data['type']) + factory, ch_type = _threaded_guild_channel_factory(data["type"]) if factory is None: - raise InvalidData('Unknown channel type {type} for channel ID {id}.'.format_map(data)) + raise InvalidData( + "Unknown channel type {type} for channel ID {id}.".format_map(data) + ) if ch_type in (ChannelType.group, ChannelType.private): - raise InvalidData('Channel ID resolved to a private channel') + raise InvalidData("Channel ID resolved to a private channel") - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) if self.id != guild_id: - raise InvalidData('Guild ID resolved to a different guild') + raise InvalidData("Guild ID resolved to a different guild") channel: GuildChannel = factory(guild=self, state=self._state, data=data) # type: ignore return channel @@ -1829,7 +1932,10 @@ class Guild(Hashable): """ data: List[BanPayload] = await self._state.http.get_bans(self.id) - return [BanEntry(user=User(state=self._state, data=e['user']), reason=e['reason']) for e in data] + return [ + BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) + for e in data + ] async def prune_members( self, @@ -1889,7 +1995,9 @@ class Guild(Hashable): """ if not isinstance(days, int): - raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') + raise InvalidArgument( + f"Expected int for ``days``, received {days.__class__.__name__} instead." + ) if roles: role_ids = [str(role.id) for role in roles] @@ -1897,9 +2005,13 @@ class Guild(Hashable): role_ids = [] data = await self._state.http.prune_members( - self.id, days, compute_prune_count=compute_prune_count, roles=role_ids, reason=reason + self.id, + days, + compute_prune_count=compute_prune_count, + roles=role_ids, + reason=reason, ) - return data['pruned'] + return data["pruned"] async def templates(self) -> List[Template]: """|coro| @@ -1948,7 +2060,9 @@ class Guild(Hashable): data = await self._state.http.guild_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> int: + async def estimate_pruned_members( + self, *, days: int, roles: List[Snowflake] = MISSING + ) -> int: """|coro| Similar to :meth:`prune_members` except instead of actually @@ -1981,7 +2095,9 @@ class Guild(Hashable): """ if not isinstance(days, int): - raise InvalidArgument(f'Expected int for ``days``, received {days.__class__.__name__} instead.') + raise InvalidArgument( + f"Expected int for ``days``, received {days.__class__.__name__} instead." + ) if roles: role_ids = [str(role.id) for role in roles] @@ -1989,7 +2105,7 @@ class Guild(Hashable): role_ids = [] data = await self._state.http.estimate_pruned_members(self.id, days, role_ids) - return data['pruned'] + return data["pruned"] async def invites(self) -> List[Invite]: """|coro| @@ -2015,12 +2131,16 @@ class Guild(Hashable): data = await self._state.http.invites_from(self.id) result = [] for invite in data: - channel = self.get_channel(int(invite['channel']['id'])) - result.append(Invite(state=self._state, data=invite, guild=self, channel=channel)) + channel = self.get_channel(int(invite["channel"]["id"])) + result.append( + Invite(state=self._state, data=invite, guild=self, channel=channel) + ) return result - async def create_template(self, *, name: str, description: str = MISSING) -> Template: + async def create_template( + self, *, name: str, description: str = MISSING + ) -> Template: """|coro| Creates a template for the guild. @@ -2039,10 +2159,10 @@ class Guild(Hashable): """ from .template import Template - payload = {'name': name} + payload = {"name": name} if description: - payload['description'] = description + payload["description"] = description data = await self._state.http.create_template(self.id, payload) @@ -2099,9 +2219,13 @@ class Guild(Hashable): data = await self._state.http.get_all_integrations(self.id) def convert(d): - factory, _ = _integration_factory(d['type']) + factory, _ = _integration_factory(d["type"]) if factory is None: - raise InvalidData('Unknown integration type {type!r} for integration ID {id}'.format_map(d)) + raise InvalidData( + "Unknown integration type {type!r} for integration ID {id}".format_map( + d + ) + ) return factory(guild=self, data=d) return [convert(d) for d in data] @@ -2206,25 +2330,29 @@ class Guild(Hashable): The created sticker. """ payload = { - 'name': name, + "name": name, } if description: - payload['description'] = description + payload["description"] = description try: emoji = unicodedata.name(emoji) except TypeError: pass else: - emoji = emoji.replace(' ', '_') + emoji = emoji.replace(" ", "_") - payload['tags'] = emoji + payload["tags"] = emoji - data = await self._state.http.create_guild_sticker(self.id, payload, file, reason) + data = await self._state.http.create_guild_sticker( + self.id, payload, file, reason + ) return self._state.store_sticker(self, data) - async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None: + async def delete_sticker( + self, sticker: Snowflake, *, reason: Optional[str] = None + ) -> None: """|coro| Deletes the custom :class:`Sticker` from the guild. @@ -2351,10 +2479,14 @@ class Guild(Hashable): else: role_ids = [] - data = await self._state.http.create_custom_emoji(self.id, name, img, roles=role_ids, reason=reason) + data = await self._state.http.create_custom_emoji( + self.id, name, img, roles=role_ids, reason=reason + ) return self._state.store_emoji(self, data) - async def delete_emoji(self, emoji: Snowflake, *, reason: Optional[str] = None) -> None: + async def delete_emoji( + self, emoji: Snowflake, *, reason: Optional[str] = None + ) -> None: """|coro| Deletes the custom :class:`Emoji` from the guild. @@ -2486,24 +2618,24 @@ class Guild(Hashable): """ fields: Dict[str, Any] = {} if permissions is not MISSING: - fields['permissions'] = str(permissions.value) + fields["permissions"] = str(permissions.value) else: - fields['permissions'] = '0' + fields["permissions"] = "0" actual_colour = colour or color or Colour.default() if isinstance(actual_colour, int): - fields['color'] = actual_colour + fields["color"] = actual_colour else: - fields['color'] = actual_colour.value + fields["color"] = actual_colour.value if hoist is not MISSING: - fields['hoist'] = hoist + fields["hoist"] = hoist if mentionable is not MISSING: - fields['mentionable'] = mentionable + fields["mentionable"] = mentionable if name is not MISSING: - fields['name'] = name + fields["name"] = name data = await self._state.http.create_role(self.id, reason=reason, **fields) role = Role(guild=self, data=data, state=self._state) @@ -2511,7 +2643,9 @@ class Guild(Hashable): # TODO: add to cache return role - async def edit_role_positions(self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]: + async def edit_role_positions( + self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None + ) -> List[Role]: """|coro| Bulk edits a list of :class:`Role` in the guild. @@ -2556,16 +2690,18 @@ class Guild(Hashable): A list of all the roles in the guild. """ if not isinstance(positions, dict): - raise InvalidArgument('positions parameter expects a dict.') + raise InvalidArgument("positions parameter expects a dict.") role_positions: List[Dict[str, Any]] = [] for role, position in positions.items(): - payload = {'id': role.id, 'position': position} + payload = {"id": role.id, "position": position} role_positions.append(payload) - data = await self._state.http.move_role_position(self.id, role_positions, reason=reason) + data = await self._state.http.move_role_position( + self.id, role_positions, reason=reason + ) roles: List[Role] = [] for d in data: role = Role(guild=self, data=d, state=self._state) @@ -2687,19 +2823,19 @@ class Guild(Hashable): # we start with { code: abc } payload = await self._state.http.get_vanity_code(self.id) - if not payload['code']: + if not payload["code"]: return None # get the vanity URL channel since default channels aren't # reliable or a thing anymore - data = await self._state.http.get_invite(payload['code']) + data = await self._state.http.get_invite(payload["code"]) - channel = self.get_channel(int(data['channel']['id'])) - payload['revoked'] = False - payload['temporary'] = False - payload['max_uses'] = 0 - payload['max_age'] = 0 - payload['uses'] = payload.get('uses', 0) + channel = self.get_channel(int(data["channel"]["id"])) + payload["revoked"] = False + payload["temporary"] = False + payload["max_uses"] = 0 + payload["max_age"] = 0 + payload["uses"] = payload.get("uses", 0) return Invite(state=self._state, data=payload, guild=self, channel=channel) # TODO: use MISSING when async iterators get refactored @@ -2776,7 +2912,13 @@ class Guild(Hashable): action = action.value return AuditLogIterator( - self, before=before, after=after, limit=limit, oldest_first=oldest_first, user_id=user_id, action_type=action + self, + before=before, + after=after, + limit=limit, + oldest_first=oldest_first, + user_id=user_id, + action_type=action, ) async def widget(self) -> Widget: @@ -2804,7 +2946,9 @@ class Guild(Hashable): return Widget(state=self._state, data=data) - async def edit_widget(self, *, enabled: bool = MISSING, channel: Optional[Snowflake] = MISSING) -> None: + async def edit_widget( + self, *, enabled: bool = MISSING, channel: Optional[Snowflake] = MISSING + ) -> None: """|coro| Edits the widget of the guild. @@ -2830,9 +2974,9 @@ class Guild(Hashable): """ payload = {} if channel is not MISSING: - payload['channel_id'] = None if channel is None else channel.id + payload["channel_id"] = None if channel is None else channel.id if enabled is not MISSING: - payload['enabled'] = enabled + payload["enabled"] = enabled await self._state.http.edit_widget(self.id, payload=payload) @@ -2858,7 +3002,7 @@ class Guild(Hashable): """ if not self._state._intents.members: - raise ClientException('Intents.members must be enabled to use this.') + raise ClientException("Intents.members must be enabled to use this.") if not self._state.is_guild_evicted(self): return await self._state.chunk_guild(self, cache=cache) @@ -2919,28 +3063,37 @@ class Guild(Hashable): """ if presences and not self._state._intents.presences: - raise ClientException('Intents.presences must be enabled to use this.') + raise ClientException("Intents.presences must be enabled to use this.") if query is None: - if query == '': - raise ValueError('Cannot pass empty query string.') + if query == "": + raise ValueError("Cannot pass empty query string.") if user_ids is None: - raise ValueError('Must pass either query or user_ids') + raise ValueError("Must pass either query or user_ids") if user_ids is not None and query is not None: - raise ValueError('Cannot pass both query and user_ids') + raise ValueError("Cannot pass both query and user_ids") if user_ids is not None and not user_ids: - raise ValueError('user_ids must contain at least 1 value') + raise ValueError("user_ids must contain at least 1 value") limit = min(100, limit or 5) return await self._state.query_members( - self, query=query, limit=limit, user_ids=user_ids, presences=presences, cache=cache + self, + query=query, + limit=limit, + user_ids=user_ids, + presences=presences, + cache=cache, ) async def change_voice_state( - self, *, channel: Optional[VocalGuildChannel], self_mute: bool = False, self_deaf: bool = False + self, + *, + channel: Optional[VocalGuildChannel], + self_mute: bool = False, + self_deaf: bool = False, ): """|coro| diff --git a/discord/http.py b/discord/http.py index 7a4c2adc..d42764fe 100644 --- a/discord/http.py +++ b/discord/http.py @@ -48,7 +48,15 @@ import weakref import aiohttp -from .errors import HTTPException, Forbidden, NotFound, LoginFailure, DiscordServerError, GatewayNotFound, InvalidArgument +from .errors import ( + HTTPException, + Forbidden, + NotFound, + LoginFailure, + DiscordServerError, + GatewayNotFound, + InvalidArgument, +) from .gateway import DiscordClientWebSocketResponse from . import __version__, utils from .utils import MISSING @@ -89,16 +97,16 @@ if TYPE_CHECKING: from types import TracebackType - T = TypeVar('T') - BE = TypeVar('BE', bound=BaseException) - MU = TypeVar('MU', bound='MaybeUnlock') + T = TypeVar("T") + BE = TypeVar("BE", bound=BaseException) + MU = TypeVar("MU", bound="MaybeUnlock") Response = Coroutine[Any, Any, T] async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any], str]: - text = await response.text(encoding='utf-8') + text = await response.text(encoding="utf-8") try: - if response.headers['content-type'] == 'application/json': + if response.headers["content-type"] == "application/json": return utils._from_json(text) except KeyError: # Thanks Cloudflare @@ -108,26 +116,31 @@ async def json_or_text(response: aiohttp.ClientResponse) -> Union[Dict[str, Any] class Route: - BASE: ClassVar[str] = 'https://discord.com/api/v8' + BASE: ClassVar[str] = "https://discord.com/api/v8" def __init__(self, method: str, path: str, **parameters: Any) -> None: self.path: str = path self.method: str = method url = self.BASE + self.path if parameters: - url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()}) + url = url.format_map( + { + k: _uriquote(v) if isinstance(v, str) else v + for k, v in parameters.items() + } + ) self.url: str = url # major parameters: - self.channel_id: Optional[Snowflake] = parameters.get('channel_id') - self.guild_id: Optional[Snowflake] = parameters.get('guild_id') - self.webhook_id: Optional[Snowflake] = parameters.get('webhook_id') - self.webhook_token: Optional[str] = parameters.get('webhook_token') + self.channel_id: Optional[Snowflake] = parameters.get("channel_id") + self.guild_id: Optional[Snowflake] = parameters.get("guild_id") + self.webhook_id: Optional[Snowflake] = parameters.get("webhook_id") + self.webhook_token: Optional[str] = parameters.get("webhook_token") @property def bucket(self) -> str: # the bucket is just method + path w/ major parameters - return f'{self.channel_id}:{self.guild_id}:{self.path}' + return f"{self.channel_id}:{self.guild_id}:{self.path}" class MaybeUnlock: @@ -153,7 +166,7 @@ class MaybeUnlock: # For some reason, the Discord voice websocket expects this header to be # completely lowercase while aiohttp respects spec and does it as case-insensitive -aiohttp.hdrs.WEBSOCKET = 'websocket' # type: ignore +aiohttp.hdrs.WEBSOCKET = "websocket" # type: ignore class HTTPClient: @@ -168,7 +181,9 @@ class HTTPClient: loop: Optional[asyncio.AbstractEventLoop] = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + self.loop: asyncio.AbstractEventLoop = ( + asyncio.get_event_loop() if loop is None else loop + ) self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() @@ -180,26 +195,29 @@ class HTTPClient: self.proxy_auth: Optional[aiohttp.BasicAuth] = proxy_auth self.use_clock: bool = not unsync_clock - user_agent = 'DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}' - self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) + user_agent = "DiscordBot (https://github.com/Rapptz/discord.py {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" + self.user_agent: str = user_agent.format( + __version__, sys.version_info, aiohttp.__version__ + ) def recreate(self) -> None: if self.__session.closed: self.__session = aiohttp.ClientSession( - connector=self.connector, ws_response_class=DiscordClientWebSocketResponse + connector=self.connector, + ws_response_class=DiscordClientWebSocketResponse, ) async def ws_connect(self, url: str, *, compress: int = 0) -> Any: kwargs = { - 'proxy_auth': self.proxy_auth, - 'proxy': self.proxy, - 'max_msg_size': 0, - 'timeout': 30.0, - 'autoclose': False, - 'headers': { - 'User-Agent': self.user_agent, + "proxy_auth": self.proxy_auth, + "proxy": self.proxy, + "max_msg_size": 0, + "timeout": 30.0, + "autoclose": False, + "headers": { + "User-Agent": self.user_agent, }, - 'compress': compress, + "compress": compress, } return await self.__session.ws_connect(url, **kwargs) @@ -224,31 +242,31 @@ class HTTPClient: # header creation headers: Dict[str, str] = { - 'User-Agent': self.user_agent, + "User-Agent": self.user_agent, } if self.token is not None: - headers['Authorization'] = 'Bot ' + self.token + headers["Authorization"] = "Bot " + self.token # some checking if it's a JSON request - if 'json' in kwargs: - headers['Content-Type'] = 'application/json' - kwargs['data'] = utils._to_json(kwargs.pop('json')) + if "json" in kwargs: + headers["Content-Type"] = "application/json" + kwargs["data"] = utils._to_json(kwargs.pop("json")) try: - reason = kwargs.pop('reason') + reason = kwargs.pop("reason") except KeyError: pass else: if reason: - headers['X-Audit-Log-Reason'] = _uriquote(reason, safe='/ ') + headers["X-Audit-Log-Reason"] = _uriquote(reason, safe="/ ") - kwargs['headers'] = headers + kwargs["headers"] = headers # Proxy support if self.proxy is not None: - kwargs['proxy'] = self.proxy + kwargs["proxy"] = self.proxy if self.proxy_auth is not None: - kwargs['proxy_auth'] = self.proxy_auth + kwargs["proxy_auth"] = self.proxy_auth if not self._global_over.is_set(): # wait until the global lock is complete @@ -267,55 +285,72 @@ class HTTPClient: form_data = aiohttp.FormData() for params in form: form_data.add_field(**params) - kwargs['data'] = form_data + kwargs["data"] = form_data try: - async with self.__session.request(method, url, **kwargs) as response: - _log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), response.status) + async with self.__session.request( + method, url, **kwargs + ) as response: + _log.debug( + "%s %s with %s has returned %s", + method, + url, + kwargs.get("data"), + response.status, + ) # even errors have text involved in them so this is safe to call data = await json_or_text(response) # check if we have rate limit header information - remaining = response.headers.get('X-Ratelimit-Remaining') - if remaining == '0' and response.status != 429: + remaining = response.headers.get("X-Ratelimit-Remaining") + if remaining == "0" and response.status != 429: # we've depleted our current bucket - delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock) - _log.debug('A rate limit bucket has been exhausted (bucket: %s, retry: %s).', bucket, delta) + delta = utils._parse_ratelimit_header( + response, use_clock=self.use_clock + ) + _log.debug( + "A rate limit bucket has been exhausted (bucket: %s, retry: %s).", + bucket, + delta, + ) maybe_lock.defer() self.loop.call_later(delta, lock.release) # the request was successful so just return the text/json if 300 > response.status >= 200: - _log.debug('%s %s has received %s', method, url, data) + _log.debug("%s %s has received %s", method, url, data) return data # we are being rate limited if response.status == 429: - if not response.headers.get('Via') or isinstance(data, str): + if not response.headers.get("Via") or isinstance(data, str): # Banned by Cloudflare more than likely. raise HTTPException(response, data) fmt = 'We are being rate limited. Retrying in %.2f seconds. Handled under the bucket "%s"' # sleep a bit - retry_after: float = data['retry_after'] + retry_after: float = data["retry_after"] _log.warning(fmt, retry_after, bucket) # check if it's a global rate limit - is_global = data.get('global', False) + is_global = data.get("global", False) if is_global: - _log.warning('Global rate limit has been hit. Retrying in %.2f seconds.', retry_after) + _log.warning( + "Global rate limit has been hit. Retrying in %.2f seconds.", + retry_after, + ) self._global_over.clear() await asyncio.sleep(retry_after) - _log.debug('Done sleeping for the rate limit. Retrying...') + _log.debug("Done sleeping for the rate limit. Retrying...") # release the global lock now that the # global rate limit has passed if is_global: self._global_over.set() - _log.debug('Global rate limit is now over.') + _log.debug("Global rate limit is now over.") continue @@ -349,18 +384,18 @@ class HTTPClient: raise HTTPException(response, data) - raise RuntimeError('Unreachable code in HTTP handling') + raise RuntimeError("Unreachable code in HTTP handling") async def get_from_cdn(self, url: str) -> bytes: async with self.__session.get(url) as resp: if resp.status == 200: return await resp.read() elif resp.status == 404: - raise NotFound(resp, 'asset not found') + raise NotFound(resp, "asset not found") elif resp.status == 403: - raise Forbidden(resp, 'cannot retrieve asset') + raise Forbidden(resp, "cannot retrieve asset") else: - raise HTTPException(resp, 'failed to get asset') + raise HTTPException(resp, "failed to get asset") # state management @@ -372,43 +407,51 @@ class HTTPClient: async def static_login(self, token: str) -> user.User: # Necessary to get aiohttp to stop complaining about session creation - self.__session = aiohttp.ClientSession(connector=self.connector, ws_response_class=DiscordClientWebSocketResponse) + self.__session = aiohttp.ClientSession( + connector=self.connector, ws_response_class=DiscordClientWebSocketResponse + ) old_token = self.token self.token = token try: - data = await self.request(Route('GET', '/users/@me')) + data = await self.request(Route("GET", "/users/@me")) except HTTPException as exc: self.token = old_token if exc.status == 401: - raise LoginFailure('Improper token has been passed.') from exc + raise LoginFailure("Improper token has been passed.") from exc raise return data def logout(self) -> Response[None]: - return self.request(Route('POST', '/auth/logout')) + return self.request(Route("POST", "/auth/logout")) # Group functionality - def start_group(self, user_id: Snowflake, recipients: List[int]) -> Response[channel.GroupDMChannel]: + def start_group( + self, user_id: Snowflake, recipients: List[int] + ) -> Response[channel.GroupDMChannel]: payload = { - 'recipients': recipients, + "recipients": recipients, } - return self.request(Route('POST', '/users/{user_id}/channels', user_id=user_id), json=payload) + return self.request( + Route("POST", "/users/{user_id}/channels", user_id=user_id), json=payload + ) def leave_group(self, channel_id) -> Response[None]: - return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id)) + return self.request( + Route("DELETE", "/channels/{channel_id}", channel_id=channel_id) + ) # Message management def start_private_message(self, user_id: Snowflake) -> Response[channel.DMChannel]: payload = { - 'recipient_id': user_id, + "recipient_id": user_id, } - return self.request(Route('POST', '/users/@me/channels'), json=payload) + return self.request(Route("POST", "/users/@me/channels"), json=payload) def send_message( self, @@ -424,40 +467,42 @@ class HTTPClient: stickers: Optional[List[sticker.StickerItem]] = None, components: Optional[List[components.Component]] = None, ) -> Response[message.Message]: - r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) payload = {} if content: - payload['content'] = content + payload["content"] = content if tts: - payload['tts'] = True + payload["tts"] = True if embed: - payload['embeds'] = [embed] + payload["embeds"] = [embed] if embeds: - payload['embeds'] = embeds + payload["embeds"] = embeds if nonce: - payload['nonce'] = nonce + payload["nonce"] = nonce if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions + payload["allowed_mentions"] = allowed_mentions if message_reference: - payload['message_reference'] = message_reference + payload["message_reference"] = message_reference if components: - payload['components'] = components + payload["components"] = components if stickers: - payload['sticker_ids'] = stickers + payload["sticker_ids"] = stickers return self.request(r, json=payload) def send_typing(self, channel_id: Snowflake) -> Response[None]: - return self.request(Route('POST', '/channels/{channel_id}/typing', channel_id=channel_id)) + return self.request( + Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id) + ) def send_multipart_helper( self, @@ -476,43 +521,43 @@ class HTTPClient: ) -> Response[message.Message]: form = [] - payload: Dict[str, Any] = {'tts': tts} + payload: Dict[str, Any] = {"tts": tts} if content: - payload['content'] = content + payload["content"] = content if embed: - payload['embeds'] = [embed] + payload["embeds"] = [embed] if embeds: - payload['embeds'] = embeds + payload["embeds"] = embeds if nonce: - payload['nonce'] = nonce + payload["nonce"] = nonce if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions + payload["allowed_mentions"] = allowed_mentions if message_reference: - payload['message_reference'] = message_reference + payload["message_reference"] = message_reference if components: - payload['components'] = components + payload["components"] = components if stickers: - payload['sticker_ids'] = stickers + payload["sticker_ids"] = stickers - form.append({'name': 'payload_json', 'value': utils._to_json(payload)}) + form.append({"name": "payload_json", "value": utils._to_json(payload)}) if len(files) == 1: file = files[0] form.append( { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', + "name": "file", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", } ) else: for index, file in enumerate(files): form.append( { - 'name': f'file{index}', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', + "name": f"file{index}", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", } ) @@ -533,7 +578,7 @@ class HTTPClient: stickers: Optional[List[sticker.StickerItem]] = None, components: Optional[List[components.Component]] = None, ) -> Response[message.Message]: - r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/messages", channel_id=channel_id) return self.send_multipart_helper( r, files=files, @@ -549,29 +594,53 @@ class HTTPClient: ) def delete_message( - self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None + self, + channel_id: Snowflake, + message_id: Snowflake, + *, + reason: Optional[str] = None, ) -> Response[None]: - r = Route('DELETE', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + r = Route( + "DELETE", + "/channels/{channel_id}/messages/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) return self.request(r, reason=reason) def delete_messages( - self, channel_id: Snowflake, message_ids: SnowflakeList, *, reason: Optional[str] = None + self, + channel_id: Snowflake, + message_ids: SnowflakeList, + *, + reason: Optional[str] = None, ) -> Response[None]: - r = Route('POST', '/channels/{channel_id}/messages/bulk-delete', channel_id=channel_id) + r = Route( + "POST", "/channels/{channel_id}/messages/bulk-delete", channel_id=channel_id + ) payload = { - 'messages': message_ids, + "messages": message_ids, } return self.request(r, json=payload, reason=reason) - def edit_message(self, channel_id: Snowflake, message_id: Snowflake, **fields: Any) -> Response[message.Message]: - r = Route('PATCH', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + def edit_message( + self, channel_id: Snowflake, message_id: Snowflake, **fields: Any + ) -> Response[message.Message]: + r = Route( + "PATCH", + "/channels/{channel_id}/messages/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) return self.request(r, json=fields) - def add_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: + def add_reaction( + self, channel_id: Snowflake, message_id: Snowflake, emoji: str + ) -> Response[None]: r = Route( - 'PUT', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me', + "PUT", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", channel_id=channel_id, message_id=message_id, emoji=emoji, @@ -579,11 +648,15 @@ class HTTPClient: return self.request(r) def remove_reaction( - self, channel_id: Snowflake, message_id: Snowflake, emoji: str, member_id: Snowflake + self, + channel_id: Snowflake, + message_id: Snowflake, + emoji: str, + member_id: Snowflake, ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/{member_id}", channel_id=channel_id, message_id=message_id, member_id=member_id, @@ -591,10 +664,12 @@ class HTTPClient: ) return self.request(r) - def remove_own_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: + def remove_own_reaction( + self, channel_id: Snowflake, message_id: Snowflake, emoji: str + ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", channel_id=channel_id, message_id=message_id, emoji=emoji, @@ -610,46 +685,57 @@ class HTTPClient: after: Optional[Snowflake] = None, ) -> Response[List[user.User]]: r = Route( - 'GET', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}', + "GET", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", channel_id=channel_id, message_id=message_id, emoji=emoji, ) params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if after: - params['after'] = after + params["after"] = after return self.request(r, params=params) - def clear_reactions(self, channel_id: Snowflake, message_id: Snowflake) -> Response[None]: + def clear_reactions( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions", channel_id=channel_id, message_id=message_id, ) return self.request(r) - def clear_single_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: + def clear_single_reaction( + self, channel_id: Snowflake, message_id: Snowflake, emoji: str + ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/messages/{message_id}/reactions/{emoji}', + "DELETE", + "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", channel_id=channel_id, message_id=message_id, emoji=emoji, ) return self.request(r) - def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: - r = Route('GET', '/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id) + def get_message( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Response[message.Message]: + r = Route( + "GET", + "/channels/{channel_id}/messages/{message_id}", + channel_id=channel_id, + message_id=message_id, + ) return self.request(r) def get_channel(self, channel_id: Snowflake) -> Response[channel.Channel]: - r = Route('GET', '/channels/{channel_id}', channel_id=channel_id) + r = Route("GET", "/channels/{channel_id}", channel_id=channel_id) return self.request(r) def logs_from( @@ -661,56 +747,74 @@ class HTTPClient: around: Optional[Snowflake] = None, ) -> Response[List[message.Message]]: params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if before is not None: - params['before'] = before + params["before"] = before if after is not None: - params['after'] = after + params["after"] = after if around is not None: - params['around'] = around + params["around"] = around - return self.request(Route('GET', '/channels/{channel_id}/messages', channel_id=channel_id), params=params) + return self.request( + Route("GET", "/channels/{channel_id}/messages", channel_id=channel_id), + params=params, + ) - def publish_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: + def publish_message( + self, channel_id: Snowflake, message_id: Snowflake + ) -> Response[message.Message]: return self.request( Route( - 'POST', - '/channels/{channel_id}/messages/{message_id}/crosspost', + "POST", + "/channels/{channel_id}/messages/{message_id}/crosspost", channel_id=channel_id, message_id=message_id, ) ) - def pin_message(self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None) -> Response[None]: + def pin_message( + self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None + ) -> Response[None]: r = Route( - 'PUT', - '/channels/{channel_id}/pins/{message_id}', + "PUT", + "/channels/{channel_id}/pins/{message_id}", channel_id=channel_id, message_id=message_id, ) return self.request(r, reason=reason) - def unpin_message(self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None) -> Response[None]: + def unpin_message( + self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None + ) -> Response[None]: r = Route( - 'DELETE', - '/channels/{channel_id}/pins/{message_id}', + "DELETE", + "/channels/{channel_id}/pins/{message_id}", channel_id=channel_id, message_id=message_id, ) return self.request(r, reason=reason) def pins_from(self, channel_id: Snowflake) -> Response[List[message.Message]]: - return self.request(Route('GET', '/channels/{channel_id}/pins', channel_id=channel_id)) + return self.request( + Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id) + ) # Member management - def kick(self, user_id: Snowflake, guild_id: Snowflake, reason: Optional[str] = None) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + def kick( + self, user_id: Snowflake, guild_id: Snowflake, reason: Optional[str] = None + ) -> Response[None]: + r = Route( + "DELETE", + "/guilds/{guild_id}/members/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) if reason: # thanks aiohttp - r.url = f'{r.url}?reason={_uriquote(reason)}' + r.url = f"{r.url}?reason={_uriquote(reason)}" return self.request(r) @@ -721,15 +825,27 @@ class HTTPClient: delete_message_days: int = 1, reason: Optional[str] = None, ) -> Response[None]: - r = Route('PUT', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route( + "PUT", + "/guilds/{guild_id}/bans/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) params = { - 'delete_message_days': delete_message_days, + "delete_message_days": delete_message_days, } return self.request(r, params=params, reason=reason) - def unban(self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id) + def unban( + self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None + ) -> Response[None]: + r = Route( + "DELETE", + "/guilds/{guild_id}/bans/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) return self.request(r, reason=reason) def guild_voice_state( @@ -741,18 +857,23 @@ class HTTPClient: deafen: Optional[bool] = None, reason: Optional[str] = None, ) -> Response[member.Member]: - r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route( + "PATCH", + "/guilds/{guild_id}/members/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) payload = {} if mute is not None: - payload['mute'] = mute + payload["mute"] = mute if deafen is not None: - payload['deaf'] = deafen + payload["deaf"] = deafen return self.request(r, json=payload, reason=reason) def edit_profile(self, payload: Dict[str, Any]) -> Response[user.User]: - return self.request(Route('PATCH', '/users/@me'), json=payload) + return self.request(Route("PATCH", "/users/@me"), json=payload) def change_my_nickname( self, @@ -761,9 +882,9 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[member.Nickname]: - r = Route('PATCH', '/guilds/{guild_id}/members/@me/nick', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/members/@me/nick", guild_id=guild_id) payload = { - 'nick': nickname, + "nick": nickname, } return self.request(r, json=payload, reason=reason) @@ -775,18 +896,32 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[member.Member]: - r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route( + "PATCH", + "/guilds/{guild_id}/members/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) payload = { - 'nick': nickname, + "nick": nickname, } return self.request(r, json=payload, reason=reason) - def edit_my_voice_state(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[None]: - r = Route('PATCH', '/guilds/{guild_id}/voice-states/@me', guild_id=guild_id) + def edit_my_voice_state( + self, guild_id: Snowflake, payload: Dict[str, Any] + ) -> Response[None]: + r = Route("PATCH", "/guilds/{guild_id}/voice-states/@me", guild_id=guild_id) return self.request(r, json=payload) - def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any]) -> Response[None]: - r = Route('PATCH', '/guilds/{guild_id}/voice-states/{user_id}', guild_id=guild_id, user_id=user_id) + def edit_voice_state( + self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any] + ) -> Response[None]: + r = Route( + "PATCH", + "/guilds/{guild_id}/voice-states/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) return self.request(r, json=payload) def edit_member( @@ -797,7 +932,12 @@ class HTTPClient: reason: Optional[str] = None, **fields: Any, ) -> Response[member.MemberWithUser]: - r = Route('PATCH', '/guilds/{guild_id}/members/{user_id}', guild_id=guild_id, user_id=user_id) + r = Route( + "PATCH", + "/guilds/{guild_id}/members/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) return self.request(r, json=fields, reason=reason) # Channel management @@ -809,25 +949,25 @@ class HTTPClient: reason: Optional[str] = None, **options: Any, ) -> Response[channel.Channel]: - r = Route('PATCH', '/channels/{channel_id}', channel_id=channel_id) + r = Route("PATCH", "/channels/{channel_id}", channel_id=channel_id) valid_keys = ( - 'name', - 'parent_id', - 'topic', - 'bitrate', - 'nsfw', - 'user_limit', - 'position', - 'permission_overwrites', - 'rate_limit_per_user', - 'type', - 'rtc_region', - 'video_quality_mode', - 'archived', - 'auto_archive_duration', - 'locked', - 'invitable', - 'default_auto_archive_duration', + "name", + "parent_id", + "topic", + "bitrate", + "nsfw", + "user_limit", + "position", + "permission_overwrites", + "rate_limit_per_user", + "type", + "rtc_region", + "video_quality_mode", + "archived", + "auto_archive_duration", + "locked", + "invitable", + "default_auto_archive_duration", ) payload = {k: v for k, v in options.items() if k in valid_keys} return self.request(r, reason=reason, json=payload) @@ -839,7 +979,7 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - r = Route('PATCH', '/guilds/{guild_id}/channels', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/channels", guild_id=guild_id) return self.request(r, json=data, reason=reason) def create_channel( @@ -851,26 +991,32 @@ class HTTPClient: **options: Any, ) -> Response[channel.GuildChannel]: payload = { - 'type': channel_type, + "type": channel_type, } valid_keys = ( - 'name', - 'parent_id', - 'topic', - 'bitrate', - 'nsfw', - 'user_limit', - 'position', - 'permission_overwrites', - 'rate_limit_per_user', - 'rtc_region', - 'video_quality_mode', - 'auto_archive_duration', + "name", + "parent_id", + "topic", + "bitrate", + "nsfw", + "user_limit", + "position", + "permission_overwrites", + "rate_limit_per_user", + "rtc_region", + "video_quality_mode", + "auto_archive_duration", + ) + payload.update( + {k: v for k, v in options.items() if k in valid_keys and v is not None} ) - payload.update({k: v for k, v in options.items() if k in valid_keys and v is not None}) - return self.request(Route('POST', '/guilds/{guild_id}/channels', guild_id=guild_id), json=payload, reason=reason) + return self.request( + Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id), + json=payload, + reason=reason, + ) def delete_channel( self, @@ -878,7 +1024,10 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - return self.request(Route('DELETE', '/channels/{channel_id}', channel_id=channel_id), reason=reason) + return self.request( + Route("DELETE", "/channels/{channel_id}", channel_id=channel_id), + reason=reason, + ) # Thread management @@ -892,12 +1041,15 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[threads.Thread]: payload = { - 'name': name, - 'auto_archive_duration': auto_archive_duration, + "name": name, + "auto_archive_duration": auto_archive_duration, } route = Route( - 'POST', '/channels/{channel_id}/messages/{message_id}/threads', channel_id=channel_id, message_id=message_id + "POST", + "/channels/{channel_id}/messages/{message_id}/threads", + channel_id=channel_id, + message_id=message_id, ) return self.request(route, json=payload, reason=reason) @@ -912,68 +1064,112 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[threads.Thread]: payload = { - 'name': name, - 'auto_archive_duration': auto_archive_duration, - 'type': type, - 'invitable': invitable, + "name": name, + "auto_archive_duration": auto_archive_duration, + "type": type, + "invitable": invitable, } - route = Route('POST', '/channels/{channel_id}/threads', channel_id=channel_id) + route = Route("POST", "/channels/{channel_id}/threads", channel_id=channel_id) return self.request(route, json=payload, reason=reason) def join_thread(self, channel_id: Snowflake) -> Response[None]: - return self.request(Route('POST', '/channels/{channel_id}/thread-members/@me', channel_id=channel_id)) - - def add_user_to_thread(self, channel_id: Snowflake, user_id: Snowflake) -> Response[None]: return self.request( - Route('PUT', '/channels/{channel_id}/thread-members/{user_id}', channel_id=channel_id, user_id=user_id) + Route( + "POST", + "/channels/{channel_id}/thread-members/@me", + channel_id=channel_id, + ) + ) + + def add_user_to_thread( + self, channel_id: Snowflake, user_id: Snowflake + ) -> Response[None]: + return self.request( + Route( + "PUT", + "/channels/{channel_id}/thread-members/{user_id}", + channel_id=channel_id, + user_id=user_id, + ) ) def leave_thread(self, channel_id: Snowflake) -> Response[None]: - return self.request(Route('DELETE', '/channels/{channel_id}/thread-members/@me', channel_id=channel_id)) + return self.request( + Route( + "DELETE", + "/channels/{channel_id}/thread-members/@me", + channel_id=channel_id, + ) + ) - def remove_user_from_thread(self, channel_id: Snowflake, user_id: Snowflake) -> Response[None]: - route = Route('DELETE', '/channels/{channel_id}/thread-members/{user_id}', channel_id=channel_id, user_id=user_id) + def remove_user_from_thread( + self, channel_id: Snowflake, user_id: Snowflake + ) -> Response[None]: + route = Route( + "DELETE", + "/channels/{channel_id}/thread-members/{user_id}", + channel_id=channel_id, + user_id=user_id, + ) return self.request(route) def get_public_archived_threads( self, channel_id: Snowflake, before: Optional[Snowflake] = None, limit: int = 50 ) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/threads/archived/public', channel_id=channel_id) + route = Route( + "GET", + "/channels/{channel_id}/threads/archived/public", + channel_id=channel_id, + ) params = {} if before: - params['before'] = before - params['limit'] = limit + params["before"] = before + params["limit"] = limit return self.request(route, params=params) def get_private_archived_threads( self, channel_id: Snowflake, before: Optional[Snowflake] = None, limit: int = 50 ) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/threads/archived/private', channel_id=channel_id) + route = Route( + "GET", + "/channels/{channel_id}/threads/archived/private", + channel_id=channel_id, + ) params = {} if before: - params['before'] = before - params['limit'] = limit + params["before"] = before + params["limit"] = limit return self.request(route, params=params) def get_joined_private_archived_threads( self, channel_id: Snowflake, before: Optional[Snowflake] = None, limit: int = 50 ) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/channels/{channel_id}/users/@me/threads/archived/private', channel_id=channel_id) + route = Route( + "GET", + "/channels/{channel_id}/users/@me/threads/archived/private", + channel_id=channel_id, + ) params = {} if before: - params['before'] = before - params['limit'] = limit + params["before"] = before + params["limit"] = limit return self.request(route, params=params) - def get_active_threads(self, guild_id: Snowflake) -> Response[threads.ThreadPaginationPayload]: - route = Route('GET', '/guilds/{guild_id}/threads/active', guild_id=guild_id) + def get_active_threads( + self, guild_id: Snowflake + ) -> Response[threads.ThreadPaginationPayload]: + route = Route("GET", "/guilds/{guild_id}/threads/active", guild_id=guild_id) return self.request(route) - def get_thread_members(self, channel_id: Snowflake) -> Response[List[threads.ThreadMember]]: - route = Route('GET', '/channels/{channel_id}/thread-members', channel_id=channel_id) + def get_thread_members( + self, channel_id: Snowflake + ) -> Response[List[threads.ThreadMember]]: + route = Route( + "GET", "/channels/{channel_id}/thread-members", channel_id=channel_id + ) return self.request(route) # Webhook management @@ -987,22 +1183,30 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[webhook.Webhook]: payload: Dict[str, Any] = { - 'name': name, + "name": name, } if avatar is not None: - payload['avatar'] = avatar + payload["avatar"] = avatar - r = Route('POST', '/channels/{channel_id}/webhooks', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/webhooks", channel_id=channel_id) return self.request(r, json=payload, reason=reason) - def channel_webhooks(self, channel_id: Snowflake) -> Response[List[webhook.Webhook]]: - return self.request(Route('GET', '/channels/{channel_id}/webhooks', channel_id=channel_id)) + def channel_webhooks( + self, channel_id: Snowflake + ) -> Response[List[webhook.Webhook]]: + return self.request( + Route("GET", "/channels/{channel_id}/webhooks", channel_id=channel_id) + ) def guild_webhooks(self, guild_id: Snowflake) -> Response[List[webhook.Webhook]]: - return self.request(Route('GET', '/guilds/{guild_id}/webhooks', guild_id=guild_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/webhooks", guild_id=guild_id) + ) def get_webhook(self, webhook_id: Snowflake) -> Response[webhook.Webhook]: - return self.request(Route('GET', '/webhooks/{webhook_id}', webhook_id=webhook_id)) + return self.request( + Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id) + ) def follow_webhook( self, @@ -1011,10 +1215,12 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[None]: payload = { - 'webhook_channel_id': str(webhook_channel_id), + "webhook_channel_id": str(webhook_channel_id), } return self.request( - Route('POST', '/channels/{channel_id}/followers', channel_id=channel_id), json=payload, reason=reason + Route("POST", "/channels/{channel_id}/followers", channel_id=channel_id), + json=payload, + reason=reason, ) # Guild management @@ -1026,126 +1232,199 @@ class HTTPClient: after: Optional[Snowflake] = None, ) -> Response[List[guild.Guild]]: params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if before: - params['before'] = before + params["before"] = before if after: - params['after'] = after + params["after"] = after - return self.request(Route('GET', '/users/@me/guilds'), params=params) + return self.request(Route("GET", "/users/@me/guilds"), params=params) def leave_guild(self, guild_id: Snowflake) -> Response[None]: - return self.request(Route('DELETE', '/users/@me/guilds/{guild_id}', guild_id=guild_id)) + return self.request( + Route("DELETE", "/users/@me/guilds/{guild_id}", guild_id=guild_id) + ) def get_guild(self, guild_id: Snowflake) -> Response[guild.Guild]: - return self.request(Route('GET', '/guilds/{guild_id}', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}", guild_id=guild_id)) def delete_guild(self, guild_id: Snowflake) -> Response[None]: - return self.request(Route('DELETE', '/guilds/{guild_id}', guild_id=guild_id)) + return self.request(Route("DELETE", "/guilds/{guild_id}", guild_id=guild_id)) - def create_guild(self, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: + def create_guild( + self, name: str, region: str, icon: Optional[str] + ) -> Response[guild.Guild]: payload = { - 'name': name, - 'region': region, + "name": name, + "region": region, } if icon: - payload['icon'] = icon + payload["icon"] = icon - return self.request(Route('POST', '/guilds'), json=payload) + return self.request(Route("POST", "/guilds"), json=payload) - def edit_guild(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[guild.Guild]: + def edit_guild( + self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any + ) -> Response[guild.Guild]: valid_keys = ( - 'name', - 'region', - 'icon', - 'afk_timeout', - 'owner_id', - 'afk_channel_id', - 'splash', - 'discovery_splash', - 'features', - 'verification_level', - 'system_channel_id', - 'default_message_notifications', - 'description', - 'explicit_content_filter', - 'banner', - 'system_channel_flags', - 'rules_channel_id', - 'public_updates_channel_id', - 'preferred_locale', + "name", + "region", + "icon", + "afk_timeout", + "owner_id", + "afk_channel_id", + "splash", + "discovery_splash", + "features", + "verification_level", + "system_channel_id", + "default_message_notifications", + "description", + "explicit_content_filter", + "banner", + "system_channel_flags", + "rules_channel_id", + "public_updates_channel_id", + "preferred_locale", ) payload = {k: v for k, v in fields.items() if k in valid_keys} - return self.request(Route('PATCH', '/guilds/{guild_id}', guild_id=guild_id), json=payload, reason=reason) + return self.request( + Route("PATCH", "/guilds/{guild_id}", guild_id=guild_id), + json=payload, + reason=reason, + ) def get_template(self, code: str) -> Response[template.Template]: - return self.request(Route('GET', '/guilds/templates/{code}', code=code)) + return self.request(Route("GET", "/guilds/templates/{code}", code=code)) def guild_templates(self, guild_id: Snowflake) -> Response[List[template.Template]]: - return self.request(Route('GET', '/guilds/{guild_id}/templates', guild_id=guild_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/templates", guild_id=guild_id) + ) - def create_template(self, guild_id: Snowflake, payload: template.CreateTemplate) -> Response[template.Template]: - return self.request(Route('POST', '/guilds/{guild_id}/templates', guild_id=guild_id), json=payload) + def create_template( + self, guild_id: Snowflake, payload: template.CreateTemplate + ) -> Response[template.Template]: + return self.request( + Route("POST", "/guilds/{guild_id}/templates", guild_id=guild_id), + json=payload, + ) - def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: - return self.request(Route('PUT', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) + def sync_template( + self, guild_id: Snowflake, code: str + ) -> Response[template.Template]: + return self.request( + Route( + "PUT", + "/guilds/{guild_id}/templates/{code}", + guild_id=guild_id, + code=code, + ) + ) - def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]: + def edit_template( + self, guild_id: Snowflake, code: str, payload + ) -> Response[template.Template]: valid_keys = ( - 'name', - 'description', + "name", + "description", ) payload = {k: v for k, v in payload.items() if k in valid_keys} return self.request( - Route('PATCH', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code), json=payload + Route( + "PATCH", + "/guilds/{guild_id}/templates/{code}", + guild_id=guild_id, + code=code, + ), + json=payload, ) def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]: - return self.request(Route('DELETE', '/guilds/{guild_id}/templates/{code}', guild_id=guild_id, code=code)) + return self.request( + Route( + "DELETE", + "/guilds/{guild_id}/templates/{code}", + guild_id=guild_id, + code=code, + ) + ) - def create_from_template(self, code: str, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: + def create_from_template( + self, code: str, name: str, region: str, icon: Optional[str] + ) -> Response[guild.Guild]: payload = { - 'name': name, - 'region': region, + "name": name, + "region": region, } if icon: - payload['icon'] = icon - return self.request(Route('POST', '/guilds/templates/{code}', code=code), json=payload) + payload["icon"] = icon + return self.request( + Route("POST", "/guilds/templates/{code}", code=code), json=payload + ) def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: - return self.request(Route('GET', '/guilds/{guild_id}/bans', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id)) def get_ban(self, user_id: Snowflake, guild_id: Snowflake) -> Response[guild.Ban]: - return self.request(Route('GET', '/guilds/{guild_id}/bans/{user_id}', guild_id=guild_id, user_id=user_id)) + return self.request( + Route( + "GET", + "/guilds/{guild_id}/bans/{user_id}", + guild_id=guild_id, + user_id=user_id, + ) + ) def get_vanity_code(self, guild_id: Snowflake) -> Response[invite.VanityInvite]: - return self.request(Route('GET', '/guilds/{guild_id}/vanity-url', guild_id=guild_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/vanity-url", guild_id=guild_id) + ) - def change_vanity_code(self, guild_id: Snowflake, code: str, *, reason: Optional[str] = None) -> Response[None]: - payload: Dict[str, Any] = {'code': code} - return self.request(Route('PATCH', '/guilds/{guild_id}/vanity-url', guild_id=guild_id), json=payload, reason=reason) + def change_vanity_code( + self, guild_id: Snowflake, code: str, *, reason: Optional[str] = None + ) -> Response[None]: + payload: Dict[str, Any] = {"code": code} + return self.request( + Route("PATCH", "/guilds/{guild_id}/vanity-url", guild_id=guild_id), + json=payload, + reason=reason, + ) - def get_all_guild_channels(self, guild_id: Snowflake) -> Response[List[guild.GuildChannel]]: - return self.request(Route('GET', '/guilds/{guild_id}/channels', guild_id=guild_id)) + def get_all_guild_channels( + self, guild_id: Snowflake + ) -> Response[List[guild.GuildChannel]]: + return self.request( + Route("GET", "/guilds/{guild_id}/channels", guild_id=guild_id) + ) def get_members( self, guild_id: Snowflake, limit: int, after: Optional[Snowflake] ) -> Response[List[member.MemberWithUser]]: params: Dict[str, Any] = { - 'limit': limit, + "limit": limit, } if after: - params['after'] = after + params["after"] = after - r = Route('GET', '/guilds/{guild_id}/members', guild_id=guild_id) + r = Route("GET", "/guilds/{guild_id}/members", guild_id=guild_id) return self.request(r, params=params) - def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.MemberWithUser]: - return self.request(Route('GET', '/guilds/{guild_id}/members/{member_id}', guild_id=guild_id, member_id=member_id)) + def get_member( + self, guild_id: Snowflake, member_id: Snowflake + ) -> Response[member.MemberWithUser]: + return self.request( + Route( + "GET", + "/guilds/{guild_id}/members/{member_id}", + guild_id=guild_id, + member_id=member_id, + ) + ) def prune_members( self, @@ -1157,13 +1436,17 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[guild.GuildPrune]: payload: Dict[str, Any] = { - 'days': days, - 'compute_prune_count': 'true' if compute_prune_count else 'false', + "days": days, + "compute_prune_count": "true" if compute_prune_count else "false", } if roles: - payload['include_roles'] = ', '.join(roles) + payload["include_roles"] = ", ".join(roles) - return self.request(Route('POST', '/guilds/{guild_id}/prune', guild_id=guild_id), json=payload, reason=reason) + return self.request( + Route("POST", "/guilds/{guild_id}/prune", guild_id=guild_id), + json=payload, + reason=reason, + ) def estimate_pruned_members( self, @@ -1172,83 +1455,132 @@ class HTTPClient: roles: List[str], ) -> Response[guild.GuildPrune]: params: Dict[str, Any] = { - 'days': days, + "days": days, } if roles: - params['include_roles'] = ', '.join(roles) + params["include_roles"] = ", ".join(roles) - return self.request(Route('GET', '/guilds/{guild_id}/prune', guild_id=guild_id), params=params) + return self.request( + Route("GET", "/guilds/{guild_id}/prune", guild_id=guild_id), params=params + ) def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]: - return self.request(Route('GET', '/stickers/{sticker_id}', sticker_id=sticker_id)) + return self.request( + Route("GET", "/stickers/{sticker_id}", sticker_id=sticker_id) + ) def list_premium_sticker_packs(self) -> Response[sticker.ListPremiumStickerPacks]: - return self.request(Route('GET', '/sticker-packs')) + return self.request(Route("GET", "/sticker-packs")) - def get_all_guild_stickers(self, guild_id: Snowflake) -> Response[List[sticker.GuildSticker]]: - return self.request(Route('GET', '/guilds/{guild_id}/stickers', guild_id=guild_id)) - - def get_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake) -> Response[sticker.GuildSticker]: + def get_all_guild_stickers( + self, guild_id: Snowflake + ) -> Response[List[sticker.GuildSticker]]: return self.request( - Route('GET', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id) + Route("GET", "/guilds/{guild_id}/stickers", guild_id=guild_id) + ) + + def get_guild_sticker( + self, guild_id: Snowflake, sticker_id: Snowflake + ) -> Response[sticker.GuildSticker]: + return self.request( + Route( + "GET", + "/guilds/{guild_id}/stickers/{sticker_id}", + guild_id=guild_id, + sticker_id=sticker_id, + ) ) def create_guild_sticker( - self, guild_id: Snowflake, payload: sticker.CreateGuildSticker, file: File, reason: str + self, + guild_id: Snowflake, + payload: sticker.CreateGuildSticker, + file: File, + reason: str, ) -> Response[sticker.GuildSticker]: initial_bytes = file.fp.read(16) try: mime_type = utils._get_mime_type_for_image(initial_bytes) except InvalidArgument: - if initial_bytes.startswith(b'{'): - mime_type = 'application/json' + if initial_bytes.startswith(b"{"): + mime_type = "application/json" else: - mime_type = 'application/octet-stream' + mime_type = "application/octet-stream" finally: file.reset() form: List[Dict[str, Any]] = [ { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': mime_type, + "name": "file", + "value": file.fp, + "filename": file.filename, + "content_type": mime_type, } ] for k, v in payload.items(): form.append( { - 'name': k, - 'value': v, + "name": k, + "value": v, } ) return self.request( - Route('POST', '/guilds/{guild_id}/stickers', guild_id=guild_id), form=form, files=[file], reason=reason + Route("POST", "/guilds/{guild_id}/stickers", guild_id=guild_id), + form=form, + files=[file], + reason=reason, ) def modify_guild_sticker( - self, guild_id: Snowflake, sticker_id: Snowflake, payload: sticker.EditGuildSticker, reason: Optional[str], + self, + guild_id: Snowflake, + sticker_id: Snowflake, + payload: sticker.EditGuildSticker, + reason: Optional[str], ) -> Response[sticker.GuildSticker]: return self.request( - Route('PATCH', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), + Route( + "PATCH", + "/guilds/{guild_id}/stickers/{sticker_id}", + guild_id=guild_id, + sticker_id=sticker_id, + ), json=payload, reason=reason, ) - def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str]) -> Response[None]: + def delete_guild_sticker( + self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str] + ) -> Response[None]: return self.request( - Route('DELETE', '/guilds/{guild_id}/stickers/{sticker_id}', guild_id=guild_id, sticker_id=sticker_id), + Route( + "DELETE", + "/guilds/{guild_id}/stickers/{sticker_id}", + guild_id=guild_id, + sticker_id=sticker_id, + ), reason=reason, ) def get_all_custom_emojis(self, guild_id: Snowflake) -> Response[List[emoji.Emoji]]: - return self.request(Route('GET', '/guilds/{guild_id}/emojis', guild_id=guild_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/emojis", guild_id=guild_id) + ) - def get_custom_emoji(self, guild_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]: - return self.request(Route('GET', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id)) + def get_custom_emoji( + self, guild_id: Snowflake, emoji_id: Snowflake + ) -> Response[emoji.Emoji]: + return self.request( + Route( + "GET", + "/guilds/{guild_id}/emojis/{emoji_id}", + guild_id=guild_id, + emoji_id=emoji_id, + ) + ) def create_custom_emoji( self, @@ -1260,12 +1592,12 @@ class HTTPClient: reason: Optional[str] = None, ) -> Response[emoji.Emoji]: payload = { - 'name': name, - 'image': image, - 'roles': roles or [], + "name": name, + "image": image, + "roles": roles or [], } - r = Route('POST', '/guilds/{guild_id}/emojis', guild_id=guild_id) + r = Route("POST", "/guilds/{guild_id}/emojis", guild_id=guild_id) return self.request(r, json=payload, reason=reason) def delete_custom_emoji( @@ -1275,7 +1607,12 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id) + r = Route( + "DELETE", + "/guilds/{guild_id}/emojis/{emoji_id}", + guild_id=guild_id, + emoji_id=emoji_id, + ) return self.request(r, reason=reason) def edit_custom_emoji( @@ -1286,42 +1623,68 @@ class HTTPClient: payload: Dict[str, Any], reason: Optional[str] = None, ) -> Response[emoji.Emoji]: - r = Route('PATCH', '/guilds/{guild_id}/emojis/{emoji_id}', guild_id=guild_id, emoji_id=emoji_id) + r = Route( + "PATCH", + "/guilds/{guild_id}/emojis/{emoji_id}", + guild_id=guild_id, + emoji_id=emoji_id, + ) return self.request(r, json=payload, reason=reason) - def get_all_integrations(self, guild_id: Snowflake) -> Response[List[integration.Integration]]: - r = Route('GET', '/guilds/{guild_id}/integrations', guild_id=guild_id) + def get_all_integrations( + self, guild_id: Snowflake + ) -> Response[List[integration.Integration]]: + r = Route("GET", "/guilds/{guild_id}/integrations", guild_id=guild_id) return self.request(r) - def create_integration(self, guild_id: Snowflake, type: integration.IntegrationType, id: int) -> Response[None]: + def create_integration( + self, guild_id: Snowflake, type: integration.IntegrationType, id: int + ) -> Response[None]: payload = { - 'type': type, - 'id': id, + "type": type, + "id": id, } - r = Route('POST', '/guilds/{guild_id}/integrations', guild_id=guild_id) + r = Route("POST", "/guilds/{guild_id}/integrations", guild_id=guild_id) return self.request(r, json=payload) - def edit_integration(self, guild_id: Snowflake, integration_id: Snowflake, **payload: Any) -> Response[None]: + def edit_integration( + self, guild_id: Snowflake, integration_id: Snowflake, **payload: Any + ) -> Response[None]: r = Route( - 'PATCH', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id + "PATCH", + "/guilds/{guild_id}/integrations/{integration_id}", + guild_id=guild_id, + integration_id=integration_id, ) return self.request(r, json=payload) - def sync_integration(self, guild_id: Snowflake, integration_id: Snowflake) -> Response[None]: + def sync_integration( + self, guild_id: Snowflake, integration_id: Snowflake + ) -> Response[None]: r = Route( - 'POST', '/guilds/{guild_id}/integrations/{integration_id}/sync', guild_id=guild_id, integration_id=integration_id + "POST", + "/guilds/{guild_id}/integrations/{integration_id}/sync", + guild_id=guild_id, + integration_id=integration_id, ) return self.request(r) def delete_integration( - self, guild_id: Snowflake, integration_id: Snowflake, *, reason: Optional[str] = None + self, + guild_id: Snowflake, + integration_id: Snowflake, + *, + reason: Optional[str] = None, ) -> Response[None]: r = Route( - 'DELETE', '/guilds/{guild_id}/integrations/{integration_id}', guild_id=guild_id, integration_id=integration_id + "DELETE", + "/guilds/{guild_id}/integrations/{integration_id}", + guild_id=guild_id, + integration_id=integration_id, ) return self.request(r, reason=reason) @@ -1335,24 +1698,30 @@ class HTTPClient: user_id: Optional[Snowflake] = None, action_type: Optional[AuditLogAction] = None, ) -> Response[audit_log.AuditLog]: - params: Dict[str, Any] = {'limit': limit} + params: Dict[str, Any] = {"limit": limit} if before: - params['before'] = before + params["before"] = before if after: - params['after'] = after + params["after"] = after if user_id: - params['user_id'] = user_id + params["user_id"] = user_id if action_type: - params['action_type'] = action_type + params["action_type"] = action_type - r = Route('GET', '/guilds/{guild_id}/audit-logs', guild_id=guild_id) + r = Route("GET", "/guilds/{guild_id}/audit-logs", guild_id=guild_id) return self.request(r, params=params) def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]: - return self.request(Route('GET', '/guilds/{guild_id}/widget.json', guild_id=guild_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/widget.json", guild_id=guild_id) + ) - def edit_widget(self, guild_id: Snowflake, payload) -> Response[widget.WidgetSettings]: - return self.request(Route('PATCH', '/guilds/{guild_id}/widget', guild_id=guild_id), json=payload) + def edit_widget( + self, guild_id: Snowflake, payload + ) -> Response[widget.WidgetSettings]: + return self.request( + Route("PATCH", "/guilds/{guild_id}/widget", guild_id=guild_id), json=payload + ) # Invite management @@ -1369,22 +1738,22 @@ class HTTPClient: target_user_id: Optional[Snowflake] = None, target_application_id: Optional[Snowflake] = None, ) -> Response[invite.Invite]: - r = Route('POST', '/channels/{channel_id}/invites', channel_id=channel_id) + r = Route("POST", "/channels/{channel_id}/invites", channel_id=channel_id) payload = { - 'max_age': max_age, - 'max_uses': max_uses, - 'temporary': temporary, - 'unique': unique, + "max_age": max_age, + "max_uses": max_uses, + "temporary": temporary, + "unique": unique, } if target_type: - payload['target_type'] = target_type + payload["target_type"] = target_type if target_user_id: - payload['target_user_id'] = target_user_id + payload["target_user_id"] = target_user_id if target_application_id: - payload['target_application_id'] = str(target_application_id) + payload["target_application_id"] = str(target_application_id) return self.request(r, reason=reason, json=payload) @@ -1392,35 +1761,64 @@ class HTTPClient: self, invite_id: str, *, with_counts: bool = True, with_expiration: bool = True ) -> Response[invite.Invite]: params = { - 'with_counts': int(with_counts), - 'with_expiration': int(with_expiration), + "with_counts": int(with_counts), + "with_expiration": int(with_expiration), } - return self.request(Route('GET', '/invites/{invite_id}', invite_id=invite_id), params=params) + return self.request( + Route("GET", "/invites/{invite_id}", invite_id=invite_id), params=params + ) def invites_from(self, guild_id: Snowflake) -> Response[List[invite.Invite]]: - return self.request(Route('GET', '/guilds/{guild_id}/invites', guild_id=guild_id)) + return self.request( + Route("GET", "/guilds/{guild_id}/invites", guild_id=guild_id) + ) - def invites_from_channel(self, channel_id: Snowflake) -> Response[List[invite.Invite]]: - return self.request(Route('GET', '/channels/{channel_id}/invites', channel_id=channel_id)) + def invites_from_channel( + self, channel_id: Snowflake + ) -> Response[List[invite.Invite]]: + return self.request( + Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id) + ) - def delete_invite(self, invite_id: str, *, reason: Optional[str] = None) -> Response[None]: - return self.request(Route('DELETE', '/invites/{invite_id}', invite_id=invite_id), reason=reason) + def delete_invite( + self, invite_id: str, *, reason: Optional[str] = None + ) -> Response[None]: + return self.request( + Route("DELETE", "/invites/{invite_id}", invite_id=invite_id), reason=reason + ) # Role management def get_roles(self, guild_id: Snowflake) -> Response[List[role.Role]]: - return self.request(Route('GET', '/guilds/{guild_id}/roles', guild_id=guild_id)) + return self.request(Route("GET", "/guilds/{guild_id}/roles", guild_id=guild_id)) def edit_role( - self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None, **fields: Any + self, + guild_id: Snowflake, + role_id: Snowflake, + *, + reason: Optional[str] = None, + **fields: Any, ) -> Response[role.Role]: - r = Route('PATCH', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id) - valid_keys = ('name', 'permissions', 'color', 'hoist', 'mentionable') + r = Route( + "PATCH", + "/guilds/{guild_id}/roles/{role_id}", + guild_id=guild_id, + role_id=role_id, + ) + valid_keys = ("name", "permissions", "color", "hoist", "mentionable") payload = {k: v for k, v in fields.items() if k in valid_keys} return self.request(r, json=payload, reason=reason) - def delete_role(self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: - r = Route('DELETE', '/guilds/{guild_id}/roles/{role_id}', guild_id=guild_id, role_id=role_id) + def delete_role( + self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None + ) -> Response[None]: + r = Route( + "DELETE", + "/guilds/{guild_id}/roles/{role_id}", + guild_id=guild_id, + role_id=role_id, + ) return self.request(r, reason=reason) def replace_roles( @@ -1431,10 +1829,14 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[member.MemberWithUser]: - return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason) + return self.edit_member( + guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason + ) - def create_role(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]: - r = Route('POST', '/guilds/{guild_id}/roles', guild_id=guild_id) + def create_role( + self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any + ) -> Response[role.Role]: + r = Route("POST", "/guilds/{guild_id}/roles", guild_id=guild_id) return self.request(r, json=fields, reason=reason) def move_role_position( @@ -1444,15 +1846,20 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[List[role.Role]]: - r = Route('PATCH', '/guilds/{guild_id}/roles', guild_id=guild_id) + r = Route("PATCH", "/guilds/{guild_id}/roles", guild_id=guild_id) return self.request(r, json=positions, reason=reason) def add_role( - self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None + self, + guild_id: Snowflake, + user_id: Snowflake, + role_id: Snowflake, + *, + reason: Optional[str] = None, ) -> Response[None]: r = Route( - 'PUT', - '/guilds/{guild_id}/members/{user_id}/roles/{role_id}', + "PUT", + "/guilds/{guild_id}/members/{user_id}/roles/{role_id}", guild_id=guild_id, user_id=user_id, role_id=role_id, @@ -1460,11 +1867,16 @@ class HTTPClient: return self.request(r, reason=reason) def remove_role( - self, guild_id: Snowflake, user_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None + self, + guild_id: Snowflake, + user_id: Snowflake, + role_id: Snowflake, + *, + reason: Optional[str] = None, ) -> Response[None]: r = Route( - 'DELETE', - '/guilds/{guild_id}/members/{user_id}/roles/{role_id}', + "DELETE", + "/guilds/{guild_id}/members/{user_id}/roles/{role_id}", guild_id=guild_id, user_id=user_id, role_id=role_id, @@ -1481,14 +1893,28 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[None]: - payload = {'id': target, 'allow': allow, 'deny': deny, 'type': type} - r = Route('PUT', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target) + payload = {"id": target, "allow": allow, "deny": deny, "type": type} + r = Route( + "PUT", + "/channels/{channel_id}/permissions/{target}", + channel_id=channel_id, + target=target, + ) return self.request(r, json=payload, reason=reason) def delete_channel_permissions( - self, channel_id: Snowflake, target: channel.OverwriteType, *, reason: Optional[str] = None + self, + channel_id: Snowflake, + target: channel.OverwriteType, + *, + reason: Optional[str] = None, ) -> Response[None]: - r = Route('DELETE', '/channels/{channel_id}/permissions/{target}', channel_id=channel_id, target=target) + r = Route( + "DELETE", + "/channels/{channel_id}/permissions/{target}", + channel_id=channel_id, + target=target, + ) return self.request(r, reason=reason) # Voice management @@ -1501,55 +1927,88 @@ class HTTPClient: *, reason: Optional[str] = None, ) -> Response[member.MemberWithUser]: - return self.edit_member(guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason) + return self.edit_member( + guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason + ) # Stage instance management - def get_stage_instance(self, channel_id: Snowflake) -> Response[channel.StageInstance]: - return self.request(Route('GET', '/stage-instances/{channel_id}', channel_id=channel_id)) - - def create_stage_instance(self, *, reason: Optional[str], **payload: Any) -> Response[channel.StageInstance]: - valid_keys = ( - 'channel_id', - 'topic', - 'privacy_level', + def get_stage_instance( + self, channel_id: Snowflake + ) -> Response[channel.StageInstance]: + return self.request( + Route("GET", "/stage-instances/{channel_id}", channel_id=channel_id) ) - payload = {k: v for k, v in payload.items() if k in valid_keys} - return self.request(Route('POST', '/stage-instances'), json=payload, reason=reason) - - def edit_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any) -> Response[None]: + def create_stage_instance( + self, *, reason: Optional[str], **payload: Any + ) -> Response[channel.StageInstance]: valid_keys = ( - 'topic', - 'privacy_level', + "channel_id", + "topic", + "privacy_level", ) payload = {k: v for k, v in payload.items() if k in valid_keys} return self.request( - Route('PATCH', '/stage-instances/{channel_id}', channel_id=channel_id), json=payload, reason=reason + Route("POST", "/stage-instances"), json=payload, reason=reason ) - def delete_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: - return self.request(Route('DELETE', '/stage-instances/{channel_id}', channel_id=channel_id), reason=reason) + def edit_stage_instance( + self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any + ) -> Response[None]: + valid_keys = ( + "topic", + "privacy_level", + ) + payload = {k: v for k, v in payload.items() if k in valid_keys} + + return self.request( + Route("PATCH", "/stage-instances/{channel_id}", channel_id=channel_id), + json=payload, + reason=reason, + ) + + def delete_stage_instance( + self, channel_id: Snowflake, *, reason: Optional[str] = None + ) -> Response[None]: + return self.request( + Route("DELETE", "/stage-instances/{channel_id}", channel_id=channel_id), + reason=reason, + ) # Application commands (global) - def get_global_commands(self, application_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]: - return self.request(Route('GET', '/applications/{application_id}/commands', application_id=application_id)) + def get_global_commands( + self, application_id: Snowflake + ) -> Response[List[interactions.ApplicationCommand]]: + return self.request( + Route( + "GET", + "/applications/{application_id}/commands", + application_id=application_id, + ) + ) def get_global_command( self, application_id: Snowflake, command_id: Snowflake ) -> Response[interactions.ApplicationCommand]: r = Route( - 'GET', - '/applications/{application_id}/commands/{command_id}', + "GET", + "/applications/{application_id}/commands/{command_id}", application_id=application_id, command_id=command_id, ) return self.request(r) - def upsert_global_command(self, application_id: Snowflake, payload) -> Response[interactions.ApplicationCommand]: - r = Route('POST', '/applications/{application_id}/commands', application_id=application_id) + def upsert_global_command( + self, application_id: Snowflake, payload + ) -> Response[interactions.ApplicationCommand]: + r = Route( + "POST", + "/applications/{application_id}/commands", + application_id=application_id, + ) return self.request(r, json=payload) def edit_global_command( @@ -1559,23 +2018,25 @@ class HTTPClient: payload: interactions.EditApplicationCommand, ) -> Response[interactions.ApplicationCommand]: valid_keys = ( - 'name', - 'description', - 'options', + "name", + "description", + "options", ) payload = {k: v for k, v in payload.items() if k in valid_keys} # type: ignore r = Route( - 'PATCH', - '/applications/{application_id}/commands/{command_id}', + "PATCH", + "/applications/{application_id}/commands/{command_id}", application_id=application_id, command_id=command_id, ) return self.request(r, json=payload) - def delete_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[None]: + def delete_global_command( + self, application_id: Snowflake, command_id: Snowflake + ) -> Response[None]: r = Route( - 'DELETE', - '/applications/{application_id}/commands/{command_id}', + "DELETE", + "/applications/{application_id}/commands/{command_id}", application_id=application_id, command_id=command_id, ) @@ -1584,7 +2045,11 @@ class HTTPClient: def bulk_upsert_global_commands( self, application_id: Snowflake, payload ) -> Response[List[interactions.ApplicationCommand]]: - r = Route('PUT', '/applications/{application_id}/commands', application_id=application_id) + r = Route( + "PUT", + "/applications/{application_id}/commands", + application_id=application_id, + ) return self.request(r, json=payload) # Application commands (guild) @@ -1593,8 +2058,8 @@ class HTTPClient: self, application_id: Snowflake, guild_id: Snowflake ) -> Response[List[interactions.ApplicationCommand]]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands", application_id=application_id, guild_id=guild_id, ) @@ -1607,8 +2072,8 @@ class HTTPClient: command_id: Snowflake, ) -> Response[interactions.ApplicationCommand]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1622,8 +2087,8 @@ class HTTPClient: payload: interactions.EditApplicationCommand, ) -> Response[interactions.ApplicationCommand]: r = Route( - 'POST', - '/applications/{application_id}/guilds/{guild_id}/commands', + "POST", + "/applications/{application_id}/guilds/{guild_id}/commands", application_id=application_id, guild_id=guild_id, ) @@ -1637,14 +2102,14 @@ class HTTPClient: payload: interactions.EditApplicationCommand, ) -> Response[interactions.ApplicationCommand]: valid_keys = ( - 'name', - 'description', - 'options', + "name", + "description", + "options", ) payload = {k: v for k, v in payload.items() if k in valid_keys} # type: ignore r = Route( - 'PATCH', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}', + "PATCH", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1658,8 +2123,8 @@ class HTTPClient: command_id: Snowflake, ) -> Response[None]: r = Route( - 'DELETE', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}', + "DELETE", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1673,8 +2138,8 @@ class HTTPClient: payload: List[interactions.EditApplicationCommand], ) -> Response[List[interactions.ApplicationCommand]]: r = Route( - 'PUT', - '/applications/{application_id}/guilds/{guild_id}/commands', + "PUT", + "/applications/{application_id}/guilds/{guild_id}/commands", application_id=application_id, guild_id=guild_id, ) @@ -1693,26 +2158,26 @@ class HTTPClient: payload: Dict[str, Any] = {} if content: - payload['content'] = content + payload["content"] = content if embeds: - payload['embeds'] = embeds + payload["embeds"] = embeds if allowed_mentions: - payload['allowed_mentions'] = allowed_mentions + payload["allowed_mentions"] = allowed_mentions form: List[Dict[str, Any]] = [ { - 'name': 'payload_json', - 'value': utils._to_json(payload), + "name": "payload_json", + "value": utils._to_json(payload), } ] if file: form.append( { - 'name': 'file', - 'value': file.fp, - 'filename': file.filename, - 'content_type': 'application/octet-stream', + "name": "file", + "value": file.fp, + "filename": file.filename, + "content_type": "application/octet-stream", } ) @@ -1727,17 +2192,17 @@ class HTTPClient: data: Optional[interactions.InteractionApplicationCommandCallbackData] = None, ) -> Response[None]: r = Route( - 'POST', - '/interactions/{interaction_id}/{interaction_token}/callback', + "POST", + "/interactions/{interaction_id}/{interaction_token}/callback", interaction_id=interaction_id, interaction_token=token, ) payload: Dict[str, Any] = { - 'type': type, + "type": type, } if data is not None: - payload['data'] = data + payload["data"] = data return self.request(r, json=payload) @@ -1747,8 +2212,8 @@ class HTTPClient: token: str, ) -> Response[message.Message]: r = Route( - 'GET', - '/webhooks/{application_id}/{interaction_token}/messages/@original', + "GET", + "/webhooks/{application_id}/{interaction_token}/messages/@original", application_id=application_id, interaction_token=token, ) @@ -1764,17 +2229,25 @@ class HTTPClient: allowed_mentions: Optional[message.AllowedMentions] = None, ) -> Response[message.Message]: r = Route( - 'PATCH', - '/webhooks/{application_id}/{interaction_token}/messages/@original', + "PATCH", + "/webhooks/{application_id}/{interaction_token}/messages/@original", application_id=application_id, interaction_token=token, ) - return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions) + return self._edit_webhook_helper( + r, + file=file, + content=content, + embeds=embeds, + allowed_mentions=allowed_mentions, + ) - def delete_original_interaction_response(self, application_id: Snowflake, token: str) -> Response[None]: + def delete_original_interaction_response( + self, application_id: Snowflake, token: str + ) -> Response[None]: r = Route( - 'DELETE', - '/webhooks/{application_id}/{interaction_token}/messages/@original', + "DELETE", + "/webhooks/{application_id}/{interaction_token}/messages/@original", application_id=application_id, interaction_token=token, ) @@ -1791,8 +2264,8 @@ class HTTPClient: allowed_mentions: Optional[message.AllowedMentions] = None, ) -> Response[message.Message]: r = Route( - 'POST', - '/webhooks/{application_id}/{interaction_token}', + "POST", + "/webhooks/{application_id}/{interaction_token}", application_id=application_id, interaction_token=token, ) @@ -1816,18 +2289,26 @@ class HTTPClient: allowed_mentions: Optional[message.AllowedMentions] = None, ) -> Response[message.Message]: r = Route( - 'PATCH', - '/webhooks/{application_id}/{interaction_token}/messages/{message_id}', + "PATCH", + "/webhooks/{application_id}/{interaction_token}/messages/{message_id}", application_id=application_id, interaction_token=token, message_id=message_id, ) - return self._edit_webhook_helper(r, file=file, content=content, embeds=embeds, allowed_mentions=allowed_mentions) + return self._edit_webhook_helper( + r, + file=file, + content=content, + embeds=embeds, + allowed_mentions=allowed_mentions, + ) - def delete_followup_message(self, application_id: Snowflake, token: str, message_id: Snowflake) -> Response[None]: + def delete_followup_message( + self, application_id: Snowflake, token: str, message_id: Snowflake + ) -> Response[None]: r = Route( - 'DELETE', - '/webhooks/{application_id}/{interaction_token}/messages/{message_id}', + "DELETE", + "/webhooks/{application_id}/{interaction_token}/messages/{message_id}", application_id=application_id, interaction_token=token, message_id=message_id, @@ -1840,8 +2321,8 @@ class HTTPClient: guild_id: Snowflake, ) -> Response[List[interactions.GuildApplicationCommandPermissions]]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands/permissions', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands/permissions", application_id=application_id, guild_id=guild_id, ) @@ -1854,8 +2335,8 @@ class HTTPClient: command_id: Snowflake, ) -> Response[interactions.GuildApplicationCommandPermissions]: r = Route( - 'GET', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions', + "GET", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1870,8 +2351,8 @@ class HTTPClient: payload: interactions.BaseGuildApplicationCommandPermissions, ) -> Response[None]: r = Route( - 'PUT', - '/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions', + "PUT", + "/applications/{application_id}/guilds/{guild_id}/commands/{command_id}/permissions", application_id=application_id, guild_id=guild_id, command_id=command_id, @@ -1885,8 +2366,8 @@ class HTTPClient: payload: List[interactions.PartialGuildApplicationCommandPermissions], ) -> Response[None]: r = Route( - 'PUT', - '/applications/{application_id}/guilds/{guild_id}/commands/permissions', + "PUT", + "/applications/{application_id}/guilds/{guild_id}/commands/permissions", application_id=application_id, guild_id=guild_id, ) @@ -1895,30 +2376,32 @@ class HTTPClient: # Misc def application_info(self) -> Response[appinfo.AppInfo]: - return self.request(Route('GET', '/oauth2/applications/@me')) + return self.request(Route("GET", "/oauth2/applications/@me")) - async def get_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> str: + async def get_gateway(self, *, encoding: str = "json", zlib: bool = True) -> str: try: - data = await self.request(Route('GET', '/gateway')) + data = await self.request(Route("GET", "/gateway")) except HTTPException as exc: raise GatewayNotFound() from exc if zlib: - value = '{0}?encoding={1}&v=9&compress=zlib-stream' + value = "{0}?encoding={1}&v=9&compress=zlib-stream" else: - value = '{0}?encoding={1}&v=9' - return value.format(data['url'], encoding) + value = "{0}?encoding={1}&v=9" + return value.format(data["url"], encoding) - async def get_bot_gateway(self, *, encoding: str = 'json', zlib: bool = True) -> Tuple[int, str]: + async def get_bot_gateway( + self, *, encoding: str = "json", zlib: bool = True + ) -> Tuple[int, str]: try: - data = await self.request(Route('GET', '/gateway/bot')) + data = await self.request(Route("GET", "/gateway/bot")) except HTTPException as exc: raise GatewayNotFound() from exc if zlib: - value = '{0}?encoding={1}&v=9&compress=zlib-stream' + value = "{0}?encoding={1}&v=9&compress=zlib-stream" else: - value = '{0}?encoding={1}&v=9' - return data['shards'], value.format(data['url'], encoding) + value = "{0}?encoding={1}&v=9" + return data["shards"], value.format(data["url"], encoding) def get_user(self, user_id: Snowflake) -> Response[user.User]: - return self.request(Route('GET', '/users/{user_id}', user_id=user_id)) + return self.request(Route("GET", "/users/{user_id}", user_id=user_id)) diff --git a/discord/integrations.py b/discord/integrations.py index 23d02930..b946cd5b 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -32,11 +32,11 @@ from .errors import InvalidArgument from .enums import try_enum, ExpireBehaviour __all__ = ( - 'IntegrationAccount', - 'IntegrationApplication', - 'Integration', - 'StreamIntegration', - 'BotIntegration', + "IntegrationAccount", + "IntegrationApplication", + "Integration", + "StreamIntegration", + "BotIntegration", ) if TYPE_CHECKING: @@ -65,14 +65,14 @@ class IntegrationAccount: The account name. """ - __slots__ = ('id', 'name') + __slots__ = ("id", "name") def __init__(self, data: IntegrationAccountPayload) -> None: - self.id: str = data['id'] - self.name: str = data['name'] + self.id: str = data["id"] + self.name: str = data["name"] def __repr__(self) -> str: - return f'' + return f"" class Integration: @@ -99,14 +99,14 @@ class Integration: """ __slots__ = ( - 'guild', - 'id', - '_state', - 'type', - 'name', - 'account', - 'user', - 'enabled', + "guild", + "id", + "_state", + "type", + "name", + "account", + "user", + "enabled", ) def __init__(self, *, data: IntegrationPayload, guild: Guild) -> None: @@ -118,14 +118,14 @@ class Integration: return f"<{self.__class__.__name__} id={self.id} name={self.name!r}>" def _from_data(self, data: IntegrationPayload) -> None: - self.id: int = int(data['id']) - self.type: IntegrationType = data['type'] - self.name: str = data['name'] - self.account: IntegrationAccount = IntegrationAccount(data['account']) + self.id: int = int(data["id"]) + self.type: IntegrationType = data["type"] + self.name: str = data["name"] + self.account: IntegrationAccount = IntegrationAccount(data["account"]) - user = data.get('user') + user = data.get("user") self.user = User(state=self._state, data=user) if user else None - self.enabled: bool = data['enabled'] + self.enabled: bool = data["enabled"] async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| @@ -186,26 +186,28 @@ class StreamIntegration(Integration): """ __slots__ = ( - 'revoked', - 'expire_behaviour', - 'expire_grace_period', - 'synced_at', - '_role_id', - 'syncing', - 'enable_emoticons', - 'subscriber_count', + "revoked", + "expire_behaviour", + "expire_grace_period", + "synced_at", + "_role_id", + "syncing", + "enable_emoticons", + "subscriber_count", ) def _from_data(self, data: StreamIntegrationPayload) -> None: super()._from_data(data) - self.revoked: bool = data['revoked'] - self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data['expire_behavior']) - self.expire_grace_period: int = data['expire_grace_period'] - self.synced_at: datetime.datetime = parse_time(data['synced_at']) - self._role_id: Optional[int] = _get_as_snowflake(data, 'role_id') - self.syncing: bool = data['syncing'] - self.enable_emoticons: bool = data['enable_emoticons'] - self.subscriber_count: int = data['subscriber_count'] + self.revoked: bool = data["revoked"] + self.expire_behaviour: ExpireBehaviour = try_enum( + ExpireBehaviour, data["expire_behavior"] + ) + self.expire_grace_period: int = data["expire_grace_period"] + self.synced_at: datetime.datetime = parse_time(data["synced_at"]) + self._role_id: Optional[int] = _get_as_snowflake(data, "role_id") + self.syncing: bool = data["syncing"] + self.enable_emoticons: bool = data["enable_emoticons"] + self.subscriber_count: int = data["subscriber_count"] @property def expire_behavior(self) -> ExpireBehaviour: @@ -252,15 +254,17 @@ class StreamIntegration(Integration): payload: Dict[str, Any] = {} if expire_behaviour is not MISSING: if not isinstance(expire_behaviour, ExpireBehaviour): - raise InvalidArgument('expire_behaviour field must be of type ExpireBehaviour') + raise InvalidArgument( + "expire_behaviour field must be of type ExpireBehaviour" + ) - payload['expire_behavior'] = expire_behaviour.value + payload["expire_behavior"] = expire_behaviour.value if expire_grace_period is not MISSING: - payload['expire_grace_period'] = expire_grace_period + payload["expire_grace_period"] = expire_grace_period if enable_emoticons is not MISSING: - payload['enable_emoticons'] = enable_emoticons + payload["enable_emoticons"] = enable_emoticons # This endpoint is undocumented. # Unsure if it returns the data or not as a result @@ -307,21 +311,21 @@ class IntegrationApplication: """ __slots__ = ( - 'id', - 'name', - 'icon', - 'description', - 'summary', - 'user', + "id", + "name", + "icon", + "description", + "summary", + "user", ) def __init__(self, *, data: IntegrationApplicationPayload, state): - self.id: int = int(data['id']) - self.name: str = data['name'] - self.icon: Optional[str] = data['icon'] - self.description: str = data['description'] - self.summary: str = data['summary'] - user = data.get('bot') + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.icon: Optional[str] = data["icon"] + self.description: str = data["description"] + self.summary: str = data["summary"] + user = data.get("bot") self.user: Optional[User] = User(state=state, data=user) if user else None @@ -350,17 +354,19 @@ class BotIntegration(Integration): The application tied to this integration. """ - __slots__ = ('application',) + __slots__ = ("application",) def _from_data(self, data: BotIntegrationPayload) -> None: super()._from_data(data) - self.application = IntegrationApplication(data=data['application'], state=self._state) + self.application = IntegrationApplication( + data=data["application"], state=self._state + ) def _integration_factory(value: str) -> Tuple[Type[Integration], str]: - if value == 'discord': + if value == "discord": return BotIntegration, value - elif value in ('twitch', 'youtube'): + elif value in ("twitch", "youtube"): return StreamIntegration, value else: return Integration, value diff --git a/discord/interactions.py b/discord/interactions.py index b89d49f5..98cd9fdf 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -41,9 +41,9 @@ from .permissions import Permissions from .webhook.async_ import async_context, Webhook, handle_message_parameters __all__ = ( - 'Interaction', - 'InteractionMessage', - 'InteractionResponse', + "Interaction", + "InteractionMessage", + "InteractionResponse", ) if TYPE_CHECKING: @@ -58,11 +58,24 @@ if TYPE_CHECKING: from aiohttp import ClientSession from .embeds import Embed from .ui.view import View - from .channel import VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, PartialMessageable + from .channel import ( + VoiceChannel, + StageChannel, + TextChannel, + CategoryChannel, + StoreChannel, + PartialMessageable, + ) from .threads import Thread InteractionChannel = Union[ - VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel, Thread, PartialMessageable + VoiceChannel, + StageChannel, + TextChannel, + CategoryChannel, + StoreChannel, + Thread, + PartialMessageable, ] MISSING: Any = utils.MISSING @@ -100,23 +113,23 @@ class Interaction: """ __slots__: Tuple[str, ...] = ( - 'id', - 'type', - 'guild_id', - 'channel_id', - 'data', - 'application_id', - 'message', - 'user', - 'token', - 'version', - '_permissions', - '_state', - '_session', - '_original_message', - '_cs_response', - '_cs_followup', - '_cs_channel', + "id", + "type", + "guild_id", + "channel_id", + "data", + "application_id", + "message", + "user", + "token", + "version", + "_permissions", + "_state", + "_session", + "_original_message", + "_cs_response", + "_cs_followup", + "_cs_channel", ) def __init__(self, *, data: InteractionPayload, state: ConnectionState): @@ -126,18 +139,18 @@ class Interaction: self._from_data(data) def _from_data(self, data: InteractionPayload): - self.id: int = int(data['id']) - self.type: InteractionType = try_enum(InteractionType, data['type']) - self.data: Optional[InteractionData] = data.get('data') - self.token: str = data['token'] - self.version: int = data['version'] - self.channel_id: Optional[int] = utils._get_as_snowflake(data, 'channel_id') - self.guild_id: Optional[int] = utils._get_as_snowflake(data, 'guild_id') - self.application_id: int = int(data['application_id']) + self.id: int = int(data["id"]) + self.type: InteractionType = try_enum(InteractionType, data["type"]) + self.data: Optional[InteractionData] = data.get("data") + self.token: str = data["token"] + self.version: int = data["version"] + self.channel_id: Optional[int] = utils._get_as_snowflake(data, "channel_id") + self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") + self.application_id: int = int(data["application_id"]) self.message: Optional[Message] try: - self.message = Message(state=self._state, channel=self.channel, data=data['message']) # type: ignore + self.message = Message(state=self._state, channel=self.channel, data=data["message"]) # type: ignore except KeyError: self.message = None @@ -148,15 +161,15 @@ class Interaction: if self.guild_id: guild = self.guild or Object(id=self.guild_id) try: - member = data['member'] # type: ignore + member = data["member"] # type: ignore except KeyError: pass else: self.user = Member(state=self._state, guild=guild, data=member) # type: ignore - self._permissions = int(member.get('permissions', 0)) + self._permissions = int(member.get("permissions", 0)) else: try: - self.user = User(state=self._state, data=data['user']) + self.user = User(state=self._state, data=data["user"]) except KeyError: pass @@ -165,7 +178,7 @@ class Interaction: """Optional[:class:`Guild`]: The guild the interaction was sent from.""" return self._state and self._state._get_guild(self.guild_id) - @utils.cached_slot_property('_cs_channel') + @utils.cached_slot_property("_cs_channel") def channel(self) -> Optional[InteractionChannel]: """Optional[Union[:class:`abc.GuildChannel`, :class:`PartialMessageable`, :class:`Thread`]]: The channel the interaction was sent from. @@ -176,8 +189,14 @@ class Interaction: channel = guild and guild._resolve_channel(self.channel_id) if channel is None: if self.channel_id is not None: - type = ChannelType.text if self.guild_id is not None else ChannelType.private - return PartialMessageable(state=self._state, id=self.channel_id, type=type) + type = ( + ChannelType.text + if self.guild_id is not None + else ChannelType.private + ) + return PartialMessageable( + state=self._state, id=self.channel_id, type=type + ) return None return channel @@ -189,7 +208,7 @@ class Interaction: """ return Permissions(self._permissions) - @utils.cached_slot_property('_cs_response') + @utils.cached_slot_property("_cs_response") def response(self) -> InteractionResponse: """:class:`InteractionResponse`: Returns an object responsible for handling responding to the interaction. @@ -198,13 +217,13 @@ class Interaction: """ return InteractionResponse(self) - @utils.cached_slot_property('_cs_followup') + @utils.cached_slot_property("_cs_followup") def followup(self) -> Webhook: """:class:`Webhook`: Returns the follow up webhook for follow up interactions.""" payload = { - 'id': self.application_id, - 'type': 3, - 'token': self.token, + "id": self.application_id, + "type": 3, + "token": self.token, } return Webhook.from_state(data=payload, state=self._state) @@ -238,7 +257,7 @@ class Interaction: # TODO: fix later to not raise? channel = self.channel if channel is None: - raise ClientException('Channel for message could not be resolved') + raise ClientException("Channel for message could not be resolved") adapter = async_context.get() data = await adapter.get_original_interaction_response( @@ -369,8 +388,8 @@ class InteractionResponse: """ __slots__: Tuple[str, ...] = ( - '_responded', - '_parent', + "_responded", + "_parent", ) def __init__(self, parent: Interaction): @@ -416,12 +435,16 @@ class InteractionResponse: elif parent.type is InteractionType.application_command: defer_type = InteractionResponseType.deferred_channel_message.value if ephemeral: - data = {'flags': 64} + data = {"flags": 64} if defer_type: adapter = async_context.get() await adapter.create_interaction_response( - parent.id, parent.token, session=parent._session, type=defer_type, data=data + parent.id, + parent.token, + session=parent._session, + type=defer_type, + data=data, ) self._responded = True @@ -446,7 +469,10 @@ class InteractionResponse: if parent.type is InteractionType.ping: adapter = async_context.get() await adapter.create_interaction_response( - parent.id, parent.token, session=parent._session, type=InteractionResponseType.pong.value + parent.id, + parent.token, + session=parent._session, + type=InteractionResponseType.pong.value, ) self._responded = True @@ -498,28 +524,28 @@ class InteractionResponse: raise InteractionResponded(self._parent) payload: Dict[str, Any] = { - 'tts': tts, + "tts": tts, } if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix embed and embeds keyword arguments') + raise TypeError("cannot mix embed and embeds keyword arguments") if embed is not MISSING: embeds = [embed] if embeds: if len(embeds) > 10: - raise ValueError('embeds cannot exceed maximum of 10 elements') - payload['embeds'] = [e.to_dict() for e in embeds] + raise ValueError("embeds cannot exceed maximum of 10 elements") + payload["embeds"] = [e.to_dict() for e in embeds] if content is not None: - payload['content'] = str(content) + payload["content"] = str(content) if ephemeral: - payload['flags'] = 64 + payload["flags"] = 64 if view is not MISSING: - payload['components'] = view.to_components() + payload["components"] = view.to_components() parent = self._parent adapter = async_context.get() @@ -591,12 +617,12 @@ class InteractionResponse: payload = {} if content is not MISSING: if content is None: - payload['content'] = None + payload["content"] = None else: - payload['content'] = str(content) + payload["content"] = str(content) if embed is not MISSING and embeds is not MISSING: - raise TypeError('cannot mix both embed and embeds keyword arguments') + raise TypeError("cannot mix both embed and embeds keyword arguments") if embed is not MISSING: if embed is None: @@ -605,17 +631,17 @@ class InteractionResponse: embeds = [embed] if embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] + payload["embeds"] = [e.to_dict() for e in embeds] if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] + payload["attachments"] = [a.to_dict() for a in attachments] if view is not MISSING: state.prevent_view_updates_for(message_id) if view is None: - payload['components'] = [] + payload["components"] = [] else: - payload['components'] = view.to_components() + payload["components"] = view.to_components() adapter = async_context.get() await adapter.create_interaction_response( @@ -633,7 +659,7 @@ class InteractionResponse: class _InteractionMessageState: - __slots__ = ('_parent', '_interaction') + __slots__ = ("_parent", "_interaction") def __init__(self, interaction: Interaction, parent: ConnectionState): self._interaction: Interaction = interaction diff --git a/discord/invite.py b/discord/invite.py index 050d2b83..1df4327c 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -33,9 +33,9 @@ from .enums import ChannelType, VerificationLevel, InviteTarget, try_enum from .appinfo import PartialAppInfo __all__ = ( - 'PartialInviteChannel', - 'PartialInviteGuild', - 'Invite', + "PartialInviteChannel", + "PartialInviteGuild", + "Invite", ) if TYPE_CHECKING: @@ -52,8 +52,8 @@ if TYPE_CHECKING: from .abc import GuildChannel from .user import User - InviteGuildType = Union[Guild, 'PartialInviteGuild', Object] - InviteChannelType = Union[GuildChannel, 'PartialInviteChannel', Object] + InviteGuildType = Union[Guild, "PartialInviteGuild", Object] + InviteChannelType = Union[GuildChannel, "PartialInviteChannel", Object] import datetime @@ -92,23 +92,25 @@ class PartialInviteChannel: The partial channel's type. """ - __slots__ = ('id', 'name', 'type') + __slots__ = ("id", "name", "type") def __init__(self, data: InviteChannelPayload): - self.id: int = int(data['id']) - self.name: str = data['name'] - self.type: ChannelType = try_enum(ChannelType, data['type']) + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.type: ChannelType = try_enum(ChannelType, data["type"]) def __str__(self) -> str: return self.name def __repr__(self) -> str: - return f'' + return ( + f"" + ) @property def mention(self) -> str: """:class:`str`: The string that allows you to mention the channel.""" - return f'<#{self.id}>' + return f"<#{self.id}>" @property def created_at(self) -> datetime.datetime: @@ -154,26 +156,38 @@ class PartialInviteGuild: The partial guild's description. """ - __slots__ = ('_state', 'features', '_icon', '_banner', 'id', 'name', '_splash', 'verification_level', 'description') + __slots__ = ( + "_state", + "features", + "_icon", + "_banner", + "id", + "name", + "_splash", + "verification_level", + "description", + ) def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): self._state: ConnectionState = state self.id: int = id - self.name: str = data['name'] - self.features: List[str] = data.get('features', []) - self._icon: Optional[str] = data.get('icon') - self._banner: Optional[str] = data.get('banner') - self._splash: Optional[str] = data.get('splash') - self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get('verification_level')) - self.description: Optional[str] = data.get('description') + self.name: str = data["name"] + self.features: List[str] = data.get("features", []) + self._icon: Optional[str] = data.get("icon") + self._banner: Optional[str] = data.get("banner") + self._splash: Optional[str] = data.get("splash") + self.verification_level: VerificationLevel = try_enum( + VerificationLevel, data.get("verification_level") + ) + self.description: Optional[str] = data.get("description") def __str__(self) -> str: return self.name def __repr__(self) -> str: return ( - f'<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} ' - f'description={self.description!r}>' + f"<{self.__class__.__name__} id={self.id} name={self.name!r} features={self.features} " + f"description={self.description!r}>" ) @property @@ -193,17 +207,21 @@ class PartialInviteGuild: """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None - return Asset._from_guild_image(self._state, self.id, self._banner, path='banners') + return Asset._from_guild_image( + self._state, self.id, self._banner, path="banners" + ) @property def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None - return Asset._from_guild_image(self._state, self.id, self._splash, path='splashes') + return Asset._from_guild_image( + self._state, self.id, self._splash, path="splashes" + ) -I = TypeVar('I', bound='Invite') +I = TypeVar("I", bound="Invite") class Invite(Hashable): @@ -308,26 +326,26 @@ class Invite(Hashable): """ __slots__ = ( - 'max_age', - 'code', - 'guild', - 'revoked', - 'created_at', - 'uses', - 'temporary', - 'max_uses', - 'inviter', - 'channel', - 'target_user', - 'target_type', - '_state', - 'approximate_member_count', - 'approximate_presence_count', - 'target_application', - 'expires_at', + "max_age", + "code", + "guild", + "revoked", + "created_at", + "uses", + "temporary", + "max_uses", + "inviter", + "channel", + "target_user", + "target_type", + "_state", + "approximate_member_count", + "approximate_presence_count", + "target_application", + "expires_at", ) - BASE = 'https://discord.gg' + BASE = "https://discord.gg" def __init__( self, @@ -338,45 +356,67 @@ class Invite(Hashable): channel: Optional[Union[PartialInviteChannel, GuildChannel]] = None, ): self._state: ConnectionState = state - self.max_age: Optional[int] = data.get('max_age') - self.code: str = data['code'] - self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get('guild'), guild) - self.revoked: Optional[bool] = data.get('revoked') - self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) - self.temporary: Optional[bool] = data.get('temporary') - self.uses: Optional[int] = data.get('uses') - self.max_uses: Optional[int] = data.get('max_uses') - self.approximate_presence_count: Optional[int] = data.get('approximate_presence_count') - self.approximate_member_count: Optional[int] = data.get('approximate_member_count') + self.max_age: Optional[int] = data.get("max_age") + self.code: str = data["code"] + self.guild: Optional[InviteGuildType] = self._resolve_guild( + data.get("guild"), guild + ) + self.revoked: Optional[bool] = data.get("revoked") + self.created_at: Optional[datetime.datetime] = parse_time( + data.get("created_at") + ) + self.temporary: Optional[bool] = data.get("temporary") + self.uses: Optional[int] = data.get("uses") + self.max_uses: Optional[int] = data.get("max_uses") + self.approximate_presence_count: Optional[int] = data.get( + "approximate_presence_count" + ) + self.approximate_member_count: Optional[int] = data.get( + "approximate_member_count" + ) - expires_at = data.get('expires_at', None) - self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None + expires_at = data.get("expires_at", None) + self.expires_at: Optional[datetime.datetime] = ( + parse_time(expires_at) if expires_at else None + ) - inviter_data = data.get('inviter') - self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data) + inviter_data = data.get("inviter") + self.inviter: Optional[User] = ( + None if inviter_data is None else self._state.create_user(inviter_data) + ) - self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get('channel'), channel) + self.channel: Optional[InviteChannelType] = self._resolve_channel( + data.get("channel"), channel + ) - target_user_data = data.get('target_user') - self.target_user: Optional[User] = None if target_user_data is None else self._state.create_user(target_user_data) + target_user_data = data.get("target_user") + self.target_user: Optional[User] = ( + None + if target_user_data is None + else self._state.create_user(target_user_data) + ) - self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) + self.target_type: InviteTarget = try_enum( + InviteTarget, data.get("target_type", 0) + ) - application = data.get('target_application') + application = data.get("target_application") self.target_application: Optional[PartialAppInfo] = ( PartialAppInfo(data=application, state=state) if application else None ) @classmethod - def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I: + def from_incomplete( + cls: Type[I], *, state: ConnectionState, data: InvitePayload + ) -> I: guild: Optional[Union[Guild, PartialInviteGuild]] try: - guild_data = data['guild'] + guild_data = data["guild"] except KeyError: # If we're here, then this is a group DM guild = None else: - guild_id = int(guild_data['id']) + guild_id = int(guild_data["id"]) guild = state._get_guild(guild_id) if guild is None: # If it's not cached, then it has to be a partial guild @@ -384,7 +424,9 @@ class Invite(Hashable): # As far as I know, invites always need a channel # So this should never raise. - channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data['channel']) + channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel( + data["channel"] + ) if guild is not None and not isinstance(guild, PartialInviteGuild): # Upgrade the partial data if applicable channel = guild.get_channel(channel.id) or channel @@ -392,10 +434,12 @@ class Invite(Hashable): return cls(state=state, data=data, guild=guild, channel=channel) @classmethod - def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: - guild_id: Optional[int] = _get_as_snowflake(data, 'guild_id') + def from_gateway( + cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload + ) -> I: + guild_id: Optional[int] = _get_as_snowflake(data, "guild_id") guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) if guild is not None: channel = guild.get_channel(channel_id) or Object(id=channel_id) # type: ignore else: @@ -415,7 +459,7 @@ class Invite(Hashable): if data is None: return None - guild_id = int(data['id']) + guild_id = int(data["id"]) return PartialInviteGuild(self._state, data, guild_id) def _resolve_channel( @@ -439,9 +483,9 @@ class Invite(Hashable): def __repr__(self) -> str: return ( - f'' + f"" ) def __hash__(self) -> int: @@ -455,7 +499,7 @@ class Invite(Hashable): @property def url(self) -> str: """:class:`str`: A property that retrieves the invite URL.""" - return self.BASE + '/' + self.code + return self.BASE + "/" + self.code async def delete(self, *, reason: Optional[str] = None): """|coro| diff --git a/discord/iterators.py b/discord/iterators.py index f725d527..593226c1 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -26,7 +26,17 @@ from __future__ import annotations import asyncio import datetime -from typing import Awaitable, TYPE_CHECKING, TypeVar, Optional, Any, Callable, Union, List, AsyncIterator +from typing import ( + Awaitable, + TYPE_CHECKING, + TypeVar, + Optional, + Any, + Callable, + Union, + List, + AsyncIterator, +) from .errors import NoMoreItems from .utils import snowflake_time, time_snowflake, maybe_coroutine @@ -34,11 +44,11 @@ from .object import Object from .audit_logs import AuditLogEntry __all__ = ( - 'ReactionIterator', - 'HistoryIterator', - 'AuditLogIterator', - 'GuildIterator', - 'MemberIterator', + "ReactionIterator", + "HistoryIterator", + "AuditLogIterator", + "GuildIterator", + "MemberIterator", ) if TYPE_CHECKING: @@ -67,8 +77,8 @@ if TYPE_CHECKING: from .threads import Thread from .abc import Snowflake -T = TypeVar('T') -OT = TypeVar('OT') +T = TypeVar("T") +OT = TypeVar("OT") _Func = Callable[[T], Union[OT, Awaitable[OT]]] OLDEST_OBJECT = Object(id=0) @@ -83,7 +93,7 @@ class _AsyncIterator(AsyncIterator[T]): def get(self, **attrs: Any) -> Awaitable[Optional[T]]: def predicate(elem: T): for attr, val in attrs.items(): - nested = attr.split('__') + nested = attr.split("__") obj = elem for attribute in nested: obj = getattr(obj, attribute) @@ -107,7 +117,7 @@ class _AsyncIterator(AsyncIterator[T]): def chunk(self, max_size: int) -> _ChunkedAsyncIterator[T]: if max_size <= 0: - raise ValueError('async iterator chunk sizes must be greater than 0.') + raise ValueError("async iterator chunk sizes must be greater than 0.") return _ChunkedAsyncIterator(self, max_size) def map(self, func: _Func[T, OT]) -> _MappedAsyncIterator[OT]: @@ -182,7 +192,7 @@ class _FilteredAsyncIterator(_AsyncIterator[T]): return item -class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): +class ReactionIterator(_AsyncIterator[Union["User", "Member"]]): def __init__(self, message, emoji, limit=100, after=None): self.message = message self.limit = limit @@ -218,14 +228,14 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): if data: self.limit -= retrieve - self.after = Object(id=int(data[-1]['id'])) + self.after = Object(id=int(data[-1]["id"])) if self.guild is None or isinstance(self.guild, Object): for element in reversed(data): await self.users.put(User(state=self.state, data=element)) else: for element in reversed(data): - member_id = int(element['id']) + member_id = int(element["id"]) member = self.guild.get_member(member_id) if member is not None: await self.users.put(member) @@ -233,7 +243,7 @@ class ReactionIterator(_AsyncIterator[Union['User', 'Member']]): await self.users.put(User(state=self.state, data=element)) -class HistoryIterator(_AsyncIterator['Message']): +class HistoryIterator(_AsyncIterator["Message"]): """Iterator for receiving a channel's message history. The messages endpoint has two behaviours we care about here: @@ -267,7 +277,15 @@ class HistoryIterator(_AsyncIterator['Message']): ``True`` if `after` is specified, otherwise ``False``. """ - def __init__(self, messageable, limit, before=None, after=None, around=None, oldest_first=None): + def __init__( + self, + messageable, + limit, + before=None, + after=None, + around=None, + oldest_first=None, + ): if isinstance(before, datetime.datetime): before = Object(id=time_snowflake(before, high=False)) @@ -295,28 +313,30 @@ class HistoryIterator(_AsyncIterator['Message']): if self.around: if self.limit is None: - raise ValueError('history does not support around with limit=None') + raise ValueError("history does not support around with limit=None") if self.limit > 101: - raise ValueError("history max limit 101 when specifying around parameter") + raise ValueError( + "history max limit 101 when specifying around parameter" + ) elif self.limit == 101: self.limit = 100 # Thanks discord self._retrieve_messages = self._retrieve_messages_around_strategy # type: ignore if self.before and self.after: - self._filter = lambda m: self.after.id < int(m['id']) < self.before.id + self._filter = lambda m: self.after.id < int(m["id"]) < self.before.id elif self.before: - self._filter = lambda m: int(m['id']) < self.before.id + self._filter = lambda m: int(m["id"]) < self.before.id elif self.after: - self._filter = lambda m: self.after.id < int(m['id']) + self._filter = lambda m: self.after.id < int(m["id"]) else: if self.reverse: self._retrieve_messages = self._retrieve_messages_after_strategy # type: ignore if self.before: - self._filter = lambda m: int(m['id']) < self.before.id + self._filter = lambda m: int(m["id"]) < self.before.id else: self._retrieve_messages = self._retrieve_messages_before_strategy # type: ignore if self.after and self.after != OLDEST_OBJECT: - self._filter = lambda m: int(m['id']) > self.after.id + self._filter = lambda m: int(m["id"]) > self.after.id async def next(self) -> Message: if self.messages.empty(): @@ -337,7 +357,7 @@ class HistoryIterator(_AsyncIterator['Message']): return r > 0 async def fill_messages(self): - if not hasattr(self, 'channel'): + if not hasattr(self, "channel"): # do the required set up channel = await self.messageable._get_channel() self.channel = channel @@ -354,7 +374,9 @@ class HistoryIterator(_AsyncIterator['Message']): channel = self.channel for element in data: - await self.messages.put(self.state.create_message(channel=channel, data=element)) + await self.messages.put( + self.state.create_message(channel=channel, data=element) + ) async def _retrieve_messages(self, retrieve) -> List[Message]: """Retrieve messages and update next parameters.""" @@ -363,35 +385,50 @@ class HistoryIterator(_AsyncIterator['Message']): async def _retrieve_messages_before_strategy(self, retrieve): """Retrieve messages using before parameter.""" before = self.before.id if self.before else None - data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before) + data: List[MessagePayload] = await self.logs_from( + self.channel.id, retrieve, before=before + ) if len(data): if self.limit is not None: self.limit -= retrieve - self.before = Object(id=int(data[-1]['id'])) + self.before = Object(id=int(data[-1]["id"])) return data async def _retrieve_messages_after_strategy(self, retrieve): """Retrieve messages using after parameter.""" after = self.after.id if self.after else None - data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after) + data: List[MessagePayload] = await self.logs_from( + self.channel.id, retrieve, after=after + ) if len(data): if self.limit is not None: self.limit -= retrieve - self.after = Object(id=int(data[0]['id'])) + self.after = Object(id=int(data[0]["id"])) return data async def _retrieve_messages_around_strategy(self, retrieve): """Retrieve messages using around parameter.""" if self.around: around = self.around.id if self.around else None - data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around) + data: List[MessagePayload] = await self.logs_from( + self.channel.id, retrieve, around=around + ) self.around = None return data return [] -class AuditLogIterator(_AsyncIterator['AuditLogEntry']): - def __init__(self, guild, limit=None, before=None, after=None, oldest_first=None, user_id=None, action_type=None): +class AuditLogIterator(_AsyncIterator["AuditLogEntry"]): + def __init__( + self, + guild, + limit=None, + before=None, + after=None, + oldest_first=None, + user_id=None, + action_type=None, + ): if isinstance(before, datetime.datetime): before = Object(id=time_snowflake(before, high=False)) if isinstance(after, datetime.datetime): @@ -420,36 +457,44 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): if self.reverse: self._strategy = self._after_strategy if self.before: - self._filter = lambda m: int(m['id']) < self.before.id + self._filter = lambda m: int(m["id"]) < self.before.id else: self._strategy = self._before_strategy if self.after and self.after != OLDEST_OBJECT: - self._filter = lambda m: int(m['id']) > self.after.id + self._filter = lambda m: int(m["id"]) > self.after.id async def _before_strategy(self, retrieve): before = self.before.id if self.before else None data: AuditLogPayload = await self.request( - self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, before=before + self.guild.id, + limit=retrieve, + user_id=self.user_id, + action_type=self.action_type, + before=before, ) - entries = data.get('audit_log_entries', []) + entries = data.get("audit_log_entries", []) if len(data) and entries: if self.limit is not None: self.limit -= retrieve - self.before = Object(id=int(entries[-1]['id'])) - return data.get('users', []), entries + self.before = Object(id=int(entries[-1]["id"])) + return data.get("users", []), entries async def _after_strategy(self, retrieve): after = self.after.id if self.after else None data: AuditLogPayload = await self.request( - self.guild.id, limit=retrieve, user_id=self.user_id, action_type=self.action_type, after=after + self.guild.id, + limit=retrieve, + user_id=self.user_id, + action_type=self.action_type, + after=after, ) - entries = data.get('audit_log_entries', []) + entries = data.get("audit_log_entries", []) if len(data) and entries: if self.limit is not None: self.limit -= retrieve - self.after = Object(id=int(entries[0]['id'])) - return data.get('users', []), entries + self.after = Object(id=int(entries[0]["id"])) + return data.get("users", []), entries async def next(self) -> AuditLogEntry: if self.entries.empty(): @@ -488,13 +533,15 @@ class AuditLogIterator(_AsyncIterator['AuditLogEntry']): for element in data: # TODO: remove this if statement later - if element['action_type'] is None: + if element["action_type"] is None: continue - await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) + await self.entries.put( + AuditLogEntry(data=element, users=self._users, guild=self.guild) + ) -class GuildIterator(_AsyncIterator['Guild']): +class GuildIterator(_AsyncIterator["Guild"]): """Iterator for receiving the client's guilds. The guilds endpoint has the same two behaviours as described @@ -543,7 +590,7 @@ class GuildIterator(_AsyncIterator['Guild']): if self.before and self.after: self._retrieve_guilds = self._retrieve_guilds_before_strategy # type: ignore - self._filter = lambda m: int(m['id']) > self.after.id + self._filter = lambda m: int(m["id"]) > self.after.id elif self.after: self._retrieve_guilds = self._retrieve_guilds_after_strategy # type: ignore else: @@ -595,7 +642,7 @@ class GuildIterator(_AsyncIterator['Guild']): if len(data): if self.limit is not None: self.limit -= retrieve - self.before = Object(id=int(data[-1]['id'])) + self.before = Object(id=int(data[-1]["id"])) return data async def _retrieve_guilds_after_strategy(self, retrieve): @@ -605,11 +652,11 @@ class GuildIterator(_AsyncIterator['Guild']): if len(data): if self.limit is not None: self.limit -= retrieve - self.after = Object(id=int(data[0]['id'])) + self.after = Object(id=int(data[0]["id"])) return data -class MemberIterator(_AsyncIterator['Member']): +class MemberIterator(_AsyncIterator["Member"]): def __init__(self, guild, limit=1000, after=None): if isinstance(after, datetime.datetime): @@ -652,7 +699,7 @@ class MemberIterator(_AsyncIterator['Member']): if len(data) < 1000: self.limit = 0 # terminate loop - self.after = Object(id=int(data[-1]['user']['id'])) + self.after = Object(id=int(data[-1]["user"]["id"])) for element in reversed(data): await self.members.put(self.create_member(element)) @@ -663,7 +710,7 @@ class MemberIterator(_AsyncIterator['Member']): return Member(data=data, guild=self.guild, state=self.state) -class ArchivedThreadIterator(_AsyncIterator['Thread']): +class ArchivedThreadIterator(_AsyncIterator["Thread"]): def __init__( self, channel_id: int, @@ -681,7 +728,7 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): self.http = guild._state.http if joined and not private: - raise ValueError('Cannot iterate over joined public archived threads') + raise ValueError("Cannot iterate over joined public archived threads") self.before: Optional[str] if before is None: @@ -721,11 +768,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): @staticmethod def get_archive_timestamp(data: ThreadPayload) -> str: - return data['thread_metadata']['archive_timestamp'] + return data["thread_metadata"]["archive_timestamp"] @staticmethod def get_thread_id(data: ThreadPayload) -> str: - return data['id'] # type: ignore + return data["id"] # type: ignore async def fill_queue(self) -> None: if not self.has_more: @@ -735,11 +782,11 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): data = await self.endpoint(self.channel_id, before=self.before, limit=limit) # This stuff is obviously WIP because 'members' is always empty - threads: List[ThreadPayload] = data.get('threads', []) + threads: List[ThreadPayload] = data.get("threads", []) for d in reversed(threads): self.queue.put_nowait(self.create_thread(d)) - self.has_more = data.get('has_more', False) + self.has_more = data.get("has_more", False) if self.limit is not None: self.limit -= len(threads) if self.limit <= 0: @@ -750,4 +797,5 @@ class ArchivedThreadIterator(_AsyncIterator['Thread']): def create_thread(self, data: ThreadPayload) -> Thread: from .threads import Thread + return Thread(guild=self.guild, state=self.guild._state, data=data) diff --git a/discord/member.py b/discord/member.py index 49b7e3ab..9ba0bda3 100644 --- a/discord/member.py +++ b/discord/member.py @@ -29,7 +29,19 @@ import inspect import itertools import sys from operator import attrgetter -from typing import Any, Dict, List, Literal, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union, overload +from typing import ( + Any, + Dict, + List, + Literal, + Optional, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, + overload, +) import discord.abc @@ -44,8 +56,8 @@ from .colour import Colour from .object import Object __all__ = ( - 'VoiceState', - 'Member', + "VoiceState", + "Member", ) if TYPE_CHECKING: @@ -113,52 +125,58 @@ class VoiceState: """ __slots__ = ( - 'session_id', - 'deaf', - 'mute', - 'self_mute', - 'self_stream', - 'self_video', - 'self_deaf', - 'afk', - 'channel', - 'requested_to_speak_at', - 'suppress', + "session_id", + "deaf", + "mute", + "self_mute", + "self_stream", + "self_video", + "self_deaf", + "afk", + "channel", + "requested_to_speak_at", + "suppress", ) - def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): - self.session_id: str = data.get('session_id') + def __init__( + self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None + ): + self.session_id: str = data.get("session_id") self._update(data, channel) def _update(self, data: VoiceStatePayload, channel: Optional[VocalGuildChannel]): - self.self_mute: bool = data.get('self_mute', False) - self.self_deaf: bool = data.get('self_deaf', False) - self.self_stream: bool = data.get('self_stream', False) - self.self_video: bool = data.get('self_video', False) - self.afk: bool = data.get('suppress', False) - self.mute: bool = data.get('mute', False) - self.deaf: bool = data.get('deaf', False) - self.suppress: bool = data.get('suppress', False) - self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time(data.get('request_to_speak_timestamp')) + self.self_mute: bool = data.get("self_mute", False) + self.self_deaf: bool = data.get("self_deaf", False) + self.self_stream: bool = data.get("self_stream", False) + self.self_video: bool = data.get("self_video", False) + self.afk: bool = data.get("suppress", False) + self.mute: bool = data.get("mute", False) + self.deaf: bool = data.get("deaf", False) + self.suppress: bool = data.get("suppress", False) + self.requested_to_speak_at: Optional[datetime.datetime] = utils.parse_time( + data.get("request_to_speak_timestamp") + ) self.channel: Optional[VocalGuildChannel] = channel def __repr__(self) -> str: attrs = [ - ('self_mute', self.self_mute), - ('self_deaf', self.self_deaf), - ('self_stream', self.self_stream), - ('suppress', self.suppress), - ('requested_to_speak_at', self.requested_to_speak_at), - ('channel', self.channel), + ("self_mute", self.self_mute), + ("self_deaf", self.self_deaf), + ("self_stream", self.self_stream), + ("suppress", self.suppress), + ("requested_to_speak_at", self.requested_to_speak_at), + ("channel", self.channel), ] - inner = ' '.join('%s=%r' % t for t in attrs) - return f'<{self.__class__.__name__} {inner}>' + inner = " ".join("%s=%r" % t for t in attrs) + return f"<{self.__class__.__name__} {inner}>" def flatten_user(cls): - for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): + for attr, value in itertools.chain( + BaseUser.__dict__.items(), User.__dict__.items() + ): # ignore private/special methods - if attr.startswith('_'): + if attr.startswith("_"): continue # don't override what we already have @@ -167,9 +185,11 @@ def flatten_user(cls): # if it's a slotted attribute or a property, redirect it # slotted members are implemented as member_descriptors in Type.__dict__ - if not hasattr(value, '__annotations__'): - getter = attrgetter('_user.' + attr) - setattr(cls, attr, property(getter, doc=f'Equivalent to :attr:`User.{attr}`')) + if not hasattr(value, "__annotations__"): + getter = attrgetter("_user." + attr) + setattr( + cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`") + ) else: # Technically, this can also use attrgetter # However I'm not sure how I feel about "functions" returning properties @@ -197,7 +217,7 @@ def flatten_user(cls): return cls -M = TypeVar('M', bound='Member') +M = TypeVar("M", bound="Member") @flatten_user @@ -258,17 +278,17 @@ class Member(discord.abc.Messageable, _UserTag): """ __slots__ = ( - '_roles', - 'joined_at', - 'premium_since', - 'activities', - 'guild', - 'pending', - 'nick', - '_client_status', - '_user', - '_state', - '_avatar', + "_roles", + "joined_at", + "premium_since", + "activities", + "guild", + "pending", + "nick", + "_client_status", + "_user", + "_state", + "_avatar", ) if TYPE_CHECKING: @@ -288,18 +308,24 @@ class Member(discord.abc.Messageable, _UserTag): accent_color: Optional[Colour] accent_colour: Optional[Colour] - def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState): + def __init__( + self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState + ): self._state: ConnectionState = state - self._user: User = state.store_user(data['user']) + self._user: User = state.store_user(data["user"]) self.guild: Guild = guild - self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get('joined_at')) - self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get('premium_since')) - self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data['roles'])) - self._client_status: Dict[Optional[str], str] = {None: 'offline'} + self.joined_at: Optional[datetime.datetime] = utils.parse_time( + data.get("joined_at") + ) + self.premium_since: Optional[datetime.datetime] = utils.parse_time( + data.get("premium_since") + ) + self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"])) + self._client_status: Dict[Optional[str], str] = {None: "offline"} self.activities: Tuple[ActivityTypes, ...] = tuple() - self.nick: Optional[str] = data.get('nick', None) - self.pending: bool = data.get('pending', False) - self._avatar: Optional[str] = data.get('avatar') + self.nick: Optional[str] = data.get("nick", None) + self.pending: bool = data.get("pending", False) + self._avatar: Optional[str] = data.get("avatar") def __str__(self) -> str: return str(self._user) @@ -309,8 +335,8 @@ class Member(discord.abc.Messageable, _UserTag): def __repr__(self) -> str: return ( - f'' + f"" ) def __eq__(self, other: Any) -> bool: @@ -325,25 +351,31 @@ class Member(discord.abc.Messageable, _UserTag): @classmethod def _from_message(cls: Type[M], *, message: Message, data: MemberPayload) -> M: author = message.author - data['user'] = author._to_minimal_user_json() # type: ignore + data["user"] = author._to_minimal_user_json() # type: ignore return cls(data=data, guild=message.guild, state=message._state) # type: ignore def _update_from_message(self, data: MemberPayload) -> None: - self.joined_at = utils.parse_time(data.get('joined_at')) - self.premium_since = utils.parse_time(data.get('premium_since')) - self._roles = utils.SnowflakeList(map(int, data['roles'])) - self.nick = data.get('nick', None) - self.pending = data.get('pending', False) + self.joined_at = utils.parse_time(data.get("joined_at")) + self.premium_since = utils.parse_time(data.get("premium_since")) + self._roles = utils.SnowflakeList(map(int, data["roles"])) + self.nick = data.get("nick", None) + self.pending = data.get("pending", False) @classmethod - def _try_upgrade(cls: Type[M], *, data: UserWithMemberPayload, guild: Guild, state: ConnectionState) -> Union[User, M]: + def _try_upgrade( + cls: Type[M], + *, + data: UserWithMemberPayload, + guild: Guild, + state: ConnectionState, + ) -> Union[User, M]: # A User object with a 'member' key try: - member_data = data.pop('member') + member_data = data.pop("member") except KeyError: return state.create_user(data) else: - member_data['user'] = data # type: ignore + member_data["user"] = data # type: ignore return cls(data=member_data, guild=guild, state=state) # type: ignore @classmethod @@ -374,25 +406,27 @@ class Member(discord.abc.Messageable, _UserTag): # the nickname change is optional, # if it isn't in the payload then it didn't change try: - self.nick = data['nick'] + self.nick = data["nick"] except KeyError: pass try: - self.pending = data['pending'] + self.pending = data["pending"] except KeyError: pass - self.premium_since = utils.parse_time(data.get('premium_since')) - self._roles = utils.SnowflakeList(map(int, data['roles'])) - self._avatar = data.get('avatar') + self.premium_since = utils.parse_time(data.get("premium_since")) + self._roles = utils.SnowflakeList(map(int, data["roles"])) + self._avatar = data.get("avatar") - def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]: - self.activities = tuple(map(create_activity, data['activities'])) + def _presence_update( + self, data: PartialPresenceUpdate, user: UserPayload + ) -> Optional[Tuple[User, User]]: + self.activities = tuple(map(create_activity, data["activities"])) self._client_status = { - sys.intern(key): sys.intern(value) for key, value in data.get('client_status', {}).items() # type: ignore + sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore } - self._client_status[None] = sys.intern(data['status']) + self._client_status[None] = sys.intern(data["status"]) if len(user) > 1: return self._update_inner_user(user) @@ -402,7 +436,12 @@ class Member(discord.abc.Messageable, _UserTag): u = self._user original = (u.name, u._avatar, u.discriminator, u._public_flags) # These keys seem to always be available - modified = (user['username'], user['avatar'], user['discriminator'], user.get('public_flags', 0)) + modified = ( + user["username"], + user["avatar"], + user["discriminator"], + user.get("public_flags", 0), + ) if original != modified: to_return = User._copy(self._user) u.name, u._avatar, u.discriminator, u._public_flags = modified @@ -430,21 +469,21 @@ class Member(discord.abc.Messageable, _UserTag): @property def mobile_status(self) -> Status: """:class:`Status`: The member's status on a mobile device, if applicable.""" - return try_enum(Status, self._client_status.get('mobile', 'offline')) + return try_enum(Status, self._client_status.get("mobile", "offline")) @property def desktop_status(self) -> Status: """:class:`Status`: The member's status on the desktop client, if applicable.""" - return try_enum(Status, self._client_status.get('desktop', 'offline')) + return try_enum(Status, self._client_status.get("desktop", "offline")) @property def web_status(self) -> Status: """:class:`Status`: The member's status on the web client, if applicable.""" - return try_enum(Status, self._client_status.get('web', 'offline')) + return try_enum(Status, self._client_status.get("web", "offline")) def is_on_mobile(self) -> bool: """:class:`bool`: A helper function that determines if a member is active on a mobile device.""" - return 'mobile' in self._client_status + return "mobile" in self._client_status @property def colour(self) -> Colour: @@ -497,8 +536,8 @@ class Member(discord.abc.Messageable, _UserTag): def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention the member.""" if self.nick: - return f'<@!{self._user.id}>' - return f'<@{self._user.id}>' + return f"<@!{self._user.id}>" + return f"<@{self._user.id}>" @property def display_name(self) -> str: @@ -531,7 +570,9 @@ class Member(discord.abc.Messageable, _UserTag): """ if self._avatar is None: return None - return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar) + return Asset._from_guild_avatar( + self._state, self.guild.id, self.id, self._avatar + ) @property def activity(self) -> Optional[ActivityTypes]: @@ -625,7 +666,9 @@ class Member(discord.abc.Messageable, _UserTag): Bans this member. Equivalent to :meth:`Guild.ban`. """ - await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days) + await self.guild.ban( + self, reason=reason, delete_message_days=delete_message_days + ) async def unban(self, *, reason: Optional[str] = None) -> None: """|coro| @@ -720,39 +763,41 @@ class Member(discord.abc.Messageable, _UserTag): payload: Dict[str, Any] = {} if nick is not MISSING: - nick = nick or '' + nick = nick or "" if me: await http.change_my_nickname(guild_id, nick, reason=reason) else: - payload['nick'] = nick + payload["nick"] = nick if deafen is not MISSING: - payload['deaf'] = deafen + payload["deaf"] = deafen if mute is not MISSING: - payload['mute'] = mute + payload["mute"] = mute if suppress is not MISSING: voice_state_payload = { - 'channel_id': self.voice.channel.id, - 'suppress': suppress, + "channel_id": self.voice.channel.id, + "suppress": suppress, } if suppress or self.bot: - voice_state_payload['request_to_speak_timestamp'] = None + voice_state_payload["request_to_speak_timestamp"] = None if me: await http.edit_my_voice_state(guild_id, voice_state_payload) else: if not suppress: - voice_state_payload['request_to_speak_timestamp'] = datetime.datetime.utcnow().isoformat() + voice_state_payload[ + "request_to_speak_timestamp" + ] = datetime.datetime.utcnow().isoformat() await http.edit_voice_state(guild_id, self.id, voice_state_payload) if voice_channel is not MISSING: - payload['channel_id'] = voice_channel and voice_channel.id + payload["channel_id"] = voice_channel and voice_channel.id if roles is not MISSING: - payload['roles'] = tuple(r.id for r in roles) + payload["roles"] = tuple(r.id for r in roles) if payload: data = await http.edit_member(guild_id, self.id, reason=reason, **payload) @@ -780,17 +825,19 @@ class Member(discord.abc.Messageable, _UserTag): The operation failed. """ payload = { - 'channel_id': self.voice.channel.id, - 'request_to_speak_timestamp': datetime.datetime.utcnow().isoformat(), + "channel_id": self.voice.channel.id, + "request_to_speak_timestamp": datetime.datetime.utcnow().isoformat(), } if self._state.self_id != self.id: - payload['suppress'] = False + payload["suppress"] = False await self._state.http.edit_voice_state(self.guild.id, self.id, payload) else: await self._state.http.edit_my_voice_state(self.guild.id, payload) - async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None: + async def move_to( + self, channel: VocalGuildChannel, *, reason: Optional[str] = None + ) -> None: """|coro| Moves a member to a new voice channel (they must be connected first). @@ -813,7 +860,9 @@ class Member(discord.abc.Messageable, _UserTag): """ await self.edit(voice_channel=channel, reason=reason) - async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: + async def add_roles( + self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True + ) -> None: r"""|coro| Gives the member a number of :class:`Role`\s. @@ -843,7 +892,9 @@ class Member(discord.abc.Messageable, _UserTag): """ if not atomic: - new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s) + new_roles = utils._unique( + Object(id=r.id) for s in (self.roles[1:], roles) for r in s + ) await self.edit(roles=new_roles, reason=reason) else: req = self._state.http.add_role @@ -852,7 +903,9 @@ class Member(discord.abc.Messageable, _UserTag): for role in roles: await req(guild_id, user_id, role.id, reason=reason) - async def remove_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: + async def remove_roles( + self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True + ) -> None: r"""|coro| Removes :class:`Role`\s from this member. diff --git a/discord/mentions.py b/discord/mentions.py index 0516decf..a5629419 100644 --- a/discord/mentions.py +++ b/discord/mentions.py @@ -25,9 +25,7 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations from typing import Type, TypeVar, Union, List, TYPE_CHECKING, Any, Union -__all__ = ( - 'AllowedMentions', -) +__all__ = ("AllowedMentions",) if TYPE_CHECKING: from .types.message import AllowedMentions as AllowedMentionsPayload @@ -36,7 +34,7 @@ if TYPE_CHECKING: class _FakeBool: def __repr__(self): - return 'True' + return "True" def __eq__(self, other): return other is True @@ -47,7 +45,7 @@ class _FakeBool: default: Any = _FakeBool() -A = TypeVar('A', bound='AllowedMentions') +A = TypeVar("A", bound="AllowedMentions") class AllowedMentions: @@ -80,7 +78,7 @@ class AllowedMentions: .. versionadded:: 1.6 """ - __slots__ = ('everyone', 'users', 'roles', 'replied_user') + __slots__ = ("everyone", "users", "roles", "replied_user") def __init__( self, @@ -116,22 +114,22 @@ class AllowedMentions: data = {} if self.everyone: - parse.append('everyone') + parse.append("everyone") if self.users == True: - parse.append('users') + parse.append("users") elif self.users != False: - data['users'] = [x.id for x in self.users] + data["users"] = [x.id for x in self.users] if self.roles == True: - parse.append('roles') + parse.append("roles") elif self.roles != False: - data['roles'] = [x.id for x in self.roles] + data["roles"] = [x.id for x in self.roles] if self.replied_user: - data['replied_user'] = True + data["replied_user"] = True - data['parse'] = parse + data["parse"] = parse return data # type: ignore def merge(self, other: AllowedMentions) -> AllowedMentions: @@ -141,11 +139,15 @@ class AllowedMentions: everyone = self.everyone if other.everyone is default else other.everyone users = self.users if other.users is default else other.users roles = self.roles if other.roles is default else other.roles - replied_user = self.replied_user if other.replied_user is default else other.replied_user - return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user) + replied_user = ( + self.replied_user if other.replied_user is default else other.replied_user + ) + return AllowedMentions( + everyone=everyone, roles=roles, users=users, replied_user=replied_user + ) def __repr__(self) -> str: return ( - f'{self.__class__.__name__}(everyone={self.everyone}, ' - f'users={self.users}, roles={self.roles}, replied_user={self.replied_user})' + f"{self.__class__.__name__}(everyone={self.everyone}, " + f"users={self.users}, roles={self.roles}, replied_user={self.replied_user})" ) diff --git a/discord/message.py b/discord/message.py index 49c5e718..7d66a674 100644 --- a/discord/message.py +++ b/discord/message.py @@ -29,7 +29,21 @@ import datetime import re import io from os import PathLike -from typing import Dict, TYPE_CHECKING, Union, List, Optional, Any, Callable, Tuple, ClassVar, Optional, overload, TypeVar, Type +from typing import ( + Dict, + TYPE_CHECKING, + Union, + List, + Optional, + Any, + Callable, + Tuple, + ClassVar, + Optional, + overload, + TypeVar, + Type, +) from . import utils from .reaction import Reaction @@ -76,15 +90,15 @@ if TYPE_CHECKING: from .role import Role from .ui.view import View - MR = TypeVar('MR', bound='MessageReference') + MR = TypeVar("MR", bound="MessageReference") EmojiInputType = Union[Emoji, PartialEmoji, str] __all__ = ( - 'Attachment', - 'Message', - 'PartialMessage', - 'MessageReference', - 'DeletedReferencedMessage', + "Attachment", + "Message", + "PartialMessage", + "MessageReference", + "DeletedReferencedMessage", ) @@ -93,15 +107,17 @@ def convert_emoji_reaction(emoji): emoji = emoji.emoji if isinstance(emoji, Emoji): - return f'{emoji.name}:{emoji.id}' + return f"{emoji.name}:{emoji.id}" if isinstance(emoji, PartialEmoji): return emoji._as_reaction() if isinstance(emoji, str): # Reactions can be in :name:id format, but not <:name:id>. # No existing emojis have <> in them, so this should be okay. - return emoji.strip('<>') + return emoji.strip("<>") - raise InvalidArgument(f'emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.') + raise InvalidArgument( + f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}." + ) class Attachment(Hashable): @@ -157,28 +173,38 @@ class Attachment(Hashable): .. versionadded:: 1.7 """ - __slots__ = ('id', 'size', 'height', 'width', 'filename', 'url', 'proxy_url', '_http', 'content_type') + __slots__ = ( + "id", + "size", + "height", + "width", + "filename", + "url", + "proxy_url", + "_http", + "content_type", + ) def __init__(self, *, data: AttachmentPayload, state: ConnectionState): - self.id: int = int(data['id']) - self.size: int = data['size'] - self.height: Optional[int] = data.get('height') - self.width: Optional[int] = data.get('width') - self.filename: str = data['filename'] - self.url: str = data.get('url') - self.proxy_url: str = data.get('proxy_url') + self.id: int = int(data["id"]) + self.size: int = data["size"] + self.height: Optional[int] = data.get("height") + self.width: Optional[int] = data.get("width") + self.filename: str = data["filename"] + self.url: str = data.get("url") + self.proxy_url: str = data.get("proxy_url") self._http = state.http - self.content_type: Optional[str] = data.get('content_type') + self.content_type: Optional[str] = data.get("content_type") def is_spoiler(self) -> bool: """:class:`bool`: Whether this attachment contains a spoiler.""" - return self.filename.startswith('SPOILER_') + return self.filename.startswith("SPOILER_") def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: - return self.url or '' + return self.url or "" async def save( self, @@ -227,7 +253,7 @@ class Attachment(Hashable): fp.seek(0) return written else: - with open(fp, 'wb') as f: + with open(fp, "wb") as f: return f.write(data) async def read(self, *, use_cached: bool = False) -> bytes: @@ -309,19 +335,19 @@ class Attachment(Hashable): def to_dict(self) -> AttachmentPayload: result: AttachmentPayload = { - 'filename': self.filename, - 'id': self.id, - 'proxy_url': self.proxy_url, - 'size': self.size, - 'url': self.url, - 'spoiler': self.is_spoiler(), + "filename": self.filename, + "id": self.id, + "proxy_url": self.proxy_url, + "size": self.size, + "url": self.url, + "spoiler": self.is_spoiler(), } if self.height: - result['height'] = self.height + result["height"] = self.height if self.width: - result['width'] = self.width + result["width"] = self.width if self.content_type: - result['content_type'] = self.content_type + result["content_type"] = self.content_type return result @@ -335,7 +361,7 @@ class DeletedReferencedMessage: .. versionadded:: 1.6 """ - __slots__ = ('_parent',) + __slots__ = ("_parent",) def __init__(self, parent: MessageReference): self._parent: MessageReference = parent @@ -347,7 +373,7 @@ class DeletedReferencedMessage: def id(self) -> int: """:class:`int`: The message ID of the deleted referenced message.""" # the parent's message id won't be None here - return self._parent.message_id # type: ignore + return self._parent.message_id # type: ignore @property def channel_id(self) -> int: @@ -394,9 +420,23 @@ class MessageReference: .. versionadded:: 1.6 """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'fail_if_not_exists', 'resolved', '_state') + __slots__ = ( + "message_id", + "channel_id", + "guild_id", + "fail_if_not_exists", + "resolved", + "_state", + ) - def __init__(self, *, message_id: int, channel_id: int, guild_id: Optional[int] = None, fail_if_not_exists: bool = True): + def __init__( + self, + *, + message_id: int, + channel_id: int, + guild_id: Optional[int] = None, + fail_if_not_exists: bool = True, + ): self._state: Optional[ConnectionState] = None self.resolved: Optional[Union[Message, DeletedReferencedMessage]] = None self.message_id: Optional[int] = message_id @@ -405,18 +445,22 @@ class MessageReference: self.fail_if_not_exists: bool = fail_if_not_exists @classmethod - def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR: + def with_state( + cls: Type[MR], state: ConnectionState, data: MessageReferencePayload + ) -> MR: self = cls.__new__(cls) - self.message_id = utils._get_as_snowflake(data, 'message_id') - self.channel_id = int(data.pop('channel_id')) - self.guild_id = utils._get_as_snowflake(data, 'guild_id') - self.fail_if_not_exists = data.get('fail_if_not_exists', True) + self.message_id = utils._get_as_snowflake(data, "message_id") + self.channel_id = int(data.pop("channel_id")) + self.guild_id = utils._get_as_snowflake(data, "guild_id") + self.fail_if_not_exists = data.get("fail_if_not_exists", True) self._state = state self.resolved = None return self @classmethod - def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR: + def from_message( + cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True + ) -> MR: """Creates a :class:`MessageReference` from an existing :class:`~discord.Message`. .. versionadded:: 1.6 @@ -439,7 +483,7 @@ class MessageReference: self = cls( message_id=message.id, channel_id=message.channel.id, - guild_id=getattr(message.guild, 'id', None), + guild_id=getattr(message.guild, "id", None), fail_if_not_exists=fail_if_not_exists, ) self._state = message._state @@ -456,36 +500,38 @@ class MessageReference: .. versionadded:: 1.7 """ - guild_id = self.guild_id if self.guild_id is not None else '@me' - return f'https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}' + guild_id = self.guild_id if self.guild_id is not None else "@me" + return f"https://discord.com/channels/{guild_id}/{self.channel_id}/{self.message_id}" def __repr__(self) -> str: - return f'' + return f"" def to_dict(self) -> MessageReferencePayload: - result: MessageReferencePayload = {'message_id': self.message_id} if self.message_id is not None else {} - result['channel_id'] = self.channel_id + result: MessageReferencePayload = ( + {"message_id": self.message_id} if self.message_id is not None else {} + ) + result["channel_id"] = self.channel_id if self.guild_id is not None: - result['guild_id'] = self.guild_id + result["guild_id"] = self.guild_id if self.fail_if_not_exists is not None: - result['fail_if_not_exists'] = self.fail_if_not_exists + result["fail_if_not_exists"] = self.fail_if_not_exists return result to_message_reference_dict = to_dict def flatten_handlers(cls): - prefix = len('_handle_') + prefix = len("_handle_") handlers = [ (key[prefix:], value) for key, value in cls.__dict__.items() - if key.startswith('_handle_') and key != '_handle_member' + if key.startswith("_handle_") and key != "_handle_member" ] # store _handle_member last - handlers.append(('member', cls._handle_member)) + handlers.append(("member", cls._handle_member)) cls._HANDLERS = handlers - cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith('_cs_')] + cls._CACHED_SLOTS = [attr for attr in cls.__slots__ if attr.startswith("_cs_")] return cls @@ -615,36 +661,36 @@ class Message(Hashable): """ __slots__ = ( - '_state', - '_edited_timestamp', - '_cs_channel_mentions', - '_cs_raw_mentions', - '_cs_clean_content', - '_cs_raw_channel_mentions', - '_cs_raw_role_mentions', - '_cs_system_content', - 'tts', - 'content', - 'channel', - 'webhook_id', - 'mention_everyone', - 'embeds', - 'id', - 'mentions', - 'author', - 'attachments', - 'nonce', - 'pinned', - 'role_mentions', - 'type', - 'flags', - 'reactions', - 'reference', - 'application', - 'activity', - 'stickers', - 'components', - 'guild', + "_state", + "_edited_timestamp", + "_cs_channel_mentions", + "_cs_raw_mentions", + "_cs_clean_content", + "_cs_raw_channel_mentions", + "_cs_raw_role_mentions", + "_cs_system_content", + "tts", + "content", + "channel", + "webhook_id", + "mention_everyone", + "embeds", + "id", + "mentions", + "author", + "attachments", + "nonce", + "pinned", + "role_mentions", + "type", + "flags", + "reactions", + "reference", + "application", + "activity", + "stickers", + "components", + "guild", ) if TYPE_CHECKING: @@ -664,39 +710,49 @@ class Message(Hashable): data: MessagePayload, ): self._state: ConnectionState = state - self.id: int = int(data['id']) - self.webhook_id: Optional[int] = utils._get_as_snowflake(data, 'webhook_id') - self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get('reactions', [])] - self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data['attachments']] - self.embeds: List[Embed] = [Embed.from_dict(a) for a in data['embeds']] - self.application: Optional[MessageApplicationPayload] = data.get('application') - self.activity: Optional[MessageActivityPayload] = data.get('activity') + self.id: int = int(data["id"]) + self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id") + self.reactions: List[Reaction] = [ + Reaction(message=self, data=d) for d in data.get("reactions", []) + ] + self.attachments: List[Attachment] = [ + Attachment(data=a, state=self._state) for a in data["attachments"] + ] + self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]] + self.application: Optional[MessageApplicationPayload] = data.get("application") + self.activity: Optional[MessageActivityPayload] = data.get("activity") self.channel: MessageableChannel = channel - self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data['edited_timestamp']) - self.type: MessageType = try_enum(MessageType, data['type']) - self.pinned: bool = data['pinned'] - self.flags: MessageFlags = MessageFlags._from_value(data.get('flags', 0)) - self.mention_everyone: bool = data['mention_everyone'] - self.tts: bool = data['tts'] - self.content: str = data['content'] - self.nonce: Optional[Union[int, str]] = data.get('nonce') - self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get('sticker_items', [])] - self.components: List[Component] = [_component_factory(d) for d in data.get('components', [])] + self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time( + data["edited_timestamp"] + ) + self.type: MessageType = try_enum(MessageType, data["type"]) + self.pinned: bool = data["pinned"] + self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) + self.mention_everyone: bool = data["mention_everyone"] + self.tts: bool = data["tts"] + self.content: str = data["content"] + self.nonce: Optional[Union[int, str]] = data.get("nonce") + self.stickers: List[StickerItem] = [ + StickerItem(data=d, state=state) for d in data.get("sticker_items", []) + ] + self.components: List[Component] = [ + _component_factory(d) for d in data.get("components", []) + ] try: # if the channel doesn't have a guild attribute, we handle that self.guild = channel.guild # type: ignore except AttributeError: - self.guild = state._get_guild(utils._get_as_snowflake(data, 'guild_id')) + self.guild = state._get_guild(utils._get_as_snowflake(data, "guild_id")) try: - ref = data['message_reference'] + ref = data["message_reference"] except KeyError: self.reference = None else: self.reference = ref = MessageReference.with_state(state, ref) try: - resolved = data['referenced_message'] + resolved = data["referenced_message"] except KeyError: pass else: @@ -712,18 +768,15 @@ class Message(Hashable): # the channel will be the correct type here ref.resolved = self.__class__(channel=chan, data=resolved, state=state) # type: ignore - for handler in ('author', 'member', 'mentions', 'mention_roles'): + for handler in ("author", "member", "mentions", "mention_roles"): try: - getattr(self, f'_handle_{handler}')(data[handler]) + getattr(self, f"_handle_{handler}")(data[handler]) except KeyError: continue def __repr__(self) -> str: name = self.__class__.__name__ - return ( - f'<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>' - ) - + return f"<{name} id={self.id} channel={self.channel!r} type={self.type!r} author={self.author!r} flags={self.flags!r}>" def __str__(self) -> Optional[str]: return self.content @@ -741,7 +794,7 @@ class Message(Hashable): def _add_reaction(self, data, emoji, user_id) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) - is_me = data['me'] = user_id == self._state.self_id + is_me = data["me"] = user_id == self._state.self_id if reaction is None: reaction = Reaction(message=self, data=data, emoji=emoji) @@ -753,12 +806,14 @@ class Message(Hashable): return reaction - def _remove_reaction(self, data: ReactionPayload, emoji: EmojiInputType, user_id: int) -> Reaction: + def _remove_reaction( + self, data: ReactionPayload, emoji: EmojiInputType, user_id: int + ) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) if reaction is None: # already removed? - raise ValueError('Emoji already removed?') + raise ValueError("Emoji already removed?") # if reaction isn't in the list, we crash. This means discord # sent bad data, or we stored improperly @@ -872,7 +927,7 @@ class Message(Hashable): return for mention in filter(None, mentions): - id_search = int(mention['id']) + id_search = int(mention["id"]) member = guild.get_member(id_search) if member is not None: r.append(member) @@ -890,11 +945,13 @@ class Message(Hashable): def _handle_components(self, components: List[ComponentPayload]): self.components = [_component_factory(d) for d in components] - def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None: + def _rebind_cached_references( + self, new_guild: Guild, new_channel: Union[TextChannel, Thread] + ) -> None: self.guild = new_guild self.channel = new_channel - @utils.cached_slot_property('_cs_raw_mentions') + @utils.cached_slot_property("_cs_raw_mentions") def raw_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of user IDs matched with the syntax of ``<@user_id>`` in the message content. @@ -902,30 +959,30 @@ class Message(Hashable): This allows you to receive the user IDs of mentioned users even in a private message context. """ - return [int(x) for x in re.findall(r'<@!?([0-9]{15,20})>', self.content)] + return [int(x) for x in re.findall(r"<@!?([0-9]{15,20})>", self.content)] - @utils.cached_slot_property('_cs_raw_channel_mentions') + @utils.cached_slot_property("_cs_raw_channel_mentions") def raw_channel_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of channel IDs matched with the syntax of ``<#channel_id>`` in the message content. """ - return [int(x) for x in re.findall(r'<#([0-9]{15,20})>', self.content)] + return [int(x) for x in re.findall(r"<#([0-9]{15,20})>", self.content)] - @utils.cached_slot_property('_cs_raw_role_mentions') + @utils.cached_slot_property("_cs_raw_role_mentions") def raw_role_mentions(self) -> List[int]: """List[:class:`int`]: A property that returns an array of role IDs matched with the syntax of ``<@&role_id>`` in the message content. """ - return [int(x) for x in re.findall(r'<@&([0-9]{15,20})>', self.content)] + return [int(x) for x in re.findall(r"<@&([0-9]{15,20})>", self.content)] - @utils.cached_slot_property('_cs_channel_mentions') + @utils.cached_slot_property("_cs_channel_mentions") def channel_mentions(self) -> List[GuildChannel]: if self.guild is None: return [] it = filter(None, map(self.guild.get_channel, self.raw_channel_mentions)) return utils._unique(it) - @utils.cached_slot_property('_cs_clean_content') + @utils.cached_slot_property("_cs_clean_content") def clean_content(self) -> str: """:class:`str`: A property that returns the content in a "cleaned up" manner. This basically means that mentions are transformed @@ -972,9 +1029,9 @@ class Message(Hashable): # fmt: on def repl(obj): - return transformations.get(re.escape(obj.group(0)), '') + return transformations.get(re.escape(obj.group(0)), "") - pattern = re.compile('|'.join(transformations.keys())) + pattern = re.compile("|".join(transformations.keys())) result = pattern.sub(repl, self.content) return escape_mentions(result) @@ -991,8 +1048,8 @@ class Message(Hashable): @property def jump_url(self) -> str: """:class:`str`: Returns a URL that allows the client to jump to this message.""" - guild_id = getattr(self.guild, 'id', '@me') - return f'https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}' + guild_id = getattr(self.guild, "id", "@me") + return f"https://discord.com/channels/{guild_id}/{self.channel.id}/{self.id}" def is_system(self) -> bool: """:class:`bool`: Whether the message is a system message. @@ -1009,7 +1066,7 @@ class Message(Hashable): MessageType.thread_starter_message, ) - @utils.cached_slot_property('_cs_system_content') + @utils.cached_slot_property("_cs_system_content") def system_content(self): r""":class:`str`: A property that returns the content that is rendered regardless of the :attr:`Message.type`. @@ -1024,24 +1081,26 @@ class Message(Hashable): if self.type is MessageType.recipient_add: if self.channel.type is ChannelType.group: - return f'{self.author.name} added {self.mentions[0].name} to the group.' + return f"{self.author.name} added {self.mentions[0].name} to the group." else: - return f'{self.author.name} added {self.mentions[0].name} to the thread.' + return ( + f"{self.author.name} added {self.mentions[0].name} to the thread." + ) if self.type is MessageType.recipient_remove: if self.channel.type is ChannelType.group: - return f'{self.author.name} removed {self.mentions[0].name} from the group.' + return f"{self.author.name} removed {self.mentions[0].name} from the group." else: - return f'{self.author.name} removed {self.mentions[0].name} from the thread.' + return f"{self.author.name} removed {self.mentions[0].name} from the thread." if self.type is MessageType.channel_name_change: - return f'{self.author.name} changed the channel name: **{self.content}**' + return f"{self.author.name} changed the channel name: **{self.content}**" if self.type is MessageType.channel_icon_change: - return f'{self.author.name} changed the channel icon.' + return f"{self.author.name} changed the channel icon." if self.type is MessageType.pins_add: - return f'{self.author.name} pinned a message to this channel.' + return f"{self.author.name} pinned a message to this channel." if self.type is MessageType.new_member: formats = [ @@ -1065,64 +1124,66 @@ class Message(Hashable): if self.type is MessageType.premium_guild_subscription: if not self.content: - return f'{self.author.name} just boosted the server!' + return f"{self.author.name} just boosted the server!" else: - return f'{self.author.name} just boosted the server **{self.content}** times!' + return f"{self.author.name} just boosted the server **{self.content}** times!" if self.type is MessageType.premium_guild_tier_1: if not self.content: - return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**' + return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 1!**" else: - return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**' + return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 1!**" if self.type is MessageType.premium_guild_tier_2: if not self.content: - return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**' + return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 2!**" else: - return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**' + return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 2!**" if self.type is MessageType.premium_guild_tier_3: if not self.content: - return f'{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**' + return f"{self.author.name} just boosted the server! {self.guild} has achieved **Level 3!**" else: - return f'{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**' + return f"{self.author.name} just boosted the server **{self.content}** times! {self.guild} has achieved **Level 3!**" if self.type is MessageType.channel_follow_add: - return f'{self.author.name} has added {self.content} to this channel' + return f"{self.author.name} has added {self.content} to this channel" if self.type is MessageType.guild_stream: # the author will be a Member - return f'{self.author.name} is live! Now streaming {self.author.activity.name}' # type: ignore + return f"{self.author.name} is live! Now streaming {self.author.activity.name}" # type: ignore if self.type is MessageType.guild_discovery_disqualified: - return 'This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details.' + return "This server has been removed from Server Discovery because it no longer passes all the requirements. Check Server Settings for more details." if self.type is MessageType.guild_discovery_requalified: - return 'This server is eligible for Server Discovery again and has been automatically relisted!' + return "This server is eligible for Server Discovery again and has been automatically relisted!" if self.type is MessageType.guild_discovery_grace_period_initial_warning: - return 'This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery.' + return "This server has failed Discovery activity requirements for 1 week. If this server fails for 4 weeks in a row, it will be automatically removed from Discovery." if self.type is MessageType.guild_discovery_grace_period_final_warning: - return 'This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery.' + return "This server has failed Discovery activity requirements for 3 weeks in a row. If this server fails for 1 more week, it will be removed from Discovery." if self.type is MessageType.thread_created: - return f'{self.author.name} started a thread: **{self.content}**. See all **threads**.' + return f"{self.author.name} started a thread: **{self.content}**. See all **threads**." if self.type is MessageType.reply: return self.content if self.type is MessageType.thread_starter_message: if self.reference is None or self.reference.resolved is None: - return 'Sorry, we couldn\'t load the first message in this thread' + return "Sorry, we couldn't load the first message in this thread" # the resolved message for the reference will be a Message return self.reference.resolved.content # type: ignore if self.type is MessageType.guild_invite_reminder: - return 'Wondering who to invite?\nStart by inviting anyone who can help you build the server!' + return "Wondering who to invite?\nStart by inviting anyone who can help you build the server!" - async def delete(self, *, delay: Optional[float] = None, silent: bool = False) -> None: + async def delete( + self, *, delay: Optional[float] = None, silent: bool = False + ) -> None: """|coro| Deletes the message. @@ -1271,45 +1332,52 @@ class Message(Hashable): payload: Dict[str, Any] = {} if content is not MISSING: if content is not None: - payload['content'] = str(content) + payload["content"] = str(content) else: - payload['content'] = None + payload["content"] = None if embed is not MISSING and embeds is not MISSING: - raise InvalidArgument('cannot pass both embed and embeds parameter to edit()') + raise InvalidArgument( + "cannot pass both embed and embeds parameter to edit()" + ) if embed is not MISSING: if embed is None: - payload['embeds'] = [] + payload["embeds"] = [] else: - payload['embeds'] = [embed.to_dict()] + payload["embeds"] = [embed.to_dict()] elif embeds is not MISSING: - payload['embeds'] = [e.to_dict() for e in embeds] + payload["embeds"] = [e.to_dict() for e in embeds] if suppress is not MISSING: flags = MessageFlags._from_value(self.flags.value) flags.suppress_embeds = suppress - payload['flags'] = flags.value + payload["flags"] = flags.value if allowed_mentions is MISSING: - if self._state.allowed_mentions is not None and self.author.id == self._state.self_id: - payload['allowed_mentions'] = self._state.allowed_mentions.to_dict() + if ( + self._state.allowed_mentions is not None + and self.author.id == self._state.self_id + ): + payload["allowed_mentions"] = self._state.allowed_mentions.to_dict() else: if allowed_mentions is not None: if self._state.allowed_mentions is not None: - payload['allowed_mentions'] = self._state.allowed_mentions.merge(allowed_mentions).to_dict() + payload["allowed_mentions"] = self._state.allowed_mentions.merge( + allowed_mentions + ).to_dict() else: - payload['allowed_mentions'] = allowed_mentions.to_dict() + payload["allowed_mentions"] = allowed_mentions.to_dict() if attachments is not MISSING: - payload['attachments'] = [a.to_dict() for a in attachments] + payload["attachments"] = [a.to_dict() for a in attachments] if view is not MISSING: self._state.prevent_view_updates_for(self.id) if view: - payload['components'] = view.to_components() + payload["components"] = view.to_components() else: - payload['components'] = [] + payload["components"] = [] data = await self._state.http.edit_message(self.channel.id, self.id, **payload) message = Message(state=self._state, channel=self.channel, data=data) @@ -1430,7 +1498,9 @@ class Message(Hashable): emoji = convert_emoji_reaction(emoji) await self._state.http.add_reaction(self.channel.id, self.id, emoji) - async def remove_reaction(self, emoji: Union[EmojiInputType, Reaction], member: Snowflake) -> None: + async def remove_reaction( + self, emoji: Union[EmojiInputType, Reaction], member: Snowflake + ) -> None: """|coro| Remove a reaction by the member from the message. @@ -1467,7 +1537,9 @@ class Message(Hashable): if member.id == self._state.self_id: await self._state.http.remove_own_reaction(self.channel.id, self.id, emoji) else: - await self._state.http.remove_reaction(self.channel.id, self.id, emoji, member.id) + await self._state.http.remove_reaction( + self.channel.id, self.id, emoji, member.id + ) async def clear_reaction(self, emoji: Union[EmojiInputType, Reaction]) -> None: """|coro| @@ -1516,7 +1588,9 @@ class Message(Hashable): """ await self._state.http.clear_reactions(self.channel.id, self.id) - async def create_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING) -> Thread: + async def create_thread( + self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING + ) -> Thread: """|coro| Creates a public thread from this message. @@ -1551,14 +1625,17 @@ class Message(Hashable): The created thread. """ if self.guild is None: - raise InvalidArgument('This message does not have guild info attached.') + raise InvalidArgument("This message does not have guild info attached.") - default_auto_archive_duration: ThreadArchiveDuration = getattr(self.channel, 'default_auto_archive_duration', 1440) + default_auto_archive_duration: ThreadArchiveDuration = getattr( + self.channel, "default_auto_archive_duration", 1440 + ) data = await self._state.http.start_thread_with_message( self.channel.id, self.id, name=name, - auto_archive_duration=auto_archive_duration or default_auto_archive_duration, + auto_archive_duration=auto_archive_duration + or default_auto_archive_duration, ) return Thread(guild=self.guild, state=self._state, data=data) @@ -1607,16 +1684,18 @@ class Message(Hashable): The reference to this message. """ - return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists) + return MessageReference.from_message( + self, fail_if_not_exists=fail_if_not_exists + ) def to_message_reference_dict(self) -> MessageReferencePayload: data: MessageReferencePayload = { - 'message_id': self.id, - 'channel_id': self.channel.id, + "message_id": self.id, + "channel_id": self.channel.id, } if self.guild is not None: - data['guild_id'] = self.guild.id + data["guild_id"] = self.guild.id return data @@ -1662,7 +1741,7 @@ class PartialMessage(Hashable): The message ID. """ - __slots__ = ('channel', 'id', '_cs_guild', '_state') + __slots__ = ("channel", "id", "_cs_guild", "_state") jump_url: str = Message.jump_url # type: ignore delete = Message.delete @@ -1686,7 +1765,9 @@ class PartialMessage(Hashable): ChannelType.public_thread, ChannelType.private_thread, ): - raise TypeError(f'Expected TextChannel, DMChannel or Thread not {type(channel)!r}') + raise TypeError( + f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}" + ) self.channel: PartialMessageableChannel = channel self._state: ConnectionState = channel._state @@ -1702,17 +1783,17 @@ class PartialMessage(Hashable): pinned = property(None, lambda x, y: None) def __repr__(self) -> str: - return f'' + return f"" @property def created_at(self) -> datetime.datetime: """:class:`datetime.datetime`: The partial message's creation time in UTC.""" return utils.snowflake_time(self.id) - @utils.cached_slot_property('_cs_guild') + @utils.cached_slot_property("_cs_guild") def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild that the partial message belongs to, if applicable.""" - return getattr(self.channel, 'guild', None) + return getattr(self.channel, "guild", None) async def fetch(self) -> Message: """|coro| @@ -1794,58 +1875,62 @@ class PartialMessage(Hashable): """ try: - content = fields['content'] + content = fields["content"] except KeyError: pass else: if content is not None: - fields['content'] = str(content) + fields["content"] = str(content) try: - embed = fields['embed'] + embed = fields["embed"] except KeyError: pass else: if embed is not None: - fields['embed'] = embed.to_dict() + fields["embed"] = embed.to_dict() try: - suppress: bool = fields.pop('suppress') + suppress: bool = fields.pop("suppress") except KeyError: pass else: flags = MessageFlags._from_value(0) flags.suppress_embeds = suppress - fields['flags'] = flags.value + fields["flags"] = flags.value - delete_after = fields.pop('delete_after', None) + delete_after = fields.pop("delete_after", None) try: - allowed_mentions = fields.pop('allowed_mentions') + allowed_mentions = fields.pop("allowed_mentions") except KeyError: pass else: if allowed_mentions is not None: if self._state.allowed_mentions is not None: - allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict() + allowed_mentions = self._state.allowed_mentions.merge( + allowed_mentions + ).to_dict() else: allowed_mentions = allowed_mentions.to_dict() - fields['allowed_mentions'] = allowed_mentions + fields["allowed_mentions"] = allowed_mentions try: - view = fields.pop('view') + view = fields.pop("view") except KeyError: # To check for the view afterwards view = None else: self._state.prevent_view_updates_for(self.id) if view: - fields['components'] = view.to_components() + fields["components"] = view.to_components() else: - fields['components'] = [] + fields["components"] = [] if fields: - data = await self._state.http.edit_message(self.channel.id, self.id, **fields) + data = await self._state.http.edit_message( + self.channel.id, self.id, **fields + ) if delete_after is not None: await self.delete(delay=delete_after) diff --git a/discord/mixins.py b/discord/mixins.py index fdacf863..0fd42641 100644 --- a/discord/mixins.py +++ b/discord/mixins.py @@ -23,10 +23,11 @@ DEALINGS IN THE SOFTWARE. """ __all__ = ( - 'EqualityComparable', - 'Hashable', + "EqualityComparable", + "Hashable", ) + class EqualityComparable: __slots__ = () @@ -40,6 +41,7 @@ class EqualityComparable: return other.id != self.id return True + class Hashable(EqualityComparable): __slots__ = () diff --git a/discord/object.py b/discord/object.py index 8061a8be..77f43adc 100644 --- a/discord/object.py +++ b/discord/object.py @@ -35,11 +35,11 @@ from typing import ( if TYPE_CHECKING: import datetime + SupportsIntCast = Union[SupportsInt, str, bytes, bytearray] -__all__ = ( - 'Object', -) +__all__ = ("Object",) + class Object(Hashable): """Represents a generic Discord object. @@ -83,12 +83,14 @@ class Object(Hashable): try: id = int(id) except ValueError: - raise TypeError(f'id parameter must be convertable to int not {id.__class__!r}') from None + raise TypeError( + f"id parameter must be convertable to int not {id.__class__!r}" + ) from None else: self.id = id def __repr__(self) -> str: - return f'' + return f"" @property def created_at(self) -> datetime.datetime: diff --git a/discord/oggparse.py b/discord/oggparse.py index e0347d2c..e356970a 100644 --- a/discord/oggparse.py +++ b/discord/oggparse.py @@ -31,20 +31,24 @@ from typing import TYPE_CHECKING, ClassVar, IO, Generator, Tuple, Optional from .errors import DiscordException __all__ = ( - 'OggError', - 'OggPage', - 'OggStream', + "OggError", + "OggPage", + "OggStream", ) + class OggError(DiscordException): """An exception that is thrown for Ogg stream parsing errors.""" + pass + # https://tools.ietf.org/html/rfc3533 # https://tools.ietf.org/html/rfc7845 + class OggPage: - _header: ClassVar[struct.Struct] = struct.Struct(' Generator[Tuple[bytes, bool], None, None]: packetlen = offset = 0 @@ -76,7 +86,7 @@ class OggPage: partial = True else: packetlen += seg - yield self.data[offset:offset+packetlen], True + yield self.data[offset : offset + packetlen], True offset += packetlen packetlen = 0 partial = False @@ -84,18 +94,19 @@ class OggPage: if partial: yield self.data[offset:], False + class OggStream: def __init__(self, stream: IO[bytes]) -> None: self.stream: IO[bytes] = stream def _next_page(self) -> Optional[OggPage]: head = self.stream.read(4) - if head == b'OggS': + if head == b"OggS": return OggPage(self.stream) elif not head: return None else: - raise OggError('invalid header magic') + raise OggError("invalid header magic") def _iter_pages(self) -> Generator[OggPage, None, None]: page = self._next_page() @@ -104,10 +115,10 @@ class OggStream: page = self._next_page() def iter_packets(self) -> Generator[bytes, None, None]: - partial = b'' + partial = b"" for page in self._iter_pages(): for data, complete in page.iter_packets(): partial += data if complete: yield partial - partial = b'' + partial = b"" diff --git a/discord/opus.py b/discord/opus.py index 97d437a3..df5f4fce 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -24,7 +24,18 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import List, Tuple, TypedDict, Any, TYPE_CHECKING, Callable, TypeVar, Literal, Optional, overload +from typing import ( + List, + Tuple, + TypedDict, + Any, + TYPE_CHECKING, + Callable, + TypeVar, + Literal, + Optional, + overload, +) import array import ctypes @@ -38,9 +49,10 @@ import sys from .errors import DiscordException, InvalidArgument if TYPE_CHECKING: - T = TypeVar('T') - BAND_CTL = Literal['narrow', 'medium', 'wide', 'superwide', 'full'] - SIGNAL_CTL = Literal['auto', 'voice', 'music'] + T = TypeVar("T") + BAND_CTL = Literal["narrow", "medium", "wide", "superwide", "full"] + SIGNAL_CTL = Literal["auto", "voice", "music"] + class BandCtl(TypedDict): narrow: int @@ -49,81 +61,89 @@ class BandCtl(TypedDict): superwide: int full: int + class SignalCtl(TypedDict): auto: int voice: int music: int + __all__ = ( - 'Encoder', - 'OpusError', - 'OpusNotLoaded', + "Encoder", + "OpusError", + "OpusNotLoaded", ) _log = logging.getLogger(__name__) -c_int_ptr = ctypes.POINTER(ctypes.c_int) +c_int_ptr = ctypes.POINTER(ctypes.c_int) c_int16_ptr = ctypes.POINTER(ctypes.c_int16) c_float_ptr = ctypes.POINTER(ctypes.c_float) _lib = None + class EncoderStruct(ctypes.Structure): pass + class DecoderStruct(ctypes.Structure): pass + EncoderStructPtr = ctypes.POINTER(EncoderStruct) DecoderStructPtr = ctypes.POINTER(DecoderStruct) ## Some constants from opus_defines.h # Error codes -OK = 0 +OK = 0 BAD_ARG = -1 # Encoder CTLs -APPLICATION_AUDIO = 2049 -APPLICATION_VOIP = 2048 +APPLICATION_AUDIO = 2049 +APPLICATION_VOIP = 2048 APPLICATION_LOWDELAY = 2051 -CTL_SET_BITRATE = 4002 -CTL_SET_BANDWIDTH = 4008 -CTL_SET_FEC = 4012 -CTL_SET_PLP = 4014 -CTL_SET_SIGNAL = 4024 +CTL_SET_BITRATE = 4002 +CTL_SET_BANDWIDTH = 4008 +CTL_SET_FEC = 4012 +CTL_SET_PLP = 4014 +CTL_SET_SIGNAL = 4024 # Decoder CTLs -CTL_SET_GAIN = 4034 +CTL_SET_GAIN = 4034 CTL_LAST_PACKET_DURATION = 4039 band_ctl: BandCtl = { - 'narrow': 1101, - 'medium': 1102, - 'wide': 1103, - 'superwide': 1104, - 'full': 1105, + "narrow": 1101, + "medium": 1102, + "wide": 1103, + "superwide": 1104, + "full": 1105, } signal_ctl: SignalCtl = { - 'auto': -1000, - 'voice': 3001, - 'music': 3002, + "auto": -1000, + "voice": 3001, + "music": 3002, } + def _err_lt(result: int, func: Callable, args: List) -> int: if result < OK: - _log.info('error has happened in %s', func.__name__) + _log.info("error has happened in %s", func.__name__) raise OpusError(result) return result + def _err_ne(result: T, func: Callable, args: List) -> T: ret = args[-1]._obj if ret.value != OK: - _log.info('error has happened in %s', func.__name__) + _log.info("error has happened in %s", func.__name__) raise OpusError(ret.value) return result + # A list of exported functions. # The first argument is obviously the name. # The second one are the types of arguments it takes. @@ -131,54 +151,90 @@ def _err_ne(result: T, func: Callable, args: List) -> T: # The fourth is the error handler. exported_functions: List[Tuple[Any, ...]] = [ # Generic - ('opus_get_version_string', - None, ctypes.c_char_p, None), - ('opus_strerror', - [ctypes.c_int], ctypes.c_char_p, None), - + ("opus_get_version_string", None, ctypes.c_char_p, None), + ("opus_strerror", [ctypes.c_int], ctypes.c_char_p, None), # Encoder functions - ('opus_encoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_encoder_create', - [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], EncoderStructPtr, _err_ne), - ('opus_encode', - [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encode_float', - [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int32, _err_lt), - ('opus_encoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_encoder_destroy', - [EncoderStructPtr], None, None), - + ("opus_encoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ( + "opus_encoder_create", + [ctypes.c_int, ctypes.c_int, ctypes.c_int, c_int_ptr], + EncoderStructPtr, + _err_ne, + ), + ( + "opus_encode", + [EncoderStructPtr, c_int16_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ( + "opus_encode_float", + [EncoderStructPtr, c_float_ptr, ctypes.c_int, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int32, + _err_lt, + ), + ("opus_encoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_encoder_destroy", [EncoderStructPtr], None, None), # Decoder functions - ('opus_decoder_get_size', - [ctypes.c_int], ctypes.c_int, None), - ('opus_decoder_create', - [ctypes.c_int, ctypes.c_int, c_int_ptr], DecoderStructPtr, _err_ne), - ('opus_decode', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_int16_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decode_float', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32, c_float_ptr, ctypes.c_int, ctypes.c_int], - ctypes.c_int, _err_lt), - ('opus_decoder_ctl', - None, ctypes.c_int32, _err_lt), - ('opus_decoder_destroy', - [DecoderStructPtr], None, None), - ('opus_decoder_get_nb_samples', - [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], ctypes.c_int, _err_lt), - + ("opus_decoder_get_size", [ctypes.c_int], ctypes.c_int, None), + ( + "opus_decoder_create", + [ctypes.c_int, ctypes.c_int, c_int_ptr], + DecoderStructPtr, + _err_ne, + ), + ( + "opus_decode", + [ + DecoderStructPtr, + ctypes.c_char_p, + ctypes.c_int32, + c_int16_ptr, + ctypes.c_int, + ctypes.c_int, + ], + ctypes.c_int, + _err_lt, + ), + ( + "opus_decode_float", + [ + DecoderStructPtr, + ctypes.c_char_p, + ctypes.c_int32, + c_float_ptr, + ctypes.c_int, + ctypes.c_int, + ], + ctypes.c_int, + _err_lt, + ), + ("opus_decoder_ctl", None, ctypes.c_int32, _err_lt), + ("opus_decoder_destroy", [DecoderStructPtr], None, None), + ( + "opus_decoder_get_nb_samples", + [DecoderStructPtr, ctypes.c_char_p, ctypes.c_int32], + ctypes.c_int, + _err_lt, + ), # Packet functions - ('opus_packet_get_bandwidth', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_channels', - [ctypes.c_char_p], ctypes.c_int, _err_lt), - ('opus_packet_get_nb_frames', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), - ('opus_packet_get_samples_per_frame', - [ctypes.c_char_p, ctypes.c_int], ctypes.c_int, _err_lt), + ("opus_packet_get_bandwidth", [ctypes.c_char_p], ctypes.c_int, _err_lt), + ("opus_packet_get_nb_channels", [ctypes.c_char_p], ctypes.c_int, _err_lt), + ( + "opus_packet_get_nb_frames", + [ctypes.c_char_p, ctypes.c_int], + ctypes.c_int, + _err_lt, + ), + ( + "opus_packet_get_samples_per_frame", + [ctypes.c_char_p, ctypes.c_int], + ctypes.c_int, + _err_lt, + ), ] + def libopus_loader(name: str) -> Any: # create the library... lib = ctypes.cdll.LoadLibrary(name) @@ -203,22 +259,24 @@ def libopus_loader(name: str) -> Any: return lib + def _load_default() -> bool: global _lib try: - if sys.platform == 'win32': + if sys.platform == "win32": _basedir = os.path.dirname(os.path.abspath(__file__)) - _bitness = struct.calcsize('P') * 8 - _target = 'x64' if _bitness > 32 else 'x86' - _filename = os.path.join(_basedir, 'bin', f'libopus-0.{_target}.dll') + _bitness = struct.calcsize("P") * 8 + _target = "x64" if _bitness > 32 else "x86" + _filename = os.path.join(_basedir, "bin", f"libopus-0.{_target}.dll") _lib = libopus_loader(_filename) else: - _lib = libopus_loader(ctypes.util.find_library('opus')) + _lib = libopus_loader(ctypes.util.find_library("opus")) except Exception: _lib = None return _lib is not None + def load_opus(name: str) -> None: """Loads the libopus shared library for use with voice. @@ -257,6 +315,7 @@ def load_opus(name: str) -> None: global _lib _lib = libopus_loader(name) + def is_loaded() -> bool: """Function to check if opus lib is successfully loaded either via the :func:`ctypes.util.find_library` call of :func:`load_opus`. @@ -271,6 +330,7 @@ def is_loaded() -> bool: global _lib return _lib is not None + class OpusError(DiscordException): """An exception that is thrown for libopus related errors. @@ -282,19 +342,22 @@ class OpusError(DiscordException): def __init__(self, code: int): self.code: int = code - msg = _lib.opus_strerror(self.code).decode('utf-8') + msg = _lib.opus_strerror(self.code).decode("utf-8") _log.info('"%s" has happened', msg) super().__init__(msg) + class OpusNotLoaded(DiscordException): """An exception that is thrown for when libopus is not loaded.""" + pass + class _OpusStruct: SAMPLING_RATE = 48000 CHANNELS = 2 FRAME_LENGTH = 20 # in milliseconds - SAMPLE_SIZE = struct.calcsize('h') * CHANNELS + SAMPLE_SIZE = struct.calcsize("h") * CHANNELS SAMPLES_PER_FRAME = int(SAMPLING_RATE / 1000 * FRAME_LENGTH) FRAME_SIZE = SAMPLES_PER_FRAME * SAMPLE_SIZE @@ -304,7 +367,8 @@ class _OpusStruct: if not is_loaded() and not _load_default(): raise OpusNotLoaded() - return _lib.opus_get_version_string().decode('utf-8') + return _lib.opus_get_version_string().decode("utf-8") + class Encoder(_OpusStruct): def __init__(self, application: int = APPLICATION_AUDIO): @@ -315,18 +379,20 @@ class Encoder(_OpusStruct): self.set_bitrate(128) self.set_fec(True) self.set_expected_packet_loss_percent(0.15) - self.set_bandwidth('full') - self.set_signal_type('auto') + self.set_bandwidth("full") + self.set_signal_type("auto") def __del__(self) -> None: - if hasattr(self, '_state'): + if hasattr(self, "_state"): _lib.opus_encoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> EncoderStruct: ret = ctypes.c_int() - return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)) + return _lib.opus_encoder_create( + self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret) + ) def set_bitrate(self, kbps: int) -> int: kbps = min(512, max(16, int(kbps))) @@ -336,14 +402,18 @@ class Encoder(_OpusStruct): def set_bandwidth(self, req: BAND_CTL) -> None: if req not in band_ctl: - raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}') + raise KeyError( + f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}' + ) k = band_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) def set_signal_type(self, req: SIGNAL_CTL) -> None: if req not in signal_ctl: - raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}') + raise KeyError( + f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}' + ) k = signal_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) @@ -352,18 +422,19 @@ class Encoder(_OpusStruct): _lib.opus_encoder_ctl(self._state, CTL_SET_FEC, 1 if enabled else 0) def set_expected_packet_loss_percent(self, percentage: float) -> None: - _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore + _lib.opus_encoder_ctl(self._state, CTL_SET_PLP, min(100, max(0, int(percentage * 100)))) # type: ignore def encode(self, pcm: bytes, frame_size: int) -> bytes: max_data_bytes = len(pcm) # bytes can be used to reference pointer - pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore + pcm_ptr = ctypes.cast(pcm, c_int16_ptr) # type: ignore data = (ctypes.c_char * max_data_bytes)() ret = _lib.opus_encode(self._state, pcm_ptr, frame_size, data, max_data_bytes) # array can be initialized with bytes but mypy doesn't know - return array.array('b', data[:ret]).tobytes() # type: ignore + return array.array("b", data[:ret]).tobytes() # type: ignore + class Decoder(_OpusStruct): def __init__(self): @@ -372,14 +443,16 @@ class Decoder(_OpusStruct): self._state: DecoderStruct = self._create_state() def __del__(self) -> None: - if hasattr(self, '_state'): + if hasattr(self, "_state"): _lib.opus_decoder_destroy(self._state) # This is a destructor, so it's okay to assign None - self._state = None # type: ignore + self._state = None # type: ignore def _create_state(self) -> DecoderStruct: ret = ctypes.c_int() - return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)) + return _lib.opus_decoder_create( + self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret) + ) @staticmethod def packet_get_nb_frames(data: bytes) -> int: @@ -411,12 +484,12 @@ class Decoder(_OpusStruct): def set_gain(self, dB: float) -> int: """Sets the decoder gain in dB, from -128 to 128.""" - dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) + dB_Q8 = max(-32768, min(32767, round(dB * 256))) # dB * 2^n where n is 8 (Q8) return self._set_gain(dB_Q8) def set_volume(self, mult: float) -> int: """Sets the output volume as a float percent, i.e. 0.5 for 50%, 1.75 for 175%, etc.""" - return self.set_gain(20 * math.log10(mult)) # amplitude ratio + return self.set_gain(20 * math.log10(mult)) # amplitude ratio def _get_last_packet_duration(self) -> int: """Gets the duration (in samples) of the last packet successfully decoded or concealed.""" @@ -428,14 +501,16 @@ class Decoder(_OpusStruct): @overload def decode(self, data: bytes, *, fec: bool) -> bytes: ... - + @overload def decode(self, data: Literal[None], *, fec: Literal[False]) -> bytes: ... def decode(self, data: Optional[bytes], *, fec: bool = False) -> bytes: if data is None and fec: - raise InvalidArgument("Invalid arguments: FEC cannot be used with null data") + raise InvalidArgument( + "Invalid arguments: FEC cannot be used with null data" + ) if data is None: frame_size = self._get_last_packet_duration() or self.SAMPLES_PER_FRAME @@ -449,6 +524,8 @@ class Decoder(_OpusStruct): pcm = (ctypes.c_int16 * (frame_size * channel_count))() pcm_ptr = ctypes.cast(pcm, c_int16_ptr) - ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) + ret = _lib.opus_decode( + self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec + ) - return array.array('h', pcm[:ret * channel_count]).tobytes() + return array.array("h", pcm[: ret * channel_count]).tobytes() diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index e2c689e2..eb34be2a 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -31,15 +31,14 @@ from .asset import Asset, AssetMixin from .errors import InvalidArgument from . import utils -__all__ = ( - 'PartialEmoji', -) +__all__ = ("PartialEmoji",) if TYPE_CHECKING: from .state import ConnectionState from datetime import datetime from .types.message import PartialEmoji as PartialEmojiPayload + class _EmojiTag: __slots__ = () @@ -49,7 +48,7 @@ class _EmojiTag: raise NotImplementedError -PE = TypeVar('PE', bound='PartialEmoji') +PE = TypeVar("PE", bound="PartialEmoji") class PartialEmoji(_EmojiTag, AssetMixin): @@ -90,9 +89,11 @@ class PartialEmoji(_EmojiTag, AssetMixin): The ID of the custom emoji, if applicable. """ - __slots__ = ('animated', 'name', 'id', '_state') + __slots__ = ("animated", "name", "id", "_state") - _CUSTOM_EMOJI_RE = re.compile(r'a)?:?(?P[A-Za-z0-9\_]+):(?P[0-9]{13,20})>?') + _CUSTOM_EMOJI_RE = re.compile( + r"a)?:?(?P[A-Za-z0-9\_]+):(?P[0-9]{13,20})>?" + ) if TYPE_CHECKING: id: Optional[int] @@ -104,11 +105,13 @@ class PartialEmoji(_EmojiTag, AssetMixin): self._state: Optional[ConnectionState] = None @classmethod - def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE: + def from_dict( + cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]] + ) -> PE: return cls( - animated=data.get('animated', False), - id=utils._get_as_snowflake(data, 'id'), - name=data.get('name') or '', + animated=data.get("animated", False), + id=utils._get_as_snowflake(data, "id"), + name=data.get("name") or "", ) @classmethod @@ -139,19 +142,19 @@ class PartialEmoji(_EmojiTag, AssetMixin): match = cls._CUSTOM_EMOJI_RE.match(value) if match is not None: groups = match.groupdict() - animated = bool(groups['animated']) - emoji_id = int(groups['id']) - name = groups['name'] + animated = bool(groups["animated"]) + emoji_id = int(groups["id"]) + name = groups["name"] return cls(name=name, animated=animated, id=emoji_id) return cls(name=value, id=None, animated=False) def to_dict(self) -> Dict[str, Any]: - o: Dict[str, Any] = {'name': self.name} + o: Dict[str, Any] = {"name": self.name} if self.id: - o['id'] = self.id + o["id"] = self.id if self.animated: - o['animated'] = self.animated + o["animated"] = self.animated return o def _to_partial(self) -> PartialEmoji: @@ -159,7 +162,12 @@ class PartialEmoji(_EmojiTag, AssetMixin): @classmethod def with_state( - cls: Type[PE], state: ConnectionState, *, name: str, animated: bool = False, id: Optional[int] = None + cls: Type[PE], + state: ConnectionState, + *, + name: str, + animated: bool = False, + id: Optional[int] = None, ) -> PE: self = cls(name=name, animated=animated, id=id) self._state = state @@ -169,11 +177,11 @@ class PartialEmoji(_EmojiTag, AssetMixin): if self.id is None: return self.name if self.animated: - return f'' - return f'<:{self.name}:{self.id}>' + return f"" + return f"<:{self.name}:{self.id}>" def __repr__(self): - return f'<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>' + return f"<{self.__class__.__name__} animated={self.animated} name={self.name!r} id={self.id}>" def __eq__(self, other: Any) -> bool: if self.is_unicode_emoji(): @@ -200,7 +208,7 @@ class PartialEmoji(_EmojiTag, AssetMixin): def _as_reaction(self) -> str: if self.id is None: return self.name - return f'{self.name}:{self.id}' + return f"{self.name}:{self.id}" @property def created_at(self) -> Optional[datetime]: @@ -220,13 +228,13 @@ class PartialEmoji(_EmojiTag, AssetMixin): If this isn't a custom emoji then an empty string is returned """ if self.is_unicode_emoji(): - return '' + return "" - fmt = 'gif' if self.animated else 'png' - return f'{Asset.BASE}/emojis/{self.id}.{fmt}' + fmt = "gif" if self.animated else "png" + return f"{Asset.BASE}/emojis/{self.id}.{fmt}" async def read(self) -> bytes: if self.is_unicode_emoji(): - raise InvalidArgument('PartialEmoji is not a custom emoji') + raise InvalidArgument("PartialEmoji is not a custom emoji") return await super().read() diff --git a/discord/permissions.py b/discord/permissions.py index 9d40ca33..a972c5ef 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -24,12 +24,24 @@ DEALINGS IN THE SOFTWARE. from __future__ import annotations -from typing import Callable, Any, ClassVar, Dict, Iterator, Set, TYPE_CHECKING, Tuple, Type, TypeVar, Optional +from typing import ( + Callable, + Any, + ClassVar, + Dict, + Iterator, + Set, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Optional, +) from .flags import BaseFlags, flag_value, fill_with_flags, alias_flag_value __all__ = ( - 'Permissions', - 'PermissionOverwrite', + "Permissions", + "PermissionOverwrite", ) # A permission alias works like a regular flag but is marked @@ -38,7 +50,9 @@ class permission_alias(alias_flag_value): alias: str -def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permission_alias]: +def make_permission_alias( + alias: str, +) -> Callable[[Callable[[Any], int]], permission_alias]: def decorator(func: Callable[[Any], int]) -> permission_alias: ret = permission_alias(func) ret.alias = alias @@ -46,7 +60,9 @@ def make_permission_alias(alias: str) -> Callable[[Callable[[Any], int]], permis return decorator -P = TypeVar('P', bound='Permissions') + +P = TypeVar("P", bound="Permissions") + @fill_with_flags() class Permissions(BaseFlags): @@ -101,12 +117,14 @@ class Permissions(BaseFlags): def __init__(self, permissions: int = 0, **kwargs: bool): if not isinstance(permissions, int): - raise TypeError(f'Expected int parameter, received {permissions.__class__.__name__} instead.') + raise TypeError( + f"Expected int parameter, received {permissions.__class__.__name__} instead." + ) self.value = permissions for key, value in kwargs.items(): if key not in self.VALID_FLAGS: - raise TypeError(f'{key!r} is not a valid permission name.') + raise TypeError(f"{key!r} is not a valid permission name.") setattr(self, key, value) def is_subset(self, other: Permissions) -> bool: @@ -114,14 +132,18 @@ class Permissions(BaseFlags): if isinstance(other, Permissions): return (self.value & other.value) == self.value else: - raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}") + raise TypeError( + f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) def is_superset(self, other: Permissions) -> bool: """Returns ``True`` if self has the same or more permissions as other.""" if isinstance(other, Permissions): return (self.value | other.value) == self.value else: - raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}") + raise TypeError( + f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}" + ) def is_strict_subset(self, other: Permissions) -> bool: """Returns ``True`` if the permissions on other are a strict subset of those on self.""" @@ -336,7 +358,7 @@ class Permissions(BaseFlags): """:class:`bool`: Returns ``True`` if a user can read messages from all or specific text channels.""" return 1 << 10 - @make_permission_alias('read_messages') + @make_permission_alias("read_messages") def view_channel(self) -> int: """:class:`bool`: An alias for :attr:`read_messages`. @@ -389,7 +411,7 @@ class Permissions(BaseFlags): """:class:`bool`: Returns ``True`` if a user can use emojis from other guilds.""" return 1 << 18 - @make_permission_alias('external_emojis') + @make_permission_alias("external_emojis") def use_external_emojis(self) -> int: """:class:`bool`: An alias for :attr:`external_emojis`. @@ -453,7 +475,7 @@ class Permissions(BaseFlags): """ return 1 << 28 - @make_permission_alias('manage_roles') + @make_permission_alias("manage_roles") def manage_permissions(self) -> int: """:class:`bool`: An alias for :attr:`manage_roles`. @@ -471,7 +493,7 @@ class Permissions(BaseFlags): """:class:`bool`: Returns ``True`` if a user can create, edit, or delete emojis.""" return 1 << 30 - @make_permission_alias('manage_emojis') + @make_permission_alias("manage_emojis") def manage_emojis_and_stickers(self) -> int: """:class:`bool`: An alias for :attr:`manage_emojis`. @@ -535,7 +557,7 @@ class Permissions(BaseFlags): """ return 1 << 37 - @make_permission_alias('external_stickers') + @make_permission_alias("external_stickers") def use_external_stickers(self) -> int: """:class:`bool`: An alias for :attr:`external_stickers`. @@ -551,7 +573,9 @@ class Permissions(BaseFlags): """ return 1 << 38 -PO = TypeVar('PO', bound='PermissionOverwrite') + +PO = TypeVar("PO", bound="PermissionOverwrite") + def _augment_from_permissions(cls): cls.VALID_NAMES = set(Permissions.VALID_FLAGS) @@ -614,7 +638,7 @@ class PermissionOverwrite: Set the value of permissions by their name. """ - __slots__ = ('_values',) + __slots__ = ("_values",) if TYPE_CHECKING: VALID_NAMES: ClassVar[Set[str]] @@ -670,7 +694,7 @@ class PermissionOverwrite: for key, value in kwargs.items(): if key not in self.VALID_NAMES: - raise ValueError(f'no permission called {key}.') + raise ValueError(f"no permission called {key}.") setattr(self, key, value) @@ -679,7 +703,9 @@ class PermissionOverwrite: def _set(self, key: str, value: Optional[bool]) -> None: if value not in (True, None, False): - raise TypeError(f'Expected bool or NoneType, received {value.__class__.__name__}') + raise TypeError( + f"Expected bool or NoneType, received {value.__class__.__name__}" + ) if value is None: self._values.pop(key, None) diff --git a/discord/player.py b/discord/player.py index 8098d3e3..dad70372 100644 --- a/discord/player.py +++ b/discord/player.py @@ -36,7 +36,18 @@ import sys import re import io -from typing import Any, Callable, Generic, IO, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union +from typing import ( + Any, + Callable, + Generic, + IO, + Optional, + TYPE_CHECKING, + Tuple, + Type, + TypeVar, + Union, +) from .errors import ClientException from .opus import Encoder as OpusEncoder @@ -47,27 +58,28 @@ if TYPE_CHECKING: from .voice_client import VoiceClient -AT = TypeVar('AT', bound='AudioSource') -FT = TypeVar('FT', bound='FFmpegOpusAudio') +AT = TypeVar("AT", bound="AudioSource") +FT = TypeVar("FT", bound="FFmpegOpusAudio") _log = logging.getLogger(__name__) __all__ = ( - 'AudioSource', - 'PCMAudio', - 'FFmpegAudio', - 'FFmpegPCMAudio', - 'FFmpegOpusAudio', - 'PCMVolumeTransformer', + "AudioSource", + "PCMAudio", + "FFmpegAudio", + "FFmpegPCMAudio", + "FFmpegOpusAudio", + "PCMVolumeTransformer", ) CREATE_NO_WINDOW: int -if sys.platform != 'win32': +if sys.platform != "win32": CREATE_NO_WINDOW = 0 else: CREATE_NO_WINDOW = 0x08000000 + class AudioSource: """Represents an audio stream. @@ -114,6 +126,7 @@ class AudioSource: def __del__(self) -> None: self.cleanup() + class PCMAudio(AudioSource): """Represents raw 16-bit 48KHz stereo PCM audio source. @@ -122,15 +135,17 @@ class PCMAudio(AudioSource): stream: :term:`py:file object` A file-like object that reads byte data representing raw PCM. """ + def __init__(self, stream: io.BufferedIOBase) -> None: self.stream: io.BufferedIOBase = stream def read(self) -> bytes: ret = self.stream.read(OpusEncoder.FRAME_SIZE) if len(ret) != OpusEncoder.FRAME_SIZE: - return b'' + return b"" return ret + class FFmpegAudio(AudioSource): """Represents an FFmpeg (or AVConv) based AudioSource. @@ -140,13 +155,22 @@ class FFmpegAudio(AudioSource): .. versionadded:: 1.3 """ - def __init__(self, source: Union[str, io.BufferedIOBase], *, executable: str = 'ffmpeg', args: Any, **subprocess_kwargs: Any): - piping = subprocess_kwargs.get('stdin') == subprocess.PIPE + def __init__( + self, + source: Union[str, io.BufferedIOBase], + *, + executable: str = "ffmpeg", + args: Any, + **subprocess_kwargs: Any, + ): + piping = subprocess_kwargs.get("stdin") == subprocess.PIPE if piping and isinstance(source, str): - raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") + raise TypeError( + "parameter conflict: 'source' parameter cannot be a string when piping to stdin" + ) args = [executable, *args] - kwargs = {'stdout': subprocess.PIPE} + kwargs = {"stdout": subprocess.PIPE} kwargs.update(subprocess_kwargs) self._process: subprocess.Popen = self._spawn_process(args, **kwargs) @@ -155,20 +179,26 @@ class FFmpegAudio(AudioSource): self._pipe_thread: Optional[threading.Thread] = None if piping: - n = f'popen-stdin-writer:{id(self):#x}' + n = f"popen-stdin-writer:{id(self):#x}" self._stdin = self._process.stdin - self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n) + self._pipe_thread = threading.Thread( + target=self._pipe_writer, args=(source,), daemon=True, name=n + ) self._pipe_thread.start() def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: process = None try: - process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) + process = subprocess.Popen( + args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs + ) except FileNotFoundError: - executable = args.partition(' ')[0] if isinstance(args, str) else args[0] - raise ClientException(executable + ' was not found.') from None + executable = args.partition(" ")[0] if isinstance(args, str) else args[0] + raise ClientException(executable + " was not found.") from None except subprocess.SubprocessError as exc: - raise ClientException(f'Popen failed: {exc.__class__.__name__}: {exc}') from exc + raise ClientException( + f"Popen failed: {exc.__class__.__name__}: {exc}" + ) from exc else: return process @@ -177,20 +207,32 @@ class FFmpegAudio(AudioSource): if proc is MISSING: return - _log.info('Preparing to terminate ffmpeg process %s.', proc.pid) + _log.info("Preparing to terminate ffmpeg process %s.", proc.pid) try: proc.kill() except Exception: - _log.exception('Ignoring error attempting to kill ffmpeg process %s', proc.pid) + _log.exception( + "Ignoring error attempting to kill ffmpeg process %s", proc.pid + ) if proc.poll() is None: - _log.info('ffmpeg process %s has not terminated. Waiting to terminate...', proc.pid) + _log.info( + "ffmpeg process %s has not terminated. Waiting to terminate...", + proc.pid, + ) proc.communicate() - _log.info('ffmpeg process %s should have terminated with a return code of %s.', proc.pid, proc.returncode) + _log.info( + "ffmpeg process %s should have terminated with a return code of %s.", + proc.pid, + proc.returncode, + ) else: - _log.info('ffmpeg process %s successfully terminated with return code of %s.', proc.pid, proc.returncode) - + _log.info( + "ffmpeg process %s successfully terminated with return code of %s.", + proc.pid, + proc.returncode, + ) def _pipe_writer(self, source: io.BufferedIOBase) -> None: while self._process: @@ -202,7 +244,11 @@ class FFmpegAudio(AudioSource): try: self._stdin.write(data) except Exception: - _log.debug('Write error for %s, this is probably not a problem', self, exc_info=True) + _log.debug( + "Write error for %s, this is probably not a problem", + self, + exc_info=True, + ) # at this point the source data is either exhausted or the process is fubar self._process.terminate() return @@ -211,6 +257,7 @@ class FFmpegAudio(AudioSource): self._kill_process() self._process = self._stdout = self._stdin = MISSING + class FFmpegPCMAudio(FFmpegAudio): """An audio source from FFmpeg (or AVConv). @@ -250,38 +297,42 @@ class FFmpegPCMAudio(FFmpegAudio): self, source: Union[str, io.BufferedIOBase], *, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", pipe: bool = False, stderr: Optional[IO[str]] = None, before_options: Optional[str] = None, - options: Optional[str] = None + options: Optional[str] = None, ) -> None: args = [] - subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} + subprocess_kwargs = { + "stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, + "stderr": stderr, + } if isinstance(before_options, str): args.extend(shlex.split(before_options)) - args.append('-i') - args.append('-' if pipe else source) - args.extend(('-f', 's16le', '-ar', '48000', '-ac', '2', '-loglevel', 'warning')) + args.append("-i") + args.append("-" if pipe else source) + args.extend(("-f", "s16le", "-ar", "48000", "-ac", "2", "-loglevel", "warning")) if isinstance(options, str): args.extend(shlex.split(options)) - args.append('pipe:1') + args.append("pipe:1") super().__init__(source, executable=executable, args=args, **subprocess_kwargs) def read(self) -> bytes: ret = self._stdout.read(OpusEncoder.FRAME_SIZE) if len(ret) != OpusEncoder.FRAME_SIZE: - return b'' + return b"" return ret def is_opus(self) -> bool: return False + class FFmpegOpusAudio(FFmpegAudio): """An audio source from FFmpeg (or AVConv). @@ -349,7 +400,7 @@ class FFmpegOpusAudio(FFmpegAudio): *, bitrate: int = 128, codec: Optional[str] = None, - executable: str = 'ffmpeg', + executable: str = "ffmpeg", pipe=False, stderr=None, before_options=None, @@ -357,28 +408,42 @@ class FFmpegOpusAudio(FFmpegAudio): ) -> None: args = [] - subprocess_kwargs = {'stdin': subprocess.PIPE if pipe else subprocess.DEVNULL, 'stderr': stderr} + subprocess_kwargs = { + "stdin": subprocess.PIPE if pipe else subprocess.DEVNULL, + "stderr": stderr, + } if isinstance(before_options, str): args.extend(shlex.split(before_options)) - args.append('-i') - args.append('-' if pipe else source) + args.append("-i") + args.append("-" if pipe else source) - codec = 'copy' if codec in ('opus', 'libopus') else 'libopus' + codec = "copy" if codec in ("opus", "libopus") else "libopus" - args.extend(('-map_metadata', '-1', - '-f', 'opus', - '-c:a', codec, - '-ar', '48000', - '-ac', '2', - '-b:a', f'{bitrate}k', - '-loglevel', 'warning')) + args.extend( + ( + "-map_metadata", + "-1", + "-f", + "opus", + "-c:a", + codec, + "-ar", + "48000", + "-ac", + "2", + "-b:a", + f"{bitrate}k", + "-loglevel", + "warning", + ) + ) if isinstance(options, str): args.extend(shlex.split(options)) - args.append('pipe:1') + args.append("pipe:1") super().__init__(source, executable=executable, args=args, **subprocess_kwargs) self._packet_iter = OggStream(self._stdout).iter_packets() @@ -388,7 +453,9 @@ class FFmpegOpusAudio(FFmpegAudio): cls: Type[FT], source: str, *, - method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None, + method: Optional[ + Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]] + ] = None, **kwargs: Any, ) -> FT: """|coro| @@ -446,7 +513,7 @@ class FFmpegOpusAudio(FFmpegAudio): An instance of this class. """ - executable = kwargs.get('executable') + executable = kwargs.get("executable") codec, bitrate = await cls.probe(source, method=method, executable=executable) return cls(source, bitrate=bitrate, codec=codec, **kwargs) # type: ignore @@ -455,7 +522,9 @@ class FFmpegOpusAudio(FFmpegAudio): cls, source: str, *, - method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None, + method: Optional[ + Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]] + ] = None, executable: Optional[str] = None, ) -> Tuple[Optional[str], Optional[int]]: """|coro| @@ -484,12 +553,12 @@ class FFmpegOpusAudio(FFmpegAudio): A 2-tuple with the codec and bitrate of the input source. """ - method = method or 'native' - executable = executable or 'ffmpeg' + method = method or "native" + executable = executable or "ffmpeg" probefunc = fallback = None if isinstance(method, str): - probefunc = getattr(cls, '_probe_codec_' + method, None) + probefunc = getattr(cls, "_probe_codec_" + method, None) if probefunc is None: raise AttributeError(f"Invalid probe method {method!r}") @@ -500,8 +569,10 @@ class FFmpegOpusAudio(FFmpegAudio): probefunc = method fallback = cls._probe_codec_fallback else: - raise TypeError("Expected str or callable for parameter 'probe', " \ - f"not '{method.__class__.__name__}'") + raise TypeError( + "Expected str or callable for parameter 'probe', " + f"not '{method.__class__.__name__}'" + ) codec = bitrate = None loop = asyncio.get_event_loop() @@ -512,7 +583,9 @@ class FFmpegOpusAudio(FFmpegAudio): _log.exception("Probe '%s' using '%s' failed", method, executable) return # type: ignore - _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) + _log.exception( + "Probe '%s' using '%s' failed, trying fallback", method, executable + ) try: codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore except Exception: @@ -525,28 +598,51 @@ class FFmpegOpusAudio(FFmpegAudio): return codec, bitrate @staticmethod - def _probe_codec_native(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - exe = executable[:2] + 'probe' if executable in ('ffmpeg', 'avconv') else executable - args = [exe, '-v', 'quiet', '-print_format', 'json', '-show_streams', '-select_streams', 'a:0', source] + def _probe_codec_native( + source, executable: str = "ffmpeg" + ) -> Tuple[Optional[str], Optional[int]]: + exe = ( + executable[:2] + "probe" + if executable in ("ffmpeg", "avconv") + else executable + ) + args = [ + exe, + "-v", + "quiet", + "-print_format", + "json", + "-show_streams", + "-select_streams", + "a:0", + source, + ] output = subprocess.check_output(args, timeout=20) codec = bitrate = None if output: data = json.loads(output) - streamdata = data['streams'][0] + streamdata = data["streams"][0] - codec = streamdata.get('codec_name') - bitrate = int(streamdata.get('bit_rate', 0)) - bitrate = max(round(bitrate/1000), 512) + codec = streamdata.get("codec_name") + bitrate = int(streamdata.get("bit_rate", 0)) + bitrate = max(round(bitrate / 1000), 512) return codec, bitrate @staticmethod - def _probe_codec_fallback(source, executable: str = 'ffmpeg') -> Tuple[Optional[str], Optional[int]]: - args = [executable, '-hide_banner', '-i', source] - proc = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdout=subprocess.PIPE, stderr=subprocess.STDOUT) + def _probe_codec_fallback( + source, executable: str = "ffmpeg" + ) -> Tuple[Optional[str], Optional[int]]: + args = [executable, "-hide_banner", "-i", source] + proc = subprocess.Popen( + args, + creationflags=CREATE_NO_WINDOW, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) out, _ = proc.communicate(timeout=20) - output = out.decode('utf8') + output = out.decode("utf8") codec = bitrate = None codec_match = re.search(r"Stream #0.*?Audio: (\w+)", output) @@ -560,11 +656,12 @@ class FFmpegOpusAudio(FFmpegAudio): return codec, bitrate def read(self) -> bytes: - return next(self._packet_iter, b'') + return next(self._packet_iter, b"") def is_opus(self) -> bool: return True + class PCMVolumeTransformer(AudioSource, Generic[AT]): """Transforms a previous :class:`AudioSource` to have volume controls. @@ -589,10 +686,10 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): def __init__(self, original: AT, volume: float = 1.0): if not isinstance(original, AudioSource): - raise TypeError(f'expected AudioSource not {original.__class__.__name__}.') + raise TypeError(f"expected AudioSource not {original.__class__.__name__}.") if original.is_opus(): - raise ClientException('AudioSource must not be Opus encoded.') + raise ClientException("AudioSource must not be Opus encoded.") self.original: AT = original self.volume = volume @@ -613,6 +710,7 @@ class PCMVolumeTransformer(AudioSource, Generic[AT]): ret = self.original.read() return audioop.mul(ret, 2, min(self._volume, 2.0)) + class AudioPlayer(threading.Thread): DELAY: float = OpusEncoder.FRAME_LENGTH / 1000.0 @@ -625,7 +723,7 @@ class AudioPlayer(threading.Thread): self._end: threading.Event = threading.Event() self._resumed: threading.Event = threading.Event() - self._resumed.set() # we are not paused + self._resumed.set() # we are not paused self._current_error: Optional[Exception] = None self._connected: threading.Event = client._connected self._lock: threading.Lock = threading.Lock() @@ -685,11 +783,11 @@ class AudioPlayer(threading.Thread): try: self.after(error) except Exception as exc: - _log.exception('Calling the after function failed.') + _log.exception("Calling the after function failed.") exc.__context__ = error traceback.print_exception(type(exc), exc, exc.__traceback__) elif error: - msg = f'Exception in voice thread {self.name}' + msg = f"Exception in voice thread {self.name}" _log.exception(msg, exc_info=error) print(msg, file=sys.stderr) traceback.print_exception(type(error), error, error.__traceback__) @@ -725,6 +823,8 @@ class AudioPlayer(threading.Thread): def _speak(self, speaking: bool) -> None: try: - asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) + asyncio.run_coroutine_threadsafe( + self.client.ws.speak(speaking), self.client.loop + ) except Exception as e: _log.info("Speaking call in player failed: %s", e) diff --git a/discord/raw_models.py b/discord/raw_models.py index cda754d1..5f212010 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -34,7 +34,7 @@ if TYPE_CHECKING: MessageUpdateEvent, ReactionClearEvent, ReactionClearEmojiEvent, - IntegrationDeleteEvent + IntegrationDeleteEvent, ) from .message import Message from .partial_emoji import PartialEmoji @@ -42,20 +42,20 @@ if TYPE_CHECKING: __all__ = ( - 'RawMessageDeleteEvent', - 'RawBulkMessageDeleteEvent', - 'RawMessageUpdateEvent', - 'RawReactionActionEvent', - 'RawReactionClearEvent', - 'RawReactionClearEmojiEvent', - 'RawIntegrationDeleteEvent', + "RawMessageDeleteEvent", + "RawBulkMessageDeleteEvent", + "RawMessageUpdateEvent", + "RawReactionActionEvent", + "RawReactionClearEvent", + "RawReactionClearEmojiEvent", + "RawIntegrationDeleteEvent", ) class _RawReprMixin: def __repr__(self) -> str: - value = ' '.join(f'{attr}={getattr(self, attr)!r}' for attr in self.__slots__) - return f'<{self.__class__.__name__} {value}>' + value = " ".join(f"{attr}={getattr(self, attr)!r}" for attr in self.__slots__) + return f"<{self.__class__.__name__} {value}>" class RawMessageDeleteEvent(_RawReprMixin): @@ -73,14 +73,14 @@ class RawMessageDeleteEvent(_RawReprMixin): The cached message, if found in the internal message cache. """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'cached_message') + __slots__ = ("message_id", "channel_id", "guild_id", "cached_message") def __init__(self, data: MessageDeleteEvent) -> None: - self.message_id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["id"]) + self.channel_id: int = int(data["channel_id"]) self.cached_message: Optional[Message] = None try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -100,15 +100,15 @@ class RawBulkMessageDeleteEvent(_RawReprMixin): The cached messages, if found in the internal message cache. """ - __slots__ = ('message_ids', 'channel_id', 'guild_id', 'cached_messages') + __slots__ = ("message_ids", "channel_id", "guild_id", "cached_messages") def __init__(self, data: BulkMessageDeleteEvent) -> None: - self.message_ids: Set[int] = {int(x) for x in data.get('ids', [])} - self.channel_id: int = int(data['channel_id']) + self.message_ids: Set[int] = {int(x) for x in data.get("ids", [])} + self.channel_id: int = int(data["channel_id"]) self.cached_messages: List[Message] = [] try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -136,16 +136,16 @@ class RawMessageUpdateEvent(_RawReprMixin): it is modified by the data in :attr:`RawMessageUpdateEvent.data`. """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'data', 'cached_message') + __slots__ = ("message_id", "channel_id", "guild_id", "data", "cached_message") def __init__(self, data: MessageUpdateEvent) -> None: - self.message_id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["id"]) + self.channel_id: int = int(data["channel_id"]) self.data: MessageUpdateEvent = data self.cached_message: Optional[Message] = None try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -179,19 +179,28 @@ class RawReactionActionEvent(_RawReprMixin): .. versionadded:: 1.3 """ - __slots__ = ('message_id', 'user_id', 'channel_id', 'guild_id', 'emoji', - 'event_type', 'member') + __slots__ = ( + "message_id", + "user_id", + "channel_id", + "guild_id", + "emoji", + "event_type", + "member", + ) - def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: - self.message_id: int = int(data['message_id']) - self.channel_id: int = int(data['channel_id']) - self.user_id: int = int(data['user_id']) + def __init__( + self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str + ) -> None: + self.message_id: int = int(data["message_id"]) + self.channel_id: int = int(data["channel_id"]) + self.user_id: int = int(data["user_id"]) self.emoji: PartialEmoji = emoji self.event_type: str = event_type self.member: Optional[Member] = None try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -209,14 +218,14 @@ class RawReactionClearEvent(_RawReprMixin): The guild ID where the reactions got cleared. """ - __slots__ = ('message_id', 'channel_id', 'guild_id') + __slots__ = ("message_id", "channel_id", "guild_id") def __init__(self, data: ReactionClearEvent) -> None: - self.message_id: int = int(data['message_id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["message_id"]) + self.channel_id: int = int(data["channel_id"]) try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -238,15 +247,15 @@ class RawReactionClearEmojiEvent(_RawReprMixin): The custom or unicode emoji being removed. """ - __slots__ = ('message_id', 'channel_id', 'guild_id', 'emoji') + __slots__ = ("message_id", "channel_id", "guild_id", "emoji") def __init__(self, data: ReactionClearEmojiEvent, emoji: PartialEmoji) -> None: self.emoji: PartialEmoji = emoji - self.message_id: int = int(data['message_id']) - self.channel_id: int = int(data['channel_id']) + self.message_id: int = int(data["message_id"]) + self.channel_id: int = int(data["channel_id"]) try: - self.guild_id: Optional[int] = int(data['guild_id']) + self.guild_id: Optional[int] = int(data["guild_id"]) except KeyError: self.guild_id: Optional[int] = None @@ -266,13 +275,13 @@ class RawIntegrationDeleteEvent(_RawReprMixin): The guild ID where the integration got deleted. """ - __slots__ = ('integration_id', 'application_id', 'guild_id') + __slots__ = ("integration_id", "application_id", "guild_id") def __init__(self, data: IntegrationDeleteEvent) -> None: - self.integration_id: int = int(data['id']) - self.guild_id: int = int(data['guild_id']) + self.integration_id: int = int(data["id"]) + self.guild_id: int = int(data["guild_id"]) try: - self.application_id: Optional[int] = int(data['application_id']) + self.application_id: Optional[int] = int(data["application_id"]) except KeyError: self.application_id: Optional[int] = None diff --git a/discord/reaction.py b/discord/reaction.py index 04eee342..80fa0c88 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -27,9 +27,7 @@ from typing import Any, TYPE_CHECKING, Union, Optional from .iterators import ReactionIterator -__all__ = ( - 'Reaction', -) +__all__ = ("Reaction",) if TYPE_CHECKING: from .types.message import Reaction as ReactionPayload @@ -38,6 +36,7 @@ if TYPE_CHECKING: from .emoji import Emoji from .abc import Snowflake + class Reaction: """Represents a reaction to a message. @@ -75,13 +74,22 @@ class Reaction: message: :class:`Message` Message this reaction is for. """ - __slots__ = ('message', 'count', 'emoji', 'me') - def __init__(self, *, message: Message, data: ReactionPayload, emoji: Optional[Union[PartialEmoji, Emoji, str]] = None): + __slots__ = ("message", "count", "emoji", "me") + + def __init__( + self, + *, + message: Message, + data: ReactionPayload, + emoji: Optional[Union[PartialEmoji, Emoji, str]] = None, + ): self.message: Message = message - self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data['emoji']) - self.count: int = data.get('count', 1) - self.me: bool = data.get('me') + self.emoji: Union[ + PartialEmoji, Emoji, str + ] = emoji or message._state.get_reaction_emoji(data["emoji"]) + self.count: int = data.get("count", 1) + self.me: bool = data.get("me") # TODO: typeguard def is_custom_emoji(self) -> bool: @@ -103,7 +111,7 @@ class Reaction: return str(self.emoji) def __repr__(self) -> str: - return f'' + return f"" async def remove(self, user: Snowflake) -> None: """|coro| @@ -155,7 +163,9 @@ class Reaction: """ await self.message.clear_reaction(self.emoji) - def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator: + def users( + self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None + ) -> ReactionIterator: """Returns an :class:`AsyncIterator` representing the users that have reacted to the message. The ``after`` parameter must represent a member @@ -201,7 +211,7 @@ class Reaction: """ if not isinstance(self.emoji, str): - emoji = f'{self.emoji.name}:{self.emoji.id}' + emoji = f"{self.emoji.name}:{self.emoji.id}" else: emoji = self.emoji diff --git a/discord/role.py b/discord/role.py index b690cbe4..da45dd58 100644 --- a/discord/role.py +++ b/discord/role.py @@ -32,8 +32,8 @@ from .mixins import Hashable from .utils import snowflake_time, _get_as_snowflake, MISSING __all__ = ( - 'RoleTags', - 'Role', + "RoleTags", + "Role", ) if TYPE_CHECKING: @@ -68,19 +68,21 @@ class RoleTags: """ __slots__ = ( - 'bot_id', - 'integration_id', - '_premium_subscriber', + "bot_id", + "integration_id", + "_premium_subscriber", ) def __init__(self, data: RoleTagPayload): - self.bot_id: Optional[int] = _get_as_snowflake(data, 'bot_id') - self.integration_id: Optional[int] = _get_as_snowflake(data, 'integration_id') + self.bot_id: Optional[int] = _get_as_snowflake(data, "bot_id") + self.integration_id: Optional[int] = _get_as_snowflake(data, "integration_id") # NOTE: The API returns "null" for this if it's valid, which corresponds to None. # This is different from other fields where "null" means "not there". # So in this case, a value of None is the same as True. # Which means we would need a different sentinel. - self._premium_subscriber: Optional[Any] = data.get('premium_subscriber', MISSING) + self._premium_subscriber: Optional[Any] = data.get( + "premium_subscriber", MISSING + ) def is_bot_managed(self) -> bool: """:class:`bool`: Whether the role is associated with a bot.""" @@ -96,12 +98,12 @@ class RoleTags: def __repr__(self) -> str: return ( - f'' + f"" ) -R = TypeVar('R', bound='Role') +R = TypeVar("R", bound="Role") class Role(Hashable): @@ -181,23 +183,23 @@ class Role(Hashable): """ __slots__ = ( - 'id', - 'name', - '_permissions', - '_colour', - 'position', - 'managed', - 'mentionable', - 'hoist', - 'guild', - 'tags', - '_state', + "id", + "name", + "_permissions", + "_colour", + "position", + "managed", + "mentionable", + "hoist", + "guild", + "tags", + "_state", ) def __init__(self, *, guild: Guild, state: ConnectionState, data: RolePayload): self.guild: Guild = guild self._state: ConnectionState = state - self.id: int = int(data['id']) + self.id: int = int(data["id"]) self._update(data) def __str__(self) -> str: @@ -207,14 +209,14 @@ class Role(Hashable): return self.id def __repr__(self) -> str: - return f'' + return f"" def __lt__(self: R, other: R) -> bool: if not isinstance(other, Role) or not isinstance(self, Role): return NotImplemented if self.guild != other.guild: - raise RuntimeError('cannot compare roles from two different guilds.') + raise RuntimeError("cannot compare roles from two different guilds.") # the @everyone role is always the lowest role in hierarchy guild_id = self.guild.id @@ -246,17 +248,17 @@ class Role(Hashable): return not r def _update(self, data: RolePayload): - self.name: str = data['name'] - self._permissions: int = int(data.get('permissions', 0)) - self.position: int = data.get('position', 0) - self._colour: int = data.get('color', 0) - self.hoist: bool = data.get('hoist', False) - self.managed: bool = data.get('managed', False) - self.mentionable: bool = data.get('mentionable', False) + self.name: str = data["name"] + self._permissions: int = int(data.get("permissions", 0)) + self.position: int = data.get("position", 0) + self._colour: int = data.get("color", 0) + self.hoist: bool = data.get("hoist", False) + self.managed: bool = data.get("managed", False) + self.mentionable: bool = data.get("mentionable", False) self.tags: Optional[RoleTags] try: - self.tags = RoleTags(data['tags']) + self.tags = RoleTags(data["tags"]) except KeyError: self.tags = None @@ -291,7 +293,11 @@ class Role(Hashable): .. versionadded:: 2.0 """ me = self.guild.me - return not self.is_default() and not self.managed and (me.top_role > self or me.id == self.guild.owner_id) + return ( + not self.is_default() + and not self.managed + and (me.top_role > self or me.id == self.guild.owner_id) + ) @property def permissions(self) -> Permissions: @@ -316,7 +322,7 @@ class Role(Hashable): @property def mention(self) -> str: """:class:`str`: Returns a string that allows you to mention a role.""" - return f'<@&{self.id}>' + return f"<@&{self.id}>" @property def members(self) -> List[Member]: @@ -340,15 +346,23 @@ class Role(Hashable): http = self._state.http - change_range = range(min(self.position, position), max(self.position, position) + 1) - roles = [r.id for r in self.guild.roles[1:] if r.position in change_range and r.id != self.id] + change_range = range( + min(self.position, position), max(self.position, position) + 1 + ) + roles = [ + r.id + for r in self.guild.roles[1:] + if r.position in change_range and r.id != self.id + ] if self.position > position: roles.insert(0, self.id) else: roles.append(self.id) - payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] + payload: List[RolePositionUpdate] = [ + {"id": z[0], "position": z[1]} for z in zip(roles, change_range) + ] await http.move_role_position(self.guild.id, payload, reason=reason) async def edit( @@ -420,23 +434,25 @@ class Role(Hashable): if colour is not MISSING: if isinstance(colour, int): - payload['color'] = colour + payload["color"] = colour else: - payload['color'] = colour.value + payload["color"] = colour.value if name is not MISSING: - payload['name'] = name + payload["name"] = name if permissions is not MISSING: - payload['permissions'] = permissions.value + payload["permissions"] = permissions.value if hoist is not MISSING: - payload['hoist'] = hoist + payload["hoist"] = hoist if mentionable is not MISSING: - payload['mentionable'] = mentionable + payload["mentionable"] = mentionable - data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) + data = await self._state.http.edit_role( + self.guild.id, self.id, reason=reason, **payload + ) return Role(guild=self.guild, data=data, state=self._state) async def delete(self, *, reason: Optional[str] = None) -> None: diff --git a/discord/shard.py b/discord/shard.py index edbdebf4..ff13eee7 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -43,18 +43,28 @@ from .errors import ( from .enums import Status -from typing import TYPE_CHECKING, Any, Callable, Tuple, Type, Optional, List, Dict, TypeVar +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Tuple, + Type, + Optional, + List, + Dict, + TypeVar, +) if TYPE_CHECKING: from .gateway import DiscordWebSocket from .activity import BaseActivity from .enums import Status - EI = TypeVar('EI', bound='EventItem') + EI = TypeVar("EI", bound="EventItem") __all__ = ( - 'AutoShardedClient', - 'ShardInfo', + "AutoShardedClient", + "ShardInfo", ) _log = logging.getLogger(__name__) @@ -70,11 +80,13 @@ class EventType: class EventItem: - __slots__ = ('type', 'shard', 'error') + __slots__ = ("type", "shard", "error") - def __init__(self, etype: int, shard: Optional['Shard'], error: Optional[Exception]) -> None: + def __init__( + self, etype: int, shard: Optional["Shard"], error: Optional[Exception] + ) -> None: self.type: int = etype - self.shard: Optional['Shard'] = shard + self.shard: Optional["Shard"] = shard self.error: Optional[Exception] = error def __lt__(self: EI, other: EI) -> bool: @@ -92,7 +104,12 @@ class EventItem: class Shard: - def __init__(self, ws: DiscordWebSocket, client: AutoShardedClient, queue_put: Callable[[EventItem], None]) -> None: + def __init__( + self, + ws: DiscordWebSocket, + client: AutoShardedClient, + queue_put: Callable[[EventItem], None], + ) -> None: self.ws: DiscordWebSocket = ws self._client: Client = client self._dispatch: Callable[..., None] = client.dispatch @@ -129,11 +146,11 @@ class Shard: async def disconnect(self) -> None: await self.close() - self._dispatch('shard_disconnect', self.id) + self._dispatch("shard_disconnect", self.id) async def _handle_disconnect(self, e: Exception) -> None: - self._dispatch('disconnect') - self._dispatch('shard_disconnect', self.id) + self._dispatch("disconnect") + self._dispatch("shard_disconnect", self.id) if not self._reconnect: self._queue_put(EventItem(EventType.close, self, e)) return @@ -149,14 +166,23 @@ class Shard: if isinstance(e, ConnectionClosed): if e.code == 4014: - self._queue_put(EventItem(EventType.terminate, self, PrivilegedIntentsRequired(self.id))) + self._queue_put( + EventItem( + EventType.terminate, self, PrivilegedIntentsRequired(self.id) + ) + ) return if e.code != 1000: self._queue_put(EventItem(EventType.close, self, e)) return retry = self._backoff.delay() - _log.error('Attempting a reconnect for shard ID %s in %.2fs', self.id, retry, exc_info=e) + _log.error( + "Attempting a reconnect for shard ID %s in %.2fs", + self.id, + retry, + exc_info=e, + ) await asyncio.sleep(retry) self._queue_put(EventItem(EventType.reconnect, self, e)) @@ -179,9 +205,9 @@ class Shard: async def reidentify(self, exc: ReconnectWebSocket) -> None: self._cancel_task() - self._dispatch('disconnect') - self._dispatch('shard_disconnect', self.id) - _log.info('Got a request to %s the websocket at Shard ID %s.', exc.op, self.id) + self._dispatch("disconnect") + self._dispatch("shard_disconnect", self.id) + _log.info("Got a request to %s the websocket at Shard ID %s.", exc.op, self.id) try: coro = DiscordWebSocket.from_client( self._client, @@ -231,7 +257,7 @@ class ShardInfo: The shard count for this cluster. If this is ``None`` then the bot has not started yet. """ - __slots__ = ('_parent', 'id', 'shard_count') + __slots__ = ("_parent", "id", "shard_count") def __init__(self, parent: Shard, shard_count: Optional[int]) -> None: self._parent: Shard = parent @@ -320,16 +346,23 @@ class AutoShardedClient(Client): if TYPE_CHECKING: _connection: AutoShardedConnectionState - def __init__(self, *args: Any, loop: Optional[asyncio.AbstractEventLoop] = None, **kwargs: Any) -> None: - kwargs.pop('shard_id', None) - self.shard_ids: Optional[List[int]] = kwargs.pop('shard_ids', None) + def __init__( + self, + *args: Any, + loop: Optional[asyncio.AbstractEventLoop] = None, + **kwargs: Any, + ) -> None: + kwargs.pop("shard_id", None) + self.shard_ids: Optional[List[int]] = kwargs.pop("shard_ids", None) super().__init__(*args, loop=loop, **kwargs) if self.shard_ids is not None: if self.shard_count is None: - raise ClientException('When passing manual shard_ids, you must provide a shard_count.') + raise ClientException( + "When passing manual shard_ids, you must provide a shard_count." + ) elif not isinstance(self.shard_ids, (list, tuple)): - raise ClientException('shard_ids parameter must be a list or a tuple.') + raise ClientException("shard_ids parameter must be a list or a tuple.") # instead of a single websocket, we have multiple # the key is the shard_id @@ -338,7 +371,9 @@ class AutoShardedClient(Client): self._connection._get_client = lambda: self self.__queue = asyncio.PriorityQueue() - def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: + def _get_websocket( + self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None + ) -> DiscordWebSocket: if shard_id is None: # guild_id won't be None if shard_id is None and shard_count won't be None here shard_id = (guild_id >> 22) % self.shard_count # type: ignore @@ -363,7 +398,7 @@ class AutoShardedClient(Client): :attr:`latencies` property. Returns ``nan`` if there are no shards ready. """ if not self.__shards: - return float('nan') + return float("nan") return sum(latency for _, latency in self.latencies) / len(self.__shards) @property @@ -372,7 +407,9 @@ class AutoShardedClient(Client): This returns a list of tuples with elements ``(shard_id, latency)``. """ - return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] + return [ + (shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items() + ] def get_shard(self, shard_id: int) -> Optional[ShardInfo]: """Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" @@ -386,14 +423,21 @@ class AutoShardedClient(Client): @property def shards(self) -> Dict[int, ShardInfo]: """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" - return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} + return { + shard_id: ShardInfo(parent, self.shard_count) + for shard_id, parent in self.__shards.items() + } - async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: + async def launch_shard( + self, gateway: str, shard_id: int, *, initial: bool = False + ) -> None: try: - coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) + coro = DiscordWebSocket.from_client( + self, initial=initial, gateway=gateway, shard_id=shard_id + ) ws = await asyncio.wait_for(coro, timeout=180.0) except Exception: - _log.exception('Failed to connect for shard_id: %s. Retrying...', shard_id) + _log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id) await asyncio.sleep(5.0) return await self.launch_shard(gateway, shard_id) @@ -458,7 +502,10 @@ class AutoShardedClient(Client): except Exception: pass - to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()] + to_close = [ + asyncio.ensure_future(shard.close(), loop=self.loop) + for shard in self.__shards.values() + ] if to_close: await asyncio.wait(to_close) @@ -503,10 +550,10 @@ class AutoShardedClient(Client): """ if status is None: - status_value = 'online' + status_value = "online" status_enum = Status.online elif status is Status.offline: - status_value = 'invisible' + status_value = "invisible" status_enum = Status.offline else: status_enum = status diff --git a/discord/stage_instance.py b/discord/stage_instance.py index b538eec3..3711fd6e 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -31,9 +31,7 @@ from .mixins import Hashable from .errors import InvalidArgument from .enums import StagePrivacyLevel, try_enum -__all__ = ( - 'StageInstance', -) +__all__ = ("StageInstance",) if TYPE_CHECKING: from .types.channel import StageInstance as StageInstancePayload @@ -82,41 +80,51 @@ class StageInstance(Hashable): """ __slots__ = ( - '_state', - 'id', - 'guild', - 'channel_id', - 'topic', - 'privacy_level', - 'discoverable_disabled', - '_cs_channel', + "_state", + "id", + "guild", + "channel_id", + "topic", + "privacy_level", + "discoverable_disabled", + "_cs_channel", ) - def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: + def __init__( + self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload + ) -> None: self._state = state self.guild = guild self._update(data) def _update(self, data: StageInstancePayload): - self.id: int = int(data['id']) - self.channel_id: int = int(data['channel_id']) - self.topic: str = data['topic'] - self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data['privacy_level']) - self.discoverable_disabled: bool = data.get('discoverable_disabled', False) + self.id: int = int(data["id"]) + self.channel_id: int = int(data["channel_id"]) + self.topic: str = data["topic"] + self.privacy_level: StagePrivacyLevel = try_enum( + StagePrivacyLevel, data["privacy_level"] + ) + self.discoverable_disabled: bool = data.get("discoverable_disabled", False) def __repr__(self) -> str: - return f'' + return f"" - @cached_slot_property('_cs_channel') + @cached_slot_property("_cs_channel") def channel(self) -> Optional[StageChannel]: """Optional[:class:`StageChannel`]: The channel that stage instance is running in.""" # the returned channel will always be a StageChannel or None - return self._state.get_channel(self.channel_id) # type: ignore + return self._state.get_channel(self.channel_id) # type: ignore def is_public(self) -> bool: return self.privacy_level is StagePrivacyLevel.public - async def edit(self, *, topic: str = MISSING, privacy_level: StagePrivacyLevel = MISSING, reason: Optional[str] = None) -> None: + async def edit( + self, + *, + topic: str = MISSING, + privacy_level: StagePrivacyLevel = MISSING, + reason: Optional[str] = None, + ) -> None: """|coro| Edits the stage instance. @@ -146,16 +154,20 @@ class StageInstance(Hashable): payload = {} if topic is not MISSING: - payload['topic'] = topic + payload["topic"] = topic if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument('privacy_level field must be of type PrivacyLevel') + raise InvalidArgument( + "privacy_level field must be of type PrivacyLevel" + ) - payload['privacy_level'] = privacy_level.value + payload["privacy_level"] = privacy_level.value if payload: - await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason) + await self._state.http.edit_stage_instance( + self.channel_id, **payload, reason=reason + ) async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| diff --git a/discord/state.py b/discord/state.py index 0a9feac1..96a74b3e 100644 --- a/discord/state.py +++ b/discord/state.py @@ -30,7 +30,20 @@ import copy import datetime import itertools import logging -from typing import Dict, Optional, TYPE_CHECKING, Union, Callable, Any, List, TypeVar, Coroutine, Sequence, Tuple, Deque +from typing import ( + Dict, + Optional, + TYPE_CHECKING, + Union, + Callable, + Any, + List, + TypeVar, + Coroutine, + Sequence, + Tuple, + Deque, +) import inspect import os @@ -76,8 +89,8 @@ if TYPE_CHECKING: from .types.guild import Guild as GuildPayload from .types.message import Message as MessagePayload - T = TypeVar('T') - CS = TypeVar('CS', bound='ConnectionState') + T = TypeVar("T") + CS = TypeVar("CS", bound="ConnectionState") Channel = Union[GuildChannel, VocalGuildChannel, PrivateChannel, PartialMessageable] @@ -132,11 +145,13 @@ class ChunkRequest: _log = logging.getLogger(__name__) -async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: +async def logging_coroutine( + coroutine: Coroutine[Any, Any, T], *, info: str +) -> Optional[T]: try: await coroutine except Exception: - _log.exception('Exception occurred during %s', info) + _log.exception("Exception occurred during %s", info) class ConnectionState: @@ -158,7 +173,7 @@ class ConnectionState: ) -> None: self.loop: asyncio.AbstractEventLoop = loop self.http: HTTPClient = http - self.max_messages: Optional[int] = options.get('max_messages', 1000) + self.max_messages: Optional[int] = options.get("max_messages", 1000) if self.max_messages is not None and self.max_messages <= 0: self.max_messages = 1000 @@ -167,52 +182,64 @@ class ConnectionState: self.hooks: Dict[str, Callable] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None - self.application_id: Optional[int] = utils._get_as_snowflake(options, 'application_id') - self.heartbeat_timeout: float = options.get('heartbeat_timeout', 60.0) - self.guild_ready_timeout: float = options.get('guild_ready_timeout', 2.0) + self.application_id: Optional[int] = utils._get_as_snowflake( + options, "application_id" + ) + self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) + self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) if self.guild_ready_timeout < 0: - raise ValueError('guild_ready_timeout cannot be negative') + raise ValueError("guild_ready_timeout cannot be negative") - allowed_mentions = options.get('allowed_mentions') + allowed_mentions = options.get("allowed_mentions") - if allowed_mentions is not None and not isinstance(allowed_mentions, AllowedMentions): - raise TypeError('allowed_mentions parameter must be AllowedMentions') + if allowed_mentions is not None and not isinstance( + allowed_mentions, AllowedMentions + ): + raise TypeError("allowed_mentions parameter must be AllowedMentions") self.allowed_mentions: Optional[AllowedMentions] = allowed_mentions self._chunk_requests: Dict[Union[int, str], ChunkRequest] = {} - activity = options.get('activity', None) + activity = options.get("activity", None) if activity: if not isinstance(activity, BaseActivity): - raise TypeError('activity parameter must derive from BaseActivity.') + raise TypeError("activity parameter must derive from BaseActivity.") activity = activity.to_dict() - status = options.get('status', None) + status = options.get("status", None) if status: if status is Status.offline: - status = 'invisible' + status = "invisible" else: status = str(status) if not isinstance(intents, Intents): - raise TypeError(f'intents parameter must be Intent not {type(intents)!r}') + raise TypeError(f"intents parameter must be Intent not {type(intents)!r}") if not intents.guilds: - _log.warning('Guilds intent seems to be disabled. This may cause state related issues.') + _log.warning( + "Guilds intent seems to be disabled. This may cause state related issues." + ) - self._chunk_guilds: bool = options.get('chunk_guilds_at_startup', intents.members) + self._chunk_guilds: bool = options.get( + "chunk_guilds_at_startup", intents.members + ) # Ensure these two are set properly if not intents.members and self._chunk_guilds: - raise ValueError('Intents.members must be enabled to chunk guilds at startup.') + raise ValueError( + "Intents.members must be enabled to chunk guilds at startup." + ) - cache_flags = options.get('member_cache_flags', None) + cache_flags = options.get("member_cache_flags", None) if cache_flags is None: cache_flags = MemberCacheFlags.from_intents(intents) else: if not isinstance(cache_flags, MemberCacheFlags): - raise TypeError(f'member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}') + raise TypeError( + f"member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}" + ) cache_flags._verify_intents(intents) @@ -227,7 +254,7 @@ class ConnectionState: self.parsers = parsers = {} for attr, func in inspect.getmembers(self): - if attr.startswith('parse_'): + if attr.startswith("parse_"): parsers[attr[6:].upper()] = func self.clear() @@ -264,7 +291,9 @@ class ConnectionState: else: self._messages: Optional[Deque[Message]] = None - def process_chunk_requests(self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool) -> None: + def process_chunk_requests( + self, guild_id: int, nonce: Optional[str], members: List[Member], complete: bool + ) -> None: removed = [] for key, request in self._chunk_requests.items(): if request.guild_id == guild_id and request.nonce == nonce: @@ -322,12 +351,12 @@ class ConnectionState: vc.main_ws = ws # type: ignore def store_user(self, data: UserPayload) -> User: - user_id = int(data['id']) + user_id = int(data["id"]) try: return self._users[user_id] except KeyError: user = User(state=self, data=data) - if user.discriminator != '0000': + if user.discriminator != "0000": self._users[user_id] = user user._stored = True return user @@ -347,12 +376,12 @@ class ConnectionState: def store_emoji(self, guild: Guild, data: EmojiPayload) -> Emoji: # the id will be present here - emoji_id = int(data['id']) # type: ignore + emoji_id = int(data["id"]) # type: ignore self._emojis[emoji_id] = emoji = Emoji(guild=guild, state=self, data=data) return emoji def store_sticker(self, guild: Guild, data: GuildStickerPayload) -> GuildSticker: - sticker_id = int(data['id']) + sticker_id = int(data["id"]) self._stickers[sticker_id] = sticker = GuildSticker(state=self, data=data) return sticker @@ -408,7 +437,9 @@ class ConnectionState: def private_channels(self) -> List[PrivateChannel]: return list(self._private_channels.values()) - def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: + def _get_private_channel( + self, channel_id: Optional[int] + ) -> Optional[PrivateChannel]: try: # the keys of self._private_channels are ints value = self._private_channels[channel_id] # type: ignore @@ -418,7 +449,9 @@ class ConnectionState: self._private_channels.move_to_end(channel_id) # type: ignore return value - def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[DMChannel]: + def _get_private_channel_by_user( + self, user_id: Optional[int] + ) -> Optional[DMChannel]: # the keys of self._private_channels are ints return self._private_channels_by_user.get(user_id) # type: ignore @@ -448,7 +481,11 @@ class ConnectionState: self._private_channels_by_user.pop(recipient.id, None) def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: - return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None + return ( + utils.find(lambda m: m.id == msg_id, reversed(self._messages)) + if self._messages + else None + ) def _add_guild_from_data(self, data: GuildPayload) -> Guild: guild = Guild(data=data, state=self) @@ -457,12 +494,18 @@ class ConnectionState: def _guild_needs_chunking(self, guild: Guild) -> bool: # If presences are enabled then we get back the old guild.large behaviour - return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) + return ( + self._chunk_guilds + and not guild.chunked + and not (self._intents.presences and not guild.large) + ) - def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]: - channel_id = int(data['channel_id']) + def _get_guild_channel( + self, data: MessagePayload + ) -> Tuple[Union[Channel, Thread], Optional[Guild]]: + channel_id = int(data["channel_id"]) try: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) except KeyError: channel = DMChannel._from_message(self, channel_id) guild = None @@ -472,16 +515,32 @@ class ConnectionState: return channel or PartialMessageable(state=self, id=channel_id), guild async def chunker( - self, guild_id: int, query: str = '', limit: int = 0, presences: bool = False, *, nonce: Optional[str] = None + self, + guild_id: int, + query: str = "", + limit: int = 0, + presences: bool = False, + *, + nonce: Optional[str] = None, ) -> None: ws = self._get_websocket(guild_id) # This is ignored upstream - await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + await ws.request_chunks( + guild_id, query=query, limit=limit, presences=presences, nonce=nonce + ) - async def query_members(self, guild: Guild, query: str, limit: int, user_ids: List[int], cache: bool, presences: bool): + async def query_members( + self, + guild: Guild, + query: str, + limit: int, + user_ids: List[int], + cache: bool, + presences: bool, + ): guild_id = guild.id ws = self._get_websocket(guild_id) if ws is None: - raise RuntimeError('Somehow do not have a websocket for this guild_id') + raise RuntimeError("Somehow do not have a websocket for this guild_id") request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) self._chunk_requests[request.nonce] = request @@ -489,11 +548,21 @@ class ConnectionState: try: # start the query operation await ws.request_chunks( - guild_id, query=query, limit=limit, user_ids=user_ids, presences=presences, nonce=request.nonce + guild_id, + query=query, + limit=limit, + user_ids=user_ids, + presences=presences, + nonce=request.nonce, ) return await asyncio.wait_for(request.wait(), timeout=30.0) except asyncio.TimeoutError: - _log.warning('Timed out waiting for chunks with query %r and limit %d for guild_id %d', query, limit, guild_id) + _log.warning( + "Timed out waiting for chunks with query %r and limit %d for guild_id %d", + query, + limit, + guild_id, + ) raise async def _delay_ready(self) -> None: @@ -503,7 +572,9 @@ class ConnectionState: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) + guild = await asyncio.wait_for( + self._ready_state.get(), timeout=self.guild_ready_timeout + ) except asyncio.TimeoutError: break else: @@ -512,20 +583,24 @@ class ConnectionState: states.append((guild, future)) else: if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) for guild, future in states: try: await asyncio.wait_for(future, timeout=5.0) except asyncio.TimeoutError: - _log.warning('Shard ID %s timed out waiting for chunks for guild_id %s.', guild.shard_id, guild.id) + _log.warning( + "Shard ID %s timed out waiting for chunks for guild_id %s.", + guild.shard_id, + guild.id, + ) if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) # remove the state try: @@ -537,8 +612,8 @@ class ConnectionState: pass else: # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + self.call_handlers("ready") + self.dispatch("ready") finally: self._ready_task = None @@ -548,33 +623,33 @@ class ConnectionState: self._ready_state = asyncio.Queue() self.clear(views=False) - self.user = ClientUser(state=self, data=data['user']) - self.store_user(data['user']) + self.user = ClientUser(state=self, data=data["user"]) + self.store_user(data["user"]) if self.application_id is None: try: - application = data['application'] + application = data["application"] except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') + self.application_id = utils._get_as_snowflake(application, "id") # flags will always be present here - self.application_flags = ApplicationFlags._from_value(application['flags']) # type: ignore + self.application_flags = ApplicationFlags._from_value(application["flags"]) # type: ignore - for guild_data in data['guilds']: + for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) - self.dispatch('connect') + self.dispatch("connect") self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data) -> None: - self.dispatch('resumed') + self.dispatch("resumed") def parse_message_create(self, data) -> None: channel, _ = self._get_guild_channel(data) # channel would be the correct type here message = Message(channel=channel, data=data, state=self) # type: ignore - self.dispatch('message', message) + self.dispatch("message", message) if self._messages is not None: self._messages.append(message) # we ensure that the channel is either a TextChannel or Thread @@ -585,21 +660,23 @@ class ConnectionState: raw = RawMessageDeleteEvent(data) found = self._get_message(raw.message_id) raw.cached_message = found - self.dispatch('raw_message_delete', raw) + self.dispatch("raw_message_delete", raw) if self._messages is not None and found is not None: - self.dispatch('message_delete', found) + self.dispatch("message_delete", found) self._messages.remove(found) def parse_message_delete_bulk(self, data) -> None: raw = RawBulkMessageDeleteEvent(data) if self._messages: - found_messages = [message for message in self._messages if message.id in raw.message_ids] + found_messages = [ + message for message in self._messages if message.id in raw.message_ids + ] else: found_messages = [] raw.cached_messages = found_messages - self.dispatch('raw_bulk_message_delete', raw) + self.dispatch("raw_bulk_message_delete", raw) if found_messages: - self.dispatch('bulk_message_delete', found_messages) + self.dispatch("bulk_message_delete", found_messages) for msg in found_messages: # self._messages won't be None here self._messages.remove(msg) # type: ignore @@ -610,25 +687,27 @@ class ConnectionState: if message is not None: older_message = copy.copy(message) raw.cached_message = older_message - self.dispatch('raw_message_edit', raw) + self.dispatch("raw_message_edit", raw) message._update(data) # Coerce the `after` parameter to take the new updated Member # ref: #5999 older_message.author = message.author - self.dispatch('message_edit', older_message, message) + self.dispatch("message_edit", older_message, message) else: - self.dispatch('raw_message_edit', raw) + self.dispatch("raw_message_edit", raw) - if 'components' in data and self._view_store.is_message_tracked(raw.message_id): - self._view_store.update_from_message(raw.message_id, data['components']) + if "components" in data and self._view_store.is_message_tracked(raw.message_id): + self._view_store.update_from_message(raw.message_id, data["components"]) def parse_message_reaction_add(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get('animated', False), name=emoji['name']) - raw = RawReactionActionEvent(data, emoji, 'REACTION_ADD') + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state( + self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"] + ) + raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") - member_data = data.get('member') + member_data = data.get("member") if member_data: guild = self._get_guild(raw.guild_id) if guild is not None: @@ -637,7 +716,7 @@ class ConnectionState: raw.member = None else: raw.member = None - self.dispatch('raw_reaction_add', raw) + self.dispatch("raw_reaction_add", raw) # rich interface here message = self._get_message(raw.message_id) @@ -647,24 +726,24 @@ class ConnectionState: user = raw.member or self._get_reaction_user(message.channel, raw.user_id) if user: - self.dispatch('reaction_add', reaction, user) + self.dispatch("reaction_add", reaction, user) def parse_message_reaction_remove_all(self, data) -> None: raw = RawReactionClearEvent(data) - self.dispatch('raw_reaction_clear', raw) + self.dispatch("raw_reaction_clear", raw) message = self._get_message(raw.message_id) if message is not None: old_reactions = message.reactions.copy() message.reactions.clear() - self.dispatch('reaction_clear', message, old_reactions) + self.dispatch("reaction_clear", message, old_reactions) def parse_message_reaction_remove(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) - raw = RawReactionActionEvent(data, emoji, 'REACTION_REMOVE') - self.dispatch('raw_reaction_remove', raw) + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) + raw = RawReactionActionEvent(data, emoji, "REACTION_REMOVE") + self.dispatch("raw_reaction_remove", raw) message = self._get_message(raw.message_id) if message is not None: @@ -676,14 +755,14 @@ class ConnectionState: else: user = self._get_reaction_user(message.channel, raw.user_id) if user: - self.dispatch('reaction_remove', reaction, user) + self.dispatch("reaction_remove", reaction, user) def parse_message_reaction_remove_emoji(self, data) -> None: - emoji = data['emoji'] - emoji_id = utils._get_as_snowflake(emoji, 'id') - emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji['name']) + emoji = data["emoji"] + emoji_id = utils._get_as_snowflake(emoji, "id") + emoji = PartialEmoji.with_state(self, id=emoji_id, name=emoji["name"]) raw = RawReactionClearEmojiEvent(data, emoji) - self.dispatch('raw_reaction_clear_emoji', raw) + self.dispatch("raw_reaction_clear_emoji", raw) message = self._get_message(raw.message_id) if message is not None: @@ -693,38 +772,44 @@ class ConnectionState: pass else: if reaction: - self.dispatch('reaction_clear_emoji', reaction) + self.dispatch("reaction_clear_emoji", reaction) def parse_interaction_create(self, data) -> None: interaction = Interaction(data=data, state=self) - if data['type'] == 3: # interaction component - custom_id = interaction.data['custom_id'] # type: ignore - component_type = interaction.data['component_type'] # type: ignore + if data["type"] == 3: # interaction component + custom_id = interaction.data["custom_id"] # type: ignore + component_type = interaction.data["component_type"] # type: ignore self._view_store.dispatch(component_type, custom_id, interaction) - self.dispatch('interaction', interaction) + self.dispatch("interaction", interaction) def parse_presence_update(self, data) -> None: - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") # guild_id won't be None here guild = self._get_guild(guild_id) if guild is None: - _log.debug('PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "PRESENCE_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) return - user = data['user'] - member_id = int(user['id']) + user = data["user"] + member_id = int(user["id"]) member = guild.get_member(member_id) if member is None: - _log.debug('PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding', member_id) + _log.debug( + "PRESENCE_UPDATE referencing an unknown member ID: %s. Discarding", + member_id, + ) return old_member = Member._copy(member) user_update = member._presence_update(data=data, user=user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) - self.dispatch('presence_update', old_member, member) + self.dispatch("presence_update", old_member, member) def parse_user_update(self, data) -> None: # self.user is *always* cached when this is called @@ -736,66 +821,78 @@ class ConnectionState: def parse_invite_create(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) - self.dispatch('invite_create', invite) + self.dispatch("invite_create", invite) def parse_invite_delete(self, data) -> None: invite = Invite.from_gateway(state=self, data=data) - self.dispatch('invite_delete', invite) + self.dispatch("invite_delete", invite) def parse_channel_delete(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) - channel_id = int(data['id']) + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = int(data["id"]) if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: guild._remove_channel(channel) - self.dispatch('guild_channel_delete', channel) + self.dispatch("guild_channel_delete", channel) def parse_channel_update(self, data) -> None: - channel_type = try_enum(ChannelType, data.get('type')) - channel_id = int(data['id']) + channel_type = try_enum(ChannelType, data.get("type")) + channel_id = int(data["id"]) if channel_type is ChannelType.group: channel = self._get_private_channel(channel_id) old_channel = copy.copy(channel) # the channel is a GroupChannel channel._update_group(data) # type: ignore - self.dispatch('private_channel_update', old_channel, channel) + self.dispatch("private_channel_update", old_channel, channel) return - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: channel = guild.get_channel(channel_id) if channel is not None: old_channel = copy.copy(channel) channel._update(guild, data) - self.dispatch('guild_channel_update', old_channel, channel) + self.dispatch("guild_channel_update", old_channel, channel) else: - _log.debug('CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + _log.debug( + "CHANNEL_UPDATE referencing an unknown channel ID: %s. Discarding.", + channel_id, + ) else: - _log.debug('CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "CHANNEL_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_channel_create(self, data) -> None: - factory, ch_type = _channel_factory(data['type']) + factory, ch_type = _channel_factory(data["type"]) if factory is None: - _log.debug('CHANNEL_CREATE referencing an unknown channel type %s. Discarding.', data['type']) + _log.debug( + "CHANNEL_CREATE referencing an unknown channel type %s. Discarding.", + data["type"], + ) return - guild_id = utils._get_as_snowflake(data, 'guild_id') + guild_id = utils._get_as_snowflake(data, "guild_id") guild = self._get_guild(guild_id) if guild is not None: # the factory can't be a DMChannel or GroupChannel here channel = factory(guild=guild, state=self, data=data) # type: ignore guild._add_channel(channel) # type: ignore - self.dispatch('guild_channel_create', channel) + self.dispatch("guild_channel_create", channel) else: - _log.debug('CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "CHANNEL_CREATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) return def parse_channel_pins_update(self, data) -> None: - channel_id = int(data['channel_id']) + channel_id = int(data["channel_id"]) try: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) except KeyError: guild = None channel = self._get_private_channel(channel_id) @@ -803,69 +900,88 @@ class ConnectionState: channel = guild and guild._resolve_channel(channel_id) if channel is None: - _log.debug('CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.', channel_id) + _log.debug( + "CHANNEL_PINS_UPDATE referencing an unknown channel ID: %s. Discarding.", + channel_id, + ) return - last_pin = utils.parse_time(data['last_pin_timestamp']) if data['last_pin_timestamp'] else None + last_pin = ( + utils.parse_time(data["last_pin_timestamp"]) + if data["last_pin_timestamp"] + else None + ) if guild is None: - self.dispatch('private_channel_pins_update', channel, last_pin) + self.dispatch("private_channel_pins_update", channel, last_pin) else: - self.dispatch('guild_channel_pins_update', channel, last_pin) + self.dispatch("guild_channel_pins_update", channel, last_pin) def parse_thread_create(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_CREATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_CREATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return thread = Thread(guild=guild, state=guild._state, data=data) has_thread = guild.get_thread(thread.id) guild._add_thread(thread) if not has_thread: - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) def parse_thread_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_UPDATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread = guild.get_thread(thread_id) if thread is not None: old = copy.copy(thread) thread._update(data) - self.dispatch('thread_update', old, thread) + self.dispatch("thread_update", old, thread) else: thread = Thread(guild=guild, state=guild._state, data=data) guild._add_thread(thread) - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) def parse_thread_delete(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_DELETE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_DELETE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread = guild.get_thread(thread_id) if thread is not None: guild._remove_thread(thread) # type: ignore - self.dispatch('thread_delete', thread) + self.dispatch("thread_delete", thread) def parse_thread_list_sync(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_LIST_SYNC referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return try: - channel_ids = set(data['channel_ids']) + channel_ids = set(data["channel_ids"]) except KeyError: # If not provided, then the entire guild is being synced # So all previous thread data should be overwritten @@ -874,12 +990,12 @@ class ConnectionState: else: previous_threads = guild._filter_threads(channel_ids) - threads = {d['id']: guild._store_thread(d) for d in data.get('threads', [])} + threads = {d["id"]: guild._store_thread(d) for d in data.get("threads", [])} - for member in data.get('members', []): + for member in data.get("members", []): try: # note: member['id'] is the thread_id - thread = threads[member['id']] + thread = threads[member["id"]] except KeyError: continue else: @@ -888,63 +1004,78 @@ class ConnectionState: for thread in threads.values(): old = previous_threads.pop(thread.id, None) if old is None: - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) for thread in previous_threads.values(): - self.dispatch('thread_remove', thread) + self.dispatch("thread_remove", thread) def parse_thread_member_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug( + "THREAD_MEMBER_UPDATE referencing an unknown thread ID: %s. Discarding", + thread_id, + ) return member = ThreadMember(thread, data) thread.me = member def parse_thread_members_update(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild: Optional[Guild] = self._get_guild(guild_id) if guild is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding', guild_id) + _log.debug( + "THREAD_MEMBERS_UPDATE referencing an unknown guild ID: %s. Discarding", + guild_id, + ) return - thread_id = int(data['id']) + thread_id = int(data["id"]) thread: Optional[Thread] = guild.get_thread(thread_id) if thread is None: - _log.debug('THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding', thread_id) + _log.debug( + "THREAD_MEMBERS_UPDATE referencing an unknown thread ID: %s. Discarding", + thread_id, + ) return - added_members = [ThreadMember(thread, d) for d in data.get('added_members', [])] - removed_member_ids = [int(x) for x in data.get('removed_member_ids', [])] + added_members = [ThreadMember(thread, d) for d in data.get("added_members", [])] + removed_member_ids = [int(x) for x in data.get("removed_member_ids", [])] self_id = self.self_id for member in added_members: if member.id != self_id: thread._add_member(member) - self.dispatch('thread_member_join', member) + self.dispatch("thread_member_join", member) else: thread.me = member - self.dispatch('thread_join', thread) + self.dispatch("thread_join", thread) for member_id in removed_member_ids: if member_id != self_id: member = thread._pop_member(member_id) if member is not None: - self.dispatch('thread_member_remove', member) + self.dispatch("thread_member_remove", member) else: - self.dispatch('thread_remove', thread) + self.dispatch("thread_remove", thread) def parse_guild_member_add(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_MEMBER_ADD referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return member = Member(guild=guild, data=data, state=self) @@ -956,30 +1087,36 @@ class ConnectionState: except AttributeError: pass - self.dispatch('member_join', member) + self.dispatch("member_join", member) def parse_guild_member_remove(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: guild._member_count -= 1 except AttributeError: pass - user_id = int(data['user']['id']) + user_id = int(data["user"]["id"]) member = guild.get_member(user_id) if member is not None: guild._remove_member(member) # type: ignore - self.dispatch('member_remove', member) + self.dispatch("member_remove", member) else: - _log.debug('GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_MEMBER_REMOVE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_guild_member_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) - user = data['user'] - user_id = int(user['id']) + guild = self._get_guild(int(data["guild_id"])) + user = data["user"] + user_id = int(user["id"]) if guild is None: - _log.debug('GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_MEMBER_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return member = guild.get_member(user_id) @@ -988,9 +1125,9 @@ class ConnectionState: member._update(data) user_update = member._update_inner_user(user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) - self.dispatch('member_update', old_member, member) + self.dispatch("member_update", old_member, member) else: if self.member_cache_flags.joined: member = Member(data=data, guild=guild, state=self) @@ -998,43 +1135,52 @@ class ConnectionState: # Force an update on the inner user if necessary user_update = member._update_inner_user(user) if user_update: - self.dispatch('user_update', user_update[0], user_update[1]) + self.dispatch("user_update", user_update[0], user_update[1]) guild._add_member(member) - _log.debug('GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.', user_id) + _log.debug( + "GUILD_MEMBER_UPDATE referencing an unknown member ID: %s. Discarding.", + user_id, + ) def parse_guild_emojis_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_EMOJIS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return before_emojis = guild.emojis for emoji in before_emojis: self._emojis.pop(emoji.id, None) # guild won't be None here - guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data['emojis'])) # type: ignore - self.dispatch('guild_emojis_update', guild, before_emojis, guild.emojis) + guild.emojis = tuple(map(lambda d: self.store_emoji(guild, d), data["emojis"])) # type: ignore + self.dispatch("guild_emojis_update", guild, before_emojis, guild.emojis) def parse_guild_stickers_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_STICKERS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return before_stickers = guild.stickers for emoji in before_stickers: self._stickers.pop(emoji.id, None) # guild won't be None here - guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data['stickers'])) # type: ignore - self.dispatch('guild_stickers_update', guild, before_stickers, guild.stickers) + guild.stickers = tuple(map(lambda d: self.store_sticker(guild, d), data["stickers"])) # type: ignore + self.dispatch("guild_stickers_update", guild, before_stickers, guild.stickers) def _get_create_guild(self, data): - if data.get('unavailable') is False: + if data.get("unavailable") is False: # GUILD_CREATE with unavailable in the response # usually means that the guild has become available # and is therefore in the cache - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is not None: guild.unavailable = False guild._from_data(data) @@ -1049,7 +1195,9 @@ class ConnectionState: cache = cache or self.member_cache_flags.joined request = self._chunk_requests.get(guild.id) if request is None: - self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) + self._chunk_requests[guild.id] = request = ChunkRequest( + guild.id, self.loop, self._get_guild, cache=cache + ) await self.chunker(guild.id, nonce=request.nonce) if wait: @@ -1060,15 +1208,15 @@ class ConnectionState: try: await asyncio.wait_for(self.chunk_guild(guild), timeout=60.0) except asyncio.TimeoutError: - _log.info('Somehow timed out waiting for chunks.') + _log.info("Somehow timed out waiting for chunks.") if unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) def parse_guild_create(self, data) -> None: - unavailable = data.get('unavailable') + unavailable = data.get("unavailable") if unavailable is True: # joined a guild with unavailable == True so.. return @@ -1091,40 +1239,47 @@ class ConnectionState: # Dispatch available if newly available if unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) def parse_guild_update(self, data) -> None: - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is not None: old_guild = copy.copy(guild) guild._from_data(data) - self.dispatch('guild_update', old_guild, guild) + self.dispatch("guild_update", old_guild, guild) else: - _log.debug('GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.', data['id']) + _log.debug( + "GUILD_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["id"], + ) def parse_guild_delete(self, data) -> None: - guild = self._get_guild(int(data['id'])) + guild = self._get_guild(int(data["id"])) if guild is None: - _log.debug('GUILD_DELETE referencing an unknown guild ID: %s. Discarding.', data['id']) + _log.debug( + "GUILD_DELETE referencing an unknown guild ID: %s. Discarding.", + data["id"], + ) return - if data.get('unavailable', False): + if data.get("unavailable", False): # GUILD_DELETE with unavailable being True means that the # guild that was available is now currently unavailable guild.unavailable = True - self.dispatch('guild_unavailable', guild) + self.dispatch("guild_unavailable", guild) return # do a cleanup of the messages cache if self._messages is not None: self._messages: Optional[Deque[Message]] = deque( - (msg for msg in self._messages if msg.guild != guild), maxlen=self.max_messages + (msg for msg in self._messages if msg.guild != guild), + maxlen=self.max_messages, ) self._remove_guild(guild) - self.dispatch('guild_remove', guild) + self.dispatch("guild_remove", guild) def parse_guild_ban_add(self, data) -> None: # we make the assumption that GUILD_BAN_ADD is done @@ -1132,205 +1287,263 @@ class ConnectionState: # hence we don't remove it from cache or do anything # strange with it, the main purpose of this event # is mainly to dispatch to another event worth listening to for logging - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: - user = User(data=data['user'], state=self) + user = User(data=data["user"], state=self) except KeyError: pass else: member = guild.get_member(user.id) or user - self.dispatch('member_ban', guild, member) + self.dispatch("member_ban", guild, member) def parse_guild_ban_remove(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) - if guild is not None and 'user' in data: - user = self.store_user(data['user']) - self.dispatch('member_unban', guild, user) + guild = self._get_guild(int(data["guild_id"])) + if guild is not None and "user" in data: + user = self.store_user(data["user"]) + self.dispatch("member_unban", guild, user) def parse_guild_role_create(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_ROLE_CREATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) return - role_data = data['role'] + role_data = data["role"] role = Role(guild=guild, data=role_data, state=self) guild._add_role(role) - self.dispatch('guild_role_create', role) + self.dispatch("guild_role_create", role) def parse_guild_role_delete(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - role_id = int(data['role_id']) + role_id = int(data["role_id"]) try: role = guild._remove_role(role_id) except KeyError: return else: - self.dispatch('guild_role_delete', role) + self.dispatch("guild_role_delete", role) else: - _log.debug('GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_ROLE_DELETE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_guild_role_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - role_data = data['role'] - role_id = int(role_data['id']) + role_data = data["role"] + role_id = int(role_data["id"]) role = guild.get_role(role_id) if role is not None: old_role = copy.copy(role) role._update(role_data) - self.dispatch('guild_role_update', old_role, role) + self.dispatch("guild_role_update", old_role, role) else: - _log.debug('GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_ROLE_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_guild_members_chunk(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) - presences = data.get('presences', []) + presences = data.get("presences", []) # the guild won't be None here - members = [Member(guild=guild, data=member, state=self) for member in data.get('members', [])] # type: ignore - _log.debug('Processed a chunk for %s members in guild ID %s.', len(members), guild_id) + members = [Member(guild=guild, data=member, state=self) for member in data.get("members", [])] # type: ignore + _log.debug( + "Processed a chunk for %s members in guild ID %s.", len(members), guild_id + ) if presences: member_dict = {str(member.id): member for member in members} for presence in presences: - user = presence['user'] - member_id = user['id'] + user = presence["user"] + member_id = user["id"] member = member_dict.get(member_id) if member is not None: member._presence_update(presence, user) - complete = data.get('chunk_index', 0) + 1 == data.get('chunk_count') - self.process_chunk_requests(guild_id, data.get('nonce'), members, complete) + complete = data.get("chunk_index", 0) + 1 == data.get("chunk_count") + self.process_chunk_requests(guild_id, data.get("nonce"), members, complete) def parse_guild_integrations_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - self.dispatch('guild_integrations_update', guild) + self.dispatch("guild_integrations_update", guild) else: - _log.debug('GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "GUILD_INTEGRATIONS_UPDATE referencing an unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_integration_create(self, data) -> None: - guild_id = int(data.pop('guild_id')) + guild_id = int(data.pop("guild_id")) guild = self._get_guild(guild_id) if guild is not None: - cls, _ = _integration_factory(data['type']) + cls, _ = _integration_factory(data["type"]) integration = cls(data=data, guild=guild) - self.dispatch('integration_create', integration) + self.dispatch("integration_create", integration) else: - _log.debug('INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "INTEGRATION_CREATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_integration_update(self, data) -> None: - guild_id = int(data.pop('guild_id')) + guild_id = int(data.pop("guild_id")) guild = self._get_guild(guild_id) if guild is not None: - cls, _ = _integration_factory(data['type']) + cls, _ = _integration_factory(data["type"]) integration = cls(data=data, guild=guild) - self.dispatch('integration_update', integration) + self.dispatch("integration_update", integration) else: - _log.debug('INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "INTEGRATION_UPDATE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_integration_delete(self, data) -> None: - guild_id = int(data['guild_id']) + guild_id = int(data["guild_id"]) guild = self._get_guild(guild_id) if guild is not None: raw = RawIntegrationDeleteEvent(data) - self.dispatch('raw_integration_delete', raw) + self.dispatch("raw_integration_delete", raw) else: - _log.debug('INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.', guild_id) + _log.debug( + "INTEGRATION_DELETE referencing an unknown guild ID: %s. Discarding.", + guild_id, + ) def parse_webhooks_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is None: - _log.debug('WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding', data['guild_id']) + _log.debug( + "WEBHOOKS_UPDATE referencing an unknown guild ID: %s. Discarding", + data["guild_id"], + ) return - channel = guild.get_channel(int(data['channel_id'])) + channel = guild.get_channel(int(data["channel_id"])) if channel is not None: - self.dispatch('webhooks_update', channel) + self.dispatch("webhooks_update", channel) else: - _log.debug('WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.', data['channel_id']) + _log.debug( + "WEBHOOKS_UPDATE referencing an unknown channel ID: %s. Discarding.", + data["channel_id"], + ) def parse_stage_instance_create(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: stage_instance = StageInstance(guild=guild, state=self, data=data) guild._stage_instances[stage_instance.id] = stage_instance - self.dispatch('stage_instance_create', stage_instance) + self.dispatch("stage_instance_create", stage_instance) else: - _log.debug('STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "STAGE_INSTANCE_CREATE referencing unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_stage_instance_update(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: - stage_instance = guild._stage_instances.get(int(data['id'])) + stage_instance = guild._stage_instances.get(int(data["id"])) if stage_instance is not None: old_stage_instance = copy.copy(stage_instance) stage_instance._update(data) - self.dispatch('stage_instance_update', old_stage_instance, stage_instance) + self.dispatch( + "stage_instance_update", old_stage_instance, stage_instance + ) else: - _log.debug('STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.', data['id']) + _log.debug( + "STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.", + data["id"], + ) else: - _log.debug('STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "STAGE_INSTANCE_UPDATE referencing unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_stage_instance_delete(self, data) -> None: - guild = self._get_guild(int(data['guild_id'])) + guild = self._get_guild(int(data["guild_id"])) if guild is not None: try: - stage_instance = guild._stage_instances.pop(int(data['id'])) + stage_instance = guild._stage_instances.pop(int(data["id"])) except KeyError: pass else: - self.dispatch('stage_instance_delete', stage_instance) + self.dispatch("stage_instance_delete", stage_instance) else: - _log.debug('STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.', data['guild_id']) + _log.debug( + "STAGE_INSTANCE_DELETE referencing unknown guild ID: %s. Discarding.", + data["guild_id"], + ) def parse_voice_state_update(self, data) -> None: - guild = self._get_guild(utils._get_as_snowflake(data, 'guild_id')) - channel_id = utils._get_as_snowflake(data, 'channel_id') + guild = self._get_guild(utils._get_as_snowflake(data, "guild_id")) + channel_id = utils._get_as_snowflake(data, "channel_id") flags = self.member_cache_flags # self.user is *always* cached when this is called self_id = self.user.id # type: ignore if guild is not None: - if int(data['user_id']) == self_id: + if int(data["user_id"]) == self_id: voice = self._get_voice_client(guild.id) if voice is not None: coro = voice.on_voice_state_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice state update handler')) + asyncio.create_task( + logging_coroutine( + coro, info="Voice Protocol voice state update handler" + ) + ) member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: if flags.voice: - if channel_id is None and flags._voice_only and member.id != self_id: + if ( + channel_id is None + and flags._voice_only + and member.id != self_id + ): # Only remove from cache if we only have the voice flag enabled # Member doesn't meet the Snowflake protocol currently guild._remove_member(member) # type: ignore elif channel_id is not None: guild._add_member(member) - self.dispatch('voice_state_update', member, before, after) + self.dispatch("voice_state_update", member, before, after) else: - _log.debug('VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.', data['user_id']) + _log.debug( + "VOICE_STATE_UPDATE referencing an unknown member ID: %s. Discarding.", + data["user_id"], + ) def parse_voice_server_update(self, data) -> None: try: - key_id = int(data['guild_id']) + key_id = int(data["guild_id"]) except KeyError: - key_id = int(data['channel_id']) + key_id = int(data["channel_id"]) vc = self._get_voice_client(key_id) if vc is not None: coro = vc.on_voice_server_update(data) - asyncio.create_task(logging_coroutine(coro, info='Voice Protocol voice server update handler')) + asyncio.create_task( + logging_coroutine( + coro, info="Voice Protocol voice server update handler" + ) + ) def parse_typing_start(self, data) -> None: channel, guild = self._get_guild_channel(data) if channel is not None: member = None - user_id = utils._get_as_snowflake(data, 'user_id') + user_id = utils._get_as_snowflake(data, "user_id") if isinstance(channel, DMChannel): member = channel.recipient @@ -1339,7 +1552,7 @@ class ConnectionState: member = guild.get_member(user_id) # type: ignore if member is None: - member_data = data.get('member') + member_data = data.get("member") if member_data: member = Member(data=member_data, state=self, guild=guild) @@ -1347,26 +1560,37 @@ class ConnectionState: member = utils.find(lambda x: x.id == user_id, channel.recipients) if member is not None: - timestamp = datetime.datetime.fromtimestamp(data.get('timestamp'), tz=datetime.timezone.utc) - self.dispatch('typing', channel, member, timestamp) + timestamp = datetime.datetime.fromtimestamp( + data.get("timestamp"), tz=datetime.timezone.utc + ) + self.dispatch("typing", channel, member, timestamp) - def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: + def _get_reaction_user( + self, channel: MessageableChannel, user_id: int + ) -> Optional[Union[User, Member]]: if isinstance(channel, TextChannel): return channel.guild.get_member(user_id) return self.get_user(user_id) def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: - emoji_id = utils._get_as_snowflake(data, 'id') + emoji_id = utils._get_as_snowflake(data, "id") if not emoji_id: - return data['name'] + return data["name"] try: return self._emojis[emoji_id] except KeyError: - return PartialEmoji.with_state(self, animated=data.get('animated', False), id=emoji_id, name=data['name']) + return PartialEmoji.with_state( + self, + animated=data.get("animated", False), + id=emoji_id, + name=data["name"], + ) - def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: + def _upgrade_partial_emoji( + self, emoji: PartialEmoji + ) -> Union[Emoji, PartialEmoji, str]: emoji_id = emoji.id if not emoji_id: return emoji.name @@ -1389,7 +1613,12 @@ class ConnectionState: return channel def create_message( - self, *, channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload + self, + *, + channel: Union[ + TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable + ], + data: MessagePayload, ) -> Message: return Message(state=self, channel=channel, data=data) @@ -1409,14 +1638,16 @@ class AutoShardedConnectionState(ConnectionState): new_guild = self._get_guild(msg.guild.id) if new_guild is not None and new_guild is not msg.guild: channel_id = msg.channel.id - channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id) + channel = new_guild._resolve_channel(channel_id) or Object( + id=channel_id + ) # channel will either be a TextChannel, Thread or Object msg._rebind_cached_references(new_guild, channel) # type: ignore async def chunker( self, guild_id: int, - query: str = '', + query: str = "", limit: int = 0, presences: bool = False, *, @@ -1424,7 +1655,9 @@ class AutoShardedConnectionState(ConnectionState): nonce: Optional[str] = None, ) -> None: ws = self._get_websocket(guild_id, shard_id=shard_id) - await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) + await ws.request_chunks( + guild_id, query=query, limit=limit, presences=presences, nonce=nonce + ) async def _delay_ready(self) -> None: await self.shards_launched.wait() @@ -1435,17 +1668,24 @@ class AutoShardedConnectionState(ConnectionState): # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) + guild = await asyncio.wait_for( + self._ready_state.get(), timeout=self.guild_ready_timeout + ) except asyncio.TimeoutError: break else: if self._guild_needs_chunking(guild): - _log.debug('Guild ID %d requires chunking, will be done in the background.', guild.id) + _log.debug( + "Guild ID %d requires chunking, will be done in the background.", + guild.id, + ) if len(current_bucket) >= max_concurrency: try: - await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) + await utils.sane_wait_for( + current_bucket, timeout=max_concurrency * 70.0 + ) except asyncio.TimeoutError: - fmt = 'Shard ID %s failed to wait for chunks from a sub-bucket with length %d' + fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" _log.warning(fmt, guild.shard_id, len(current_bucket)) finally: current_bucket = [] @@ -1468,15 +1708,18 @@ class AutoShardedConnectionState(ConnectionState): await utils.sane_wait_for(futures, timeout=timeout) except asyncio.TimeoutError: _log.warning( - 'Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds', shard_id, timeout, len(guilds) + "Shard ID %s failed to wait for chunks (timeout=%.2f) for %d guilds", + shard_id, + timeout, + len(guilds), ) for guild in children: if guild.unavailable is False: - self.dispatch('guild_available', guild) + self.dispatch("guild_available", guild) else: - self.dispatch('guild_join', guild) + self.dispatch("guild_join", guild) - self.dispatch('shard_ready', shard_id) + self.dispatch("shard_ready", shard_id) # remove the state try: @@ -1490,38 +1733,40 @@ class AutoShardedConnectionState(ConnectionState): self._ready_task = None # dispatch the event - self.call_handlers('ready') - self.dispatch('ready') + self.call_handlers("ready") + self.dispatch("ready") def parse_ready(self, data) -> None: - if not hasattr(self, '_ready_state'): + if not hasattr(self, "_ready_state"): self._ready_state = asyncio.Queue() - self.user = user = ClientUser(state=self, data=data['user']) + self.user = user = ClientUser(state=self, data=data["user"]) # self._users is a list of Users, we're setting a ClientUser self._users[user.id] = user # type: ignore if self.application_id is None: try: - application = data['application'] + application = data["application"] except KeyError: pass else: - self.application_id = utils._get_as_snowflake(application, 'id') - self.application_flags = ApplicationFlags._from_value(application['flags']) + self.application_id = utils._get_as_snowflake(application, "id") + self.application_flags = ApplicationFlags._from_value( + application["flags"] + ) - for guild_data in data['guilds']: + for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) if self._messages: self._update_message_references() - self.dispatch('connect') - self.dispatch('shard_connect', data['__shard_id__']) + self.dispatch("connect") + self.dispatch("shard_connect", data["__shard_id__"]) if self._ready_task is None: self._ready_task = asyncio.create_task(self._delay_ready()) def parse_resumed(self, data) -> None: - self.dispatch('resumed') - self.dispatch('shard_resumed', data['__shard_id__']) + self.dispatch("resumed") + self.dispatch("shard_resumed", data["__shard_id__"]) diff --git a/discord/sticker.py b/discord/sticker.py index 8ec2ad01..f03ddb47 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -33,11 +33,11 @@ from .errors import InvalidData from .enums import StickerType, StickerFormatType, try_enum __all__ = ( - 'StickerPack', - 'StickerItem', - 'Sticker', - 'StandardSticker', - 'GuildSticker', + "StickerPack", + "StickerItem", + "Sticker", + "StandardSticker", + "GuildSticker", ) if TYPE_CHECKING: @@ -102,15 +102,15 @@ class StickerPack(Hashable): """ __slots__ = ( - '_state', - 'id', - 'stickers', - 'name', - 'sku_id', - 'cover_sticker_id', - 'cover_sticker', - 'description', - '_banner', + "_state", + "id", + "stickers", + "name", + "sku_id", + "cover_sticker_id", + "cover_sticker", + "description", + "_banner", ) def __init__(self, *, state: ConnectionState, data: StickerPackPayload) -> None: @@ -118,15 +118,17 @@ class StickerPack(Hashable): self._from_data(data) def _from_data(self, data: StickerPackPayload) -> None: - self.id: int = int(data['id']) - stickers = data['stickers'] - self.stickers: List[StandardSticker] = [StandardSticker(state=self._state, data=sticker) for sticker in stickers] - self.name: str = data['name'] - self.sku_id: int = int(data['sku_id']) - self.cover_sticker_id: int = int(data['cover_sticker_id']) - self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore - self.description: str = data['description'] - self._banner: int = int(data['banner_asset_id']) + self.id: int = int(data["id"]) + stickers = data["stickers"] + self.stickers: List[StandardSticker] = [ + StandardSticker(state=self._state, data=sticker) for sticker in stickers + ] + self.name: str = data["name"] + self.sku_id: int = int(data["sku_id"]) + self.cover_sticker_id: int = int(data["cover_sticker_id"]) + self.cover_sticker: StandardSticker = get(self.stickers, id=self.cover_sticker_id) # type: ignore + self.description: str = data["description"] + self._banner: int = int(data["banner_asset_id"]) @property def banner(self) -> Asset: @@ -134,7 +136,7 @@ class StickerPack(Hashable): return Asset._from_sticker_banner(self._state, self._banner) def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: return self.name @@ -205,17 +207,19 @@ class StickerItem(_StickerTag): The URL for the sticker's image. """ - __slots__ = ('_state', 'name', 'id', 'format', 'url') + __slots__ = ("_state", "name", "id", "format", "url") def __init__(self, *, state: ConnectionState, data: StickerItemPayload): self._state: ConnectionState = state - self.name: str = data['name'] - self.id: int = int(data['id']) - self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type']) - self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}' + self.name: str = data["name"] + self.id: int = int(data["id"]) + self.format: StickerFormatType = try_enum( + StickerFormatType, data["format_type"] + ) + self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}" def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: return self.name @@ -236,7 +240,7 @@ class StickerItem(_StickerTag): The retrieved sticker. """ data: StickerPayload = await self._state.http.get_sticker(self.id) - cls, _ = _sticker_factory(data['type']) # type: ignore + cls, _ = _sticker_factory(data["type"]) # type: ignore return cls(state=self._state, data=data) @@ -275,21 +279,23 @@ class Sticker(_StickerTag): The URL for the sticker's image. """ - __slots__ = ('_state', 'id', 'name', 'description', 'format', 'url') + __slots__ = ("_state", "id", "name", "description", "format", "url") def __init__(self, *, state: ConnectionState, data: StickerPayload) -> None: self._state: ConnectionState = state self._from_data(data) def _from_data(self, data: StickerPayload) -> None: - self.id: int = int(data['id']) - self.name: str = data['name'] - self.description: str = data['description'] - self.format: StickerFormatType = try_enum(StickerFormatType, data['format_type']) - self.url: str = f'{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}' + self.id: int = int(data["id"]) + self.name: str = data["name"] + self.description: str = data["description"] + self.format: StickerFormatType = try_enum( + StickerFormatType, data["format_type"] + ) + self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}" def __repr__(self) -> str: - return f'' + return f"" def __str__(self) -> str: return self.name @@ -337,21 +343,23 @@ class StandardSticker(Sticker): The sticker's sort order within its pack. """ - __slots__ = ('sort_value', 'pack_id', 'type', 'tags') + __slots__ = ("sort_value", "pack_id", "type", "tags") def _from_data(self, data: StandardStickerPayload) -> None: super()._from_data(data) - self.sort_value: int = data['sort_value'] - self.pack_id: int = int(data['pack_id']) + self.sort_value: int = data["sort_value"] + self.pack_id: int = int(data["pack_id"]) self.type: StickerType = StickerType.standard try: - self.tags: List[str] = [tag.strip() for tag in data['tags'].split(',')] + self.tags: List[str] = [tag.strip() for tag in data["tags"].split(",")] except KeyError: self.tags = [] def __repr__(self) -> str: - return f'' + return ( + f"" + ) async def pack(self) -> StickerPack: """|coro| @@ -370,13 +378,15 @@ class StandardSticker(Sticker): :class:`StickerPack` The retrieved sticker pack. """ - data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs() - packs = data['sticker_packs'] - pack = find(lambda d: int(d['id']) == self.pack_id, packs) + data: ListPremiumStickerPacksPayload = ( + await self._state.http.list_premium_sticker_packs() + ) + packs = data["sticker_packs"] + pack = find(lambda d: int(d["id"]) == self.pack_id, packs) if pack: return StickerPack(state=self._state, data=pack) - raise InvalidData(f'Could not find corresponding sticker pack for {self!r}') + raise InvalidData(f"Could not find corresponding sticker pack for {self!r}") class GuildSticker(Sticker): @@ -419,21 +429,21 @@ class GuildSticker(Sticker): The name of a unicode emoji that represents this sticker. """ - __slots__ = ('available', 'guild_id', 'user', 'emoji', 'type', '_cs_guild') + __slots__ = ("available", "guild_id", "user", "emoji", "type", "_cs_guild") def _from_data(self, data: GuildStickerPayload) -> None: super()._from_data(data) - self.available: bool = data['available'] - self.guild_id: int = int(data['guild_id']) - user = data.get('user') + self.available: bool = data["available"] + self.guild_id: int = int(data["guild_id"]) + user = data.get("user") self.user: Optional[User] = self._state.store_user(user) if user else None - self.emoji: str = data['tags'] + self.emoji: str = data["tags"] self.type: StickerType = StickerType.guild def __repr__(self) -> str: - return f'' + return f"" - @cached_slot_property('_cs_guild') + @cached_slot_property("_cs_guild") def guild(self) -> Optional[Guild]: """Optional[:class:`Guild`]: The guild that this sticker is from. Could be ``None`` if the bot is not in the guild. @@ -480,10 +490,10 @@ class GuildSticker(Sticker): payload: EditGuildSticker = {} if name is not MISSING: - payload['name'] = name + payload["name"] = name if description is not MISSING: - payload['description'] = description + payload["description"] = description if emoji is not MISSING: try: @@ -491,11 +501,13 @@ class GuildSticker(Sticker): except TypeError: pass else: - emoji = emoji.replace(' ', '_') + emoji = emoji.replace(" ", "_") - payload['tags'] = emoji + payload["tags"] = emoji - data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason) + data: GuildStickerPayload = await self._state.http.modify_guild_sticker( + self.guild_id, self.id, payload, reason + ) return GuildSticker(state=self._state, data=data) async def delete(self, *, reason: Optional[str] = None) -> None: @@ -521,7 +533,9 @@ class GuildSticker(Sticker): await self._state.http.delete_guild_sticker(self.guild_id, self.id, reason) -def _sticker_factory(sticker_type: Literal[1, 2]) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]: +def _sticker_factory( + sticker_type: Literal[1, 2] +) -> Tuple[Type[Union[StandardSticker, GuildSticker, Sticker]], StickerType]: value = try_enum(StickerType, sticker_type) if value == StickerType.standard: return StandardSticker, value diff --git a/discord/team.py b/discord/team.py index 538aaba1..bf89c32d 100644 --- a/discord/team.py +++ b/discord/team.py @@ -40,8 +40,8 @@ if TYPE_CHECKING: ) __all__ = ( - 'Team', - 'TeamMember', + "Team", + "TeamMember", ) @@ -62,26 +62,28 @@ class Team: .. versionadded:: 1.3 """ - __slots__ = ('_state', 'id', 'name', '_icon', 'owner_id', 'members') + __slots__ = ("_state", "id", "name", "_icon", "owner_id", "members") def __init__(self, state: ConnectionState, data: TeamPayload): self._state: ConnectionState = state - self.id: int = int(data['id']) - self.name: str = data['name'] - self._icon: Optional[str] = data['icon'] - self.owner_id: Optional[int] = utils._get_as_snowflake(data, 'owner_user_id') - self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data['members']] + self.id: int = int(data["id"]) + self.name: str = data["name"] + self._icon: Optional[str] = data["icon"] + self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id") + self.members: List[TeamMember] = [ + TeamMember(self, self._state, member) for member in data["members"] + ] def __repr__(self) -> str: - return f'<{self.__class__.__name__} id={self.id} name={self.name}>' + return f"<{self.__class__.__name__} id={self.id} name={self.name}>" @property def icon(self) -> Optional[Asset]: """Optional[:class:`.Asset`]: Retrieves the team's icon asset, if any.""" if self._icon is None: return None - return Asset._from_icon(self._state, self.id, self._icon, path='team') + return Asset._from_icon(self._state, self.id, self._icon, path="team") @property def owner(self) -> Optional[TeamMember]: @@ -130,16 +132,18 @@ class TeamMember(BaseUser): The membership state of the member (e.g. invited or accepted) """ - __slots__ = ('team', 'membership_state', 'permissions') + __slots__ = ("team", "membership_state", "permissions") def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload): self.team: Team = team - self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data['membership_state']) - self.permissions: List[str] = data['permissions'] - super().__init__(state=state, data=data['user']) + self.membership_state: TeamMembershipState = try_enum( + TeamMembershipState, data["membership_state"] + ) + self.permissions: List[str] = data["permissions"] + super().__init__(state=state, data=data["user"]) def __repr__(self) -> str: return ( - f'<{self.__class__.__name__} id={self.id} name={self.name!r} ' - f'discriminator={self.discriminator!r} membership_state={self.membership_state!r}>' + f"<{self.__class__.__name__} id={self.id} name={self.name!r} " + f"discriminator={self.discriminator!r} membership_state={self.membership_state!r}>" ) diff --git a/discord/template.py b/discord/template.py index 30af3a4d..30798936 100644 --- a/discord/template.py +++ b/discord/template.py @@ -29,9 +29,7 @@ from .utils import parse_time, _get_as_snowflake, _bytes_to_base64_data, MISSING from .enums import VoiceRegion from .guild import Guild -__all__ = ( - 'Template', -) +__all__ = ("Template",) if TYPE_CHECKING: import datetime @@ -44,7 +42,7 @@ class _FriendlyHttpAttributeErrorHelper: __slots__ = () def __getattr__(self, attr): - raise AttributeError('PartialTemplateState does not support http methods.') + raise AttributeError("PartialTemplateState does not support http methods.") class _PartialTemplateState: @@ -84,7 +82,7 @@ class _PartialTemplateState: return [] def __getattr__(self, attr): - raise AttributeError(f'PartialTemplateState does not support {attr!r}.') + raise AttributeError(f"PartialTemplateState does not support {attr!r}.") class Template: @@ -118,16 +116,16 @@ class Template: """ __slots__ = ( - 'code', - 'uses', - 'name', - 'description', - 'creator', - 'created_at', - 'updated_at', - 'source_guild', - 'is_dirty', - '_state', + "code", + "uses", + "name", + "description", + "creator", + "created_at", + "updated_at", + "source_guild", + "is_dirty", + "_state", ) def __init__(self, *, state: ConnectionState, data: TemplatePayload) -> None: @@ -135,38 +133,46 @@ class Template: self._store(data) def _store(self, data: TemplatePayload) -> None: - self.code: str = data['code'] - self.uses: int = data['usage_count'] - self.name: str = data['name'] - self.description: Optional[str] = data['description'] - creator_data = data.get('creator') - self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data) + self.code: str = data["code"] + self.uses: int = data["usage_count"] + self.name: str = data["name"] + self.description: Optional[str] = data["description"] + creator_data = data.get("creator") + self.creator: Optional[User] = ( + None if creator_data is None else self._state.create_user(creator_data) + ) - self.created_at: Optional[datetime.datetime] = parse_time(data.get('created_at')) - self.updated_at: Optional[datetime.datetime] = parse_time(data.get('updated_at')) + self.created_at: Optional[datetime.datetime] = parse_time( + data.get("created_at") + ) + self.updated_at: Optional[datetime.datetime] = parse_time( + data.get("updated_at") + ) - guild_id = int(data['source_guild_id']) + guild_id = int(data["source_guild_id"]) guild: Optional[Guild] = self._state._get_guild(guild_id) self.source_guild: Guild if guild is None: - source_serialised = data['serialized_source_guild'] - source_serialised['id'] = guild_id + source_serialised = data["serialized_source_guild"] + source_serialised["id"] = guild_id state = _PartialTemplateState(state=self._state) # Guild expects a ConnectionState, we're passing a _PartialTemplateState self.source_guild = Guild(data=source_serialised, state=state) # type: ignore else: self.source_guild = guild - self.is_dirty: Optional[bool] = data.get('is_dirty', None) + self.is_dirty: Optional[bool] = data.get("is_dirty", None) def __repr__(self) -> str: return ( - f'