From 629f36e7d78f034580572dc93659cae9082b48fe Mon Sep 17 00:00:00 2001 From: Rapptz Date: Tue, 5 Apr 2022 06:40:31 -0400 Subject: [PATCH] [commands] Add fallback behaviour to the default parameter instances This allows users to explicitly override the default annotation for CurrentAuthor and CurrentChannel since they're wider than what most users would expect --- discord/ext/commands/core.py | 11 ++++++++++- discord/ext/commands/parameters.py | 18 +++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index cf1d8ff88..3e2eb4688 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -138,7 +138,16 @@ def get_signature_parameters( default = parameter.default if isinstance(default, Parameter): # update from the default if default.annotation is not Parameter.empty: - parameter._annotation = default.annotation + # There are a few cases to care about here. + # x: TextChannel = commands.CurrentChannel + # x = commands.CurrentChannel + # In both of these cases, the default parameter has an explicit annotation + # but in the second case it's only used as the fallback. + if default._fallback: + if parameter.annotation is Parameter.empty: + parameter._annotation = default.annotation + else: + parameter._annotation = default.annotation parameter._default = default.default parameter._displayed_default = default._displayed_default diff --git a/discord/ext/commands/parameters.py b/discord/ext/commands/parameters.py index 1204dd5e3..4bc4bdeee 100644 --- a/discord/ext/commands/parameters.py +++ b/discord/ext/commands/parameters.py @@ -31,6 +31,16 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, OrderedDict, Union from discord.utils import MISSING, maybe_coroutine from .errors import NoPrivateMessage +from .converter import GuildConverter + +from discord import ( + Member, + User, + TextChannel, + VoiceChannel, + DMChannel, + Thread, +) if TYPE_CHECKING: from typing_extensions import Self @@ -77,7 +87,7 @@ class Parameter(inspect.Parameter): .. versionadded:: 2.0 """ - __slots__ = ('_displayed_default',) + __slots__ = ('_displayed_default', '_fallback') def __init__( self, @@ -93,6 +103,7 @@ class Parameter(inspect.Parameter): self._default = default self._annotation = annotation self._displayed_default = displayed_default + self._fallback = False def replace( self, @@ -218,12 +229,16 @@ An alias for :func:`parameter`. Author = parameter( default=attrgetter('author'), displayed_default='', + converter=Union[Member, User], ) +Author._fallback = True CurrentChannel = parameter( default=attrgetter('channel'), displayed_default='', + converter=Union[TextChannel, DMChannel, Thread, VoiceChannel], ) +CurrentChannel._fallback = True def default_guild(ctx: Context) -> Guild: @@ -235,6 +250,7 @@ def default_guild(ctx: Context) -> Guild: CurrentGuild = parameter( default=default_guild, displayed_default='', + converter=GuildConverter, )