[commands] Add support for aliasing to FlagConverter
This commit is contained in:
parent
0c1c9284f6
commit
42463bae67
@ -37,7 +37,7 @@ from .view import StringView
|
|||||||
from .converter import run_converters
|
from .converter import run_converters
|
||||||
|
|
||||||
from discord.utils import maybe_coroutine
|
from discord.utils import maybe_coroutine
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass, field
|
||||||
from typing import (
|
from typing import (
|
||||||
Dict,
|
Dict,
|
||||||
Optional,
|
Optional,
|
||||||
@ -87,6 +87,8 @@ class Flag:
|
|||||||
------------
|
------------
|
||||||
name: :class:`str`
|
name: :class:`str`
|
||||||
The name of the flag.
|
The name of the flag.
|
||||||
|
aliases: List[:class:`str`]
|
||||||
|
The aliases of the flag name.
|
||||||
attribute: :class:`str`
|
attribute: :class:`str`
|
||||||
The attribute in the class that corresponds to this flag.
|
The attribute in the class that corresponds to this flag.
|
||||||
default: Any
|
default: Any
|
||||||
@ -101,6 +103,7 @@ class Flag:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
name: str = MISSING
|
name: str = MISSING
|
||||||
|
aliases: List[str] = field(default_factory=list)
|
||||||
attribute: str = MISSING
|
attribute: str = MISSING
|
||||||
annotation: Any = MISSING
|
annotation: Any = MISSING
|
||||||
default: Any = MISSING
|
default: Any = MISSING
|
||||||
@ -120,6 +123,7 @@ class Flag:
|
|||||||
def flag(
|
def flag(
|
||||||
*,
|
*,
|
||||||
name: str = MISSING,
|
name: str = MISSING,
|
||||||
|
aliases: List[str] = MISSING,
|
||||||
default: Any = MISSING,
|
default: Any = MISSING,
|
||||||
max_args: int = MISSING,
|
max_args: int = MISSING,
|
||||||
override: bool = MISSING,
|
override: bool = MISSING,
|
||||||
@ -131,6 +135,8 @@ def flag(
|
|||||||
------------
|
------------
|
||||||
name: :class:`str`
|
name: :class:`str`
|
||||||
The flag name. If not given, defaults to the attribute name.
|
The flag name. If not given, defaults to the attribute name.
|
||||||
|
aliases: List[:class:`str`]
|
||||||
|
Aliases to the flag name. If not given no aliases are set.
|
||||||
default: Any
|
default: Any
|
||||||
The default parameter. This could be either a value or a callable that takes
|
The default parameter. This could be either a value or a callable that takes
|
||||||
:class:`Context` as its sole parameter. If not given then it defaults to
|
:class:`Context` as its sole parameter. If not given then it defaults to
|
||||||
@ -143,7 +149,7 @@ def flag(
|
|||||||
Whether multiple given values overrides the previous value. The default
|
Whether multiple given values overrides the previous value. The default
|
||||||
value depends on the annotation given.
|
value depends on the annotation given.
|
||||||
"""
|
"""
|
||||||
return Flag(name=name, default=default, max_args=max_args, override=override)
|
return Flag(name=name, aliases=aliases, default=default, max_args=max_args, override=override)
|
||||||
|
|
||||||
|
|
||||||
def validate_flag_name(name: str, forbidden: Set[str]):
|
def validate_flag_name(name: str, forbidden: Set[str]):
|
||||||
@ -161,8 +167,10 @@ def validate_flag_name(name: str, forbidden: Set[str]):
|
|||||||
|
|
||||||
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
|
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
|
||||||
annotations = namespace.get('__annotations__', {})
|
annotations = namespace.get('__annotations__', {})
|
||||||
|
case_insensitive = namespace['__commands_flag_case_insensitive__']
|
||||||
flags: Dict[str, Flag] = {}
|
flags: Dict[str, Flag] = {}
|
||||||
cache: Dict[str, Any] = {}
|
cache: Dict[str, Any] = {}
|
||||||
|
names: Set[str] = set()
|
||||||
for name, annotation in annotations.items():
|
for name, annotation in annotations.items():
|
||||||
flag = namespace.pop(name, MISSING)
|
flag = namespace.pop(name, MISSING)
|
||||||
if isinstance(flag, Flag):
|
if isinstance(flag, Flag):
|
||||||
@ -176,6 +184,9 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
|||||||
|
|
||||||
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
|
annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache)
|
||||||
|
|
||||||
|
if flag.aliases is MISSING:
|
||||||
|
flag.aliases = []
|
||||||
|
|
||||||
# Add sensible defaults based off of the type annotation
|
# Add sensible defaults based off of the type annotation
|
||||||
# <type> -> (max_args=1)
|
# <type> -> (max_args=1)
|
||||||
# List[str] -> (max_args=-1)
|
# List[str] -> (max_args=-1)
|
||||||
@ -221,6 +232,21 @@ def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[s
|
|||||||
if flag.override is MISSING:
|
if flag.override is MISSING:
|
||||||
flag.override = False
|
flag.override = False
|
||||||
|
|
||||||
|
# Validate flag names are unique
|
||||||
|
name = flag.name.casefold() if case_insensitive else flag.name
|
||||||
|
if name in names:
|
||||||
|
raise TypeError(f'{flag.name!r} flag conflicts with previous flag or alias.')
|
||||||
|
else:
|
||||||
|
names.add(name)
|
||||||
|
|
||||||
|
for alias in flag.aliases:
|
||||||
|
# Validate alias is unique
|
||||||
|
alias = alias.casefold() if case_insensitive else alias
|
||||||
|
if alias in names:
|
||||||
|
raise TypeError(f'{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.')
|
||||||
|
else:
|
||||||
|
names.add(alias)
|
||||||
|
|
||||||
flags[flag.name] = flag
|
flags[flag.name] = flag
|
||||||
|
|
||||||
return flags
|
return flags
|
||||||
@ -230,6 +256,7 @@ class FlagsMeta(type):
|
|||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
__commands_is_flag__: bool
|
__commands_is_flag__: bool
|
||||||
__commands_flags__: Dict[str, Flag]
|
__commands_flags__: Dict[str, Flag]
|
||||||
|
__commands_flag_aliases__: Dict[str, str]
|
||||||
__commands_flag_regex__: Pattern[str]
|
__commands_flag_regex__: Pattern[str]
|
||||||
__commands_flag_case_insensitive__: bool
|
__commands_flag_case_insensitive__: bool
|
||||||
__commands_flag_delimiter__: str
|
__commands_flag_delimiter__: str
|
||||||
@ -271,25 +298,37 @@ class FlagsMeta(type):
|
|||||||
del frame
|
del frame
|
||||||
|
|
||||||
flags: Dict[str, Flag] = {}
|
flags: Dict[str, Flag] = {}
|
||||||
|
aliases: Dict[str, str] = {}
|
||||||
for base in reversed(bases):
|
for base in reversed(bases):
|
||||||
if base.__dict__.get('__commands_is_flag__', False):
|
if base.__dict__.get('__commands_is_flag__', False):
|
||||||
flags.update(base.__dict__['__commands_flags__'])
|
flags.update(base.__dict__['__commands_flags__'])
|
||||||
|
aliases.update(base.__dict__['__commands_flag_aliases__'])
|
||||||
|
|
||||||
|
for flag_name, flag in get_flags(attrs, global_ns, local_ns).items():
|
||||||
|
flags[flag_name] = flag
|
||||||
|
aliases.update({alias_name: flag_name for alias_name in flag.aliases})
|
||||||
|
|
||||||
flags.update(get_flags(attrs, global_ns, local_ns))
|
|
||||||
forbidden = set(delimiter).union(prefix)
|
forbidden = set(delimiter).union(prefix)
|
||||||
for flag_name in flags:
|
for flag_name in flags:
|
||||||
validate_flag_name(flag_name, forbidden)
|
validate_flag_name(flag_name, forbidden)
|
||||||
|
for alias_name in aliases:
|
||||||
|
validate_flag_name(alias_name, forbidden)
|
||||||
|
|
||||||
regex_flags = 0
|
regex_flags = 0
|
||||||
if case_insensitive:
|
if case_insensitive:
|
||||||
flags = {key.casefold(): value for key, value in flags.items()}
|
flags = {key.casefold(): value for key, value in flags.items()}
|
||||||
|
aliases = {key.casefold(): value.casefold() for key, value in aliases.items()}
|
||||||
regex_flags = re.IGNORECASE
|
regex_flags = re.IGNORECASE
|
||||||
|
|
||||||
keys = sorted((re.escape(k) for k in flags), key=lambda t: len(t), reverse=True)
|
keys = list(re.escape(k) for k in flags)
|
||||||
|
keys.extend(re.escape(a) for a in aliases)
|
||||||
|
keys = sorted(keys, key=lambda t: len(t), reverse=True)
|
||||||
|
|
||||||
joined = '|'.join(keys)
|
joined = '|'.join(keys)
|
||||||
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
|
pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)
|
||||||
attrs['__commands_flag_regex__'] = pattern
|
attrs['__commands_flag_regex__'] = pattern
|
||||||
attrs['__commands_flags__'] = flags
|
attrs['__commands_flags__'] = flags
|
||||||
|
attrs['__commands_flag_aliases__'] = aliases
|
||||||
|
|
||||||
return type.__new__(cls, name, bases, attrs)
|
return type.__new__(cls, name, bases, attrs)
|
||||||
|
|
||||||
@ -432,6 +471,7 @@ class FlagConverter(metaclass=FlagsMeta):
|
|||||||
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
|
def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
|
||||||
result: Dict[str, List[str]] = {}
|
result: Dict[str, List[str]] = {}
|
||||||
flags = cls.__commands_flags__
|
flags = cls.__commands_flags__
|
||||||
|
aliases = cls.__commands_flag_aliases__
|
||||||
last_position = 0
|
last_position = 0
|
||||||
last_flag: Optional[Flag] = None
|
last_flag: Optional[Flag] = None
|
||||||
|
|
||||||
@ -442,6 +482,9 @@ class FlagConverter(metaclass=FlagsMeta):
|
|||||||
if case_insensitive:
|
if case_insensitive:
|
||||||
key = key.casefold()
|
key = key.casefold()
|
||||||
|
|
||||||
|
if key in aliases:
|
||||||
|
key = aliases[key]
|
||||||
|
|
||||||
flag = flags.get(key)
|
flag = flags.get(key)
|
||||||
if last_position and last_flag is not None:
|
if last_position and last_flag is not None:
|
||||||
value = argument[last_position : begin - 1].lstrip()
|
value = argument[last_position : begin - 1].lstrip()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user