This applies if and only if the flag can be default constructed; that is, all of the flags are optional (not required).
547 lines
18 KiB
Python
547 lines
18 KiB
Python
"""
|
|
The MIT License (MIT)
|
|
|
|
Copyright (c) 2015-present Rapptz
|
|
|
|
Permission is hereby granted, free of charge, to any person obtaining a
|
|
copy of this software and associated documentation files (the "Software"),
|
|
to deal in the Software without restriction, including without limitation
|
|
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
|
and/or sell copies of the Software, and to permit persons to whom the
|
|
Software is furnished to do so, subject to the following conditions:
|
|
|
|
The above copyright notice and this permission notice shall be included in
|
|
all copies or substantial portions of the Software.
|
|
|
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
|
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
|
DEALINGS IN THE SOFTWARE.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from .errors import (
|
|
BadFlagArgument,
|
|
CommandError,
|
|
MissingFlagArgument,
|
|
TooManyFlags,
|
|
MissingRequiredFlag,
|
|
)
|
|
|
|
from .core import resolve_annotation
|
|
from .view import StringView
|
|
from .converter import run_converters
|
|
|
|
from discord.utils import maybe_coroutine
|
|
from dataclasses import dataclass
|
|
from typing import (
|
|
Dict,
|
|
Optional,
|
|
Pattern,
|
|
Set,
|
|
TYPE_CHECKING,
|
|
Tuple,
|
|
List,
|
|
Any,
|
|
Type,
|
|
TypeVar,
|
|
Union,
|
|
)
|
|
|
|
import inspect
|
|
import sys
|
|
import re
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = ('Flag', 'flag', 'FlagConverter')
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
from .context import Context
|
|
|
|
|
|
class _MissingSentinel:
|
|
def __repr__(self):
|
|
return 'MISSING'
|
|
|
|
|
|
MISSING: Any = _MissingSentinel()
|
|
|
|
|
|
@dataclass
class Flag:
    """Represents a flag parameter for :class:`FlagConverter`.

    Flag objects are usually produced through the
    :func:`~discord.ext.commands.flag` function, although using that
    function is optional. These cannot be constructed manually.

    Every field except ``cast_to_dict`` starts out as ``MISSING`` until a
    value is supplied for it.

    Attributes
    ------------
    name: :class:`str`
        The name of the flag.
    attribute: :class:`str`
        The attribute in the class that corresponds to this flag.
    default: Any
        The default value of the flag, if available.
    annotation: Any
        The underlying evaluated annotation of the flag.
    max_args: :class:`int`
        The maximum number of arguments the flag can accept.
        A negative value indicates an unlimited amount of arguments.
    override: :class:`bool`
        Whether multiple given values overrides the previous value.
    cast_to_dict: :class:`bool`
        Whether the converted values are passed through :class:`dict`
        before being stored. Defaults to ``False``.
    """

    name: str = MISSING
    attribute: str = MISSING
    annotation: Any = MISSING
    default: Any = MISSING
    max_args: int = MISSING
    override: bool = MISSING
    cast_to_dict: bool = False

    @property
    def required(self) -> bool:
        """:class:`bool`: Whether the flag is required.

        A flag is required exactly when no default value has been set for it.
        """
        has_default = self.default is not MISSING
        return not has_default
|
|
|
|
|
|
def flag(
    *,
    name: str = MISSING,
    default: Any = MISSING,
    max_args: int = MISSING,
    override: bool = MISSING,
) -> Any:
    """Customise how a :class:`FlagConverter` class attribute behaves as a flag.

    All parameters are keyword-only and optional; anything left out stays
    ``MISSING`` on the returned object.

    Parameters
    ------------
    name: :class:`str`
        The flag name. If not given, defaults to the attribute name.
    default: Any
        The default parameter. This could be either a value or a callable that takes
        :class:`Context` as its sole parameter. If not given then it defaults to
        the default value given to the attribute.
    max_args: :class:`int`
        The maximum number of arguments the flag can accept.
        A negative value indicates an unlimited amount of arguments.
        The default value depends on the annotation given.
    override: :class:`bool`
        Whether multiple given values overrides the previous value. The default
        value depends on the annotation given.

    Returns
    --------
    :class:`Flag`
        A flag object carrying the given overrides. The return annotation is
        ``Any`` so the result can be assigned to an attribute of any annotated type.
    """
    configured = Flag(
        name=name,
        default=default,
        max_args=max_args,
        override=override,
    )
    return configured
|
|
|
|
|
|
def validate_flag_name(name: str, forbidden: Set[str]) -> None:
    """Check that ``name`` is usable as a flag name.

    Characters are inspected left to right and the first offending character
    decides which error is reported.

    Parameters
    ------------
    name: :class:`str`
        The candidate flag name.
    forbidden: Set[:class:`str`]
        Characters that must not appear anywhere in the name.

    Raises
    --------
    ValueError
        The name is empty, or contains whitespace, a backslash, or one of
        the ``forbidden`` characters.
    """
    if not name:
        raise ValueError('flag names should not be empty')

    for char in name:
        message = None
        if char.isspace():
            message = f'flag name {name!r} cannot have spaces'
        elif char == '\\':
            message = f'flag name {name!r} cannot have backslashes'
        elif char in forbidden:
            message = f'flag name {name!r} cannot have any of {forbidden!r} within them'

        if message is not None:
            raise ValueError(message)
|
|
|
|
|
|
def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]:
    """Build the flag mapping for a class body namespace.

    Every annotated name in ``namespace`` becomes a :class:`Flag`. The value
    assigned to that name (if any) is removed from ``namespace``; when it is
    already a :class:`Flag` it is reused, otherwise it becomes the flag's default.

    Parameters
    ------------
    namespace: Dict[:class:`str`, Any]
        The class namespace. It is mutated: annotated names are popped from it.
    globals: Dict[:class:`str`, Any]
        Global namespace handed to ``resolve_annotation``.
    locals: Dict[:class:`str`, Any]
        Local namespace handed to ``resolve_annotation``.

    Raises
    --------
    TypeError
        An annotation has an ``__origin__`` that is not one of the supported forms.

    Returns
    --------
    Dict[:class:`str`, :class:`Flag`]
        The flags keyed by flag name.
    """
    collected: Dict[str, Flag] = {}
    cache: Dict[str, Any] = {}
    declared = namespace.get('__annotations__', {})

    for attr_name, raw_annotation in declared.items():
        assigned = namespace.pop(attr_name, MISSING)
        if isinstance(assigned, Flag):
            entry = assigned
            entry.annotation = raw_annotation
        else:
            # A plain value (or nothing at all) acts as the flag's default.
            entry = Flag(name=attr_name, annotation=raw_annotation, default=assigned)

        entry.attribute = attr_name
        if entry.name is MISSING:
            entry.name = attr_name

        hint = resolve_annotation(entry.annotation, globals, locals, cache)
        entry.annotation = hint

        # Add sensible defaults based off of the type annotation
        # <type> -> (max_args=1)
        # List[str] -> (max_args=-1)
        # Tuple[int, ...] -> (max_args=1)
        # Dict[K, V] -> (max_args=-1, override=True)
        # Optional[str] -> (default=None, max_args=1)
        origin = getattr(hint, '__origin__', MISSING)

        if origin is MISSING:
            # A regular type hint
            fallback_max_args = 1
        elif origin is Union and hint.__args__[-1] is type(None):
            # typing.Optional
            fallback_max_args = 1
            if entry.default is MISSING:
                entry.default = None
        elif origin is tuple:
            # typing.Tuple
            # tuple parsing is e.g. `flag: peter 20`
            # for Tuple[str, int] would give you flag: ('peter', 20)
            fallback_max_args = 1
        elif origin is list:
            # typing.List
            fallback_max_args = -1
        elif origin is dict:
            # typing.Dict[K, V]
            # Equivalent to:
            # typing.List[typing.Tuple[K, V]]
            entry.cast_to_dict = True
            fallback_max_args = -1
            if entry.override is MISSING:
                entry.override = True
        else:
            raise TypeError(f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag'.replace('', '')) if False else TypeError(
                _unsupported_annotation_message(hint, entry)
            )

        # Explicit settings always win over the annotation-derived fallbacks.
        if entry.max_args is MISSING:
            entry.max_args = fallback_max_args
        if entry.override is MISSING:
            entry.override = False

        collected[entry.name] = entry

    return collected


def _unsupported_annotation_message(annotation: Any, flag: Flag) -> str:
    # Formats the error text for an annotation form get_flags cannot handle.
    return f'Unsupported typing annotation {annotation!r} for {flag.name!r} flag'
|
|
|
|
|
|
class FlagsMeta(type):
    """Metaclass that gathers flag definitions and builds the flag-matching regex.

    Class keyword arguments ``case_insensitive``, ``delimiter`` and ``prefix``
    are consumed here and recorded on the created class as
    ``__commands_flag_*__`` attributes.
    """

    if TYPE_CHECKING:
        __commands_is_flag__: bool
        __commands_flags__: Dict[str, Flag]
        __commands_flag_regex__: Pattern[str]
        __commands_flag_case_insensitive__: bool
        __commands_flag_delimiter__: str
        __commands_flag_prefix__: str

    def __new__(
        cls: Type[type],
        name: str,
        bases: Tuple[type, ...],
        attrs: Dict[str, Any],
        *,
        case_insensitive: bool = False,
        delimiter: str = ':',
        prefix: str = '',
    ):
        """Create the class, collecting its flags and compiling its regex.

        Raises
        --------
        TypeError
            Both ``prefix`` and ``delimiter`` are empty.
        ValueError
            Raised by ``validate_flag_name`` for an unusable flag name.
        """
        attrs.update(
            {
                '__commands_is_flag__': True,
                '__commands_flag_case_insensitive__': case_insensitive,
                '__commands_flag_delimiter__': delimiter,
                '__commands_flag_prefix__': prefix,
            }
        )

        if not (prefix or delimiter):
            raise TypeError('Must have either a delimiter or a prefix set')

        # Namespaces used to evaluate the annotations.
        try:
            global_ns = sys.modules[attrs['__module__']].__dict__
        except KeyError:
            global_ns = {}

        # Locals come from the calling frame when there is one; this lookup is
        # kept inline because moving it into a helper would change frame depth.
        frame = inspect.currentframe()
        try:
            if frame is None:
                local_ns = {}
            elif frame.f_back is None:
                local_ns = frame.f_locals
            else:
                local_ns = frame.f_back.f_locals
        finally:
            del frame

        # Start from the flags of flag-enabled bases; bases are visited in
        # reverse so earlier-listed bases override later ones.
        flags: Dict[str, Flag] = {}
        for parent in reversed(bases):
            parent_ns = parent.__dict__
            if parent_ns.get('__commands_is_flag__', False):
                flags.update(parent_ns['__commands_flags__'])

        flags.update(get_flags(attrs, global_ns, local_ns))

        # Flag names may not contain any character of the delimiter or prefix.
        forbidden = set(delimiter) | set(prefix)
        for flag_name in flags:
            validate_flag_name(flag_name, forbidden)

        if case_insensitive:
            flags = {key.casefold(): value for key, value in flags.items()}
            regex_flags = re.IGNORECASE
        else:
            regex_flags = 0

        # Longest alternatives first so a longer name is tried before a
        # shorter name that is its prefix.
        keys = sorted(map(re.escape, flags), key=len, reverse=True)
        joined = '|'.join(keys)
        pattern = re.compile(f'(({re.escape(prefix)})(?P<flag>{joined}){re.escape(delimiter)})', regex_flags)

        attrs['__commands_flag_regex__'] = pattern
        attrs['__commands_flags__'] = flags

        return type.__new__(cls, name, bases, attrs)
|
|
|
|
|
|
async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]:
    """Convert every word of ``argument`` with one converter.

    Words are read one at a time through ``StringView.get_quoted_word`` and
    each is passed to ``run_converters`` with the same ``converter``.

    Raises
    --------
    CommandError
        Re-raised unchanged when the converter raises it.
    BadFlagArgument
        Any other exception raised during conversion, chained as the cause.

    Returns
    --------
    Tuple[Any, ...]
        The converted words, in input order.
    """
    reader = StringView(argument)
    param: inspect.Parameter = ctx.current_parameter  # type: ignore
    collected = []

    while not reader.eof:
        reader.skip_ws()
        if reader.eof:
            break

        token = reader.get_quoted_word()
        if token is None:
            break

        try:
            item = await run_converters(ctx, converter, token, param)
        except CommandError:
            raise
        except Exception as exc:
            raise BadFlagArgument(flag) from exc

        collected.append(item)

    return tuple(collected)
|
|
|
|
|
|
async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]:
    """Convert successive words of ``argument`` with positionally matched converters.

    One word is read per converter. Reading stops early when the input runs
    out; input left over after the last converter is not examined.

    Raises
    --------
    CommandError
        Re-raised unchanged when a converter raises it.
    BadFlagArgument
        Any other exception raised during conversion (chained as the cause),
        or fewer values were produced than there are converters.

    Returns
    --------
    Tuple[Any, ...]
        One converted value per converter.
    """
    reader = StringView(argument)
    param: inspect.Parameter = ctx.current_parameter  # type: ignore
    collected = []

    for converter in converters:
        reader.skip_ws()
        if reader.eof:
            break

        token = reader.get_quoted_word()
        if token is None:
            break

        try:
            item = await run_converters(ctx, converter, token, param)
        except CommandError:
            raise
        except Exception as exc:
            raise BadFlagArgument(flag) from exc

        collected.append(item)

    # Every converter must have received a value.
    if len(collected) != len(converters):
        raise BadFlagArgument(flag)

    return tuple(collected)
|
|
|
|
|
|
async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) -> Any:
    """Convert one raw flag value according to the flag's annotation.

    Parameters
    ------------
    ctx
        The invocation context; its ``current_parameter`` is forwarded to the converters.
    argument: :class:`str`
        The raw text given for the flag.
    flag: :class:`Flag`
        The flag being converted; used for its annotation and for error reporting.
    annotation: Any
        Annotation to use instead of ``flag.annotation``. A falsy value means
        "use the flag's own annotation".

    Raises
    --------
    CommandError
        Re-raised unchanged when a converter raises it.
    BadFlagArgument
        Any other exception raised during conversion, chained as the cause.
    """
    param: inspect.Parameter = ctx.current_parameter  # type: ignore
    target = annotation or flag.annotation
    # Plain types have no ``__origin__`` and go straight to the converters below.
    origin = getattr(target, '__origin__', None)

    if origin is tuple:
        if target.__args__[-1] is Ellipsis:
            # Tuple[X, ...]: one converter for every word.
            return await tuple_convert_all(ctx, argument, flag, target.__args__[0])
        # Tuple[X, Y, ...]: one converter per position.
        return await tuple_convert_flag(ctx, argument, flag, target.__args__)

    if origin is list or (origin is Union and target.__args__[-1] is type(None)):
        # typing.List[x] or typing.Optional[x]: convert using the inner type.
        return await convert_flag(ctx, argument, flag, target.__args__[0])

    if origin is dict:
        # typing.Dict[K, V] -> typing.Tuple[K, V]
        return await tuple_convert_flag(ctx, argument, flag, target.__args__)

    try:
        return await run_converters(ctx, target, argument, param)
    except CommandError:
        raise
    except Exception as exc:
        raise BadFlagArgument(flag) from exc
|
|
|
|
|
|
# Type variable bound to FlagConverter; lets classmethods annotated with
# ``cls: Type[F]`` declare that they return an instance of that same class.
F = TypeVar('F', bound='FlagConverter')
|
|
|
|
|
|
class FlagConverter(metaclass=FlagsMeta):
    """A converter that allows for a user-friendly flag syntax.

    The flags are defined using :pep:`526` type annotations similar
    to the :mod:`dataclasses` Python module. For more information on
    how this converter works, check the appropriate
    :ref:`documentation <ext_commands_flag_converter>`.

    .. versionadded:: 2.0

    Parameters
    -----------
    case_insensitive: :class:`bool`
        A class parameter to toggle case insensitivity of the flag parsing.
        If ``True`` then flags are parsed in a case insensitive manner.
        Defaults to ``False``.
    prefix: :class:`str`
        The prefix that all flags must be prefixed with. By default
        there is no prefix.
    delimiter: :class:`str`
        The delimiter that separates a flag's argument from the flag's name.
        By default this is ``:``.
    """

    @classmethod
    def get_flags(cls) -> Dict[str, Flag]:
        """Dict[:class:`str`, :class:`Flag`]: A mapping of flag name to flag object this converter has."""
        return cls.__commands_flags__.copy()

    @classmethod
    def _can_be_constructible(cls) -> bool:
        """Return ``True`` when no flag is required, i.e. every flag has a default."""
        return all(not flag.required for flag in cls.__commands_flags__.values())

    @classmethod
    async def _construct_default(cls: Type[F], ctx: Context) -> F:
        """Build an instance from defaults only, without parsing any input.

        Callable defaults are called with ``ctx`` through ``maybe_coroutine``;
        any other default is assigned as-is.
        """
        self: F = cls.__new__(cls)
        flags = cls.__commands_flags__
        for flag in flags.values():
            if callable(flag.default):
                default = await maybe_coroutine(flag.default, ctx)
                setattr(self, flag.attribute, default)
            else:
                setattr(self, flag.attribute, flag.default)
        return self

    def __repr__(self) -> str:
        pairs = ' '.join([f'{flag.attribute}={getattr(self, flag.attribute)!r}' for flag in self.get_flags().values()])
        return f'<{self.__class__.__name__} {pairs}>'

    @classmethod
    def parse_flags(cls, argument: str) -> Dict[str, List[str]]:
        """Split ``argument`` into the raw string values given for each flag.

        Parameters
        ----------
        argument: :class:`str`
            The text to scan for flags.

        Raises
        --------
        MissingFlagArgument
            A flag was given without any value after it.

        Returns
        --------
        Dict[:class:`str`, List[:class:`str`]]
            Raw values per flag, in the order they appeared. Keys match the
            keys of :meth:`get_flags`, so they are casefolded when the
            converter is case insensitive.
        """
        result: Dict[str, List[str]] = {}
        flags = cls.get_flags()
        last_position = 0
        last_flag: Optional[Flag] = None

        case_insensitive = cls.__commands_flag_case_insensitive__
        for match in cls.__commands_flag_regex__.finditer(argument):
            begin, end = match.span(0)
            key = match.group('flag')
            if case_insensitive:
                key = key.casefold()

            flag = flags.get(key)
            if last_position and last_flag is not None:
                # The previous flag's value runs up to this match; the slice
                # drops the one character directly before the match.
                value = argument[last_position : begin - 1].lstrip()
                if not value:
                    raise MissingFlagArgument(last_flag)

                # Case-insensitive converters key their flags by casefolded
                # name, so the result must use the same key or convert()
                # would never find values for names containing uppercase.
                name = last_flag.name.casefold() if case_insensitive else last_flag.name
                result.setdefault(name, []).append(value)

            last_position = end
            last_flag = flag

        # Add the remaining string to the last available flag
        if last_position and last_flag is not None:
            value = argument[last_position:].strip()
            if not value:
                raise MissingFlagArgument(last_flag)

            name = last_flag.name.casefold() if case_insensitive else last_flag.name
            result.setdefault(name, []).append(value)

        # Verification of values will come at a later stage
        return result

    @classmethod
    async def convert(cls: Type[F], ctx: Context, argument: str) -> F:
        """|coro|

        The method that actually converts an argument to the flag mapping.

        Parameters
        ----------
        cls: Type[:class:`FlagConverter`]
            The flag converter class.
        ctx: :class:`Context`
            The invocation context.
        argument: :class:`str`
            The argument to convert from.

        Raises
        --------
        FlagError
            A flag related parsing error.
        CommandError
            A command related error.

        Returns
        --------
        :class:`FlagConverter`
            The flag converter instance with all flags parsed.
        """
        arguments = cls.parse_flags(argument)
        flags = cls.get_flags()

        self: F = cls.__new__(cls)
        for name, flag in flags.items():
            try:
                values = arguments[name]
            except KeyError:
                # The flag was not given: it must be optional, and gets its default.
                if flag.required:
                    raise MissingRequiredFlag(flag)
                else:
                    if callable(flag.default):
                        default = await maybe_coroutine(flag.default, ctx)
                        setattr(self, flag.attribute, default)
                    else:
                        setattr(self, flag.attribute, flag.default)
                    continue

            # A non-positive max_args means there is no limit to enforce.
            if flag.max_args > 0 and len(values) > flag.max_args:
                if flag.override:
                    # Keep only the most recently given values.
                    values = values[-flag.max_args :]
                else:
                    raise TooManyFlags(flag, values)

            # Special case: a single-value flag is stored as a bare value, not a list.
            if flag.max_args == 1:
                value = await convert_flag(ctx, values[0], flag)
                setattr(self, flag.attribute, value)
                continue

            # Another special case, tuple parsing.
            # Tuple parsing is basically converting arguments within the flag
            # So, given flag: hello 20 as the input and Tuple[str, int] as the type hint
            # We would receive ('hello', 20) as the resulting value
            # This uses the same whitespace and quoting rules as regular parameters.
            values = [await convert_flag(ctx, value, flag) for value in values]

            if flag.cast_to_dict:
                values = dict(values)  # type: ignore

            setattr(self, flag.attribute, values)

        return self
|