mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-07 04:17:16 +00:00
[commands] use __args__ and __origin__ where applicable
This commit is contained in:
parent
c54c4cb215
commit
7f91ae8b67
@ -30,8 +30,6 @@ from typing import (
|
|||||||
Literal,
|
Literal,
|
||||||
Tuple,
|
Tuple,
|
||||||
Union,
|
Union,
|
||||||
get_args as get_typing_args,
|
|
||||||
get_origin as get_typing_origin,
|
|
||||||
)
|
)
|
||||||
import asyncio
|
import asyncio
|
||||||
import functools
|
import functools
|
||||||
@ -86,6 +84,10 @@ def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
|||||||
params.append(p)
|
params.append(p)
|
||||||
return tuple(params)
|
return tuple(params)
|
||||||
|
|
||||||
|
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
||||||
|
none_cls = type(None)
|
||||||
|
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
|
||||||
|
|
||||||
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
|
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
|
||||||
if isinstance(tp, ForwardRef):
|
if isinstance(tp, ForwardRef):
|
||||||
tp = tp.__forward_arg__
|
tp = tp.__forward_arg__
|
||||||
@ -102,6 +104,12 @@ def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any]
|
|||||||
if hasattr(tp, '__args__'):
|
if hasattr(tp, '__args__'):
|
||||||
implicit_str = True
|
implicit_str = True
|
||||||
args = tp.__args__
|
args = tp.__args__
|
||||||
|
if tp.__origin__ is Union:
|
||||||
|
try:
|
||||||
|
if args.index(type(None)) != len(args) - 1:
|
||||||
|
args = normalise_optional_params(tp.__args__)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
if tp.__origin__ is Literal:
|
if tp.__origin__ is Literal:
|
||||||
if not PY_310:
|
if not PY_310:
|
||||||
args = flatten_literal_params(tp.__args__)
|
args = flatten_literal_params(tp.__args__)
|
||||||
@ -547,12 +555,13 @@ class Command(_BaseCommand):
|
|||||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||||
|
|
||||||
async def do_conversion(self, ctx, converter, argument, param):
|
async def do_conversion(self, ctx, converter, argument, param):
|
||||||
origin = get_typing_origin(converter)
|
origin = getattr(converter, '__origin__', None)
|
||||||
|
|
||||||
if origin is Union:
|
if origin is Union:
|
||||||
errors = []
|
errors = []
|
||||||
_NoneType = type(None)
|
_NoneType = type(None)
|
||||||
for conv in get_typing_args(converter):
|
union_args = converter.__args__
|
||||||
|
for conv in union_args:
|
||||||
# if we got to this part in the code, then the previous conversions have failed
|
# if we got to this part in the code, then the previous conversions have failed
|
||||||
# so we should just undo the view, return the default, and allow parsing to continue
|
# so we should just undo the view, return the default, and allow parsing to continue
|
||||||
# with the other parameters
|
# with the other parameters
|
||||||
@ -568,12 +577,13 @@ class Command(_BaseCommand):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
# if we're here, then we failed all the converters
|
# if we're here, then we failed all the converters
|
||||||
raise BadUnionArgument(param, get_typing_args(converter), errors)
|
raise BadUnionArgument(param, union_args, errors)
|
||||||
|
|
||||||
if origin is Literal:
|
if origin is Literal:
|
||||||
errors = []
|
errors = []
|
||||||
conversions = {}
|
conversions = {}
|
||||||
for literal in converter.__args__:
|
literal_args = converter.__args__
|
||||||
|
for literal in literal_args:
|
||||||
literal_type = type(literal)
|
literal_type = type(literal)
|
||||||
try:
|
try:
|
||||||
value = conversions[literal_type]
|
value = conversions[literal_type]
|
||||||
@ -591,7 +601,7 @@ class Command(_BaseCommand):
|
|||||||
return value
|
return value
|
||||||
|
|
||||||
# if we're here, then we failed to match all the literals
|
# if we're here, then we failed to match all the literals
|
||||||
raise BadLiteralArgument(param, converter.__args__, errors)
|
raise BadLiteralArgument(param, literal_args, errors)
|
||||||
|
|
||||||
return await self._actual_conversion(ctx, converter, argument, param)
|
return await self._actual_conversion(ctx, converter, argument, param)
|
||||||
|
|
||||||
@ -614,7 +624,7 @@ class Command(_BaseCommand):
|
|||||||
# The greedy converter is simple -- it keeps going until it fails in which case,
|
# The greedy converter is simple -- it keeps going until it fails in which case,
|
||||||
# it undos the view ready for the next parameter to use instead
|
# it undos the view ready for the next parameter to use instead
|
||||||
if isinstance(converter, converters.Greedy):
|
if isinstance(converter, converters.Greedy):
|
||||||
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
|
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
|
||||||
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
|
return await self._transform_greedy_pos(ctx, param, required, converter.converter)
|
||||||
elif param.kind == param.VAR_POSITIONAL:
|
elif param.kind == param.VAR_POSITIONAL:
|
||||||
return await self._transform_greedy_var_pos(ctx, param, converter.converter)
|
return await self._transform_greedy_var_pos(ctx, param, converter.converter)
|
||||||
@ -782,7 +792,7 @@ class Command(_BaseCommand):
|
|||||||
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
|
raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.')
|
||||||
|
|
||||||
for name, param in iterator:
|
for name, param in iterator:
|
||||||
if param.kind == param.POSITIONAL_OR_KEYWORD or param.kind == param.POSITIONAL_ONLY:
|
if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY):
|
||||||
transformed = await self.transform(ctx, param)
|
transformed = await self.transform(ctx, param)
|
||||||
args.append(transformed)
|
args.append(transformed)
|
||||||
elif param.kind == param.KEYWORD_ONLY:
|
elif param.kind == param.KEYWORD_ONLY:
|
||||||
@ -1074,7 +1084,7 @@ class Command(_BaseCommand):
|
|||||||
return ''
|
return ''
|
||||||
|
|
||||||
def _is_typing_optional(self, annotation):
|
def _is_typing_optional(self, annotation):
|
||||||
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
|
return getattr(annotation, '__origin__', None) is Union and type(None) in annotation.__args__
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def signature(self):
|
def signature(self):
|
||||||
@ -1094,13 +1104,14 @@ class Command(_BaseCommand):
|
|||||||
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
||||||
# parameter signature is a literal list of it's values
|
# parameter signature is a literal list of it's values
|
||||||
annotation = param.annotation.converter if greedy else param.annotation
|
annotation = param.annotation.converter if greedy else param.annotation
|
||||||
origin = get_typing_origin(annotation)
|
origin = getattr(annotation, '__origin__', None)
|
||||||
if not greedy and origin is Union:
|
if not greedy and origin is Union:
|
||||||
union_args = get_typing_args(annotation)
|
none_cls = type(None)
|
||||||
optional = union_args[-1] is type(None)
|
union_args = annotation.__args__
|
||||||
if optional:
|
optional = union_args[-1] is none_cls
|
||||||
|
if len(union_args) == 2 and optional:
|
||||||
annotation = union_args[0]
|
annotation = union_args[0]
|
||||||
origin = get_typing_origin(annotation)
|
origin = getattr(annotation, '__origin__', None)
|
||||||
|
|
||||||
if origin is Literal:
|
if origin is Literal:
|
||||||
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
|
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
|
||||||
|
@ -23,7 +23,6 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from discord.errors import ClientException, DiscordException
|
from discord.errors import ClientException, DiscordException
|
||||||
import typing
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
@ -646,7 +645,7 @@ class BadUnionArgument(UserInputError):
|
|||||||
try:
|
try:
|
||||||
return x.__name__
|
return x.__name__
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
if typing.get_origin(x) is not None:
|
if hasattr(x, '__origin__'):
|
||||||
return repr(x)
|
return repr(x)
|
||||||
return x.__class__.__name__
|
return x.__class__.__name__
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user