mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 01:53:01 +00:00
[commands] use __args__ and __origin__ where applicable
This commit is contained in:
@@ -30,8 +30,6 @@ from typing import (
|
||||
Literal,
|
||||
Tuple,
|
||||
Union,
|
||||
get_args as get_typing_args,
|
||||
get_origin as get_typing_origin,
|
||||
)
|
||||
import asyncio
|
||||
import functools
|
||||
@@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
||||
params.append(p)
|
||||
return tuple(params)
|
||||
|
||||
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
||||
none_cls = type(None)
|
||||
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
|
||||
|
||||
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
|
||||
if isinstance(tp, ForwardRef):
|
||||
tp = tp.__forward_arg__
|
||||
@@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any]
|
||||
if hasattr(tp, '__args__'):
|
||||
implicit_str = True
|
||||
args = tp.__args__
|
||||
if tp.__origin__ is Union:
|
||||
try:
|
||||
if args.index(type(None)) != len(args) - 1:
|
||||
args = normalise_optional_params(tp.__args__)
|
||||
except ValueError:
|
||||
pass
|
||||
if tp.__origin__ is Literal:
|
||||
if not PY_310:
|
||||
args = flatten_literal_params(tp.__args__)
|
||||
@@ -547,12 +555,13 @@ class Command(_BaseCommand):
|
||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||
|
||||
async def do_conversion(self, ctx, converter, argument, param):
|
||||
origin = get_typing_origin(converter)
|
||||
origin = getattr(converter, '__origin__', None)
|
||||
|
||||
if origin is Union:
|
||||
errors = []
|
||||
_NoneType = type(None)
|
||||
for conv in get_typing_args(converter):
|
||||
union_args = converter.__args__
|
||||
for conv in union_args:
|
||||
# if we got to this part in the code, then the previous conversions have failed
|
||||
# so we should just undo the view, return the default, and allow parsing to continue
|
||||
# with the other parameters
|
||||
@@ -568,12 +577,13 @@ class Command(_BaseCommand):
|
||||
return value
|
||||
|
||||
# if we're here, then we failed all the converters
|
||||
raise BadUnionArgument(param, get_typing_args(converter), errors)
|
||||
raise BadUnionArgument(param, union_args, errors)
|
||||
|
||||
if origin is Literal:
|
||||
errors = []
|
||||
conversions = {}
|
||||
for literal in converter.__args__:
|
||||
literal_args = converter.__args__
|
||||
for literal in literal_args:
|
||||
literal_type = type(literal)
|
||||
try:
|
||||
value = conversions[literal_type]
|
||||
@@ -591,7 +601,7 @@ class Command(_BaseCommand):
|
||||
return value
|
||||
|
||||
# if we're here, then we failed to match all the literals
|
||||
raise BadLiteralArgument(param, converter.__args__, errors)
|
||||
raise BadLiteralArgument(param, literal_args, errors)
|
||||
|
||||
return await self._actual_conversion(ctx, converter, argument, param)
|
||||
|
||||
@@ -614,7 +624,7 @@ class Command(_BaseCommand):
|
||||
# The greedy converter is simple -- it keeps going until it fails in which case,
|
||||
# it undos the view ready for the next parameter to use instead
|
||||
if isinstance(converter, converters.Greedy):
|
||||
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
|
||||
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
|
||||
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
|
||||
elif param.kind == param.VAR_POSITIONAL:
|
||||
return await self._transform_greedy_var_pos(ctx, param, converter.converter)
|
||||
@@ -782,7 +792,7 @@ class Command(_BaseCommand):
|
||||
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
|
||||
|
||||
for name, param in iterator:
|
||||
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
|
||||
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
|
||||
transformed = await self.transform(ctx, param)
|
||||
args.append(transformed)
|
||||
elif param.kind == param.KEYWORD_ONLY:
|
||||
@@ -1074,7 +1084,7 @@ class Command(_BaseCommand):
|
||||
return ''
|
||||
|
||||
def _is_typing_optional(self, annotation):
|
||||
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
|
||||
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__
|
||||
|
||||
@property
|
||||
def signature(self):
|
||||
@@ -1094,13 +1104,14 @@ class Command(_BaseCommand):
|
||||
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
||||
# parameter signature is a literal list of it's values
|
||||
annotation = param.annotation.converter if greedy else param.annotation
|
||||
origin = get_typing_origin(annotation)
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
if not greedy and origin is Union:
|
||||
union_args = get_typing_args(annotation)
|
||||
optional = union_args[-1] is type(None)
|
||||
if optional:
|
||||
none_cls = type(None)
|
||||
union_args = annotation.__args__
|
||||
optional = union_args[-1] is none_cls
|
||||
if len(union_args) == 2 and optional:
|
||||
annotation = union_args[0]
|
||||
origin = get_typing_origin(annotation)
|
||||
origin = getattr(annotation, '__origin__', None)
|
||||
|
||||
if origin is Literal:
|
||||
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
|
||||
|
@@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from discord.errors import ClientException, DiscordException
|
||||
import typing
|
||||
|
||||
|
||||
__all__ = (
|
||||
@@ -646,7 +645,7 @@ class BadUnionArgument(UserInputError):
|
||||
try:
|
||||
return x.__name__
|
||||
except AttributeError:
|
||||
if typing.get_origin(x) is not None:
|
||||
if hasattr(x, '__origin__'):
|
||||
return repr(x)
|
||||
return x.__class__.__name__
|
||||
|
||||
|
Reference in New Issue
Block a user