Add support for annotation transformers

This facilitates the "converter-like" API of the app_commands
submodule. As a consequence of this refactor, more types are supported
like channels and attachment.
This commit is contained in:
Rapptz
2022-02-28 08:06:32 -05:00
parent c10ed93cef
commit 3cf3065c02
4 changed files with 565 additions and 189 deletions

View File

@@ -41,7 +41,6 @@ from typing import (
TypeVar,
Union,
)
from dataclasses import dataclass
from textwrap import TextWrapper
import sys
@@ -51,6 +50,7 @@ from .enums import AppCommandOptionType, AppCommandType
from ..interactions import Interaction
from ..enums import ChannelType, try_enum
from .models import AppCommandChannel, AppCommandThread, Choice
from .transformers import annotation_to_parameter, CommandParameter, NoneType
from .errors import AppCommandError, CommandInvokeError, CommandSignatureMismatch, CommandAlreadyRegistered
from ..utils import resolve_annotation, MISSING, is_inside_class
from ..user import User
@@ -72,7 +72,6 @@ if TYPE_CHECKING:
from .namespace import Namespace
__all__ = (
'CommandParameter',
'Command',
'ContextMenu',
'Group',
@@ -130,158 +129,6 @@ def _to_kebab_case(text: str) -> str:
return CAMEL_CASE_REGEX.sub('-', text).lower()
@dataclass
class CommandParameter:
"""Represents a application command parameter.
Attributes
-----------
name: :class:`str`
The name of the parameter.
description: :class:`str`
The description of the parameter
required: :class:`bool`
Whether the parameter is required
choices: List[:class:`~discord.app_commands.Choice`]
A list of choices this parameter takes
type: :class:`~discord.app_commands.AppCommandOptionType`
The underlying type of this parameter.
channel_types: List[:class:`~discord.ChannelType`]
The channel types that are allowed for this parameter.
min_value: Optional[:class:`int`]
The minimum supported value for this parameter.
max_value: Optional[:class:`int`]
The maximum supported value for this parameter.
autocomplete: :class:`bool`
Whether this parameter enables autocomplete.
"""
name: str = MISSING
description: str = MISSING
required: bool = MISSING
default: Any = MISSING
choices: List[Choice] = MISSING
type: AppCommandOptionType = MISSING
channel_types: List[ChannelType] = MISSING
min_value: Optional[int] = None
max_value: Optional[int] = None
autocomplete: bool = MISSING
_annotation: Any = MISSING
def to_dict(self) -> Dict[str, Any]:
base = {
'type': self.type.value,
'name': self.name,
'description': self.description,
'required': self.required,
}
if self.choices:
base['choices'] = [choice.to_dict() for choice in self.choices]
if self.channel_types:
base['channel_types'] = [t.value for t in self.channel_types]
if self.autocomplete:
base['autocomplete'] = True
if self.min_value is not None:
base['min_value'] = self.min_value
if self.max_value is not None:
base['max_value'] = self.max_value
return base
annotation_to_option_type: Dict[Any, AppCommandOptionType] = {
str: AppCommandOptionType.string,
int: AppCommandOptionType.integer,
float: AppCommandOptionType.number,
bool: AppCommandOptionType.boolean,
User: AppCommandOptionType.user,
Member: AppCommandOptionType.user,
Role: AppCommandOptionType.role,
AppCommandChannel: AppCommandOptionType.channel,
AppCommandThread: AppCommandOptionType.channel,
# StageChannel: AppCommandOptionType.channel,
# StoreChannel: AppCommandOptionType.channel,
# VoiceChannel: AppCommandOptionType.channel,
# TextChannel: AppCommandOptionType.channel,
}
NoneType = type(None)
allowed_default_types: Dict[AppCommandOptionType, Tuple[Type[Any], ...]] = {
AppCommandOptionType.string: (str, NoneType),
AppCommandOptionType.integer: (int, NoneType),
AppCommandOptionType.boolean: (bool, NoneType),
}
# Some sanity checks:
# str => string
# int => int
# User => user
# etc ...
# Optional[str] => string, required: false, default: None
# Optional[int] => integer, required: false, default: None
# Optional[Model] = None => resolved, required: false, default: None
# Optional[Model] can only have (CommandParameter, None) as default
# Optional[int | str | bool] can have (CommandParameter, None, int | str | bool) as a default
# Union[str, Member] => disallowed
# Union[int, str] => disallowed
# Union[Member, User] => user
# Optional[Union[Member, User]] => user, required: false, default: None
# Union[Member, User, Object] => mentionable
# Union[Models] => mentionable
# Optional[Union[Models]] => mentionable, required: false, default: None
def _annotation_to_type(
annotation: Any,
*,
mapping=annotation_to_option_type,
_none=NoneType,
) -> Tuple[AppCommandOptionType, Any]:
# Straight simple case, a regular ol' parameter
try:
option_type = mapping[annotation]
except KeyError:
pass
else:
return (option_type, MISSING)
# Check if there's an origin
origin = getattr(annotation, '__origin__', None)
if origin is not Union:
# Only Union/Optional is supported so bail early
raise TypeError(f'unsupported type annotation {annotation!r}')
default = MISSING
if annotation.__args__[-1] is _none:
if len(annotation.__args__) == 2:
underlying = annotation.__args__[0]
option_type = mapping.get(underlying)
if option_type is None:
raise TypeError(f'unsupported inner optional type {underlying!r}')
return (option_type, None)
else:
args = annotation.__args__[:-1]
default = None
else:
args = annotation.__args__
# At this point only models are allowed
# Since Optional[int | bool | str] will be taken care of above
# The only valid transformations here are:
# [Member, User] => user
# [Member, User, Role] => mentionable
# [Member | User, Role] => mentionable
supported_types: Set[Any] = {Role, Member, User}
if not all(arg in supported_types for arg in args):
raise TypeError(f'unsupported types given inside {annotation!r}')
if args == (User, Member) or args == (Member, User):
return (AppCommandOptionType.user, default)
return (AppCommandOptionType.mentionable, default)
def _context_menu_annotation(annotation: Any, *, _none=NoneType) -> AppCommandType:
if annotation is Message:
return AppCommandType.message
@@ -324,33 +171,6 @@ def _populate_descriptions(params: Dict[str, CommandParameter], descriptions: Di
raise TypeError(f'unknown parameter given: {first}')
def _get_parameter(annotation: Any, parameter: inspect.Parameter) -> CommandParameter:
(type, default) = _annotation_to_type(annotation)
if default is MISSING:
default = parameter.default
if default is parameter.empty:
default = MISSING
result = CommandParameter(
type=type,
default=default,
required=default is MISSING,
name=parameter.name,
)
if parameter.kind in (parameter.POSITIONAL_ONLY, parameter.VAR_KEYWORD, parameter.VAR_POSITIONAL):
raise TypeError(f'unsupported parameter kind in callback: {parameter.kind!s}')
# Verify validity of the default parameter
if result.default is not MISSING:
valid_types: Tuple[Any, ...] = allowed_default_types.get(result.type, (NoneType,))
if not isinstance(result.default, valid_types):
raise TypeError(f'invalid default parameter type given ({result.default.__class__}), expected {valid_types}')
result._annotation = annotation
return result
def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, CommandParameter]:
params = inspect.signature(func).parameters
cache = {}
@@ -368,7 +188,7 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s
raise TypeError(f'annotation for {parameter.name} must be given')
resolved = resolve_annotation(parameter.annotation, globalns, globalns, cache)
param = _get_parameter(resolved, parameter)
param = annotation_to_parameter(resolved, parameter)
parameters.append(param)
values = sorted(parameters, key=lambda a: a.required, reverse=True)
@@ -377,7 +197,9 @@ def _extract_parameters_from_callback(func: Callable[..., Any], globalns: Dict[s
try:
descriptions = func.__discord_app_commands_param_description__
except AttributeError:
pass
for param in values:
if param.description is MISSING:
param.description = '...'
else:
_populate_descriptions(result, descriptions)
@@ -489,14 +311,24 @@ class Command(Generic[GroupT, P, T]):
await parent.parent.on_error(interaction, self, error)
async def _invoke_with_namespace(self, interaction: Interaction, namespace: Namespace) -> T:
defaults = ((name, param.default) for name, param in self._params.items() if not param.required)
namespace._update_with_defaults(defaults)
values = namespace.__dict__
for name, param in self._params.items():
if not param.required:
values.setdefault(name, param.default)
else:
try:
value = values[name]
except KeyError:
raise CommandSignatureMismatch(self) from None
else:
values[name] = await param.transform(interaction, value)
# These type ignores are because the type checker doesn't quite understand the narrowing here
# Likewise, it thinks we're missing positional arguments when there aren't any.
try:
if self.binding is not None:
return await self._callback(self.binding, interaction, **namespace.__dict__) # type: ignore
return await self._callback(interaction, **namespace.__dict__) # type: ignore
return await self._callback(self.binding, interaction, **values) # type: ignore
return await self._callback(interaction, **values) # type: ignore
except TypeError as e:
# In order to detect mismatch from the provided signature and the Discord data,
# there are many ways it can go wrong yet all of them eventually lead to a TypeError