mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-24 18:13:00 +00:00 
			
		
		
		
	Add support for choice option parameters
This implements it in three different ways: * The first is using typing.Literal for quick and easy ones * The second is using enum.Enum for slightly more complex ones * The last is using a Choice type hint with a decorator to pass a list of choices. This should hopefully cover most use cases.
This commit is contained in:
		| @@ -77,6 +77,7 @@ __all__ = ( | ||||
|     'Group', | ||||
|     'command', | ||||
|     'describe', | ||||
|     'choices', | ||||
| ) | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
| @@ -171,6 +172,31 @@ def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Di | ||||
|         raise TypeError(f'unknown parameter given: {first}') | ||||
|  | ||||
|  | ||||
| def _populate_choices(params: Dict[str, CommandParameter], all_choices: Dict[str, List[Choice]]) -> None: | ||||
|     for name, param in params.items(): | ||||
|         choices = all_choices.pop(name, MISSING) | ||||
|         if choices is MISSING: | ||||
|             continue | ||||
|  | ||||
|         if not isinstance(choices, list): | ||||
|             raise TypeError('choices must be a list of Choice') | ||||
|  | ||||
|         if not all(isinstance(choice, Choice) for choice in choices): | ||||
|             raise TypeError('choices must be a list of Choice') | ||||
|  | ||||
|         if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer): | ||||
|             raise TypeError('choices are only supported for integer, string, or number option types') | ||||
|  | ||||
|         # There's a type safety hole if someone does Choice[float] as an annotation | ||||
|         # but the values are actually Choice[int]. Since the input-output is the same this feels | ||||
|         # safe enough to ignore. | ||||
|         param.choices = choices | ||||
|  | ||||
|     if all_choices: | ||||
|         first = next(iter(all_choices)) | ||||
|         raise TypeError(f'unknown parameter given: {first}') | ||||
|  | ||||
|  | ||||
| def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]: | ||||
|     params = inspect.signature(func).parameters | ||||
|     cache = {} | ||||
| @@ -203,6 +229,13 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s | ||||
|     else: | ||||
|         _populate_descriptions(result, descriptions) | ||||
|  | ||||
|     try: | ||||
|         choices = func.__discord_app_commands_param_choices__ | ||||
|     except AttributeError: | ||||
|         pass | ||||
|     else: | ||||
|         _populate_choices(result, choices) | ||||
|  | ||||
|     return result | ||||
|  | ||||
|  | ||||
| @@ -313,15 +346,15 @@ class Command(Generic[GroupT, P, T]): | ||||
|     async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T: | ||||
|         values = namespace.__dict__ | ||||
|         for name, param in self._params.items(): | ||||
|             if not param.required: | ||||
|                 values.setdefault(name, param.default) | ||||
|             else: | ||||
|                 try: | ||||
|                     value = values[name] | ||||
|                 except KeyError: | ||||
|                     raise CommandSignatureMismatch(self) from None | ||||
|             try: | ||||
|                 value = values[name] | ||||
|             except KeyError: | ||||
|                 if not param.required: | ||||
|                     values[name] = param.default | ||||
|                 else: | ||||
|                     values[name] = await param.transform(interaction, value) | ||||
|                     raise CommandSignatureMismatch(self) from None | ||||
|             else: | ||||
|                 values[name] = await param.transform(interaction, value) | ||||
|  | ||||
|         # These type ignores are because the type checker doesn't quite understand the narrowing here | ||||
|         # Likewise, it thinks we're missing positional arguments when there aren't any. | ||||
| @@ -768,7 +801,7 @@ def describe(**parameters: str) -> Callable[[T], T]: | ||||
|     .. code-block:: python3 | ||||
|  | ||||
|         @app_commands.command() | ||||
|         @app_commads.describe(member='the member to ban') | ||||
|         @app_commands.describe(member='the member to ban') | ||||
|         async def ban(interaction: discord.Interaction, member: discord.Member): | ||||
|             await interaction.response.send_message(f'Banned {member}') | ||||
|  | ||||
| @@ -787,7 +820,79 @@ def describe(**parameters: str) -> Callable[[T], T]: | ||||
|         if isinstance(inner, Command): | ||||
|             _populate_descriptions(inner._params, parameters) | ||||
|         else: | ||||
|             inner.__discord_app_commands_param_description__ = parameters  # type: ignore - Runtime attribute assignment | ||||
|             try: | ||||
|                 inner.__discord_app_commands_param_description__.update(parameters)  # type: ignore - Runtime attribute access | ||||
|             except AttributeError: | ||||
|                 inner.__discord_app_commands_param_description__ = parameters  # type: ignore - Runtime attribute assignment | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|     return decorator | ||||
|  | ||||
|  | ||||
| def choices(**parameters: List[Choice]) -> Callable[[T], T]: | ||||
|     r"""Instructs the given parameters by their name to use the given choices for their choices. | ||||
|  | ||||
|     Example: | ||||
|  | ||||
|     .. code-block:: python3 | ||||
|  | ||||
|         @app_commands.command() | ||||
|         @app_commands.describe(fruits='fruits to choose from') | ||||
|         @app_commands.choices(fruits=[ | ||||
|             Choice(name='apple', value=1), | ||||
|             Choice(name='banana', value=2), | ||||
|             Choice(name='cherry', value=3), | ||||
|         ]) | ||||
|         async def fruit(interaction: discord.Interaction, fruits: Choice[int]): | ||||
|             await interaction.response.send_message(f'Your favourite fruit is {fruits.name}.') | ||||
|  | ||||
|     .. note:: | ||||
|  | ||||
|         This is not the only way to provide choices to a command. There are two more ergonomic ways | ||||
|         of doing this. The first one is to use a :obj:`typing.Literal` annotation: | ||||
|  | ||||
|         .. code-block:: python3 | ||||
|  | ||||
|             @app_commands.command() | ||||
|             @app_commands.describe(fruits='fruits to choose from') | ||||
|             async def fruit(interaction: discord.Interaction, fruits: Literal['apple', 'banana', 'cherry']): | ||||
|                 await interaction.response.send_message(f'Your favourite fruit is {fruits}.') | ||||
|  | ||||
|         The second way is to use an :class:`enum.Enum`: | ||||
|  | ||||
|         .. code-block:: python3 | ||||
|  | ||||
|             class Fruits(enum.Enum): | ||||
|                 apple = 1 | ||||
|                 banana = 2 | ||||
|                 cherry = 3 | ||||
|  | ||||
|             @app_commands.command() | ||||
|             @app_commands.describe(fruits='fruits to choose from') | ||||
|             async def fruit(interaction: discord.Interaction, fruits: Fruits): | ||||
|                 await interaction.response.send_message(f'Your favourite fruit is {fruits}.') | ||||
|  | ||||
|  | ||||
|     Parameters | ||||
|     ----------- | ||||
|     \*\*parameters | ||||
|         The choices of the parameters. | ||||
|  | ||||
|     Raises | ||||
|     -------- | ||||
|     TypeError | ||||
|         The parameter name is not found. | ||||
|     """ | ||||
|  | ||||
|     def decorator(inner: T) -> T: | ||||
|         if isinstance(inner, Command): | ||||
|             _populate_choices(inner._params, parameters) | ||||
|         else: | ||||
|             try: | ||||
|                 inner.__discord_app_commands_param_choices__.update(parameters)  # type: ignore - Runtime attribute access | ||||
|             except AttributeError: | ||||
|                 inner.__discord_app_commands_param_choices__ = parameters  # type: ignore - Runtime attribute assignment | ||||
|  | ||||
|         return inner | ||||
|  | ||||
|   | ||||
| @@ -31,7 +31,7 @@ from ..enums import ChannelType, try_enum | ||||
| from ..mixins import Hashable | ||||
| from ..utils import _get_as_snowflake, parse_time, snowflake_time | ||||
| from .enums import AppCommandOptionType, AppCommandType | ||||
| from typing import List, NamedTuple, TYPE_CHECKING, Optional, Union | ||||
| from typing import Generic, List, NamedTuple, TYPE_CHECKING, Optional, TypeVar, Union | ||||
|  | ||||
| __all__ = ( | ||||
|     'AppCommand', | ||||
| @@ -42,6 +42,8 @@ __all__ = ( | ||||
|     'Choice', | ||||
| ) | ||||
|  | ||||
| ChoiceT = TypeVar('ChoiceT', str, int, float, Union[str, int, float]) | ||||
|  | ||||
|  | ||||
| def is_app_command_argument_type(value: int) -> bool: | ||||
|     return 11 >= value >= 3 | ||||
| @@ -145,7 +147,7 @@ class AppCommand(Hashable): | ||||
|         return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>' | ||||
|  | ||||
|  | ||||
| class Choice(NamedTuple): | ||||
| class Choice(Generic[ChoiceT]): | ||||
|     """Represents an application command argument choice. | ||||
|  | ||||
|     .. versionadded:: 2.0 | ||||
| @@ -160,6 +162,10 @@ class Choice(NamedTuple): | ||||
|  | ||||
|             Checks if two choices are not equal. | ||||
|  | ||||
|         .. describe:: hash(x) | ||||
|  | ||||
|             Returns the choice's hash. | ||||
|  | ||||
|     Parameters | ||||
|     ----------- | ||||
|     name: :class:`str` | ||||
| @@ -168,8 +174,20 @@ class Choice(NamedTuple): | ||||
|         The value of the choice. | ||||
|     """ | ||||
|  | ||||
|     name: str | ||||
|     value: Union[int, str, float] | ||||
|     __slots__ = ('name', 'value') | ||||
|  | ||||
|     def __init__(self, *, name: str, value: ChoiceT): | ||||
|         self.name: str = name | ||||
|         self.value: ChoiceT = value | ||||
|  | ||||
|     def __eq__(self, o: object) -> bool: | ||||
|         return isinstance(o, Choice) and self.name == o.name and self.value == o.value | ||||
|  | ||||
|     def __hash__(self) -> int: | ||||
|         return hash((self.name, self.value)) | ||||
|  | ||||
|     def __repr__(self) -> str: | ||||
|         return f'{self.__class__.__name__}(name={self.name!r}, value={self.value!r})' | ||||
|  | ||||
|     def to_dict(self) -> ApplicationCommandOptionChoice: | ||||
|         return { | ||||
|   | ||||
| @@ -26,7 +26,8 @@ from __future__ import annotations | ||||
| import inspect | ||||
|  | ||||
| from dataclasses import dataclass | ||||
| from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union | ||||
| from enum import Enum | ||||
| from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar, Union | ||||
|  | ||||
| from .enums import AppCommandOptionType | ||||
| from .errors import TransformerError | ||||
| @@ -113,6 +114,13 @@ class CommandParameter: | ||||
|  | ||||
|     async def transform(self, interaction: Interaction, value: Any) -> Any: | ||||
|         if hasattr(self._annotation, '__discord_app_commands_transformer__'): | ||||
|             # This one needs special handling for type safety reasons | ||||
|             if self._annotation.__discord_app_commands_is_choice__: | ||||
|                 choice = next((c for c in self.choices if c.value == value), None) | ||||
|                 if choice is None: | ||||
|                     raise TransformerError(value, self.type, self._annotation) | ||||
|                 return choice | ||||
|  | ||||
|             return await self._annotation.transform(interaction, value) | ||||
|         return value | ||||
|  | ||||
| @@ -149,6 +157,7 @@ class Transformer: | ||||
|     """ | ||||
|  | ||||
|     __discord_app_commands_transformer__: ClassVar[bool] = True | ||||
|     __discord_app_commands_is_choice__: ClassVar[bool] = False | ||||
|  | ||||
|     @classmethod | ||||
|     def type(cls) -> AppCommandOptionType: | ||||
| @@ -221,24 +230,93 @@ class _TransformMetadata: | ||||
|         self.metadata: Type[Transformer] = metadata | ||||
|  | ||||
|  | ||||
| async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any: | ||||
|     return value | ||||
|  | ||||
|  | ||||
| def _make_range_transformer( | ||||
|     opt_type: AppCommandOptionType, | ||||
|     *, | ||||
|     min: Optional[Union[int, float]] = None, | ||||
|     max: Optional[Union[int, float]] = None, | ||||
| ) -> Type[Transformer]: | ||||
|     async def transform(cls, interaction: Interaction, value: Any) -> Any: | ||||
|         return value | ||||
|  | ||||
|     ns = { | ||||
|         'type': classmethod(lambda _: opt_type), | ||||
|         'min_value': classmethod(lambda _: min), | ||||
|         'max_value': classmethod(lambda _: max), | ||||
|         'transform': classmethod(transform), | ||||
|         'transform': classmethod(_identity_transform), | ||||
|     } | ||||
|     return type('RangeTransformer', (Transformer,), ns) | ||||
|  | ||||
|  | ||||
| def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]: | ||||
|     if len(values) < 2: | ||||
|         raise TypeError(f'typing.Literal requires at least two values.') | ||||
|  | ||||
|     first = type(values[0]) | ||||
|     if first is int: | ||||
|         opt_type = AppCommandOptionType.integer | ||||
|     elif first is float: | ||||
|         opt_type = AppCommandOptionType.number | ||||
|     elif first is str: | ||||
|         opt_type = AppCommandOptionType.string | ||||
|     else: | ||||
|         raise TypeError(f'expected int, str, or float values not {first!r}') | ||||
|  | ||||
|     ns = { | ||||
|         'type': classmethod(lambda _: opt_type), | ||||
|         'transform': classmethod(_identity_transform), | ||||
|         '__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values], | ||||
|     } | ||||
|     return type('LiteralTransformer', (Transformer,), ns) | ||||
|  | ||||
|  | ||||
| def _make_choice_transformer(inner_type: Any) -> Type[Transformer]: | ||||
|     if inner_type is int: | ||||
|         opt_type = AppCommandOptionType.integer | ||||
|     elif inner_type is float: | ||||
|         opt_type = AppCommandOptionType.number | ||||
|     elif inner_type is str: | ||||
|         opt_type = AppCommandOptionType.string | ||||
|     else: | ||||
|         raise TypeError(f'expected int, str, or float values not {inner_type!r}') | ||||
|  | ||||
|     ns = { | ||||
|         'type': classmethod(lambda _: opt_type), | ||||
|         'transform': classmethod(_identity_transform), | ||||
|         '__discord_app_commands_is_choice__': True, | ||||
|     } | ||||
|     return type('ChoiceTransformer', (Transformer,), ns) | ||||
|  | ||||
|  | ||||
| def _make_enum_transformer(enum) -> Type[Transformer]: | ||||
|     values = list(enum) | ||||
|     if len(values) < 2: | ||||
|         raise TypeError(f'enum.Enum requires at least two values.') | ||||
|  | ||||
|     first = type(values[0].value) | ||||
|     if first is int: | ||||
|         opt_type = AppCommandOptionType.integer | ||||
|     elif first is float: | ||||
|         opt_type = AppCommandOptionType.number | ||||
|     elif first is str: | ||||
|         opt_type = AppCommandOptionType.string | ||||
|     else: | ||||
|         raise TypeError(f'expected int, str, or float values not {first!r}') | ||||
|  | ||||
|     async def transform(cls, interaction: Interaction, value: Any) -> Any: | ||||
|         return enum(value) | ||||
|  | ||||
|     ns = { | ||||
|         'type': classmethod(lambda _: opt_type), | ||||
|         'transform': classmethod(transform), | ||||
|         '__discord_app_commands_transformer_enum__': enum, | ||||
|         '__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values], | ||||
|     } | ||||
|  | ||||
|     return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns) | ||||
|  | ||||
|  | ||||
| if TYPE_CHECKING: | ||||
|     from typing_extensions import Annotated as Transform | ||||
|     from typing_extensions import Annotated as Range | ||||
| @@ -465,11 +543,24 @@ def get_supported_annotation( | ||||
|     if hasattr(annotation, '__discord_app_commands_transform__'): | ||||
|         return (annotation.metadata, MISSING) | ||||
|  | ||||
|     if inspect.isclass(annotation) and issubclass(annotation, Transformer): | ||||
|         return (annotation, MISSING) | ||||
|     if inspect.isclass(annotation): | ||||
|         if issubclass(annotation, Transformer): | ||||
|             return (annotation, MISSING) | ||||
|         if issubclass(annotation, Enum): | ||||
|             return (_make_enum_transformer(annotation), MISSING) | ||||
|         if annotation is Choice: | ||||
|             raise TypeError(f'Choice requires a type argument of int, str, or float') | ||||
|  | ||||
|     # Check if there's an origin | ||||
|     origin = getattr(annotation, '__origin__', None) | ||||
|     if origin is Literal: | ||||
|         args = annotation.__args__  # type: ignore | ||||
|         return (_make_literal_transformer(args), MISSING) | ||||
|  | ||||
|     if origin is Choice: | ||||
|         arg = annotation.__args__[0]  # type: ignore | ||||
|         return (_make_choice_transformer(arg), MISSING) | ||||
|  | ||||
|     if origin is not Union: | ||||
|         # Only Union/Optional is supported right now so bail early | ||||
|         raise TypeError(f'unsupported type annotation {annotation!r}') | ||||
| @@ -522,9 +613,11 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co | ||||
|  | ||||
|     # Verify validity of the default parameter | ||||
|     if default is not MISSING: | ||||
|         valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) | ||||
|         if not isinstance(default, valid_types): | ||||
|             raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') | ||||
|         enum_type = getattr(inner, '__discord_app_commands_transformer_enum__', None) | ||||
|         if default.__class__ is not enum_type: | ||||
|             valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,)) | ||||
|             if not isinstance(default, valid_types): | ||||
|                 raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}') | ||||
|  | ||||
|     result = CommandParameter( | ||||
|         type=type, | ||||
| @@ -534,6 +627,13 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co | ||||
|         name=parameter.name, | ||||
|     ) | ||||
|  | ||||
|     try: | ||||
|         choices = inner.__discord_app_commands_transformer_choices__ | ||||
|     except AttributeError: | ||||
|         pass | ||||
|     else: | ||||
|         result.choices = choices | ||||
|  | ||||
|     # These methods should be duck typed | ||||
|     if type in (AppCommandOptionType.number, AppCommandOptionType.integer): | ||||
|         result.min_value = inner.min_value() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user