mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-11-02 22:42:54 +00:00
Add support for annotation transformers
This facilitates the "converter-like" API of the app_commands submodule. As a consequence of this refactor, more types are supported like channels and attachment.
This commit is contained in:
@@ -41,7 +41,6 @@ from typing import (
|
||||
TypeVar,
|
||||
Union,
|
||||
)
|
||||
from dataclasses import dataclass
|
||||
from textwrap import TextWrapper
|
||||
|
||||
import sys
|
||||
@@ -51,6 +50,7 @@ from .enums import AppCommandOptionType, AppCommandType
|
||||
from ..interactions import Interaction
|
||||
from ..enums import ChannelType, try_enum
|
||||
from .models import AppCommandChannel, AppCommandThread, Choice
|
||||
from .transformers import annotation_to_parameter, CommandParameter, NoneType
|
||||
from .errors import AppCommandError, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
|
||||
from ..utils import resolve_annotation, MISSING, is_inside_class
|
||||
from ..user import User
|
||||
@@ -72,7 +72,6 @@ if TYPE_CHECKING:
|
||||
from .namespace import Namespace
|
||||
|
||||
__all__ = (
|
||||
'CommandParameter',
|
||||
'Command',
|
||||
'ContextMenu',
|
||||
'Group',
|
||||
@@ -130,158 +129,6 @@ def _to_kebab_case(text: str) -> str:
|
||||
return CAMEL_CASE_REGEX.sub('-', text).lower()
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommandParameter:
|
||||
"""Represents a application command parameter.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
name: :class:`str`
|
||||
The name of the parameter.
|
||||
description: :class:`str`
|
||||
The description of the parameter
|
||||
required: :class:`bool`
|
||||
Whether the parameter is required
|
||||
choices: List[:class:`~discord.app_commands.Choice`]
|
||||
A list of choices this parameter takes
|
||||
type: :class:`~discord.app_commands.AppCommandOptionType`
|
||||
The underlying type of this parameter.
|
||||
channel_types: List[:class:`~discord.ChannelType`]
|
||||
The channel types that are allowed for this parameter.
|
||||
min_value: Optional[:class:`int`]
|
||||
The minimum supported value for this parameter.
|
||||
max_value: Optional[:class:`int`]
|
||||
The maximum supported value for this parameter.
|
||||
autocomplete: :class:`bool`
|
||||
Whether this parameter enables autocomplete.
|
||||
"""
|
||||
|
||||
name: str = MISSING
|
||||
description: str = MISSING
|
||||
required: bool = MISSING
|
||||
default: Any = MISSING
|
||||
choices: List[Choice] = MISSING
|
||||
type: AppCommandOptionType = MISSING
|
||||
channel_types: List[ChannelType] = MISSING
|
||||
min_value: Optional[int] = None
|
||||
max_value: Optional[int] = None
|
||||
autocomplete: bool = MISSING
|
||||
_annotation: Any = MISSING
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
base = {
|
||||
'type': self.type.value,
|
||||
'name': self.name,
|
||||
'description': self.description,
|
||||
'required': self.required,
|
||||
}
|
||||
|
||||
if self.choices:
|
||||
base['choices'] = [choice.to_dict() for choice in self.choices]
|
||||
if self.channel_types:
|
||||
base['channel_types'] = [t.value for t in self.channel_types]
|
||||
if self.autocomplete:
|
||||
base['autocomplete'] = True
|
||||
if self.min_value is not None:
|
||||
base['min_value'] = self.min_value
|
||||
if self.max_value is not None:
|
||||
base['max_value'] = self.max_value
|
||||
|
||||
return base
|
||||
|
||||
|
||||
annotation_to_option_type: Dict[Any, AppCommandOptionType] = {
|
||||
str: AppCommandOptionType.string,
|
||||
int: AppCommandOptionType.integer,
|
||||
float: AppCommandOptionType.number,
|
||||
bool: AppCommandOptionType.boolean,
|
||||
User: AppCommandOptionType.user,
|
||||
Member: AppCommandOptionType.user,
|
||||
Role: AppCommandOptionType.role,
|
||||
AppCommandChannel: AppCommandOptionType.channel,
|
||||
AppCommandThread: AppCommandOptionType.channel,
|
||||
# StageChannel: AppCommandOptionType.channel,
|
||||
# StoreChannel: AppCommandOptionType.channel,
|
||||
# VoiceChannel: AppCommandOptionType.channel,
|
||||
# TextChannel: AppCommandOptionType.channel,
|
||||
}
|
||||
|
||||
NoneType = type(None)
|
||||
allowed_default_types: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
|
||||
AppCommandOptionType.string: (str, NoneType),
|
||||
AppCommandOptionType.integer: (int, NoneType),
|
||||
AppCommandOptionType.boolean: (bool, NoneType),
|
||||
}
|
||||
|
||||
|
||||
# Some sanity checks:
|
||||
# str => string
|
||||
# int => int
|
||||
# User => user
|
||||
# etc ...
|
||||
# Optional[str] => string, required: false, default: None
|
||||
# Optional[int] => integer, required: false, default: None
|
||||
# Optional[Model] = None => resolved, required: false, default: None
|
||||
# Optional[Model] can only have (CommandParameter, None) as default
|
||||
# Optional[int | str | bool] can have (CommandParameter, None, int | str | bool) as a default
|
||||
# Union[str, Member] => disallowed
|
||||
# Union[int, str] => disallowed
|
||||
# Union[Member, User] => user
|
||||
# Optional[Union[Member, User]] => user, required: false, default: None
|
||||
# Union[Member, User, Object] => mentionable
|
||||
# Union[Models] => mentionable
|
||||
# Optional[Union[Models]] => mentionable, required: false, default: None
|
||||
|
||||
|
||||
def _annotation_to_type(
|
||||
annotation: Any,
|
||||
*,
|
||||
mapping=annotation_to_option_type,
|
||||
_none=NoneType,
|
||||
) -> Tuple[AppCommandOptionType, Any]:
|
||||
# Straight simple case, a regular ol' parameter
|
||||
try:
|
||||
option_type = mapping[annotation]
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
return (option_type, MISSING)
|
||||
|
||||
# Check if there's an origin
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
if origin is not Union:
|
||||
# Only Union/Optional is supported so bail early
|
||||
raise TypeError(f'unsupported type annotation {annotation!r}')
|
||||
|
||||
default = MISSING
|
||||
if annotation.__args__[-1] is _none:
|
||||
if len(annotation.__args__) == 2:
|
||||
underlying = annotation.__args__[0]
|
||||
option_type = mapping.get(underlying)
|
||||
if option_type is None:
|
||||
raise TypeError(f'unsupported inner optional type {underlying!r}')
|
||||
return (option_type, None)
|
||||
else:
|
||||
args = annotation.__args__[:-1]
|
||||
default = None
|
||||
else:
|
||||
args = annotation.__args__
|
||||
|
||||
# At this point only models are allowed
|
||||
# Since Optional[int | bool | str] will be taken care of above
|
||||
# The only valid transformations here are:
|
||||
# [Member, User] => user
|
||||
# [Member, User, Role] => mentionable
|
||||
# [Member | User, Role] => mentionable
|
||||
supported_types: Set[Any] = {Role, Member, User}
|
||||
if not all(arg in supported_types for arg in args):
|
||||
raise TypeError(f'unsupported types given inside {annotation!r}')
|
||||
if args == (User, Member) or args == (Member, User):
|
||||
return (AppCommandOptionType.user, default)
|
||||
|
||||
return (AppCommandOptionType.mentionable, default)
|
||||
|
||||
|
||||
def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType:
|
||||
if annotation is Message:
|
||||
return AppCommandType.message
|
||||
@@ -324,33 +171,6 @@ def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Di
|
||||
raise TypeError(f'unknown parameter given: {first}')
|
||||
|
||||
|
||||
def _get_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter:
|
||||
(type, default) = _annotation_to_type(annotation)
|
||||
if default is MISSING:
|
||||
default = parameter.default
|
||||
if default is parameter.empty:
|
||||
default = MISSING
|
||||
|
||||
result = CommandParameter(
|
||||
type=type,
|
||||
default=default,
|
||||
required=default is MISSING,
|
||||
name=parameter.name,
|
||||
)
|
||||
|
||||
if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL):
|
||||
raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}')
|
||||
|
||||
# Verify validity of the default parameter
|
||||
if result.default is not MISSING:
|
||||
valid_types: Tuple[Any, ...] = allowed_default_types.get(result.type, (NoneType,))
|
||||
if not isinstance(result.default, valid_types):
|
||||
raise TypeError(f'invalid default parameter type given ({result.default.__class__}), expected {valid_types}')
|
||||
|
||||
result._annotation = annotation
|
||||
return result
|
||||
|
||||
|
||||
def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]:
|
||||
params = inspect.signature(func).parameters
|
||||
cache = {}
|
||||
@@ -368,7 +188,7 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s
|
||||
raise TypeError(f'annotation for {parameter.name} must be given')
|
||||
|
||||
resolved = resolve_annotation(parameter.annotation, globalns, globalns, cache)
|
||||
param = _get_parameter(resolved, parameter)
|
||||
param = annotation_to_parameter(resolved, parameter)
|
||||
parameters.append(param)
|
||||
|
||||
values = sorted(parameters, key=lambda a: a.required, reverse=True)
|
||||
@@ -377,7 +197,9 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s
|
||||
try:
|
||||
descriptions = func.__discord_app_commands_param_description__
|
||||
except AttributeError:
|
||||
pass
|
||||
for param in values:
|
||||
if param.description is MISSING:
|
||||
param.description = '...'
|
||||
else:
|
||||
_populate_descriptions(result, descriptions)
|
||||
|
||||
@@ -489,14 +311,24 @@ class Command(Generic[GroupT, P, T]):
|
||||
await parent.parent.on_error(interaction, self, error)
|
||||
|
||||
async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T:
|
||||
defaults = ((name, param.default) for name, param in self._params.items() if not param.required)
|
||||
namespace._update_with_defaults(defaults)
|
||||
values = namespace.__dict__
|
||||
for name, param in self._params.items():
|
||||
if not param.required:
|
||||
values.setdefault(name, param.default)
|
||||
else:
|
||||
try:
|
||||
value = values[name]
|
||||
except KeyError:
|
||||
raise CommandSignatureMismatch(self) from None
|
||||
else:
|
||||
values[name] = await param.transform(interaction, value)
|
||||
|
||||
# These type ignores are because the type checker doesn't quite understand the narrowing here
|
||||
# Likewise, it thinks we're missing positional arguments when there aren't any.
|
||||
try:
|
||||
if self.binding is not None:
|
||||
return await self._callback(self.binding, interaction, **namespace.__dict__) # type: ignore
|
||||
return await self._callback(interaction, **namespace.__dict__) # type: ignore
|
||||
return await self._callback(self.binding, interaction, **values) # type: ignore
|
||||
return await self._callback(interaction, **values) # type: ignore
|
||||
except TypeError as e:
|
||||
# In order to detect mismatch from the provided signature and the Discord data,
|
||||
# there are many ways it can go wrong yet all of them eventually lead to a TypeError
|
||||
|
||||
Reference in New Issue
Block a user