mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-11-03 23:12:56 +00:00 
			
		
		
		
	Add support for choice option parameters
This implements it in three different ways: * The first is using typing.Literal for quick and easy ones * The second is using enum.Enum for slightly more complex ones * The last is using a Choice type hint with a decorator to pass a list of choices. This should hopefully cover most use cases.
This commit is contained in:
		@@ -77,6 +77,7 @@ __all__ = (
 | 
			
		||||
    'Group',
 | 
			
		||||
    'command',
 | 
			
		||||
    'describe',
 | 
			
		||||
    'choices',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
@@ -171,6 +172,31 @@ def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Di
 | 
			
		||||
        raise TypeError(f'unknown parameter given: {first}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _populate_choices(params: Dict[str, CommandParameter], all_choices: Dict[str, List[Choice]]) -> None:
 | 
			
		||||
    for name, param in params.items():
 | 
			
		||||
        choices = all_choices.pop(name, MISSING)
 | 
			
		||||
        if choices is MISSING:
 | 
			
		||||
            continue
 | 
			
		||||
 | 
			
		||||
        if not isinstance(choices, list):
 | 
			
		||||
            raise TypeError('choices must be a list of Choice')
 | 
			
		||||
 | 
			
		||||
        if not all(isinstance(choice, Choice) for choice in choices):
 | 
			
		||||
            raise TypeError('choices must be a list of Choice')
 | 
			
		||||
 | 
			
		||||
        if param.type not in (AppCommandOptionType.string, AppCommandOptionType.number, AppCommandOptionType.integer):
 | 
			
		||||
            raise TypeError('choices are only supported for integer, string, or number option types')
 | 
			
		||||
 | 
			
		||||
        # There's a type safety hole if someone does Choice[float] as an annotation
 | 
			
		||||
        # but the values are actually Choice[int]. Since the input-output is the same this feels
 | 
			
		||||
        # safe enough to ignore.
 | 
			
		||||
        param.choices = choices
 | 
			
		||||
 | 
			
		||||
    if all_choices:
 | 
			
		||||
        first = next(iter(all_choices))
 | 
			
		||||
        raise TypeError(f'unknown parameter given: {first}')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]:
 | 
			
		||||
    params = inspect.signature(func).parameters
 | 
			
		||||
    cache = {}
 | 
			
		||||
@@ -203,6 +229,13 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s
 | 
			
		||||
    else:
 | 
			
		||||
        _populate_descriptions(result, descriptions)
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        choices = func.__discord_app_commands_param_choices__
 | 
			
		||||
    except AttributeError:
 | 
			
		||||
        pass
 | 
			
		||||
    else:
 | 
			
		||||
        _populate_choices(result, choices)
 | 
			
		||||
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@@ -313,15 +346,15 @@ class Command(Generic[GroupT, P, T]):
 | 
			
		||||
    async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T:
 | 
			
		||||
        values = namespace.__dict__
 | 
			
		||||
        for name, param in self._params.items():
 | 
			
		||||
            if not param.required:
 | 
			
		||||
                values.setdefault(name, param.default)
 | 
			
		||||
            else:
 | 
			
		||||
                try:
 | 
			
		||||
                    value = values[name]
 | 
			
		||||
                except KeyError:
 | 
			
		||||
                    raise CommandSignatureMismatch(self) from None
 | 
			
		||||
            try:
 | 
			
		||||
                value = values[name]
 | 
			
		||||
            except KeyError:
 | 
			
		||||
                if not param.required:
 | 
			
		||||
                    values[name] = param.default
 | 
			
		||||
                else:
 | 
			
		||||
                    values[name] = await param.transform(interaction, value)
 | 
			
		||||
                    raise CommandSignatureMismatch(self) from None
 | 
			
		||||
            else:
 | 
			
		||||
                values[name] = await param.transform(interaction, value)
 | 
			
		||||
 | 
			
		||||
        # These type ignores are because the type checker doesn't quite understand the narrowing here
 | 
			
		||||
        # Likewise, it thinks we're missing positional arguments when there aren't any.
 | 
			
		||||
@@ -768,7 +801,7 @@ def describe(**parameters: str) -> Callable[[T], T]:
 | 
			
		||||
    .. code-block:: python3
 | 
			
		||||
 | 
			
		||||
        @app_commands.command()
 | 
			
		||||
        @app_commads.describe(member='the member to ban')
 | 
			
		||||
        @app_commands.describe(member='the member to ban')
 | 
			
		||||
        async def ban(interaction: discord.Interaction, member: discord.Member):
 | 
			
		||||
            await interaction.response.send_message(f'Banned {member}')
 | 
			
		||||
 | 
			
		||||
@@ -787,7 +820,79 @@ def describe(**parameters: str) -> Callable[[T], T]:
 | 
			
		||||
        if isinstance(inner, Command):
 | 
			
		||||
            _populate_descriptions(inner._params, parameters)
 | 
			
		||||
        else:
 | 
			
		||||
            inner.__discord_app_commands_param_description__ = parameters  # type: ignore - Runtime attribute assignment
 | 
			
		||||
            try:
 | 
			
		||||
                inner.__discord_app_commands_param_description__.update(parameters)  # type: ignore - Runtime attribute access
 | 
			
		||||
            except AttributeError:
 | 
			
		||||
                inner.__discord_app_commands_param_description__ = parameters  # type: ignore - Runtime attribute assignment
 | 
			
		||||
 | 
			
		||||
        return inner
 | 
			
		||||
 | 
			
		||||
    return decorator
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def choices(**parameters: List[Choice]) -> Callable[[T], T]:
 | 
			
		||||
    r"""Instructs the given parameters by their name to use the given choices for their choices.
 | 
			
		||||
 | 
			
		||||
    Example:
 | 
			
		||||
 | 
			
		||||
    .. code-block:: python3
 | 
			
		||||
 | 
			
		||||
        @app_commands.command()
 | 
			
		||||
        @app_commands.describe(fruits='fruits to choose from')
 | 
			
		||||
        @app_commands.choices(fruits=[
 | 
			
		||||
            Choice(name='apple', value=1),
 | 
			
		||||
            Choice(name='banana', value=2),
 | 
			
		||||
            Choice(name='cherry', value=3),
 | 
			
		||||
        ])
 | 
			
		||||
        async def fruit(interaction: discord.Interaction, fruits: Choice[int]):
 | 
			
		||||
            await interaction.response.send_message(f'Your favourite fruit is {fruits.name}.')
 | 
			
		||||
 | 
			
		||||
    .. note::
 | 
			
		||||
 | 
			
		||||
        This is not the only way to provide choices to a command. There are two more ergonomic ways
 | 
			
		||||
        of doing this. The first one is to use a :obj:`typing.Literal` annotation:
 | 
			
		||||
 | 
			
		||||
        .. code-block:: python3
 | 
			
		||||
 | 
			
		||||
            @app_commands.command()
 | 
			
		||||
            @app_commands.describe(fruits='fruits to choose from')
 | 
			
		||||
            async def fruit(interaction: discord.Interaction, fruits: Literal['apple', 'banana', 'cherry']):
 | 
			
		||||
                await interaction.response.send_message(f'Your favourite fruit is {fruits}.')
 | 
			
		||||
 | 
			
		||||
        The second way is to use an :class:`enum.Enum`:
 | 
			
		||||
 | 
			
		||||
        .. code-block:: python3
 | 
			
		||||
 | 
			
		||||
            class Fruits(enum.Enum):
 | 
			
		||||
                apple = 1
 | 
			
		||||
                banana = 2
 | 
			
		||||
                cherry = 3
 | 
			
		||||
 | 
			
		||||
            @app_commands.command()
 | 
			
		||||
            @app_commands.describe(fruits='fruits to choose from')
 | 
			
		||||
            async def fruit(interaction: discord.Interaction, fruits: Fruits):
 | 
			
		||||
                await interaction.response.send_message(f'Your favourite fruit is {fruits}.')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    Parameters
 | 
			
		||||
    -----------
 | 
			
		||||
    \*\*parameters
 | 
			
		||||
        The choices of the parameters.
 | 
			
		||||
 | 
			
		||||
    Raises
 | 
			
		||||
    --------
 | 
			
		||||
    TypeError
 | 
			
		||||
        The parameter name is not found.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    def decorator(inner: T) -> T:
 | 
			
		||||
        if isinstance(inner, Command):
 | 
			
		||||
            _populate_choices(inner._params, parameters)
 | 
			
		||||
        else:
 | 
			
		||||
            try:
 | 
			
		||||
                inner.__discord_app_commands_param_choices__.update(parameters)  # type: ignore - Runtime attribute access
 | 
			
		||||
            except AttributeError:
 | 
			
		||||
                inner.__discord_app_commands_param_choices__ = parameters  # type: ignore - Runtime attribute assignment
 | 
			
		||||
 | 
			
		||||
        return inner
 | 
			
		||||
 | 
			
		||||
 
 | 
			
		||||
@@ -31,7 +31,7 @@ from ..enums import ChannelType, try_enum
 | 
			
		||||
from ..mixins import Hashable
 | 
			
		||||
from ..utils import _get_as_snowflake, parse_time, snowflake_time
 | 
			
		||||
from .enums import AppCommandOptionType, AppCommandType
 | 
			
		||||
from typing import List, NamedTuple, TYPE_CHECKING, Optional, Union
 | 
			
		||||
from typing import Generic, List, NamedTuple, TYPE_CHECKING, Optional, TypeVar, Union
 | 
			
		||||
 | 
			
		||||
__all__ = (
 | 
			
		||||
    'AppCommand',
 | 
			
		||||
@@ -42,6 +42,8 @@ __all__ = (
 | 
			
		||||
    'Choice',
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
ChoiceT = TypeVar('ChoiceT', str, int, float, Union[str, int, float])
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def is_app_command_argument_type(value: int) -> bool:
 | 
			
		||||
    return 11 >= value >= 3
 | 
			
		||||
@@ -145,7 +147,7 @@ class AppCommand(Hashable):
 | 
			
		||||
        return f'<{self.__class__.__name__} id={self.id!r} name={self.name!r} type={self.type!r}>'
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Choice(NamedTuple):
 | 
			
		||||
class Choice(Generic[ChoiceT]):
 | 
			
		||||
    """Represents an application command argument choice.
 | 
			
		||||
 | 
			
		||||
    .. versionadded:: 2.0
 | 
			
		||||
@@ -160,6 +162,10 @@ class Choice(NamedTuple):
 | 
			
		||||
 | 
			
		||||
            Checks if two choices are not equal.
 | 
			
		||||
 | 
			
		||||
        .. describe:: hash(x)
 | 
			
		||||
 | 
			
		||||
            Returns the choice's hash.
 | 
			
		||||
 | 
			
		||||
    Parameters
 | 
			
		||||
    -----------
 | 
			
		||||
    name: :class:`str`
 | 
			
		||||
@@ -168,8 +174,20 @@ class Choice(NamedTuple):
 | 
			
		||||
        The value of the choice.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    name: str
 | 
			
		||||
    value: Union[int, str, float]
 | 
			
		||||
    __slots__ = ('name', 'value')
 | 
			
		||||
 | 
			
		||||
    def __init__(self, *, name: str, value: ChoiceT):
 | 
			
		||||
        self.name: str = name
 | 
			
		||||
        self.value: ChoiceT = value
 | 
			
		||||
 | 
			
		||||
    def __eq__(self, o: object) -> bool:
 | 
			
		||||
        return isinstance(o, Choice) and self.name == o.name and self.value == o.value
 | 
			
		||||
 | 
			
		||||
    def __hash__(self) -> int:
 | 
			
		||||
        return hash((self.name, self.value))
 | 
			
		||||
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        return f'{self.__class__.__name__}(name={self.name!r}, value={self.value!r})'
 | 
			
		||||
 | 
			
		||||
    def to_dict(self) -> ApplicationCommandOptionChoice:
 | 
			
		||||
        return {
 | 
			
		||||
 
 | 
			
		||||
@@ -26,7 +26,8 @@ from __future__ import annotations
 | 
			
		||||
import inspect
 | 
			
		||||
 | 
			
		||||
from dataclasses import dataclass
 | 
			
		||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Set, Tuple, Type, TypeVar, Union
 | 
			
		||||
from enum import Enum
 | 
			
		||||
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Literal, Optional, Set, Tuple, Type, TypeVar, Union
 | 
			
		||||
 | 
			
		||||
from .enums import AppCommandOptionType
 | 
			
		||||
from .errors import TransformerError
 | 
			
		||||
@@ -113,6 +114,13 @@ class CommandParameter:
 | 
			
		||||
 | 
			
		||||
    async def transform(self, interaction: Interaction, value: Any) -> Any:
 | 
			
		||||
        if hasattr(self._annotation, '__discord_app_commands_transformer__'):
 | 
			
		||||
            # This one needs special handling for type safety reasons
 | 
			
		||||
            if self._annotation.__discord_app_commands_is_choice__:
 | 
			
		||||
                choice = next((c for c in self.choices if c.value == value), None)
 | 
			
		||||
                if choice is None:
 | 
			
		||||
                    raise TransformerError(value, self.type, self._annotation)
 | 
			
		||||
                return choice
 | 
			
		||||
 | 
			
		||||
            return await self._annotation.transform(interaction, value)
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
@@ -149,6 +157,7 @@ class Transformer:
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    __discord_app_commands_transformer__: ClassVar[bool] = True
 | 
			
		||||
    __discord_app_commands_is_choice__: ClassVar[bool] = False
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
    def type(cls) -> AppCommandOptionType:
 | 
			
		||||
@@ -221,24 +230,93 @@ class _TransformMetadata:
 | 
			
		||||
        self.metadata: Type[Transformer] = metadata
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def _identity_transform(cls, interaction: Interaction, value: Any) -> Any:
 | 
			
		||||
    return value
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_range_transformer(
 | 
			
		||||
    opt_type: AppCommandOptionType,
 | 
			
		||||
    *,
 | 
			
		||||
    min: Optional[Union[int, float]] = None,
 | 
			
		||||
    max: Optional[Union[int, float]] = None,
 | 
			
		||||
) -> Type[Transformer]:
 | 
			
		||||
    async def transform(cls, interaction: Interaction, value: Any) -> Any:
 | 
			
		||||
        return value
 | 
			
		||||
 | 
			
		||||
    ns = {
 | 
			
		||||
        'type': classmethod(lambda _: opt_type),
 | 
			
		||||
        'min_value': classmethod(lambda _: min),
 | 
			
		||||
        'max_value': classmethod(lambda _: max),
 | 
			
		||||
        'transform': classmethod(transform),
 | 
			
		||||
        'transform': classmethod(_identity_transform),
 | 
			
		||||
    }
 | 
			
		||||
    return type('RangeTransformer', (Transformer,), ns)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_literal_transformer(values: Tuple[Any, ...]) -> Type[Transformer]:
 | 
			
		||||
    if len(values) < 2:
 | 
			
		||||
        raise TypeError(f'typing.Literal requires at least two values.')
 | 
			
		||||
 | 
			
		||||
    first = type(values[0])
 | 
			
		||||
    if first is int:
 | 
			
		||||
        opt_type = AppCommandOptionType.integer
 | 
			
		||||
    elif first is float:
 | 
			
		||||
        opt_type = AppCommandOptionType.number
 | 
			
		||||
    elif first is str:
 | 
			
		||||
        opt_type = AppCommandOptionType.string
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError(f'expected int, str, or float values not {first!r}')
 | 
			
		||||
 | 
			
		||||
    ns = {
 | 
			
		||||
        'type': classmethod(lambda _: opt_type),
 | 
			
		||||
        'transform': classmethod(_identity_transform),
 | 
			
		||||
        '__discord_app_commands_transformer_choices__': [Choice(name=str(v), value=v) for v in values],
 | 
			
		||||
    }
 | 
			
		||||
    return type('LiteralTransformer', (Transformer,), ns)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_choice_transformer(inner_type: Any) -> Type[Transformer]:
 | 
			
		||||
    if inner_type is int:
 | 
			
		||||
        opt_type = AppCommandOptionType.integer
 | 
			
		||||
    elif inner_type is float:
 | 
			
		||||
        opt_type = AppCommandOptionType.number
 | 
			
		||||
    elif inner_type is str:
 | 
			
		||||
        opt_type = AppCommandOptionType.string
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError(f'expected int, str, or float values not {inner_type!r}')
 | 
			
		||||
 | 
			
		||||
    ns = {
 | 
			
		||||
        'type': classmethod(lambda _: opt_type),
 | 
			
		||||
        'transform': classmethod(_identity_transform),
 | 
			
		||||
        '__discord_app_commands_is_choice__': True,
 | 
			
		||||
    }
 | 
			
		||||
    return type('ChoiceTransformer', (Transformer,), ns)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def _make_enum_transformer(enum) -> Type[Transformer]:
 | 
			
		||||
    values = list(enum)
 | 
			
		||||
    if len(values) < 2:
 | 
			
		||||
        raise TypeError(f'enum.Enum requires at least two values.')
 | 
			
		||||
 | 
			
		||||
    first = type(values[0].value)
 | 
			
		||||
    if first is int:
 | 
			
		||||
        opt_type = AppCommandOptionType.integer
 | 
			
		||||
    elif first is float:
 | 
			
		||||
        opt_type = AppCommandOptionType.number
 | 
			
		||||
    elif first is str:
 | 
			
		||||
        opt_type = AppCommandOptionType.string
 | 
			
		||||
    else:
 | 
			
		||||
        raise TypeError(f'expected int, str, or float values not {first!r}')
 | 
			
		||||
 | 
			
		||||
    async def transform(cls, interaction: Interaction, value: Any) -> Any:
 | 
			
		||||
        return enum(value)
 | 
			
		||||
 | 
			
		||||
    ns = {
 | 
			
		||||
        'type': classmethod(lambda _: opt_type),
 | 
			
		||||
        'transform': classmethod(transform),
 | 
			
		||||
        '__discord_app_commands_transformer_enum__': enum,
 | 
			
		||||
        '__discord_app_commands_transformer_choices__': [Choice(name=v.name, value=v.value) for v in values],
 | 
			
		||||
    }
 | 
			
		||||
 | 
			
		||||
    return type(f'{enum.__name__}EnumTransformer', (Transformer,), ns)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from typing_extensions import Annotated as Transform
 | 
			
		||||
    from typing_extensions import Annotated as Range
 | 
			
		||||
@@ -465,11 +543,24 @@ def get_supported_annotation(
 | 
			
		||||
    if hasattr(annotation, '__discord_app_commands_transform__'):
 | 
			
		||||
        return (annotation.metadata, MISSING)
 | 
			
		||||
 | 
			
		||||
    if inspect.isclass(annotation) and issubclass(annotation, Transformer):
 | 
			
		||||
        return (annotation, MISSING)
 | 
			
		||||
    if inspect.isclass(annotation):
 | 
			
		||||
        if issubclass(annotation, Transformer):
 | 
			
		||||
            return (annotation, MISSING)
 | 
			
		||||
        if issubclass(annotation, Enum):
 | 
			
		||||
            return (_make_enum_transformer(annotation), MISSING)
 | 
			
		||||
        if annotation is Choice:
 | 
			
		||||
            raise TypeError(f'Choice requires a type argument of int, str, or float')
 | 
			
		||||
 | 
			
		||||
    # Check if there's an origin
 | 
			
		||||
    origin = getattr(annotation, '__origin__', None)
 | 
			
		||||
    if origin is Literal:
 | 
			
		||||
        args = annotation.__args__  # type: ignore
 | 
			
		||||
        return (_make_literal_transformer(args), MISSING)
 | 
			
		||||
 | 
			
		||||
    if origin is Choice:
 | 
			
		||||
        arg = annotation.__args__[0]  # type: ignore
 | 
			
		||||
        return (_make_choice_transformer(arg), MISSING)
 | 
			
		||||
 | 
			
		||||
    if origin is not Union:
 | 
			
		||||
        # Only Union/Optional is supported right now so bail early
 | 
			
		||||
        raise TypeError(f'unsupported type annotation {annotation!r}')
 | 
			
		||||
@@ -522,9 +613,11 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
 | 
			
		||||
 | 
			
		||||
    # Verify validity of the default parameter
 | 
			
		||||
    if default is not MISSING:
 | 
			
		||||
        valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,))
 | 
			
		||||
        if not isinstance(default, valid_types):
 | 
			
		||||
            raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}')
 | 
			
		||||
        enum_type = getattr(inner, '__discord_app_commands_transformer_enum__', None)
 | 
			
		||||
        if default.__class__ is not enum_type:
 | 
			
		||||
            valid_types: Tuple[Any, ...] = ALLOWED_DEFAULTS.get(type, (NoneType,))
 | 
			
		||||
            if not isinstance(default, valid_types):
 | 
			
		||||
                raise TypeError(f'invalid default parameter type given ({default.__class__}), expected {valid_types}')
 | 
			
		||||
 | 
			
		||||
    result = CommandParameter(
 | 
			
		||||
        type=type,
 | 
			
		||||
@@ -534,6 +627,13 @@ def annotation_to_parameter(annotation: Any, parameter: inspect.Parameter) -> Co
 | 
			
		||||
        name=parameter.name,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    try:
 | 
			
		||||
        choices = inner.__discord_app_commands_transformer_choices__
 | 
			
		||||
    except AttributeError:
 | 
			
		||||
        pass
 | 
			
		||||
    else:
 | 
			
		||||
        result.choices = choices
 | 
			
		||||
 | 
			
		||||
    # These methods should be duck typed
 | 
			
		||||
    if type in (AppCommandOptionType.number, AppCommandOptionType.integer):
 | 
			
		||||
        result.min_value = inner.min_value()
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user