diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 178b252e..c135a9ba 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -53,7 +53,13 @@ from operator import itemgetter import discord from .errors import * -from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping +from .cooldowns import ( + Cooldown, + BucketType, + CooldownMapping, + MaxConcurrency, + DynamicCooldownMapping, +) from .converter import ( CONVERTER_MAPPING, Converter, @@ -74,7 +80,10 @@ if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard from discord.message import Message - from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption + from discord.types.interactions import ( + EditApplicationCommand, + ApplicationCommandInteractionDataOption, + ) from ._types import ( Coro, @@ -330,10 +339,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): This overwrites the global ``slash_commands`` parameter of :class:`.Bot`. .. versionadded:: 2.0 - slash_command_guilds: Optional[List[:class:`int`]] + guilds: Optional[List[:class:`int`]] If this is set, only upload this slash command to these guild IDs. - This overwrites the global ``slash_command_guilds`` parameter of :class:`.Bot`. + This overwrites the global ``guilds`` parameter of :class:`.Bot`. .. versionadded:: 2.0 @@ -382,7 +391,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.slash_command: Optional[bool] = kwargs.get("slash_command", None) self.message_command: Optional[bool] = kwargs.get("message_command", None) - self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get("slash_command_guilds", None) + self.slash_command_guilds: Optional[Iterable[int]] = kwargs.get("guilds", None) help_doc = kwargs.get("help") if help_doc is not None: @@ -401,7 +410,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): self.extras: Dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError("Aliases of a command must be a list or a tuple of strings.") + raise TypeError( + "Aliases of a command must be a list or a tuple of strings." + ) self.description: str = inspect.cleandoc(kwargs.get("description", "")) self.hidden: bool = kwargs.get("hidden", False) @@ -427,7 +438,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") + raise TypeError( + "Cooldown must be a an instance of CooldownMapping or None." + ) self.checks: List[Check] = checks self._buckets: CooldownMapping = buckets @@ -468,7 +481,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): @property def callback( self, - ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: + ) -> Union[ + Callable[Concatenate[CogT, Context, P], Coro[T]], + Callable[Concatenate[Context, P], Coro[T]], + ]: return self._callback @callback.setter @@ -488,7 +504,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: globalns = {} - self.params, self.option_descriptions = get_signature_parameters(function, globalns) + self.params, self.option_descriptions = get_signature_parameters( + function, globalns + ) def _update_attrs(self, **command_attrs: Any): for key, value in command_attrs.items(): @@ -622,7 +640,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): required = param.default is param.empty converter = get_converter(param) - consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + consume_rest_is_special = ( + param.kind == param.KEYWORD_ONLY and not self.rest_is_raw + ) view = ctx.view view.skip_ws() @@ -630,9 +650,13 @@ class Command(_BaseCommand, Generic[CogT, P, T]): # it undos the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos(ctx, param, required, converter.converter) + return await self._transform_greedy_pos( + ctx, param, required, converter.converter + ) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos(ctx, param, converter.converter) + return await self._transform_greedy_var_pos( + ctx, param, converter.converter + ) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] @@ -645,7 +669,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if required: if self._is_typing_optional(param.annotation): return None - if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible(): + if ( + hasattr(converter, "__commands_is_flag__") + and converter._can_be_constructible() + ): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return param.default @@ -690,7 +717,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return param.default return result - async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: + async def _transform_greedy_var_pos( + self, ctx: Context, param: inspect.Parameter, converter: Any + ) -> Any: view = ctx.view previous = view.index try: @@ -804,13 +833,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]): try: next(iterator) except StopIteration: - raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.') + raise discord.ClientException( + f'Callback for {self.name} command is missing "self" parameter.' + ) # next we have the 'ctx' as the next parameter try: next(iterator) except StopIteration: - raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') + raise discord.ClientException( + f'Callback for {self.name} command is missing "ctx" parameter.' + ) for name, param in iterator: ctx.current_parameter = param @@ -837,7 +870,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): break if not self.ignore_extra and not view.eof: - raise TooManyArguments("Too many arguments passed to " + self.qualified_name) + raise TooManyArguments( + "Too many arguments passed to " + self.qualified_name + ) async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -897,7 +932,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): ctx.command = self if not await self.can_run(ctx): - raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") + raise CheckFailure( + f"The check functions for command {self.qualified_name} failed." + ) if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -1106,7 +1143,9 @@ class Command(_BaseCommand, Generic[CogT, P, T]): return self.help.split("\n", 1)[0] return "" - def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: + def _is_typing_optional( + self, annotation: Union[T, Optional[T]] + ) -> TypeGuard[Optional[T]]: return getattr(annotation, "__origin__", None) is Union and type(None) in annotation.__args__ # type: ignore @property @@ -1137,13 +1176,24 @@ class Command(_BaseCommand, Generic[CogT, P, T]): origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) + name = "|".join( + f'"{v}"' if isinstance(v, str) else str(v) + for v in annotation.__args__ + ) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = param.default if isinstance(param.default, str) else param.default is not None + should_print = ( + param.default + if isinstance(param.default, str) + else param.default is not None + ) if should_print: - result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...") + result.append( + f"[{name}={param.default}]" + if not greedy + else f"[{name}={param.default}]..." + ) continue else: result.append(f"[{name}]") @@ -1192,21 +1242,29 @@ class Command(_BaseCommand, Generic[CogT, P, T]): raise DisabledCommand(f"{self.name} command is disabled") if ctx.interaction is None and ( - self.message_command is False or (self.message_command is None and not ctx.bot.message_commands) + self.message_command is False + or (self.message_command is None and not ctx.bot.message_commands) ): - raise DisabledCommand(f"{self.name} command cannot be run as a message command") + raise DisabledCommand( + f"{self.name} command cannot be run as a message command" + ) if ctx.interaction is not None and ( - self.slash_command is False or (self.slash_command is None and not ctx.bot.slash_commands) + self.slash_command is False + or (self.slash_command is None and not ctx.bot.slash_commands) ): - raise DisabledCommand(f"{self.name} command cannot be run as a slash command") + raise DisabledCommand( + f"{self.name} command cannot be run as a slash command" + ) original = ctx.command ctx.command = self try: if not await ctx.bot.can_run(ctx): - raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") + raise CheckFailure( + f"The global check functions for command {self.qualified_name} failed." + ) cog = self.cog if cog is not None: @@ -1234,7 +1292,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): param for name, flag in annotation.get_flags().items() for param in self._param_to_options( - name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple + name, + flag.annotation, + required=flag.required, + varadic=flag.annotation is tuple, ) ] @@ -1268,7 +1329,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]): break elif origin is Union: - if annotation in {Union[discord.Member, discord.Role], Union[MemberConverter, RoleConverter]}: + if annotation in { + Union[discord.Member, discord.Role], + Union[MemberConverter, RoleConverter], + }: option["type"] = 9 elif origin is Literal: @@ -1281,18 +1345,27 @@ class Command(_BaseCommand, Generic[CogT, P, T]): option["type"] = application_option_type_lookup[python_type] option["choices"] = [ - {"name": literal_value, "value": literal_value} for literal_value in annotation.__args__ + {"name": literal_value, "value": literal_value} + for literal_value in annotation.__args__ ] return [option] # type: ignore - def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: + def to_application_command( + self, nested: int = 0 + ) -> Optional[EditApplicationCommand]: if self.slash_command is False: return elif nested == 3: - raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") + raise ApplicationCommandRegistrationError( + self, f"{self.qualified_name} is too deeply nested!" + ) - payload = {"name": self.name, "description": self.short_doc or "no description", "options": []} + payload = { + "name": self.name, + "description": self.short_doc or "no description", + "options": [], + } if nested != 0: payload["type"] = 1 @@ -1300,15 +1373,23 @@ class Command(_BaseCommand, Generic[CogT, P, T]): options = self._param_to_options( name, param.annotation if param.annotation is not param.empty else str, - varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy), - required=(param.default is param.empty and not self._is_typing_optional(param.annotation)) + varadic=param.kind == param.KEYWORD_ONLY + or isinstance(param.annotation, Greedy), + required=( + param.default is param.empty + and not self._is_typing_optional(param.annotation) + ) or param.kind == param.VAR_POSITIONAL, ) if options is not None: - payload["options"].extend(option for option in options if option is not None) + payload["options"].extend( + option for option in options if option is not None + ) # Now we have all options, make sure required is before optional. - payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True) + payload["options"] = sorted( + payload["options"], key=itemgetter("required"), reverse=True + ) return payload # type: ignore @@ -1327,7 +1408,9 @@ class GroupMixin(Generic[CogT]): def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get("case_insensitive", True) - self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} + self.all_commands: Dict[str, Command[CogT, Any, Any]] = ( + _CaseInsensitiveDict() if case_insensitive else {} + ) self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1533,7 +1616,12 @@ class GroupMixin(Generic[CogT]): *args: Any, **kwargs: Any, ) -> Callable[ - [Union[Callable[Concatenate[CogT, ContextT, P], Coro[T]], Callable[Concatenate[ContextT, P], Coro[T]]]], + [ + Union[ + Callable[Concatenate[CogT, ContextT, P], Coro[T]], + Callable[Concatenate[ContextT, P], Coro[T]], + ] + ], Group[CogT, P, T], ]: ... @@ -1684,17 +1772,24 @@ class Group(GroupMixin[CogT], Command[CogT, P, T]): view.previous = previous await super().reinvoke(ctx, call_hooks=call_hooks) - def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: + def to_application_command( + self, nested: int = 0 + ) -> Optional[EditApplicationCommand]: if self.slash_command is False: return elif nested == 2: - raise ApplicationCommandRegistrationError(self, f"{self.qualified_name} is too deeply nested!") + raise ApplicationCommandRegistrationError( + self, f"{self.qualified_name} is too deeply nested!" + ) return { # type: ignore "name": self.name, "type": int(not (nested - 1)) + 1, "description": self.short_doc or "no description", - "options": [cmd.to_application_command(nested=nested + 1) for cmd in sorted(self.commands, key=lambda x: x.name)], + "options": [ + cmd.to_application_command(nested=nested + 1) + for cmd in sorted(self.commands, key=lambda x: x.name) + ], } @@ -2001,7 +2096,9 @@ def check_any(*checks: Check) -> Callable[[T], T]: try: pred = wrapped.predicate except AttributeError: - raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None + raise TypeError( + f"{wrapped!r} must be wrapped by commands.check decorator" + ) from None else: unwrapped.append(pred) @@ -2103,7 +2200,10 @@ def has_any_role(*items: Union[int, str]) -> Callable[[T], T]: # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore if any( - getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items ): return True raise MissingAnyRole(list(items)) @@ -2162,7 +2262,10 @@ def bot_has_any_role(*items: int) -> Callable[[T], T]: me = ctx.me getter = functools.partial(discord.utils.get, me.roles) if any( - getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items + getter(id=item) is not None + if isinstance(item, int) + else getter(name=item) is not None + for item in items ): return True raise BotMissingAnyRole(list(items)) @@ -2208,7 +2311,9 @@ def has_permissions(**perms: bool) -> Callable[[T], T]: ch = ctx.channel permissions = ch.permissions_for(ctx.author) # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2235,7 +2340,9 @@ def bot_has_permissions(**perms: bool) -> Callable[[T], T]: me = guild.me if guild is not None else ctx.bot.user permissions = ctx.channel.permissions_for(me) # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2264,7 +2371,9 @@ def has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.author.guild_permissions # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2290,7 +2399,9 @@ def bot_has_guild_permissions(**perms: bool) -> Callable[[T], T]: raise NoPrivateMessage permissions = ctx.me.guild_permissions # type: ignore - missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] + missing = [ + perm for perm, value in perms.items() if getattr(permissions, perm) != value + ] if not missing: return True @@ -2368,7 +2479,9 @@ def is_nsfw() -> Callable[[T], T]: def pred(ctx: Context) -> bool: ch = ctx.channel - if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): + if ctx.guild is None or ( + isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() + ): return True raise NSFWChannelRequired(ch) # type: ignore @@ -2376,7 +2489,9 @@ def is_nsfw() -> Callable[[T], T]: def cooldown( - rate: int, per: float, type: Union[BucketType, Callable[[Message], Any]] = BucketType.default + rate: int, + per: float, + type: Union[BucketType, Callable[[Message], Any]] = BucketType.default, ) -> Callable[[T], T]: """A decorator that adds a cooldown to a :class:`.Command` @@ -2411,14 +2526,17 @@ def cooldown( if not hasattr(func, "__command_attrs__"): func.__command_attrs__ = {} - func.__command_attrs__["cooldown"] = CooldownMapping(Cooldown(rate, per), type) + func.__command_attrs__["cooldown"] = CooldownMapping( + Cooldown(rate, per), type + ) return func return decorator # type: ignore def dynamic_cooldown( - cooldown: Union[BucketType, Callable[[Message], Any]], type: BucketType = BucketType.default + cooldown: Union[BucketType, Callable[[Message], Any]], + type: BucketType = BucketType.default, ) -> Callable[[T], T]: """A decorator that adds a dynamic cooldown to a :class:`.Command` @@ -2464,7 +2582,9 @@ def dynamic_cooldown( return decorator # type: ignore -def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: +def max_concurrency( + number: int, per: BucketType = BucketType.default, *, wait: bool = False +) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a :class:`.Command` or its subclasses. This enables you to only allow a certain number of command invocations at the same time,