From 8f846ba2f5e4412e9fb163af3bc9093502c383bd Mon Sep 17 00:00:00 2001 From: StockerMC <44980366+StockerMC@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:32:55 -0400 Subject: [PATCH 1/5] Add the ability to set the option name with commands.Option --- discord/ext/commands/converter.py | 6 ++++-- discord/ext/commands/core.py | 6 +++++- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index feff650d..0e9a3daf 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -1039,16 +1039,18 @@ class Option(Generic[T, DT]): # type: ignore __slots__ = ( "default", "description", + "name", ) - def __init__(self, default: T = inspect.Parameter.empty, *, description: DT) -> None: + def __init__(self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING) -> None: self.description = description self.default = default + self.name: str = name if TYPE_CHECKING: # Terrible workaround for type checking reasons - def Option(default: T = inspect.Parameter.empty, *, description: str) -> T: + def Option(default: T = inspect.Parameter.empty, *, description: str, name: str = discord.utils.MISSING) -> T: ... diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 42ef5ff9..a23ae878 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -174,8 +174,12 @@ def get_signature_parameters( annotation = parameter.annotation if isinstance(parameter.default, Option): # type: ignore option = parameter.default - descriptions[name] = option.description parameter = parameter.replace(default=option.default) + if option.name is not MISSING: + name = option.name + parameter.replace(name=name) + + descriptions[name] = option.description if annotation is parameter.empty: params[name] = parameter -- 2.47.2 From 23b390971f694613e604618fdd3c47c2c0a5426d Mon Sep 17 00:00:00 2001 From: StockerMC <44980366+StockerMC@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:35:57 -0400 Subject: [PATCH 2/5] Document commands.Option.name --- discord/ext/commands/converter.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 0e9a3daf..4406490b 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -1032,6 +1032,8 @@ class Option(Generic[T, DT]): # type: ignore The default for this option, overwrites Option during parsing. description: :class:`str` The description for this option, is unpacked to :attr:`.Command.option_descriptions` + name: :class:`str` + The name of the option. This defaults to the parameter name. """ description: DT -- 2.47.2 From 161affa246288b4c4c5033e975736c1220ed90cc Mon Sep 17 00:00:00 2001 From: StockerMC <44980366+StockerMC@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:58:15 -0400 Subject: [PATCH 3/5] Format with black --- discord/ext/commands/converter.py | 144 +++++++++++++----- discord/ext/commands/core.py | 238 +++++++++++++++++++++++------- 2 files changed, 294 insertions(+), 88 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 4406490b..62f93a6f 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -161,7 +161,9 @@ class ObjectConverter(IDConverter[discord.Object]): """ async def convert(self, ctx: Context, argument: str) -> discord.Object: - match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument) + match = self._get_id_match(argument) or re.match( + r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument + ) if match is None: raise ObjectNotFound(argument) @@ -198,10 +200,14 @@ class MemberConverter(IDConverter[discord.Member]): if len(argument) > 5 and argument[-5] == "#": username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get(members, name=username, discriminator=discriminator) + return discord.utils.get( + members, name=username, discriminator=discriminator + ) else: members = await guild.query_members(argument, limit=100, cache=cache) - return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members) + return discord.utils.find( + lambda m: m.name == argument or m.nick == argument, members + ) async def query_member_by_id(self, bot, guild, user_id): ws = bot._get_websocket(shard_id=guild.shard_id) @@ -226,7 +232,9 @@ class MemberConverter(IDConverter[discord.Member]): async def convert(self, ctx: Context, argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) + match = self._get_id_match(argument) or re.match( + r"<@!?([0-9]{15,20})>$", argument + ) guild = ctx.guild result = None user_id = None @@ -239,7 +247,9 @@ class MemberConverter(IDConverter[discord.Member]): else: user_id = int(match.group(1)) if guild: - result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id) + result = guild.get_member(user_id) or _utils_get( + ctx.message.mentions, id=user_id + ) else: result = _get_from_guilds(bot, "get_member", user_id) @@ -279,13 +289,17 @@ class UserConverter(IDConverter[discord.User]): """ async def convert(self, ctx: Context, argument: str) -> discord.User: - match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) + match = self._get_id_match(argument) or re.match( + r"<@!?([0-9]{15,20})>$", argument + ) result = None state = ctx._state if match is not None: user_id = int(match.group(1)) - result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id) + result = ctx.bot.get_user(user_id) or _utils_get( + ctx.message.mentions, id=user_id + ) if result is None: try: result = await ctx.bot.fetch_user(user_id) @@ -333,7 +347,9 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _get_id_matches(ctx, argument): - id_regex = re.compile(r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$") + id_regex = re.compile( + r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$" + ) link_regex = re.compile( r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" r"(?P[0-9]{15,20}|@me)" @@ -355,7 +371,9 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): return guild_id, message_id, channel_id @staticmethod - def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: + def _resolve_channel( + ctx, guild_id, channel_id + ) -> Optional[PartialMessageableChannel]: if guild_id is not None: guild = ctx.bot.get_guild(guild_id) if guild is not None and channel_id is not None: @@ -389,7 +407,9 @@ class MessageConverter(IDConverter[discord.Message]): """ async def convert(self, ctx: Context, argument: str) -> discord.Message: - guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches( + ctx, argument + ) message = ctx.bot._connection._get_message(message_id) if message: return message @@ -420,13 +440,19 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel) + return self._resolve_channel( + ctx, argument, "channels", discord.abc.GuildChannel + ) @staticmethod - def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: + def _resolve_channel( + ctx: Context, argument: str, attribute: str, type: Type[CT] + ) -> CT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) + match = IDConverter._get_id_match(argument) or re.match( + r"<#([0-9]{15,20})>$", argument + ) result = None guild = ctx.guild @@ -454,10 +480,14 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): return result @staticmethod - def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: + def _resolve_thread( + ctx: Context, argument: str, attribute: str, type: Type[TT] + ) -> TT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) + match = IDConverter._get_id_match(argument) or re.match( + r"<#([0-9]{15,20})>$", argument + ) result = None guild = ctx.guild @@ -494,7 +524,9 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "text_channels", discord.TextChannel + ) class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): @@ -514,7 +546,9 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "voice_channels", discord.VoiceChannel + ) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -533,7 +567,9 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "stage_channels", discord.StageChannel + ) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -553,7 +589,9 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "categories", discord.CategoryChannel + ) class StoreChannelConverter(IDConverter[discord.StoreChannel]): @@ -572,7 +610,9 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: - return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel) + return GuildChannelConverter._resolve_channel( + ctx, argument, "channels", discord.StoreChannel + ) class ThreadConverter(IDConverter[discord.Thread]): @@ -590,7 +630,9 @@ class ThreadConverter(IDConverter[discord.Thread]): """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: - return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread) + return GuildChannelConverter._resolve_thread( + ctx, argument, "threads", discord.Thread + ) class ColourConverter(Converter[discord.Colour]): @@ -619,7 +661,9 @@ class ColourConverter(Converter[discord.Colour]): Added support for ``rgb`` function and 3-digit hex shortcuts """ - RGB_REGEX = re.compile(r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)") + RGB_REGEX = re.compile( + r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)" + ) def parse_hex_number(self, argument): arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument @@ -700,7 +744,9 @@ class RoleConverter(IDConverter[discord.Role]): if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument) + match = self._get_id_match(argument) or re.match( + r"<@&([0-9]{15,20})>$", argument + ) if match: result = guild.get_role(int(match.group(1))) else: @@ -779,7 +825,9 @@ class EmojiConverter(IDConverter[discord.Emoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.Emoji: - match = self._get_id_match(argument) or re.match(r"$", argument) + match = self._get_id_match(argument) or re.match( + r"$", argument + ) result = None bot = ctx.bot guild = ctx.guild @@ -821,7 +869,10 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): emoji_id = int(match.group(3)) return discord.PartialEmoji.with_state( - ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id + ctx.bot._connection, + animated=emoji_animated, + name=emoji_name, + id=emoji_id, ) raise PartialEmojiConversionFailure(argument) @@ -906,7 +957,11 @@ class clean_content(Converter[str]): def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) - return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" + return ( + f"@{m.display_name if self.use_nicknames else m.name}" + if m + else "@deleted-user" + ) def resolve_role(id: int) -> str: r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) @@ -996,7 +1051,11 @@ class Greedy(List[T]): origin = getattr(converter, "__origin__", None) args = getattr(converter, "__args__", ()) - if not (callable(converter) or isinstance(converter, Converter) or origin is not None): + if not ( + callable(converter) + or isinstance(converter, Converter) + or origin is not None + ): raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: @@ -1044,7 +1103,13 @@ class Option(Generic[T, DT]): # type: ignore "name", ) - def __init__(self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING) -> None: + def __init__( + self, + default: T = inspect.Parameter.empty, + *, + description: DT, + name: str = discord.utils.MISSING, + ) -> None: self.description = description self.default = default self.name: str = name @@ -1052,7 +1117,12 @@ class Option(Generic[T, DT]): # type: ignore if TYPE_CHECKING: # Terrible workaround for type checking reasons - def Option(default: T = inspect.Parameter.empty, *, description: str, name: str = discord.utils.MISSING) -> T: + def Option( + default: T = inspect.Parameter.empty, + *, + description: str, + name: str = discord.utils.MISSING, + ) -> T: ... @@ -1107,7 +1177,9 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = { } -async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): +async def _actual_conversion( + ctx: Context, converter, argument: str, param: inspect.Parameter +): if converter is bool: return _convert_to_bool(argument) @@ -1116,7 +1188,9 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp except AttributeError: pass else: - if module is not None and (module.startswith("discord.") and not module.endswith("converter")): + if module is not None and ( + module.startswith("discord.") and not module.endswith("converter") + ): converter = CONVERTER_MAPPING.get(converter, converter) try: @@ -1142,10 +1216,14 @@ async def _actual_conversion(ctx: Context, converter, argument: str, param: insp except AttributeError: name = converter.__class__.__name__ - raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc + raise BadArgument( + f'Converting to "{name}" failed for parameter "{param.name}".' + ) from exc -async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): +async def run_converters( + ctx: Context, converter, argument: str, param: inspect.Parameter +): """|coro| Runs converters for a given converter, argument, and parameter. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index a23ae878..edcc4ee7 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -53,7 +53,13 @@ from operator import itemgetter import discord from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping +from .cooldowns import ( + Cooldown, + BucketType, + CooldownMapping, + MaxConcurrency, + DynamicCooldownMapping, +) from .converter import ( CONVERTER_MAPPING, Converter, @@ -74,7 +80,10 @@ if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard from discord.message import Message - from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption + from discord.types.interactions import ( + EditApplicationCommand, + ApplicationCommandInteractionDataOption, + ) from ._types import ( Coro, @@ -394,7 +403,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.slash_command: Optional[bool] = kwargs.get("slash_command", None) self.message_command: Optional[bool] = kwargs.get("message_command", None) - self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get("slash_command_guilds", None) + self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get( + "slash_command_guilds", None + ) help_doc = kwargs.get("help") if help_doc is not None: @@ -413,7 +424,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.extras: Dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError("Aliases of a command must be a list or a tuple of strings.") + raise TypeError( + "Aliases of a command must be a list or a tuple of strings." + ) self.description: str = inspect.cleandoc(kwargs.get("description", "")) self.hidden: bool = kwargs.get("hidden", False) @@ -439,7 +452,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) self.checks: List[Check] = checks self._buckets: CooldownMapping = buckets @@ -480,7 +495,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): @property def callback( self, - ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: + ) -> Union[ + Callable[Concatenate[CogT, Context, P], Coro[T]], + Callable[Concatenate[Context, P], Coro[T]], + ]: return self._callback @callback.setter @@ -500,7 +518,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: globalns = {} - self.params, self.option_descriptions = get_signature_parameters(function, globalns) + self.params, self.option_descriptions = get_signature_parameters( + function, globalns + ) def _update_attrs(self, **command_attrs: Any): for key, value in command_attrs.items(): @@ -634,7 +654,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): required = param.default is param.empty converter = get_converter(param) - consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + consume_rest_is_special = ( + param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + ) view = ctx.view view.skip_ws() @@ -642,9 +664,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # it undos the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos(ctx, param, required, converter.converter) + return await self._transform_greedy_pos( + ctx, param, required, converter.converter + ) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos(ctx, param, converter.converter) + return await self._transform_greedy_var_pos( + ctx, param, converter.converter + ) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] @@ -657,7 +683,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if required: if self._is_typing_optional(param.annotation): return None - if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible(): + if ( + hasattr(converter, "__commands_is_flag__") + and converter._can_be_constructible() + ): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return param.default @@ -702,7 +731,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return param.default return result - async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: + async def _transform_greedy_var_pos( + self, ctx: Context, param: inspect.Parameter, converter: Any + ) -> Any: view = ctx.view previous = view.index try: @@ -816,13 +847,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]): try: next(iterator) except StopIteration: - raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.') + raise discord.ClientException( + f'Callback for {self.name} command is missing "self" parameter.' + ) # next we have the 'ctx' as the next parameter try: next(iterator) except StopIteration: - raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') + raise discord.ClientException( + f'Callback for {self.name} command is missing "ctx" parameter.' + ) for name, param in iterator: ctx.current_parameter = param @@ -849,7 +884,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): break if not self.ignore_extra and not view.eof: - raise TooManyArguments("Too many arguments passed to " + self.qualified_name) + raise TooManyArguments( + "Too many arguments passed to " + self.qualified_name + ) async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -909,7 +946,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.command = self if not await self.can_run(ctx): - raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") + raise CheckFailure( + f"The check functions for command {self.qualified_name} failed." + ) if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -1118,7 +1157,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return self.help.split("\n", 1)[0] return "" - def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: + def _is_typing_optional( + self, annotation: Union[T, Optional[T]] + ) -> TypeGuard[Optional[T]]: return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore @property @@ -1149,13 +1190,24 @@ class Command(_BaseCommand, Generic[CogT, P, T]): origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) + name = "|".join( + f'"{v}"' if isinstance(v, str) else str(v) + for v in annotation.__args__ + ) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = param.default if isinstance(param.default, str) else param.default is not None + should_print = ( + param.default + if isinstance(param.default, str) + else param.default is not None + ) if should_print: - result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...") + result.append( + f"[{name}={param.default}]" + if not greedy + else f"[{name}={param.default}]..." + ) continue else: result.append(f"[{name}]") @@ -1204,21 +1256,29 @@ class Command(_BaseCommand, Generic[CogT, P, T]): raise DisabledCommand(f"{self.name} command is disabled") if ctx.interaction is None and ( - self.message_command is False or (self.message_command is None and not ctx.bot.message_commands) + self.message_command is False + or (self.message_command is None and not ctx.bot.message_commands) ): - raise DisabledCommand(f"{self.name} command cannot be run as a message command") + raise DisabledCommand( + f"{self.name} command cannot be run as a message command" + ) if ctx.interaction is not None and ( - self.slash_command is False or (self.slash_command is None and not ctx.bot.slash_commands) + self.slash_command is False + or (self.slash_command is None and not ctx.bot.slash_commands) ): - raise DisabledCommand(f"{self.name} command cannot be run as a slash command") + raise DisabledCommand( + f"{self.name} command cannot be run as a slash command" + ) original = ctx.command ctx.command = self try: if not await ctx.bot.can_run(ctx): - raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) cog = self.cog if cog is not None: @@ -1246,7 +1306,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): param for name, flag in annotation.get_flags().items() for param in self._param_to_options( - name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple + name, + flag.annotation, + required=flag.required, + varadic=flag.annotation is tuple, ) ] @@ -1279,16 +1342,27 @@ class Command(_BaseCommand, Generic[CogT, P, T]): option["type"] = discord_type # Set channel types if discord_type == 7: - option["channel_types"] = application_option_channel_types[annotation] + option["channel_types"] = application_option_channel_types[ + annotation + ] break elif origin is Union: - if annotation in {Union[discord.Member, discord.Role], Union[MemberConverter, RoleConverter]}: + if annotation in { + Union[discord.Member, discord.Role], + Union[MemberConverter, RoleConverter], + }: option["type"] = 9 - elif all([arg in application_option_channel_types for arg in annotation.__args__]): + elif all( + [arg in application_option_channel_types for arg in annotation.__args__] + ): option["type"] = 7 - option["channel_types"] = [discord_value for arg in annotation.__args__ for discord_value in application_option_channel_types[arg]] + option["channel_types"] = [ + discord_value + for arg in annotation.__args__ + for discord_value in application_option_channel_types[arg] + ] elif origin is Literal: literal_values = annotation.__args__ @@ -1300,18 +1374,27 @@ class Command(_BaseCommand, Generic[CogT, P, T]): option["type"] = application_option_type_lookup[python_type] option["choices"] = [ - {"name": literal_value, "value": literal_value} for literal_value in annotation.__args__ + {"name": literal_value, "value": literal_value} + for literal_value in annotation.__args__ ] return [option] # type: ignore - def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: + def to_application_command( + self, nested: int = 0 + ) -> Optional[EditApplicationCommand]: if self.slash_command is False: return elif nested == 3: - raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") + raise ApplicationCommandRegistrationError( + self, f"{self.qualified_name} is too deeply nested!" + ) - payload = {"name": self.name, "description": self.short_doc or "no description", "options": []} + payload = { + "name": self.name, + "description": self.short_doc or "no description", + "options": [], + } if nested != 0: payload["type"] = 1 @@ -1319,15 +1402,23 @@ class Command(_BaseCommand, Generic[CogT, P, T]): options = self._param_to_options( name, param.annotation if param.annotation is not param.empty else str, - varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy), - required=(param.default is param.empty and not self._is_typing_optional(param.annotation)) + varadic=param.kind == param.KEYWORD_ONLY + or isinstance(param.annotation, Greedy), + required=( + param.default is param.empty + and not self._is_typing_optional(param.annotation) + ) or param.kind == param.VAR_POSITIONAL, ) if options is not None: - payload["options"].extend(option for option in options if option is not None) + payload["options"].extend( + option for option in options if option is not None + ) # Now we have all options, make sure required is before optional. - payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True) + payload["options"] = sorted( + payload["options"], key=itemgetter("required"), reverse=True + ) return payload # type: ignore @@ -1346,7 +1437,9 @@ class GroupMixin(Generic[CogT]): def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get("case_insensitive", True) - self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} + self.all_commands: Dict[str, Command[CogT, Any, Any]] = ( + _CaseInsensitiveDict() if case_insensitive else {} + ) self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1552,7 +1645,12 @@ class GroupMixin(Generic[CogT]): *args: Any, **kwargs: Any, ) -> Callable[ - [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], Group[CogT, P, T], ]: ... @@ -1703,18 +1801,23 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) - def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: + def to_application_command( + self, nested: int = 0 + ) -> Optional[EditApplicationCommand]: if self.slash_command is False: return elif nested == 2: - raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") + raise ApplicationCommandRegistrationError( + self, f"{self.qualified_name} is too deeply nested!" + ) return { # type: ignore "name": self.name, "type": int(not (nested - 1)) + 1, "description": self.short_doc or "no description", "options": [ - cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name) + cmd.to_application_command(nested=nested + 1) + for cmd in sorted(self.commands, key=lambda x: x.name) ], } @@ -2022,7 +2125,9 @@ def check_any(*checks: Check) -> Callable[[T], T]: try: pred = wrapped.predicate except AttributeError: - raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None + raise TypeError( + f"{wrapped!r} must be wrapped by commands.check decorator" + ) from None else: unwrapped.append(pred) @@ -2124,7 +2229,10 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore if any( - getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items ): return True raise MissingAnyRole(list(items)) @@ -2183,7 +2291,10 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: me = ctx.me getter = functools.partial(discord.utils.get, me.roles) if any( - getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items ): return True raise BotMissingAnyRole(list(items)) @@ -2229,7 +2340,9 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: ch = ctx.channel permissions = ch.permissions_for(ctx.author) # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2256,7 +2369,9 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: me = guild.me if guild is not None else ctx.bot.user permissions = ctx.channel.permissions_for(me) # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2285,7 +2400,9 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.author.guild_permissions # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2311,7 +2428,9 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.me.guild_permissions # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2389,7 +2508,9 @@ def is_nsfw() -> Callable[[T], T]: def pred(ctx: Context) -> bool: ch = ctx.channel - if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): + if ctx.guild is None or ( + isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() + ): return True raise NSFWChannelRequired(ch) # type: ignore @@ -2397,7 +2518,9 @@ def is_nsfw() -> Callable[[T], T]: def cooldown( - rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default + rate: int, + per: float, + type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, ) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` @@ -2432,14 +2555,17 @@ def cooldown( if not hasattr(func, "__command_attrs__"): func.__command_attrs__ = {} - func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type) + func.__command_attrs__["cooldown"] = CooldownMapping( + Cooldown(rate, per), type + ) return func return decorator # type: ignore def dynamic_cooldown( - cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default + cooldown: Union[BucketType, Callable[[Message], Any]], + type: BucketType = BucketType.default, ) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` @@ -2485,7 +2611,9 @@ def dynamic_cooldown( return decorator # type: ignore -def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: +def max_concurrency( + number: int, per: BucketType = BucketType.default, *, wait: bool = False +) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. This enables you to only allow a certain number of command invocations at the same time, -- 2.47.2 From 022bbd3d5147c961fd6cbe2604e5d310316e2858 Mon Sep 17 00:00:00 2001 From: StockerMC <44980366+StockerMC@users.noreply.github.com> Date: Fri, 15 Oct 2021 13:35:57 -0400 Subject: [PATCH 4/5] Revert "Format with black" This reverts commit 161affa246288b4c4c5033e975736c1220ed90cc. --- discord/ext/commands/converter.py | 144 +++++------------- discord/ext/commands/core.py | 238 +++++++----------------------- 2 files changed, 88 insertions(+), 294 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 62f93a6f..4406490b 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -161,9 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]): """ async def convert(self, ctx: Context, argument: str) -> discord.Object: - match = self._get_id_match(argument) or re.match( - r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument) if match is None: raise ObjectNotFound(argument) @@ -200,14 +198,10 @@ class MemberConverter(IDConverter[discord.Member]): if len(argument) > 5 and argument[-5] == "#": username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get( - members, name=username, discriminator=discriminator - ) + return discord.utils.get(members, name=username, discriminator=discriminator) else: members = await guild.query_members(argument, limit=100, cache=cache) - return discord.utils.find( - lambda m: m.name == argument or m.nick == argument, members - ) + return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members) async def query_member_by_id(self, bot, guild, user_id): ws = bot._get_websocket(shard_id=guild.shard_id) @@ -232,9 +226,7 @@ class MemberConverter(IDConverter[discord.Member]): async def convert(self, ctx: Context, argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match( - r"<@!?([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) guild = ctx.guild result = None user_id = None @@ -247,9 +239,7 @@ class MemberConverter(IDConverter[discord.Member]): else: user_id = int(match.group(1)) if guild: - result = guild.get_member(user_id) or _utils_get( - ctx.message.mentions, id=user_id - ) + result = guild.get_member(user_id) or _utils_get(ctx.message.mentions, id=user_id) else: result = _get_from_guilds(bot, "get_member", user_id) @@ -289,17 +279,13 @@ class UserConverter(IDConverter[discord.User]): """ async def convert(self, ctx: Context, argument: str) -> discord.User: - match = self._get_id_match(argument) or re.match( - r"<@!?([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) result = None state = ctx._state if match is not None: user_id = int(match.group(1)) - result = ctx.bot.get_user(user_id) or _utils_get( - ctx.message.mentions, id=user_id - ) + result = ctx.bot.get_user(user_id) or _utils_get(ctx.message.mentions, id=user_id) if result is None: try: result = await ctx.bot.fetch_user(user_id) @@ -347,9 +333,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _get_id_matches(ctx, argument): - id_regex = re.compile( - r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$" - ) + id_regex = re.compile(r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$") link_regex = re.compile( r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" r"(?P[0-9]{15,20}|@me)" @@ -371,9 +355,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): return guild_id, message_id, channel_id @staticmethod - def _resolve_channel( - ctx, guild_id, channel_id - ) -> Optional[PartialMessageableChannel]: + def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: if guild_id is not None: guild = ctx.bot.get_guild(guild_id) if guild is not None and channel_id is not None: @@ -407,9 +389,7 @@ class MessageConverter(IDConverter[discord.Message]): """ async def convert(self, ctx: Context, argument: str) -> discord.Message: - guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches( - ctx, argument - ) + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) message = ctx.bot._connection._get_message(message_id) if message: return message @@ -440,19 +420,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return self._resolve_channel( - ctx, argument, "channels", discord.abc.GuildChannel - ) + return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel) @staticmethod - def _resolve_channel( - ctx: Context, argument: str, attribute: str, type: Type[CT] - ) -> CT: + def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match( - r"<#([0-9]{15,20})>$", argument - ) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -480,14 +454,10 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): return result @staticmethod - def _resolve_thread( - ctx: Context, argument: str, attribute: str, type: Type[TT] - ) -> TT: + def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match( - r"<#([0-9]{15,20})>$", argument - ) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -524,9 +494,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "text_channels", discord.TextChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel) class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): @@ -546,9 +514,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "voice_channels", discord.VoiceChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -567,9 +533,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "stage_channels", discord.StageChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -589,9 +553,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "categories", discord.CategoryChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel) class StoreChannelConverter(IDConverter[discord.StoreChannel]): @@ -610,9 +572,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "channels", discord.StoreChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel) class ThreadConverter(IDConverter[discord.Thread]): @@ -630,9 +590,7 @@ class ThreadConverter(IDConverter[discord.Thread]): """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: - return GuildChannelConverter._resolve_thread( - ctx, argument, "threads", discord.Thread - ) + return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread) class ColourConverter(Converter[discord.Colour]): @@ -661,9 +619,7 @@ class ColourConverter(Converter[discord.Colour]): Added support for ``rgb`` function and 3-digit hex shortcuts """ - RGB_REGEX = re.compile( - r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)" - ) + RGB_REGEX = re.compile(r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)") def parse_hex_number(self, argument): arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument @@ -744,9 +700,7 @@ class RoleConverter(IDConverter[discord.Role]): if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match( - r"<@&([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument) if match: result = guild.get_role(int(match.group(1))) else: @@ -825,9 +779,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.Emoji: - match = self._get_id_match(argument) or re.match( - r"$", argument - ) + match = self._get_id_match(argument) or re.match(r"$", argument) result = None bot = ctx.bot guild = ctx.guild @@ -869,10 +821,7 @@ class PartialEmojiConverter(Converter[discord.PartialEmoji]): emoji_id = int(match.group(3)) return discord.PartialEmoji.with_state( - ctx.bot._connection, - animated=emoji_animated, - name=emoji_name, - id=emoji_id, + ctx.bot._connection, animated=emoji_animated, name=emoji_name, id=emoji_id ) raise PartialEmojiConversionFailure(argument) @@ -957,11 +906,7 @@ class clean_content(Converter[str]): def resolve_member(id: int) -> str: m = _utils_get(msg.mentions, id=id) or ctx.guild.get_member(id) - return ( - f"@{m.display_name if self.use_nicknames else m.name}" - if m - else "@deleted-user" - ) + return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: r = _utils_get(msg.role_mentions, id=id) or ctx.guild.get_role(id) @@ -1051,11 +996,7 @@ class Greedy(List[T]): origin = getattr(converter, "__origin__", None) args = getattr(converter, "__args__", ()) - if not ( - callable(converter) - or isinstance(converter, Converter) - or origin is not None - ): + if not (callable(converter) or isinstance(converter, Converter) or origin is not None): raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: @@ -1103,13 +1044,7 @@ class Option(Generic[T, DT]): # type: ignore "name", ) - def __init__( - self, - default: T = inspect.Parameter.empty, - *, - description: DT, - name: str = discord.utils.MISSING, - ) -> None: + def __init__(self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING) -> None: self.description = description self.default = default self.name: str = name @@ -1117,12 +1052,7 @@ class Option(Generic[T, DT]): # type: ignore if TYPE_CHECKING: # Terrible workaround for type checking reasons - def Option( - default: T = inspect.Parameter.empty, - *, - description: str, - name: str = discord.utils.MISSING, - ) -> T: + def Option(default: T = inspect.Parameter.empty, *, description: str, name: str = discord.utils.MISSING) -> T: ... @@ -1177,9 +1107,7 @@ CONVERTER_MAPPING: Dict[Type[Any], Any] = { } -async def _actual_conversion( - ctx: Context, converter, argument: str, param: inspect.Parameter -): +async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): if converter is bool: return _convert_to_bool(argument) @@ -1188,9 +1116,7 @@ async def _actual_conversion( except AttributeError: pass else: - if module is not None and ( - module.startswith("discord.") and not module.endswith("converter") - ): + if module is not None and (module.startswith("discord.") and not module.endswith("converter")): converter = CONVERTER_MAPPING.get(converter, converter) try: @@ -1216,14 +1142,10 @@ async def _actual_conversion( except AttributeError: name = converter.__class__.__name__ - raise BadArgument( - f'Converting to "{name}" failed for parameter "{param.name}".' - ) from exc + raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc -async def run_converters( - ctx: Context, converter, argument: str, param: inspect.Parameter -): +async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): """|coro| Runs converters for a given converter, argument, and parameter. diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index edcc4ee7..a23ae878 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -53,13 +53,7 @@ from operator import itemgetter import discord from .errors import * -from .cooldowns import ( - Cooldown, - BucketType, - CooldownMapping, - MaxConcurrency, - DynamicCooldownMapping, -) +from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping from .converter import ( CONVERTER_MAPPING, Converter, @@ -80,10 +74,7 @@ if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard from discord.message import Message - from discord.types.interactions import ( - EditApplicationCommand, - ApplicationCommandInteractionDataOption, - ) + from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption from ._types import ( Coro, @@ -403,9 +394,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.slash_command: Optional[bool] = kwargs.get("slash_command", None) self.message_command: Optional[bool] = kwargs.get("message_command", None) - self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get( - "slash_command_guilds", None - ) + self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get("slash_command_guilds", None) help_doc = kwargs.get("help") if help_doc is not None: @@ -424,9 +413,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.extras: Dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError( - "Aliases of a command must be a list or a tuple of strings." - ) + raise TypeError("Aliases of a command must be a list or a tuple of strings.") self.description: str = inspect.cleandoc(kwargs.get("description", "")) self.hidden: bool = kwargs.get("hidden", False) @@ -452,9 +439,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError( - "Cooldown must be a an instance of CooldownMapping or None." - ) + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self.checks: List[Check] = checks self._buckets: CooldownMapping = buckets @@ -495,10 +480,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): @property def callback( self, - ) -> Union[ - Callable[Concatenate[CogT, Context, P], Coro[T]], - Callable[Concatenate[Context, P], Coro[T]], - ]: + ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: return self._callback @callback.setter @@ -518,9 +500,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: globalns = {} - self.params, self.option_descriptions = get_signature_parameters( - function, globalns - ) + self.params, self.option_descriptions = get_signature_parameters(function, globalns) def _update_attrs(self, **command_attrs: Any): for key, value in command_attrs.items(): @@ -654,9 +634,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): required = param.default is param.empty converter = get_converter(param) - consume_rest_is_special = ( - param.kind == param.KEYWORD_ONLY and not self.rest_is_raw - ) + consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw view = ctx.view view.skip_ws() @@ -664,13 +642,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # it undos the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos( - ctx, param, required, converter.converter - ) + return await self._transform_greedy_pos(ctx, param, required, converter.converter) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos( - ctx, param, converter.converter - ) + return await self._transform_greedy_var_pos(ctx, param, converter.converter) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] @@ -683,10 +657,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if required: if self._is_typing_optional(param.annotation): return None - if ( - hasattr(converter, "__commands_is_flag__") - and converter._can_be_constructible() - ): + if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible(): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return param.default @@ -731,9 +702,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return param.default return result - async def _transform_greedy_var_pos( - self, ctx: Context, param: inspect.Parameter, converter: Any - ) -> Any: + async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: view = ctx.view previous = view.index try: @@ -847,17 +816,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): try: next(iterator) except StopIteration: - raise discord.ClientException( - f'Callback for {self.name} command is missing "self" parameter.' - ) + raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.') # next we have the 'ctx' as the next parameter try: next(iterator) except StopIteration: - raise discord.ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') for name, param in iterator: ctx.current_parameter = param @@ -884,9 +849,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): break if not self.ignore_extra and not view.eof: - raise TooManyArguments( - "Too many arguments passed to " + self.qualified_name - ) + raise TooManyArguments("Too many arguments passed to " + self.qualified_name) async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -946,9 +909,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.command = self if not await self.can_run(ctx): - raise CheckFailure( - f"The check functions for command {self.qualified_name} failed." - ) + raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -1157,9 +1118,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return self.help.split("\n", 1)[0] return "" - def _is_typing_optional( - self, annotation: Union[T, Optional[T]] - ) -> TypeGuard[Optional[T]]: + def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore @property @@ -1190,24 +1149,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = "|".join( - f'"{v}"' if isinstance(v, str) else str(v) - for v in annotation.__args__ - ) + name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = ( - param.default - if isinstance(param.default, str) - else param.default is not None - ) + should_print = param.default if isinstance(param.default, str) else param.default is not None if should_print: - result.append( - f"[{name}={param.default}]" - if not greedy - else f"[{name}={param.default}]..." - ) + result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...") continue else: result.append(f"[{name}]") @@ -1256,29 +1204,21 @@ class Command(_BaseCommand, Generic[CogT, P, T]): raise DisabledCommand(f"{self.name} command is disabled") if ctx.interaction is None and ( - self.message_command is False - or (self.message_command is None and not ctx.bot.message_commands) + self.message_command is False or (self.message_command is None and not ctx.bot.message_commands) ): - raise DisabledCommand( - f"{self.name} command cannot be run as a message command" - ) + raise DisabledCommand(f"{self.name} command cannot be run as a message command") if ctx.interaction is not None and ( - self.slash_command is False - or (self.slash_command is None and not ctx.bot.slash_commands) + self.slash_command is False or (self.slash_command is None and not ctx.bot.slash_commands) ): - raise DisabledCommand( - f"{self.name} command cannot be run as a slash command" - ) + raise DisabledCommand(f"{self.name} command cannot be run as a slash command") original = ctx.command ctx.command = self try: if not await ctx.bot.can_run(ctx): - raise CheckFailure( - f"The global check functions for command {self.qualified_name} failed." - ) + raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") cog = self.cog if cog is not None: @@ -1306,10 +1246,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): param for name, flag in annotation.get_flags().items() for param in self._param_to_options( - name, - flag.annotation, - required=flag.required, - varadic=flag.annotation is tuple, + name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple ) ] @@ -1342,27 +1279,16 @@ class Command(_BaseCommand, Generic[CogT, P, T]): option["type"] = discord_type # Set channel types if discord_type == 7: - option["channel_types"] = application_option_channel_types[ - annotation - ] + option["channel_types"] = application_option_channel_types[annotation] break elif origin is Union: - if annotation in { - Union[discord.Member, discord.Role], - Union[MemberConverter, RoleConverter], - }: + if annotation in {Union[discord.Member, discord.Role], Union[MemberConverter, RoleConverter]}: option["type"] = 9 - elif all( - [arg in application_option_channel_types for arg in annotation.__args__] - ): + elif all([arg in application_option_channel_types for arg in annotation.__args__]): option["type"] = 7 - option["channel_types"] = [ - discord_value - for arg in annotation.__args__ - for discord_value in application_option_channel_types[arg] - ] + option["channel_types"] = [discord_value for arg in annotation.__args__ for discord_value in application_option_channel_types[arg]] elif origin is Literal: literal_values = annotation.__args__ @@ -1374,27 +1300,18 @@ class Command(_BaseCommand, Generic[CogT, P, T]): option["type"] = application_option_type_lookup[python_type] option["choices"] = [ - {"name": literal_value, "value": literal_value} - for literal_value in annotation.__args__ + {"name": literal_value, "value": literal_value} for literal_value in annotation.__args__ ] return [option] # type: ignore - def to_application_command( - self, nested: int = 0 - ) -> Optional[EditApplicationCommand]: + def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: if self.slash_command is False: return elif nested == 3: - raise ApplicationCommandRegistrationError( - self, f"{self.qualified_name} is too deeply nested!" - ) + raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") - payload = { - "name": self.name, - "description": self.short_doc or "no description", - "options": [], - } + payload = {"name": self.name, "description": self.short_doc or "no description", "options": []} if nested != 0: payload["type"] = 1 @@ -1402,23 +1319,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]): options = self._param_to_options( name, param.annotation if param.annotation is not param.empty else str, - varadic=param.kind == param.KEYWORD_ONLY - or isinstance(param.annotation, Greedy), - required=( - param.default is param.empty - and not self._is_typing_optional(param.annotation) - ) + varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy), + required=(param.default is param.empty and not self._is_typing_optional(param.annotation)) or param.kind == param.VAR_POSITIONAL, ) if options is not None: - payload["options"].extend( - option for option in options if option is not None - ) + payload["options"].extend(option for option in options if option is not None) # Now we have all options, make sure required is before optional. - payload["options"] = sorted( - payload["options"], key=itemgetter("required"), reverse=True - ) + payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True) return payload # type: ignore @@ -1437,9 +1346,7 @@ class GroupMixin(Generic[CogT]): def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get("case_insensitive", True) - self.all_commands: Dict[str, Command[CogT, Any, Any]] = ( - _CaseInsensitiveDict() if case_insensitive else {} - ) + self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1645,12 +1552,7 @@ class GroupMixin(Generic[CogT]): *args: Any, **kwargs: Any, ) -> Callable[ - [ - Union[ - Callable[Concatenate[CogT, ContextT, P], Coro[T]], - Callable[Concatenate[ContextT, P], Coro[T]], - ] - ], + [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], Group[CogT, P, T], ]: ... @@ -1801,23 +1703,18 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) - def to_application_command( - self, nested: int = 0 - ) -> Optional[EditApplicationCommand]: + def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: if self.slash_command is False: return elif nested == 2: - raise ApplicationCommandRegistrationError( - self, f"{self.qualified_name} is too deeply nested!" - ) + raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") return { # type: ignore "name": self.name, "type": int(not (nested - 1)) + 1, "description": self.short_doc or "no description", "options": [ - cmd.to_application_command(nested=nested + 1) - for cmd in sorted(self.commands, key=lambda x: x.name) + cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name) ], } @@ -2125,9 +2022,7 @@ def check_any(*checks: Check) -> Callable[[T], T]: try: pred = wrapped.predicate except AttributeError: - raise TypeError( - f"{wrapped!r} must be wrapped by commands.check decorator" - ) from None + raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None else: unwrapped.append(pred) @@ -2229,10 +2124,7 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore if any( - getter(id=item) is not None - if isinstance(item, int) - else getter(name=item) is not None - for item in items + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items ): return True raise MissingAnyRole(list(items)) @@ -2291,10 +2183,7 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: me = ctx.me getter = functools.partial(discord.utils.get, me.roles) if any( - getter(id=item) is not None - if isinstance(item, int) - else getter(name=item) is not None - for item in items + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items ): return True raise BotMissingAnyRole(list(items)) @@ -2340,9 +2229,7 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: ch = ctx.channel permissions = ch.permissions_for(ctx.author) # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2369,9 +2256,7 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: me = guild.me if guild is not None else ctx.bot.user permissions = ctx.channel.permissions_for(me) # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2400,9 +2285,7 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.author.guild_permissions # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2428,9 +2311,7 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.me.guild_permissions # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2508,9 +2389,7 @@ def is_nsfw() -> Callable[[T], T]: def pred(ctx: Context) -> bool: ch = ctx.channel - if ctx.guild is None or ( - isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() - ): + if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True raise NSFWChannelRequired(ch) # type: ignore @@ -2518,9 +2397,7 @@ def is_nsfw() -> Callable[[T], T]: def cooldown( - rate: int, - per: float, - type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, + rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default ) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` @@ -2555,17 +2432,14 @@ def cooldown( if not hasattr(func, "__command_attrs__"): func.__command_attrs__ = {} - func.__command_attrs__["cooldown"] = CooldownMapping( - Cooldown(rate, per), type - ) + func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type) return func return decorator # type: ignore def dynamic_cooldown( - cooldown: Union[BucketType, Callable[[Message], Any]], - type: BucketType = BucketType.default, + cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default ) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` @@ -2611,9 +2485,7 @@ def dynamic_cooldown( return decorator # type: ignore -def max_concurrency( - number: int, per: BucketType = BucketType.default, *, wait: bool = False -) -> Callable[[T], T]: +def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. This enables you to only allow a certain number of command invocations at the same time, -- 2.47.2 From 8dba9897d4a9be2c278f812bbbf4e0ee96e69076 Mon Sep 17 00:00:00 2001 From: StockerMC <44980366+StockerMC@users.noreply.github.com> Date: Sat, 16 Oct 2021 09:55:50 -0400 Subject: [PATCH 5/5] Format with black --- discord/ext/commands/converter.py | 4 +++- discord/ext/commands/core.py | 6 +++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 4406490b..bbef6719 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -1044,7 +1044,9 @@ class Option(Generic[T, DT]): # type: ignore "name", ) - def __init__(self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING) -> None: + def __init__( + self, default: T = inspect.Parameter.empty, *, description: DT, name: str = discord.utils.MISSING + ) -> None: self.description = description self.default = default self.name: str = name diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index a23ae878..3d6a9631 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -1288,7 +1288,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]): elif all([arg in application_option_channel_types for arg in annotation.__args__]): option["type"] = 7 - option["channel_types"] = [discord_value for arg in annotation.__args__ for discord_value in application_option_channel_types[arg]] + option["channel_types"] = [ + discord_value + for arg in annotation.__args__ + for discord_value in application_option_channel_types[arg] + ] elif origin is Literal: literal_values = annotation.__args__ -- 2.47.2