From 42463bae672f1356e4a7cb6076cadd1597c086af Mon Sep 17 00:00:00 2001 From: Josh Date: Wed, 21 Apr 2021 14:31:01 +1000 Subject: [PATCH] [commands] Add support for aliasing to FlagConverter --- discord/ext/commands/flags.py | 51 ++++++++++++++++++++++++++++++++--- 1 file changed, 47 insertions(+), 4 deletions(-) diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index 460774db6..8f5d34d39 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -37,7 +37,7 @@ from .view import StringView from .converter import run_converters from discord.utils import maybe_coroutine -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import ( Dict, Optional, @@ -87,6 +87,8 @@ class Flag: ------------ name: :class:`str` The name of the flag. + aliases: List[:class:`str`] + The aliases of the flag name. attribute: :class:`str` The attribute in the class that corresponds to this flag. default: Any @@ -101,6 +103,7 @@ class Flag: """ name: str = MISSING + aliases: List[str] = field(default_factory=list) attribute: str = MISSING annotation: Any = MISSING default: Any = MISSING @@ -120,6 +123,7 @@ class Flag: def flag( *, name: str = MISSING, + aliases: List[str] = MISSING, default: Any = MISSING, max_args: int = MISSING, override: bool = MISSING, @@ -131,6 +135,8 @@ def flag( ------------ name: :class:`str` The flag name. If not given, defaults to the attribute name. + aliases: List[:class:`str`] + Aliases to the flag name. If not given no aliases are set. default: Any The default parameter. This could be either a value or a callable that takes :class:`Context` as its sole parameter. If not given then it defaults to @@ -143,7 +149,7 @@ def flag( Whether multiple given values overrides the previous value. The default value depends on the annotation given. """ - return Flag(name=name, default=default, max_args=max_args, override=override) + return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override) def validate_flag_name(name: str, forbidden: Set[str]): @@ -161,8 +167,10 @@ def validate_flag_name(name: str, forbidden: Set[str]): def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: annotations = namespace.get('__annotations__', {}) + case_insensitive = namespace['__commands_flag_case_insensitive__'] flags: Dict[str, Flag] = {} cache: Dict[str, Any] = {} + names: Set[str] = set() for name, annotation in annotations.items(): flag = namespace.pop(name, MISSING) if isinstance(flag, Flag): @@ -176,6 +184,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) + if flag.aliases is MISSING: + flag.aliases = [] + # Add sensible defaults based off of the type annotation # -> (max_args=1) # List[str] -> (max_args=-1) @@ -221,6 +232,21 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s if flag.override is MISSING: flag.override = False + # Validate flag names are unique + name = flag.name.casefold() if case_insensitive else flag.name + if name in names: + raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.') + else: + names.add(name) + + for alias in flag.aliases: + # Validate alias is unique + alias = alias.casefold() if case_insensitive else alias + if alias in names: + raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.') + else: + names.add(alias) + flags[flag.name] = flag return flags @@ -230,6 +256,7 @@ class FlagsMeta(type): if TYPE_CHECKING: __commands_is_flag__: bool __commands_flags__: Dict[str, Flag] + __commands_flag_aliases__: Dict[str, str] __commands_flag_regex__: Pattern[str] __commands_flag_case_insensitive__: bool __commands_flag_delimiter__: str @@ -271,25 +298,37 @@ class FlagsMeta(type): del frame flags: Dict[str, Flag] = {} + aliases: Dict[str, str] = {} for base in reversed(bases): if base.__dict__.get('__commands_is_flag__', False): flags.update(base.__dict__['__commands_flags__']) + aliases.update(base.__dict__['__commands_flag_aliases__']) + + for flag_name, flag in get_flags(attrs, global_ns, local_ns).items(): + flags[flag_name] = flag + aliases.update({alias_name: flag_name for alias_name in flag.aliases}) - flags.update(get_flags(attrs, global_ns, local_ns)) forbidden = set(delimiter).union(prefix) for flag_name in flags: validate_flag_name(flag_name, forbidden) + for alias_name in aliases: + validate_flag_name(alias_name, forbidden) regex_flags = 0 if case_insensitive: flags = {key.casefold(): value for key, value in flags.items()} + aliases = {key.casefold(): value.casefold() for key, value in aliases.items()} regex_flags = re.IGNORECASE - keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True) + keys = list(re.escape(k) for k in flags) + keys.extend(re.escape(a) for a in aliases) + keys = sorted(keys, key=lambda t: len(t), reverse=True) + joined = '|'.join(keys) pattern = re.compile(f'(({re.escape(prefix)})(?P{joined}){re.escape(delimiter)})', regex_flags) attrs['__commands_flag_regex__'] = pattern attrs['__commands_flags__'] = flags + attrs['__commands_flag_aliases__'] = aliases return type.__new__(cls, name, bases, attrs) @@ -432,6 +471,7 @@ class FlagConverter(metaclass=FlagsMeta): def parse_flags(cls, argument: str) -> Dict[str, List[str]]: result: Dict[str, List[str]] = {} flags = cls.__commands_flags__ + aliases = cls.__commands_flag_aliases__ last_position = 0 last_flag: Optional[Flag] = None @@ -442,6 +482,9 @@ class FlagConverter(metaclass=FlagsMeta): if case_insensitive: key = key.casefold() + if key in aliases: + key = aliases[key] + flag = flags.get(key) if last_position and last_flag is not None: value = argument[last_position : begin - 1].lstrip()