[commands] Add fallback behaviour to the default parameter instances

This allows users to explicitly override the default annotation for
CurrentAuthor and CurrentChannel since they're wider than what most
users would expect
This commit is contained in:
Rapptz 2022-04-05 06:40:31 -04:00
parent f15f601779
commit 629f36e7d7
2 changed files with 27 additions and 2 deletions

View File

@ -138,7 +138,16 @@ def get_signature_parameters(
default = parameter.default
if isinstance(default, Parameter): # update from the default
if default.annotation is not Parameter.empty:
parameter._annotation = default.annotation
# There are a few cases to care about here.
# x: TextChannel = commands.CurrentChannel
# x = commands.CurrentChannel
# In both of these cases, the default parameter has an explicit annotation
# but in the second case it's only used as the fallback.
if default._fallback:
if parameter.annotation is Parameter.empty:
parameter._annotation = default.annotation
else:
parameter._annotation = default.annotation
parameter._default = default.default
parameter._displayed_default = default._displayed_default

View File

@ -31,6 +31,16 @@ from typing import TYPE_CHECKING, Any, Literal, Optional, OrderedDict, Union
from discord.utils import MISSING, maybe_coroutine
from .errors import NoPrivateMessage
from .converter import GuildConverter
from discord import (
Member,
User,
TextChannel,
VoiceChannel,
DMChannel,
Thread,
)
if TYPE_CHECKING:
from typing_extensions import Self
@ -77,7 +87,7 @@ class Parameter(inspect.Parameter):
.. versionadded:: 2.0
"""
__slots__ = ('_displayed_default',)
__slots__ = ('_displayed_default', '_fallback')
def __init__(
self,
@ -93,6 +103,7 @@ class Parameter(inspect.Parameter):
self._default = default
self._annotation = annotation
self._displayed_default = displayed_default
self._fallback = False
def replace(
self,
@ -218,12 +229,16 @@ An alias for :func:`parameter`.
Author = parameter(
default=attrgetter('author'),
displayed_default='<you>',
converter=Union[Member, User],
)
Author._fallback = True
CurrentChannel = parameter(
default=attrgetter('channel'),
displayed_default='<this channel>',
converter=Union[TextChannel, DMChannel, Thread, VoiceChannel],
)
CurrentChannel._fallback = True
def default_guild(ctx: Context) -> Guild:
@ -235,6 +250,7 @@ def default_guild(ctx: Context) -> Guild:
CurrentGuild = parameter(
default=default_guild,
displayed_default='<this server>',
converter=GuildConverter,
)