mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-18 23:15:48 +00:00
[commands] Properly support commands.param in hybrid commands
This commit is contained in:
parent
fa3a4c109b
commit
f072edfdfc
@ -42,8 +42,9 @@ import inspect
|
||||
from discord import app_commands
|
||||
from discord.utils import MISSING, maybe_coroutine, async_all
|
||||
from .core import Command, Group
|
||||
from .errors import CommandRegistrationError, CommandError, HybridCommandError, ConversionError
|
||||
from .errors import BadArgument, CommandRegistrationError, CommandError, HybridCommandError, ConversionError
|
||||
from .converter import Converter
|
||||
from .parameters import Parameter
|
||||
from .cog import Cog
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -51,7 +52,6 @@ if TYPE_CHECKING:
|
||||
from ._types import ContextT, Coro, BotT
|
||||
from .bot import Bot
|
||||
from .context import Context
|
||||
from .parameters import Parameter
|
||||
from discord.app_commands.commands import (
|
||||
Check as AppCommandCheck,
|
||||
AutocompleteCallback,
|
||||
@ -71,6 +71,7 @@ CogT = TypeVar('CogT', bound='Cog')
|
||||
CommandT = TypeVar('CommandT', bound='Command')
|
||||
# CHT = TypeVar('CHT', bound='Check')
|
||||
GroupT = TypeVar('GroupT', bound='Group')
|
||||
_NoneType = type(None)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
P = ParamSpec('P')
|
||||
@ -85,6 +86,17 @@ else:
|
||||
P2 = TypeVar('P2')
|
||||
|
||||
|
||||
class _CallableDefault:
|
||||
__slots__ = ('func',)
|
||||
|
||||
def __init__(self, func: Callable[[Context], Any]) -> None:
|
||||
self.func: Callable[[Context], Any] = func
|
||||
|
||||
@property
|
||||
def __class__(self) -> Any:
|
||||
return _NoneType
|
||||
|
||||
|
||||
def is_converter(converter: Any) -> bool:
|
||||
return (inspect.isclass(converter) and issubclass(converter, Converter)) or isinstance(converter, Converter)
|
||||
|
||||
@ -107,12 +119,33 @@ def make_converter_transformer(converter: Any) -> Type[app_commands.Transformer]
|
||||
return type('ConverterTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
|
||||
|
||||
|
||||
def make_callable_transformer(func: Callable[[str], Any]) -> Type[app_commands.Transformer]:
|
||||
async def transform(cls, interaction: discord.Interaction, value: str) -> Any:
|
||||
try:
|
||||
return func(value)
|
||||
except CommandError:
|
||||
raise
|
||||
except Exception as exc:
|
||||
raise BadArgument(f'Converting to "{func.__name__}" failed') from exc
|
||||
|
||||
return type('CallableTransformer', (app_commands.Transformer,), {'transform': classmethod(transform)})
|
||||
|
||||
|
||||
def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Signature) -> List[inspect.Parameter]:
|
||||
# Need to convert commands.Parameter back to inspect.Parameter so this will be a bit ugly
|
||||
params = signature.parameters.copy()
|
||||
for name, parameter in parameters.items():
|
||||
if is_converter(parameter.converter) and not hasattr(parameter.converter, '__discord_app_commands_transformer__'):
|
||||
is_transformer = hasattr(parameter.converter, '__discord_app_commands_transformer__')
|
||||
if is_converter(parameter.converter) and not is_transformer:
|
||||
params[name] = params[name].replace(annotation=make_converter_transformer(parameter.converter))
|
||||
if callable(parameter.converter) and not inspect.isclass(parameter.converter) and not is_transformer:
|
||||
params[name] = params[name].replace(annotation=make_callable_transformer(parameter.converter))
|
||||
if callable(parameter.default):
|
||||
params[name] = params[name].replace(default=_CallableDefault(parameter.default))
|
||||
|
||||
if isinstance(params[name].default, Parameter):
|
||||
# If we're here, then then it hasn't been handled yet so it should be removed completely
|
||||
params[name] = params[name].replace(default=parameter.empty)
|
||||
|
||||
return list(params.values())
|
||||
|
||||
@ -146,6 +179,28 @@ class HybridAppCommand(discord.app_commands.Command[CogT, P, T]):
|
||||
}
|
||||
return self._copy_with(parent=self.parent, binding=self.binding, bindings=bindings)
|
||||
|
||||
async def _transform_arguments(
|
||||
self, interaction: discord.Interaction, namespace: app_commands.Namespace
|
||||
) -> Dict[str, Any]:
|
||||
values = namespace.__dict__
|
||||
transformed_values = {}
|
||||
|
||||
for param in self._params.values():
|
||||
try:
|
||||
value = values[param.display_name]
|
||||
except KeyError:
|
||||
if not param.required:
|
||||
if isinstance(param.default, _CallableDefault):
|
||||
transformed_values[param.name] = await maybe_coroutine(param.default.func, interaction._baton)
|
||||
else:
|
||||
transformed_values[param.name] = param.default
|
||||
else:
|
||||
raise app_commands.CommandSignatureMismatch(self) from None
|
||||
else:
|
||||
transformed_values[param.name] = await param.transform(interaction, value)
|
||||
|
||||
return transformed_values
|
||||
|
||||
async def _check_can_run(self, interaction: discord.Interaction) -> bool:
|
||||
# Hybrid checks must run like so:
|
||||
# - Bot global check once
|
||||
|
Loading…
x
Reference in New Issue
Block a user