Add Option converter, fix default optional, fix help command

This commit is contained in:
Gnome 2021-08-31 18:44:32 +01:00
parent 1a22df6228
commit 7c83c335d1
4 changed files with 45 additions and 11 deletions

View File

@ -1105,7 +1105,7 @@ class BotBase(GroupMixin):
option = next((o for o in command_options if o['name'] == name), None) # type: ignore option = next((o for o in command_options if o['name'] == name), None) # type: ignore
if option is None: if option is None:
if not command._is_typing_optional(param.annotation): if param.default is param.empty and not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param) raise errors.MissingRequiredArgument(param)
elif ( elif (
option["type"] == 3 option["type"] == 3

View File

@ -77,6 +77,7 @@ __all__ = (
'GuildStickerConverter', 'GuildStickerConverter',
'clean_content', 'clean_content',
'Greedy', 'Greedy',
'Option',
'run_converters', 'run_converters',
) )
@ -96,6 +97,9 @@ T_co = TypeVar('T_co', covariant=True)
CT = TypeVar('CT', bound=discord.abc.GuildChannel) CT = TypeVar('CT', bound=discord.abc.GuildChannel)
TT = TypeVar('TT', bound=discord.Thread) TT = TypeVar('TT', bound=discord.Thread)
NT = TypeVar('NT', bound=str)
DT = TypeVar('DT', bound=str)
@runtime_checkable @runtime_checkable
class Converter(Protocol[T_co]): class Converter(Protocol[T_co]):
@ -1004,6 +1008,20 @@ class Greedy(List[T]):
return cls(converter=converter) return cls(converter=converter)
if TYPE_CHECKING:
def Option(default: T = inspect.Parameter.empty, *, name: str = None, description: str) -> T: ...
else:
class Option(Generic[T, DT, NT]):
description: DT
name: Optional[NT]
default: Union[T, inspect.Parameter.empty]
__slots__ = ('name', 'default', 'description',)
def __init__(self, default: T = inspect.Parameter.empty, *, name: NT = None, description: DT) -> None:
self.description = description
self.default = default
self.name = name
def _convert_to_bool(argument: str) -> bool: def _convert_to_bool(argument: str) -> bool:
lowered = argument.lower() lowered = argument.lower()

View File

@ -39,19 +39,21 @@ from typing import (
TypeVar, TypeVar,
Type, Type,
TYPE_CHECKING, TYPE_CHECKING,
cast,
overload, overload,
) )
import asyncio import asyncio
import functools import functools
import inspect import inspect
import datetime import datetime
from collections import defaultdict
from operator import itemgetter from operator import itemgetter
import discord import discord
from .errors import * from .errors import *
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
from .converter import run_converters, get_converter, Greedy from .converter import run_converters, get_converter, Greedy, Option
from ._types import _BaseCommand from ._types import _BaseCommand
from .cog import Cog from .cog import Cog
from .context import Context from .context import Context
@ -136,13 +138,19 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
return function return function
def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]: def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Tuple[Dict[str, inspect.Parameter], Dict[str, str]]:
signature = inspect.signature(function) signature = inspect.signature(function)
params = {} params = {}
cache: Dict[str, Any] = {} cache: Dict[str, Any] = {}
descriptions = defaultdict(lambda: 'no description')
eval_annotation = discord.utils.evaluate_annotation eval_annotation = discord.utils.evaluate_annotation
for name, parameter in signature.parameters.items(): for name, parameter in signature.parameters.items():
annotation = parameter.annotation annotation = parameter.annotation
if isinstance(parameter.default, Option): # type: ignore
option = parameter.default
descriptions[name] = option.description
parameter = parameter.replace(default=option.default)
if annotation is parameter.empty: if annotation is parameter.empty:
params[name] = parameter params[name] = parameter
continue continue
@ -156,7 +164,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A
params[name] = parameter.replace(annotation=annotation) params[name] = parameter.replace(annotation=annotation)
return params return params, descriptions
def wrap_callback(coro): def wrap_callback(coro):
@ -421,7 +429,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
except AttributeError: except AttributeError:
globalns = {} globalns = {}
self.params = get_signature_parameters(function, globalns) self.params, self.option_descriptions = get_signature_parameters(function, globalns)
def add_check(self, func: Check) -> None: def add_check(self, func: Check) -> None:
"""Adds a check to the command. """Adds a check to the command.
@ -1160,7 +1168,6 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if nested != 0: if nested != 0:
payload["type"] = 1 payload["type"] = 1
option_descriptions = self.extras.get("option_descriptions", {})
for name, param in self.clean_params.items(): for name, param in self.clean_params.items():
annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str
origin = getattr(param.annotation, "__origin__", None) origin = getattr(param.annotation, "__origin__", None)
@ -1171,10 +1178,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
option: Dict[str, Any] = { option: Dict[str, Any] = {
"name": name, "name": name,
"required": not self._is_typing_optional(annotation), "description": self.option_descriptions[name],
"description": option_descriptions.get(name, "no description"), "required": param.default is param.empty and not self._is_typing_optional(annotation),
} }
annotation = cast(Any, annotation)
if not option["required"] and origin is not None and len(annotation.__args__) == 2: if not option["required"] and origin is not None and len(annotation.__args__) == 2:
# Unpack Optional[T] (Union[T, None]) into just T # Unpack Optional[T] (Union[T, None]) into just T
annotation, origin = annotation.__args__[0], None annotation, origin = annotation.__args__[0], None
@ -1182,7 +1190,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
if origin is None: if origin is None:
option["type"] = next( option["type"] = next(
(num for t, num in application_option_type_lookup.items() (num for t, num in application_option_type_lookup.items()
if issubclass(annotation, t)), str if issubclass(annotation, t)), 3
) )
elif origin is Literal and len(origin.__args__) <= 25: # type: ignore elif origin is Literal and len(origin.__args__) <= 25: # type: ignore
option["choices"] = [{ option["choices"] = [{

View File

@ -615,7 +615,7 @@ class HelpCommand:
:class:`.abc.Messageable` :class:`.abc.Messageable`
The destination where the help command will be output. The destination where the help command will be output.
""" """
return self.context.channel return self.context
async def send_error_message(self, error): async def send_error_message(self, error):
"""|coro| """|coro|
@ -977,6 +977,14 @@ class DefaultHelpCommand(HelpCommand):
for page in self.paginator.pages: for page in self.paginator.pages:
await destination.send(page) await destination.send(page)
interaction = self.context.interaction
if (
interaction is not None
and destination == self.context.author
and not interaction.response.is_done()
):
await interaction.response.send_message("Sent help to your DMs!", ephemeral=True)
def add_command_formatting(self, command): def add_command_formatting(self, command):
"""A utility function to format the non-indented block of commands and groups. """A utility function to format the non-indented block of commands and groups.
@ -1007,7 +1015,7 @@ class DefaultHelpCommand(HelpCommand):
elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold: elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold:
return ctx.author return ctx.author
else: else:
return ctx.channel return ctx
async def prepare_help_command(self, ctx, command): async def prepare_help_command(self, ctx, command):
self.paginator.clear() self.paginator.clear()