diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 36946c85..46ccdee4 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -1105,7 +1105,7 @@ class BotBase(GroupMixin): option = next((o for o in command_options if o['name'] == name), None) # type: ignore if option is None: - if not command._is_typing_optional(param.annotation): + if param.default is param.empty and not command._is_typing_optional(param.annotation): raise errors.MissingRequiredArgument(param) elif ( option["type"] == 3 diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 5740a188..bf552a59 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -77,6 +77,7 @@ __all__ = ( 'GuildStickerConverter', 'clean_content', 'Greedy', + 'Option', 'run_converters', ) @@ -96,6 +97,9 @@ T_co = TypeVar('T_co', covariant=True) CT = TypeVar('CT', bound=discord.abc.GuildChannel) TT = TypeVar('TT', bound=discord.Thread) +NT = TypeVar('NT', bound=str) +DT = TypeVar('DT', bound=str) + @runtime_checkable class Converter(Protocol[T_co]): @@ -1004,6 +1008,20 @@ class Greedy(List[T]): return cls(converter=converter) +if TYPE_CHECKING: + def Option(default: T = inspect.Parameter.empty, *, name: str = None, description: str) -> T: ... +else: + class Option(Generic[T, DT, NT]): + description: DT + name: Optional[NT] + default: Union[T, inspect.Parameter.empty] + __slots__ = ('name', 'default', 'description',) + + def __init__(self, default: T = inspect.Parameter.empty, *, name: NT = None, description: DT) -> None: + self.description = description + self.default = default + self.name = name + def _convert_to_bool(argument: str) -> bool: lowered = argument.lower() diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index 123a76d6..83730243 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -39,19 +39,21 @@ from typing import ( TypeVar, Type, TYPE_CHECKING, + cast, overload, ) import asyncio import functools import inspect import datetime +from collections import defaultdict from operator import itemgetter import discord from .errors import * from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping -from .converter import run_converters, get_converter, Greedy +from .converter import run_converters, get_converter, Greedy, Option from ._types import _BaseCommand from .cog import Cog from .context import Context @@ -136,13 +138,19 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: return function -def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]: +def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Tuple[Dict[str, inspect.Parameter], Dict[str, str]]: signature = inspect.signature(function) params = {} cache: Dict[str, Any] = {} + descriptions = defaultdict(lambda: 'no description') eval_annotation = discord.utils.evaluate_annotation for name, parameter in signature.parameters.items(): annotation = parameter.annotation + if isinstance(parameter.default, Option): # type: ignore + option = parameter.default + descriptions[name] = option.description + parameter = parameter.replace(default=option.default) + if annotation is parameter.empty: params[name] = parameter continue @@ -156,7 +164,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A params[name] = parameter.replace(annotation=annotation) - return params + return params, descriptions def wrap_callback(coro): @@ -421,7 +429,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): except AttributeError: globalns = {} - self.params = get_signature_parameters(function, globalns) + self.params, self.option_descriptions = get_signature_parameters(function, globalns) def add_check(self, func: Check) -> None: """Adds a check to the command. @@ -1160,7 +1168,6 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if nested != 0: payload["type"] = 1 - option_descriptions = self.extras.get("option_descriptions", {}) for name, param in self.clean_params.items(): annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str origin = getattr(param.annotation, "__origin__", None) @@ -1171,10 +1178,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]): option: Dict[str, Any] = { "name": name, - "required": not self._is_typing_optional(annotation), - "description": option_descriptions.get(name, "no description"), + "description": self.option_descriptions[name], + "required": param.default is param.empty and not self._is_typing_optional(annotation), } + annotation = cast(Any, annotation) if not option["required"] and origin is not None and len(annotation.__args__) == 2: # Unpack Optional[T] (Union[T, None]) into just T annotation, origin = annotation.__args__[0], None @@ -1182,7 +1190,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]): if origin is None: option["type"] = next( (num for t, num in application_option_type_lookup.items() - if issubclass(annotation, t)), str + if issubclass(annotation, t)), 3 ) elif origin is Literal and len(origin.__args__) <= 25: # type: ignore option["choices"] = [{ diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index afaacbfb..722ccc05 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -615,7 +615,7 @@ class HelpCommand: :class:`.abc.Messageable` The destination where the help command will be output. """ - return self.context.channel + return self.context async def send_error_message(self, error): """|coro| @@ -977,6 +977,14 @@ class DefaultHelpCommand(HelpCommand): for page in self.paginator.pages: await destination.send(page) + interaction = self.context.interaction + if ( + interaction is not None + and destination == self.context.author + and not interaction.response.is_done() + ): + await interaction.response.send_message("Sent help to your DMs!", ephemeral=True) + def add_command_formatting(self, command): """A utility function to format the non-indented block of commands and groups. @@ -1007,7 +1015,7 @@ class DefaultHelpCommand(HelpCommand): elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold: return ctx.author else: - return ctx.channel + return ctx async def prepare_help_command(self, ctx, command): self.paginator.clear()