[commands] Use typing.get_type_hints to resolve ForwardRefs
This commit is contained in:
parent
72275a73fa
commit
7a34de1570
@ -27,6 +27,7 @@ import functools
|
||||
import inspect
|
||||
import typing
|
||||
import datetime
|
||||
import sys
|
||||
|
||||
import discord
|
||||
|
||||
@ -299,17 +300,35 @@ class Command(_BaseCommand):
|
||||
signature = inspect.signature(function)
|
||||
self.params = signature.parameters.copy()
|
||||
|
||||
# PEP-563 allows postponing evaluation of annotations with a __future__
|
||||
# import. When postponed, Parameter.annotation will be a string and must
|
||||
# be replaced with the real value for the converters to work later on
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value.annotation, str):
|
||||
self.params[key] = value = value.replace(annotation=eval(value.annotation, function.__globals__))
|
||||
# see: https://bugs.python.org/issue41341
|
||||
resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved
|
||||
|
||||
try:
|
||||
type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
|
||||
except NameError as e:
|
||||
raise NameError(f'unresolved forward reference: {e.args[0]}') from None
|
||||
|
||||
for key, value in self.params.items():
|
||||
# coalesce the forward references
|
||||
self.params[key] = value = value.replace(annotation=type_hints.get(key))
|
||||
# fail early for when someone passes an unparameterized Greedy type
|
||||
if value.annotation is converters.Greedy:
|
||||
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
||||
|
||||
def _return_resolved(self, type, **kwargs):
|
||||
return type
|
||||
|
||||
def _recursive_resolve(self, type, *, globals=None):
|
||||
if not isinstance(type, typing.ForwardRef):
|
||||
return type
|
||||
|
||||
resolved = eval(type.__forward_arg__, globals)
|
||||
args = typing.get_args(resolved)
|
||||
for index, arg in enumerate(args):
|
||||
inner_resolve_result = self._recursive_resolve(arg, globals=globals)
|
||||
resolved[index] = inner_resolve_result
|
||||
return resolved
|
||||
|
||||
def add_check(self, func):
|
||||
"""Adds a check to the command.
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user