mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-10-24 10:02:56 +00:00
[commands] Refactor typing evaluation to not use get_type_hints
get_type_hints had a few issues: 1. It would convert = None default parameters to Optional 2. It would not allow values as type annotations 3. It would not implicitly convert some string literals as ForwardRef In Python 3.9 `list['Foo']` does not convert into `list[ForwardRef('Foo')]` even though `typing.List` does this behaviour. In order to streamline it, evaluation had to be rewritten manually to support our usecases. This patch also flattens nested typing.Literal which was not done until Python 3.9.2.
This commit is contained in:
@@ -22,10 +22,20 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
ForwardRef,
|
||||
Iterable,
|
||||
Literal,
|
||||
Tuple,
|
||||
Union,
|
||||
get_args as get_typing_args,
|
||||
get_origin as get_typing_origin,
|
||||
)
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import typing
|
||||
import datetime
|
||||
import sys
|
||||
|
||||
@@ -64,6 +74,83 @@ __all__ = (
|
||||
'bot_has_guild_permissions'
|
||||
)
|
||||
|
||||
PY_310 = sys.version_info >= (3, 10)
|
||||
|
||||
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
||||
params = []
|
||||
literal_cls = type(Literal[0])
|
||||
for p in parameters:
|
||||
if isinstance(p, literal_cls):
|
||||
params.extend(p.__args__)
|
||||
else:
|
||||
params.append(p)
|
||||
return tuple(params)
|
||||
|
||||
def _evaluate_annotation(tp: Any, globals: Dict[str, Any], cache: Dict[str, Any] = {}, *, implicit_str=True):
|
||||
if isinstance(tp, ForwardRef):
|
||||
tp = tp.__forward_arg__
|
||||
# ForwardRefs always evaluate their internals
|
||||
implicit_str = True
|
||||
|
||||
if implicit_str and isinstance(tp, str):
|
||||
if tp in cache:
|
||||
return cache[tp]
|
||||
evaluated = eval(tp, globals)
|
||||
cache[tp] = evaluated
|
||||
return _evaluate_annotation(evaluated, globals, cache)
|
||||
|
||||
if hasattr(tp, '__args__'):
|
||||
implicit_str = True
|
||||
args = tp.__args__
|
||||
if tp.__origin__ is Literal:
|
||||
if not PY_310:
|
||||
args = flatten_literal_params(tp.__args__)
|
||||
implicit_str = False
|
||||
|
||||
evaluated_args = tuple(
|
||||
_evaluate_annotation(arg, globals, cache, implicit_str=implicit_str) for arg in args
|
||||
)
|
||||
|
||||
if evaluated_args == args:
|
||||
return tp
|
||||
|
||||
try:
|
||||
return tp.copy_with(evaluated_args)
|
||||
except AttributeError:
|
||||
return tp.__origin__[evaluated_args]
|
||||
|
||||
return tp
|
||||
|
||||
def resolve_annotation(annotation: Any, globalns: Dict[str, Any], cache: Dict[str, Any] = {}) -> Any:
|
||||
if annotation is None:
|
||||
return type(None)
|
||||
if isinstance(annotation, str):
|
||||
annotation = ForwardRef(annotation)
|
||||
return _evaluate_annotation(annotation, globalns, cache)
|
||||
|
||||
def get_signature_parameters(function) -> Dict[str, inspect.Parameter]:
|
||||
globalns = function.__globals__
|
||||
signature = inspect.signature(function)
|
||||
params = {}
|
||||
cache: Dict[str, Any] = {}
|
||||
for name, parameter in signature.parameters.items():
|
||||
annotation = parameter.annotation
|
||||
if annotation is parameter.empty:
|
||||
params[name] = parameter
|
||||
continue
|
||||
if annotation is None:
|
||||
params[name] = parameter.replace(annotation=type(None))
|
||||
continue
|
||||
|
||||
annotation = _evaluate_annotation(annotation, globalns, cache)
|
||||
if annotation is converters.Greedy:
|
||||
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
||||
|
||||
params[name] = parameter.replace(annotation=annotation)
|
||||
|
||||
return params
|
||||
|
||||
|
||||
def wrap_callback(coro):
|
||||
@functools.wraps(coro)
|
||||
async def wrapped(*args, **kwargs):
|
||||
@@ -300,40 +387,7 @@ class Command(_BaseCommand):
|
||||
def callback(self, function):
|
||||
self._callback = function
|
||||
self.module = function.__module__
|
||||
|
||||
signature = inspect.signature(function)
|
||||
self.params = signature.parameters.copy()
|
||||
|
||||
# see: https://bugs.python.org/issue41341
|
||||
resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved
|
||||
|
||||
try:
|
||||
type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
|
||||
except NameError as e:
|
||||
raise NameError(f'unresolved forward reference: {e.args[0]}') from None
|
||||
|
||||
for key, value in self.params.items():
|
||||
# coalesce the forward references
|
||||
if key in type_hints:
|
||||
self.params[key] = value = value.replace(annotation=type_hints[key])
|
||||
|
||||
# fail early for when someone passes an unparameterized Greedy type
|
||||
if value.annotation is converters.Greedy:
|
||||
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
||||
|
||||
def _return_resolved(self, type, **kwargs):
|
||||
return type
|
||||
|
||||
def _recursive_resolve(self, type, *, globals=None):
|
||||
if not isinstance(type, typing.ForwardRef):
|
||||
return type
|
||||
|
||||
resolved = eval(type.__forward_arg__, globals)
|
||||
args = typing.get_args(resolved)
|
||||
for index, arg in enumerate(args):
|
||||
inner_resolve_result = self._recursive_resolve(arg, globals=globals)
|
||||
resolved[index] = inner_resolve_result
|
||||
return resolved
|
||||
self.params = get_signature_parameters(function)
|
||||
|
||||
def add_check(self, func):
|
||||
"""Adds a check to the command.
|
||||
@@ -493,12 +547,12 @@ class Command(_BaseCommand):
|
||||
raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc
|
||||
|
||||
async def do_conversion(self, ctx, converter, argument, param):
|
||||
origin = typing.get_origin(converter)
|
||||
origin = get_typing_origin(converter)
|
||||
|
||||
if origin is typing.Union:
|
||||
if origin is Union:
|
||||
errors = []
|
||||
_NoneType = type(None)
|
||||
for conv in typing.get_args(converter):
|
||||
for conv in get_typing_args(converter):
|
||||
# if we got to this part in the code, then the previous conversions have failed
|
||||
# so we should just undo the view, return the default, and allow parsing to continue
|
||||
# with the other parameters
|
||||
@@ -514,13 +568,12 @@ class Command(_BaseCommand):
|
||||
return value
|
||||
|
||||
# if we're here, then we failed all the converters
|
||||
raise BadUnionArgument(param, typing.get_args(converter), errors)
|
||||
raise BadUnionArgument(param, get_typing_args(converter), errors)
|
||||
|
||||
if origin is typing.Literal:
|
||||
if origin is Literal:
|
||||
errors = []
|
||||
conversions = {}
|
||||
literal_args = tuple(self._flattened_typing_literal_args(converter))
|
||||
for literal in literal_args:
|
||||
for literal in converter.__args__:
|
||||
literal_type = type(literal)
|
||||
try:
|
||||
value = conversions[literal_type]
|
||||
@@ -538,7 +591,7 @@ class Command(_BaseCommand):
|
||||
return value
|
||||
|
||||
# if we're here, then we failed to match all the literals
|
||||
raise BadLiteralArgument(param, literal_args, errors)
|
||||
raise BadLiteralArgument(param, converter.__args__, errors)
|
||||
|
||||
return await self._actual_conversion(ctx, converter, argument, param)
|
||||
|
||||
@@ -1021,14 +1074,7 @@ class Command(_BaseCommand):
|
||||
return ''
|
||||
|
||||
def _is_typing_optional(self, annotation):
|
||||
return typing.get_origin(annotation) is typing.Union and typing.get_args(annotation)[-1] is type(None)
|
||||
|
||||
def _flattened_typing_literal_args(self, annotation):
|
||||
for literal in typing.get_args(annotation):
|
||||
if typing.get_origin(literal) is typing.Literal:
|
||||
yield from self._flattened_typing_literal_args(literal)
|
||||
else:
|
||||
yield literal
|
||||
return get_typing_origin(annotation) is Union and get_typing_args(annotation)[-1] is type(None)
|
||||
|
||||
@property
|
||||
def signature(self):
|
||||
@@ -1048,17 +1094,16 @@ class Command(_BaseCommand):
|
||||
# for typing.Literal[...], typing.Optional[typing.Literal[...]], and Greedy[typing.Literal[...]], the
|
||||
# parameter signature is a literal list of it's values
|
||||
annotation = param.annotation.converter if greedy else param.annotation
|
||||
origin = typing.get_origin(annotation)
|
||||
if not greedy and origin is typing.Union:
|
||||
union_args = typing.get_args(annotation)
|
||||
origin = get_typing_origin(annotation)
|
||||
if not greedy and origin is Union:
|
||||
union_args = get_typing_args(annotation)
|
||||
optional = union_args[-1] is type(None)
|
||||
if optional:
|
||||
annotation = union_args[0]
|
||||
origin = typing.get_origin(annotation)
|
||||
origin = get_typing_origin(annotation)
|
||||
|
||||
if origin is typing.Literal:
|
||||
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v)
|
||||
for v in self._flattened_typing_literal_args(annotation))
|
||||
if origin is Literal:
|
||||
name = '|'.join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__)
|
||||
if param.default is not param.empty:
|
||||
# We don't want None or '' to trigger the [name=value] case and instead it should
|
||||
# do [name] since [name=None] or [name=] are not exactly useful for the user.
|
||||
|
Reference in New Issue
Block a user