Add Option converter, fix default optional, fix help command
This commit is contained in:
parent
1a22df6228
commit
7c83c335d1
@ -1105,7 +1105,7 @@ class BotBase(GroupMixin):
|
|||||||
option = next((o for o in command_options if o['name'] == name), None) # type: ignore
|
option = next((o for o in command_options if o['name'] == name), None) # type: ignore
|
||||||
|
|
||||||
if option is None:
|
if option is None:
|
||||||
if not command._is_typing_optional(param.annotation):
|
if param.default is param.empty and not command._is_typing_optional(param.annotation):
|
||||||
raise errors.MissingRequiredArgument(param)
|
raise errors.MissingRequiredArgument(param)
|
||||||
elif (
|
elif (
|
||||||
option["type"] == 3
|
option["type"] == 3
|
||||||
|
@ -77,6 +77,7 @@ __all__ = (
|
|||||||
'GuildStickerConverter',
|
'GuildStickerConverter',
|
||||||
'clean_content',
|
'clean_content',
|
||||||
'Greedy',
|
'Greedy',
|
||||||
|
'Option',
|
||||||
'run_converters',
|
'run_converters',
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -96,6 +97,9 @@ T_co = TypeVar('T_co', covariant=True)
|
|||||||
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
|
CT = TypeVar('CT', bound=discord.abc.GuildChannel)
|
||||||
TT = TypeVar('TT', bound=discord.Thread)
|
TT = TypeVar('TT', bound=discord.Thread)
|
||||||
|
|
||||||
|
NT = TypeVar('NT', bound=str)
|
||||||
|
DT = TypeVar('DT', bound=str)
|
||||||
|
|
||||||
|
|
||||||
@runtime_checkable
|
@runtime_checkable
|
||||||
class Converter(Protocol[T_co]):
|
class Converter(Protocol[T_co]):
|
||||||
@ -1004,6 +1008,20 @@ class Greedy(List[T]):
|
|||||||
|
|
||||||
return cls(converter=converter)
|
return cls(converter=converter)
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
def Option(default: T = inspect.Parameter.empty, *, name: str = None, description: str) -> T: ...
|
||||||
|
else:
|
||||||
|
class Option(Generic[T, DT, NT]):
|
||||||
|
description: DT
|
||||||
|
name: Optional[NT]
|
||||||
|
default: Union[T, inspect.Parameter.empty]
|
||||||
|
__slots__ = ('name', 'default', 'description',)
|
||||||
|
|
||||||
|
def __init__(self, default: T = inspect.Parameter.empty, *, name: NT = None, description: DT) -> None:
|
||||||
|
self.description = description
|
||||||
|
self.default = default
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_bool(argument: str) -> bool:
|
def _convert_to_bool(argument: str) -> bool:
|
||||||
lowered = argument.lower()
|
lowered = argument.lower()
|
||||||
|
@ -39,19 +39,21 @@ from typing import (
|
|||||||
TypeVar,
|
TypeVar,
|
||||||
Type,
|
Type,
|
||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
|
cast,
|
||||||
overload,
|
overload,
|
||||||
)
|
)
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
import inspect
|
import inspect
|
||||||
import datetime
|
import datetime
|
||||||
|
from collections import defaultdict
|
||||||
from operator import itemgetter
|
from operator import itemgetter
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .errors import *
|
from .errors import *
|
||||||
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
|
from .cooldowns import Cooldown, BucketType, CooldownMapping, MaxConcurrency, DynamicCooldownMapping
|
||||||
from .converter import run_converters, get_converter, Greedy
|
from .converter import run_converters, get_converter, Greedy, Option
|
||||||
from ._types import _BaseCommand
|
from ._types import _BaseCommand
|
||||||
from .cog import Cog
|
from .cog import Cog
|
||||||
from .context import Context
|
from .context import Context
|
||||||
@ -136,13 +138,19 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
|
|||||||
return function
|
return function
|
||||||
|
|
||||||
|
|
||||||
def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]:
|
def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Tuple[Dict[str, inspect.Parameter], Dict[str, str]]:
|
||||||
signature = inspect.signature(function)
|
signature = inspect.signature(function)
|
||||||
params = {}
|
params = {}
|
||||||
cache: Dict[str, Any] = {}
|
cache: Dict[str, Any] = {}
|
||||||
|
descriptions = defaultdict(lambda: 'no description')
|
||||||
eval_annotation = discord.utils.evaluate_annotation
|
eval_annotation = discord.utils.evaluate_annotation
|
||||||
for name, parameter in signature.parameters.items():
|
for name, parameter in signature.parameters.items():
|
||||||
annotation = parameter.annotation
|
annotation = parameter.annotation
|
||||||
|
if isinstance(parameter.default, Option): # type: ignore
|
||||||
|
option = parameter.default
|
||||||
|
descriptions[name] = option.description
|
||||||
|
parameter = parameter.replace(default=option.default)
|
||||||
|
|
||||||
if annotation is parameter.empty:
|
if annotation is parameter.empty:
|
||||||
params[name] = parameter
|
params[name] = parameter
|
||||||
continue
|
continue
|
||||||
@ -156,7 +164,7 @@ def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, A
|
|||||||
|
|
||||||
params[name] = parameter.replace(annotation=annotation)
|
params[name] = parameter.replace(annotation=annotation)
|
||||||
|
|
||||||
return params
|
return params, descriptions
|
||||||
|
|
||||||
|
|
||||||
def wrap_callback(coro):
|
def wrap_callback(coro):
|
||||||
@ -421,7 +429,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
except AttributeError:
|
except AttributeError:
|
||||||
globalns = {}
|
globalns = {}
|
||||||
|
|
||||||
self.params = get_signature_parameters(function, globalns)
|
self.params, self.option_descriptions = get_signature_parameters(function, globalns)
|
||||||
|
|
||||||
def add_check(self, func: Check) -> None:
|
def add_check(self, func: Check) -> None:
|
||||||
"""Adds a check to the command.
|
"""Adds a check to the command.
|
||||||
@ -1160,7 +1168,6 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
if nested != 0:
|
if nested != 0:
|
||||||
payload["type"] = 1
|
payload["type"] = 1
|
||||||
|
|
||||||
option_descriptions = self.extras.get("option_descriptions", {})
|
|
||||||
for name, param in self.clean_params.items():
|
for name, param in self.clean_params.items():
|
||||||
annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str
|
annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str
|
||||||
origin = getattr(param.annotation, "__origin__", None)
|
origin = getattr(param.annotation, "__origin__", None)
|
||||||
@ -1171,10 +1178,11 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
|
|
||||||
option: Dict[str, Any] = {
|
option: Dict[str, Any] = {
|
||||||
"name": name,
|
"name": name,
|
||||||
"required": not self._is_typing_optional(annotation),
|
"description": self.option_descriptions[name],
|
||||||
"description": option_descriptions.get(name, "no description"),
|
"required": param.default is param.empty and not self._is_typing_optional(annotation),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
annotation = cast(Any, annotation)
|
||||||
if not option["required"] and origin is not None and len(annotation.__args__) == 2:
|
if not option["required"] and origin is not None and len(annotation.__args__) == 2:
|
||||||
# Unpack Optional[T] (Union[T, None]) into just T
|
# Unpack Optional[T] (Union[T, None]) into just T
|
||||||
annotation, origin = annotation.__args__[0], None
|
annotation, origin = annotation.__args__[0], None
|
||||||
@ -1182,7 +1190,7 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
|
|||||||
if origin is None:
|
if origin is None:
|
||||||
option["type"] = next(
|
option["type"] = next(
|
||||||
(num for t, num in application_option_type_lookup.items()
|
(num for t, num in application_option_type_lookup.items()
|
||||||
if issubclass(annotation, t)), str
|
if issubclass(annotation, t)), 3
|
||||||
)
|
)
|
||||||
elif origin is Literal and len(origin.__args__) <= 25: # type: ignore
|
elif origin is Literal and len(origin.__args__) <= 25: # type: ignore
|
||||||
option["choices"] = [{
|
option["choices"] = [{
|
||||||
|
@ -615,7 +615,7 @@ class HelpCommand:
|
|||||||
:class:`.abc.Messageable`
|
:class:`.abc.Messageable`
|
||||||
The destination where the help command will be output.
|
The destination where the help command will be output.
|
||||||
"""
|
"""
|
||||||
return self.context.channel
|
return self.context
|
||||||
|
|
||||||
async def send_error_message(self, error):
|
async def send_error_message(self, error):
|
||||||
"""|coro|
|
"""|coro|
|
||||||
@ -977,6 +977,14 @@ class DefaultHelpCommand(HelpCommand):
|
|||||||
for page in self.paginator.pages:
|
for page in self.paginator.pages:
|
||||||
await destination.send(page)
|
await destination.send(page)
|
||||||
|
|
||||||
|
interaction = self.context.interaction
|
||||||
|
if (
|
||||||
|
interaction is not None
|
||||||
|
and destination == self.context.author
|
||||||
|
and not interaction.response.is_done()
|
||||||
|
):
|
||||||
|
await interaction.response.send_message("Sent help to your DMs!", ephemeral=True)
|
||||||
|
|
||||||
def add_command_formatting(self, command):
|
def add_command_formatting(self, command):
|
||||||
"""A utility function to format the non-indented block of commands and groups.
|
"""A utility function to format the non-indented block of commands and groups.
|
||||||
|
|
||||||
@ -1007,7 +1015,7 @@ class DefaultHelpCommand(HelpCommand):
|
|||||||
elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold:
|
elif self.dm_help is None and len(self.paginator) > self.dm_help_threshold:
|
||||||
return ctx.author
|
return ctx.author
|
||||||
else:
|
else:
|
||||||
return ctx.channel
|
return ctx
|
||||||
|
|
||||||
async def prepare_help_command(self, ctx, command):
|
async def prepare_help_command(self, ctx, command):
|
||||||
self.paginator.clear()
|
self.paginator.clear()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user