mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 01:53:01 +00:00
[commands] Add support for discord.Attachment converters
This commit is contained in:
@@ -56,7 +56,7 @@ from .errors import *
|
||||
from .parameters import Parameter, Signature
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from typing_extensions import Concatenate, ParamSpec, Self, TypeGuard
|
||||
from typing_extensions import Concatenate, ParamSpec, Self
|
||||
|
||||
from discord.message import Message
|
||||
|
||||
@@ -237,6 +237,27 @@ class _CaseInsensitiveDict(dict):
|
||||
super().__setitem__(k.casefold(), v)
|
||||
|
||||
|
||||
class _AttachmentIterator:
|
||||
def __init__(self, data: List[discord.Attachment]):
|
||||
self.data: List[discord.Attachment] = data
|
||||
self.index: int = 0
|
||||
|
||||
def __iter__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __next__(self) -> discord.Attachment:
|
||||
try:
|
||||
value = self.data[self.index]
|
||||
except IndexError:
|
||||
raise StopIteration
|
||||
else:
|
||||
self.index += 1
|
||||
return value
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
return self.index >= len(self.data)
|
||||
|
||||
|
||||
class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
r"""A class that implements the protocol for a bot text command.
|
||||
|
||||
@@ -592,7 +613,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
finally:
|
||||
ctx.bot.dispatch('command_error', ctx, error)
|
||||
|
||||
async def transform(self, ctx: Context[BotT], param: Parameter, /) -> Any:
|
||||
async def transform(self, ctx: Context[BotT], param: Parameter, attachments: _AttachmentIterator, /) -> Any:
|
||||
converter = param.converter
|
||||
consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw
|
||||
view = ctx.view
|
||||
@@ -601,6 +622,10 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
# The greedy converter is simple -- it keeps going until it fails in which case,
|
||||
# it undos the view ready for the next parameter to use instead
|
||||
if isinstance(converter, Greedy):
|
||||
# Special case for Greedy[discord.Attachment] to consume the attachments iterator
|
||||
if converter.converter is discord.Attachment:
|
||||
return list(attachments)
|
||||
|
||||
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
|
||||
return await self._transform_greedy_pos(ctx, param, param.required, converter.converter)
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
@@ -611,6 +636,20 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
# into just X and do the parsing that way.
|
||||
converter = converter.converter
|
||||
|
||||
# Try to detect Optional[discord.Attachment] or discord.Attachment special converter
|
||||
if converter is discord.Attachment:
|
||||
try:
|
||||
return next(attachments)
|
||||
except StopIteration:
|
||||
raise MissingRequiredAttachment(param)
|
||||
|
||||
if self._is_typing_optional(param.annotation) and param.annotation.__args__[0] is discord.Attachment:
|
||||
if attachments.is_empty():
|
||||
# I have no idea who would be doing Optional[discord.Attachment] = 1
|
||||
# but for those cases then 1 should be returned instead of None
|
||||
return None if param.default is param.empty else param.default
|
||||
return next(attachments)
|
||||
|
||||
if view.eof:
|
||||
if param.kind == param.VAR_POSITIONAL:
|
||||
raise RuntimeError() # break the loop
|
||||
@@ -759,6 +798,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
ctx.kwargs = {}
|
||||
args = ctx.args
|
||||
kwargs = ctx.kwargs
|
||||
attachments = _AttachmentIterator(ctx.message.attachments)
|
||||
|
||||
view = ctx.view
|
||||
iterator = iter(self.params.items())
|
||||
@@ -766,7 +806,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
for name, param in iterator:
|
||||
ctx.current_parameter = param
|
||||
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
|
||||
transformed = await self.transform(ctx, param)
|
||||
transformed = await self.transform(ctx, param, attachments)
|
||||
args.append(transformed)
|
||||
elif param.kind == param.KEYWORD_ONLY:
|
||||
# kwarg only param denotes "consume rest" semantics
|
||||
@@ -774,14 +814,14 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
ctx.current_argument = argument = view.read_rest()
|
||||
kwargs[name] = await run_converters(ctx, param.converter, argument, param)
|
||||
else:
|
||||
kwargs[name] = await self.transform(ctx, param)
|
||||
kwargs[name] = await self.transform(ctx, param, attachments)
|
||||
break
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
if view.eof and self.require_var_positional:
|
||||
raise MissingRequiredArgument(param)
|
||||
while not view.eof:
|
||||
try:
|
||||
transformed = await self.transform(ctx, param)
|
||||
transformed = await self.transform(ctx, param, attachments)
|
||||
args.append(transformed)
|
||||
except RuntimeError:
|
||||
break
|
||||
@@ -1080,7 +1120,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
return self.help.split('\n', 1)[0]
|
||||
return ''
|
||||
|
||||
def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]:
|
||||
def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> bool:
|
||||
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__ # type: ignore
|
||||
|
||||
@property
|
||||
@@ -1108,6 +1148,17 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
||||
annotation = union_args[0]
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
|
||||
if annotation is discord.Attachment:
|
||||
# For discord.Attachment we need to signal to the user that it's an attachment
|
||||
# It's not exactly pretty but it's enough to differentiate
|
||||
if optional:
|
||||
result.append(f'[{name} (upload a file)]')
|
||||
elif greedy:
|
||||
result.append(f'[{name} (upload files)]...')
|
||||
else:
|
||||
result.append(f'<{name} (upload a file)>')
|
||||
continue
|
||||
|
||||
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
||||
# parameter signature is a literal list of it's values
|
||||
if origin is Literal:
|
||||
|
@@ -45,6 +45,7 @@ if TYPE_CHECKING:
|
||||
__all__ = (
|
||||
'CommandError',
|
||||
'MissingRequiredArgument',
|
||||
'MissingRequiredAttachment',
|
||||
'BadArgument',
|
||||
'PrivateMessageOnly',
|
||||
'NoPrivateMessage',
|
||||
@@ -184,6 +185,25 @@ class MissingRequiredArgument(UserInputError):
|
||||
super().__init__(f'{param.name} is a required argument that is missing.')
|
||||
|
||||
|
||||
class MissingRequiredAttachment(UserInputError):
|
||||
"""Exception raised when parsing a command and a parameter
|
||||
that requires an attachment is not given.
|
||||
|
||||
This inherits from :exc:`UserInputError`
|
||||
|
||||
.. versionadded:: 2.0
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
param: :class:`Parameter`
|
||||
The argument that is missing an attachment.
|
||||
"""
|
||||
|
||||
def __init__(self, param: Parameter) -> None:
|
||||
self.param: Parameter = param
|
||||
super().__init__(f'{param.name} is a required argument that is missing an attachment.')
|
||||
|
||||
|
||||
class TooManyArguments(UserInputError):
|
||||
"""Exception raised when the command was passed too many arguments and its
|
||||
:attr:`.Command.ignore_extra` attribute was not set to ``True``.
|
||||
|
@@ -186,6 +186,9 @@ def replace_parameters(parameters: Dict[str, Parameter], signature: inspect.Sign
|
||||
# However, in here, it probably makes sense to make it required.
|
||||
# I'm unsure how to allow the user to choose right now.
|
||||
inner = converter.converter
|
||||
if inner is discord.Attachment:
|
||||
raise TypeError('discord.Attachment with Greedy is not supported in hybrid commands')
|
||||
|
||||
param = param.replace(annotation=make_greedy_transformer(inner, parameter))
|
||||
elif is_converter(converter):
|
||||
param = param.replace(annotation=make_converter_transformer(converter))
|
||||
|
Reference in New Issue
Block a user