diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 47c58c7d..81ca0b7d 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -64,6 +64,7 @@ from .core import GroupMixin from .converter import Greedy from .view import StringView, supported_quotes from .context import Context +from .flags import FlagConverter from . import errors from .help import HelpCommand, DefaultHelpCommand from .cog import Cog @@ -272,9 +273,9 @@ class BotBase(GroupMixin): for guild in guilds: commands[guild].append(payload) - http: HTTPClient = self.http # type: ignore + http: HTTPClient = self.http # type: ignore global_commands = commands.pop(None, None) - application_id = self.application_id or (await self.application_info()).id # type: ignore + application_id = self.application_id or (await self.application_info()).id # type: ignore if global_commands is not None: if self.slash_command_guilds is None: await http.bulk_upsert_global_commands( @@ -1271,7 +1272,20 @@ class BotBase(GroupMixin): ignore_params: List[inspect.Parameter] = [] message.content = f"{prefix}{command_name} " for name, param in command.clean_params.items(): - option = next((o for o in command_options if o["name"] == name), None) # type: ignore + if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter): + for name, flag in param.annotation.get_flags().items(): + option = next((o for o in command_options if o["name"] == name), None) + + if option is None: + if flag.required: + raise errors.MissingRequiredFlag(flag) + else: + prefix = param.annotation.__commands_flag_prefix__ + delimiter = param.annotation.__commands_flag_delimiter__ + message.content += f"{prefix}{name} {option['value']}{delimiter}" # type: ignore + continue + + option = next((o for o in command_options if o["name"] == name), None) if option is None: if param.default is param.empty and not command._is_typing_optional(param.annotation): raise errors.MissingRequiredArgument(param) @@ -1280,8 +1294,7 @@ class BotBase(GroupMixin): elif option["type"] == 3 and param.kind != param.KEYWORD_ONLY and not isinstance(param.annotation, Greedy): # String with space in without "consume rest" option = cast(_ApplicationCommandInteractionDataOptionString, option) - quoted_string = _quote_string_safe(option["value"]) - message.content += f"{quoted_string} " + message.content += f"{_quote_string_safe(option['value'])} " else: message.content += f'{option.get("value", "")} ' diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 4c6d2c5a..ed1b8bfe 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -58,13 +58,14 @@ from .converter import CONVERTER_MAPPING, Converter, run_converters, get_convert from ._types import _BaseCommand from .cog import Cog from .context import Context +from .flags import FlagConverter if TYPE_CHECKING: from typing_extensions import Concatenate, ParamSpec, TypeGuard from discord.message import Message - from discord.types.interactions import EditApplicationCommand + from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption from ._types import ( Coro, @@ -1202,6 +1203,65 @@ class Command(_BaseCommand, Generic[CogT, P, T]): finally: ctx.command = original + def _param_to_options( + self, name: str, annotation: Any, required: bool, varadic: bool + ) -> List[Optional[ApplicationCommandInteractionDataOption]]: + + origin = getattr(annotation, "__origin__", None) + if inspect.isclass(annotation) and issubclass(annotation, FlagConverter): + return [ + param + for name, flag in annotation.get_flags().items() + for param in self._param_to_options( + name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple + ) + ] + + if varadic: + annotation = str + origin = None + + if not required and origin is not None and len(annotation.__args__) == 2: + # Unpack Optional[T] (Union[T, None]) into just T + annotation, origin = annotation.__args__[0], None + + option: Dict[str, Any] = { + "name": name, + "required": required, + "description": self.option_descriptions[name], + } + + if origin is None: + if not inspect.isclass(annotation): + annotation = type(annotation) + + if issubclass(annotation, Converter): + # If this is a converter, we want to check if it is a native + # one, in which we can get the original type, eg, (MemberConverter -> Member) + annotation = REVERSED_CONVERTER_MAPPING.get(annotation, annotation) + + option["type"] = 3 + for python_type, discord_type in application_option_type_lookup.items(): + if issubclass(annotation, python_type): + option["type"] = discord_type + break + + elif origin is Literal: + literal_values = annotation.__args__ + python_type = type(literal_values[0]) + if ( + all(type(value) == python_type for value in literal_values) + and python_type in application_option_type_lookup.keys() + ): + + option["type"] = application_option_type_lookup[python_type] + option["choices"] = [ + {"name": literal_value, "value": literal_value} for literal_value in annotation.__args__ + ] + + option.setdefault("type", 3) # STRING + return [option] # type: ignore + def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]: if self.slash_command is False: return @@ -1213,55 +1273,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]): payload["type"] = 1 for name, param in self.clean_params.items(): - annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str - origin = getattr(param.annotation, "__origin__", None) - - if origin is None and isinstance(annotation, Greedy): - annotation = annotation.converter - origin = Greedy - - option: Dict[str, Any] = { - "name": name, - "description": self.option_descriptions[name], - "required": (param.default is param.empty and not self._is_typing_optional(annotation)) + options = self._param_to_options( + name, + param.annotation if param.annotation is not param.empty else str, + varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy), + required=(param.default is param.empty and not self._is_typing_optional(param.annotation)) or param.kind == param.VAR_POSITIONAL, - } - - annotation = cast(Any, annotation) - if not option["required"] and origin is not None and len(annotation.__args__) == 2: - # Unpack Optional[T] (Union[T, None]) into just T - annotation, origin = annotation.__args__[0], None - - if origin is None: - if not inspect.isclass(annotation): - annotation = type(annotation) - - if issubclass(annotation, Converter): - # If this is a converter, we want to check if it is a native - # one, in which we can get the original type, eg, (MemberConverter -> Member) - annotation = REVERSED_CONVERTER_MAPPING.get(annotation, annotation) - - option["type"] = 3 - for python_type, discord_type in application_option_type_lookup.items(): - if issubclass(annotation, python_type): - option["type"] = discord_type - break - - elif origin is Literal: - literal_values = annotation.__args__ - python_type = type(literal_values[0]) - if ( - all(type(value) == python_type for value in literal_values) - and python_type in application_option_type_lookup.keys() - ): - - option["type"] = application_option_type_lookup[python_type] - option["choices"] = [ - {"name": literal_value, "value": literal_value} for literal_value in annotation.__args__ - ] - - option.setdefault("type", 3) # STRING - payload["options"].append(option) + ) + if options is not None: + payload["options"].extend(option for option in options if option is not None) # Now we have all options, make sure required is before optional. payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True)