Rearrange some stuff and add flag support

This commit is contained in:
Gnome 2021-09-09 20:49:03 +01:00
parent 17096629cd
commit 2f3d59e625
2 changed files with 87 additions and 54 deletions

View File

@ -64,6 +64,7 @@ from .core import GroupMixin
from .converter import Greedy
from .view import StringView, supported_quotes
from .context import Context
from .flags import FlagConverter
from . import errors
from .help import HelpCommand, DefaultHelpCommand
from .cog import Cog
@ -272,9 +273,9 @@ class BotBase(GroupMixin):
for guild in guilds:
commands[guild].append(payload)
http: HTTPClient = self.http # type: ignore
http: HTTPClient = self.http # type: ignore
global_commands = commands.pop(None, None)
application_id = self.application_id or (await self.application_info()).id # type: ignore
application_id = self.application_id or (await self.application_info()).id # type: ignore
if global_commands is not None:
if self.slash_command_guilds is None:
await http.bulk_upsert_global_commands(
@ -1271,7 +1272,20 @@ class BotBase(GroupMixin):
ignore_params: List[inspect.Parameter] = []
message.content = f"{prefix}{command_name} "
for name, param in command.clean_params.items():
option = next((o for o in command_options if o["name"] == name), None) # type: ignore
if inspect.isclass(param.annotation) and issubclass(param.annotation, FlagConverter):
for name, flag in param.annotation.get_flags().items():
option = next((o for o in command_options if o["name"] == name), None)
if option is None:
if flag.required:
raise errors.MissingRequiredFlag(flag)
else:
prefix = param.annotation.__commands_flag_prefix__
delimiter = param.annotation.__commands_flag_delimiter__
message.content += f"{prefix}{name} {option['value']}{delimiter}" # type: ignore
continue
option = next((o for o in command_options if o["name"] == name), None)
if option is None:
if param.default is param.empty and not command._is_typing_optional(param.annotation):
raise errors.MissingRequiredArgument(param)
@ -1280,8 +1294,7 @@ class BotBase(GroupMixin):
elif option["type"] == 3 and param.kind != param.KEYWORD_ONLY and not isinstance(param.annotation, Greedy):
# String with space in without "consume rest"
option = cast(_ApplicationCommandInteractionDataOptionString, option)
quoted_string = _quote_string_safe(option["value"])
message.content += f"{quoted_string} "
message.content += f"{_quote_string_safe(option['value'])} "
else:
message.content += f'{option.get("value", "")} '

View File

@ -58,13 +58,14 @@ from .converter import CONVERTER_MAPPING, Converter, run_converters, get_convert
from ._types import _BaseCommand
from .cog import Cog
from .context import Context
from .flags import FlagConverter
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec, TypeGuard
from discord.message import Message
from discord.types.interactions import EditApplicationCommand
from discord.types.interactions import EditApplicationCommand, ApplicationCommandInteractionDataOption
from ._types import (
Coro,
@ -1202,6 +1203,65 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
finally:
ctx.command = original
def _param_to_options(
self, name: str, annotation: Any, required: bool, varadic: bool
) -> List[Optional[ApplicationCommandInteractionDataOption]]:
origin = getattr(annotation, "__origin__", None)
if inspect.isclass(annotation) and issubclass(annotation, FlagConverter):
return [
param
for name, flag in annotation.get_flags().items()
for param in self._param_to_options(
name, flag.annotation, required=flag.required, varadic=flag.annotation is tuple
)
]
if varadic:
annotation = str
origin = None
if not required and origin is not None and len(annotation.__args__) == 2:
# Unpack Optional[T] (Union[T, None]) into just T
annotation, origin = annotation.__args__[0], None
option: Dict[str, Any] = {
"name": name,
"required": required,
"description": self.option_descriptions[name],
}
if origin is None:
if not inspect.isclass(annotation):
annotation = type(annotation)
if issubclass(annotation, Converter):
# If this is a converter, we want to check if it is a native
# one, in which we can get the original type, eg, (MemberConverter -> Member)
annotation = REVERSED_CONVERTER_MAPPING.get(annotation, annotation)
option["type"] = 3
for python_type, discord_type in application_option_type_lookup.items():
if issubclass(annotation, python_type):
option["type"] = discord_type
break
elif origin is Literal:
literal_values = annotation.__args__
python_type = type(literal_values[0])
if (
all(type(value) == python_type for value in literal_values)
and python_type in application_option_type_lookup.keys()
):
option["type"] = application_option_type_lookup[python_type]
option["choices"] = [
{"name": literal_value, "value": literal_value} for literal_value in annotation.__args__
]
option.setdefault("type", 3) # STRING
return [option] # type: ignore
def to_application_command(self, nested: int = 0) -> Optional[EditApplicationCommand]:
if self.slash_command is False:
return
@ -1213,55 +1273,15 @@ class Command(_BaseCommand, Generic[CogT, P, T]):
payload["type"] = 1
for name, param in self.clean_params.items():
annotation: Type[Any] = param.annotation if param.annotation is not param.empty else str
origin = getattr(param.annotation, "__origin__", None)
if origin is None and isinstance(annotation, Greedy):
annotation = annotation.converter
origin = Greedy
option: Dict[str, Any] = {
"name": name,
"description": self.option_descriptions[name],
"required": (param.default is param.empty and not self._is_typing_optional(annotation))
options = self._param_to_options(
name,
param.annotation if param.annotation is not param.empty else str,
varadic=param.kind == param.KEYWORD_ONLY or isinstance(param.annotation, Greedy),
required=(param.default is param.empty and not self._is_typing_optional(param.annotation))
or param.kind == param.VAR_POSITIONAL,
}
annotation = cast(Any, annotation)
if not option["required"] and origin is not None and len(annotation.__args__) == 2:
# Unpack Optional[T] (Union[T, None]) into just T
annotation, origin = annotation.__args__[0], None
if origin is None:
if not inspect.isclass(annotation):
annotation = type(annotation)
if issubclass(annotation, Converter):
# If this is a converter, we want to check if it is a native
# one, in which we can get the original type, eg, (MemberConverter -> Member)
annotation = REVERSED_CONVERTER_MAPPING.get(annotation, annotation)
option["type"] = 3
for python_type, discord_type in application_option_type_lookup.items():
if issubclass(annotation, python_type):
option["type"] = discord_type
break
elif origin is Literal:
literal_values = annotation.__args__
python_type = type(literal_values[0])
if (
all(type(value) == python_type for value in literal_values)
and python_type in application_option_type_lookup.keys()
):
option["type"] = application_option_type_lookup[python_type]
option["choices"] = [
{"name": literal_value, "value": literal_value} for literal_value in annotation.__args__
]
option.setdefault("type", 3) # STRING
payload["options"].append(option)
)
if options is not None:
payload["options"].extend(option for option in options if option is not None)
# Now we have all options, make sure required is before optional.
payload["options"] = sorted(payload["options"], key=itemgetter("required"), reverse=True)