mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-19 15:36:02 +00:00
[commands] Add support for aliasing to FlagConverter
This commit is contained in:
parent
0c1c9284f6
commit
42463bae67
@ -37,7 +37,7 @@ from .view import StringView
|
||||
from .converter import run_converters
|
||||
|
||||
from discord.utils import maybe_coroutine
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from typing import (
|
||||
Dict,
|
||||
Optional,
|
||||
@ -87,6 +87,8 @@ class Flag:
|
||||
------------
|
||||
name: :class:`str`
|
||||
The name of the flag.
|
||||
aliases: List[:class:`str`]
|
||||
The aliases of the flag name.
|
||||
attribute: :class:`str`
|
||||
The attribute in the class that corresponds to this flag.
|
||||
default: Any
|
||||
@ -101,6 +103,7 @@ class Flag:
|
||||
"""
|
||||
|
||||
name: str = MISSING
|
||||
aliases: List[str] = field(default_factory=list)
|
||||
attribute: str = MISSING
|
||||
annotation: Any = MISSING
|
||||
default: Any = MISSING
|
||||
@ -120,6 +123,7 @@ class Flag:
|
||||
def flag(
|
||||
*,
|
||||
name: str = MISSING,
|
||||
aliases: List[str] = MISSING,
|
||||
default: Any = MISSING,
|
||||
max_args: int = MISSING,
|
||||
override: bool = MISSING,
|
||||
@ -131,6 +135,8 @@ def flag(
|
||||
------------
|
||||
name: :class:`str`
|
||||
The flag name. If not given, defaults to the attribute name.
|
||||
aliases: List[:class:`str`]
|
||||
Aliases to the flag name. If not given no aliases are set.
|
||||
default: Any
|
||||
The default parameter. This could be either a value or a callable that takes
|
||||
:class:`Context` as its sole parameter. If not given then it defaults to
|
||||
@ -143,7 +149,7 @@ def flag(
|
||||
Whether multiple given values overrides the previous value. The default
|
||||
value depends on the annotation given.
|
||||
"""
|
||||
return Flag(name=name, default=default, max_args=max_args, override=override)
|
||||
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
|
||||
|
||||
|
||||
def validate_flag_name(name: str, forbidden: Set[str]):
|
||||
@ -161,8 +167,10 @@ def validate_flag_name(name: str, forbidden: Set[str]):
|
||||
|
||||
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
|
||||
annotations = namespace.get('__annotations__', {})
|
||||
case_insensitive = namespace['__commands_flag_case_insensitive__']
|
||||
flags: Dict[str, Flag] = {}
|
||||
cache: Dict[str, Any] = {}
|
||||
names: Set[str] = set()
|
||||
for name, annotation in annotations.items():
|
||||
flag = namespace.pop(name, MISSING)
|
||||
if isinstance(flag, Flag):
|
||||
@ -176,6 +184,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
|
||||
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
|
||||
|
||||
if flag.aliases is MISSING:
|
||||
flag.aliases = []
|
||||
|
||||
# Add sensible defaults based off of the type annotation
|
||||
# <type> -> (max_args=1)
|
||||
# List[str] -> (max_args=-1)
|
||||
@ -221,6 +232,21 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
||||
if flag.override is MISSING:
|
||||
flag.override = False
|
||||
|
||||
# Validate flag names are unique
|
||||
name = flag.name.casefold() if case_insensitive else flag.name
|
||||
if name in names:
|
||||
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
|
||||
else:
|
||||
names.add(name)
|
||||
|
||||
for alias in flag.aliases:
|
||||
# Validate alias is unique
|
||||
alias = alias.casefold() if case_insensitive else alias
|
||||
if alias in names:
|
||||
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
|
||||
else:
|
||||
names.add(alias)
|
||||
|
||||
flags[flag.name] = flag
|
||||
|
||||
return flags
|
||||
@ -230,6 +256,7 @@ class FlagsMeta(type):
|
||||
if TYPE_CHECKING:
|
||||
__commands_is_flag__: bool
|
||||
__commands_flags__: Dict[str, Flag]
|
||||
__commands_flag_aliases__: Dict[str, str]
|
||||
__commands_flag_regex__: Pattern[str]
|
||||
__commands_flag_case_insensitive__: bool
|
||||
__commands_flag_delimiter__: str
|
||||
@ -271,25 +298,37 @@ class FlagsMeta(type):
|
||||
del frame
|
||||
|
||||
flags: Dict[str, Flag] = {}
|
||||
aliases: Dict[str, str] = {}
|
||||
for base in reversed(bases):
|
||||
if base.__dict__.get('__commands_is_flag__', False):
|
||||
flags.update(base.__dict__['__commands_flags__'])
|
||||
aliases.update(base.__dict__['__commands_flag_aliases__'])
|
||||
|
||||
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
|
||||
flags[flag_name] = flag
|
||||
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
|
||||
|
||||
flags.update(get_flags(attrs, global_ns, local_ns))
|
||||
forbidden = set(delimiter).union(prefix)
|
||||
for flag_name in flags:
|
||||
validate_flag_name(flag_name, forbidden)
|
||||
for alias_name in aliases:
|
||||
validate_flag_name(alias_name, forbidden)
|
||||
|
||||
regex_flags = 0
|
||||
if case_insensitive:
|
||||
flags = {key.casefold(): value for key, value in flags.items()}
|
||||
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
|
||||
regex_flags = re.IGNORECASE
|
||||
|
||||
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True)
|
||||
keys = list(re.escape(k) for k in flags)
|
||||
keys.extend(re.escape(a) for a in aliases)
|
||||
keys = sorted(keys, key=lambda t: len(t), reverse=True)
|
||||
|
||||
joined = '|'.join(keys)
|
||||
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
|
||||
attrs['__commands_flag_regex__'] = pattern
|
||||
attrs['__commands_flags__'] = flags
|
||||
attrs['__commands_flag_aliases__'] = aliases
|
||||
|
||||
return type.__new__(cls, name, bases, attrs)
|
||||
|
||||
@ -432,6 +471,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
|
||||
result: Dict[str, List[str]] = {}
|
||||
flags = cls.__commands_flags__
|
||||
aliases = cls.__commands_flag_aliases__
|
||||
last_position = 0
|
||||
last_flag: Optional[Flag] = None
|
||||
|
||||
@ -442,6 +482,9 @@ class FlagConverter(metaclass=FlagsMeta):
|
||||
if case_insensitive:
|
||||
key = key.casefold()
|
||||
|
||||
if key in aliases:
|
||||
key = aliases[key]
|
||||
|
||||
flag = flags.get(key)
|
||||
if last_position and last_flag is not None:
|
||||
value = argument[last_position : begin - 1].lstrip()
|
||||
|
Loading…
x
Reference in New Issue
Block a user