[commands] Use typing.get_type_hints to resolve ForwardRefs
This commit is contained in:
parent
72275a73fa
commit
7a34de1570
@ -27,6 +27,7 @@ import functools
|
|||||||
import inspect
|
import inspect
|
||||||
import typing
|
import typing
|
||||||
import datetime
|
import datetime
|
||||||
|
import sys
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
@ -299,17 +300,35 @@ class Command(_BaseCommand):
|
|||||||
signature = inspect.signature(function)
|
signature = inspect.signature(function)
|
||||||
self.params = signature.parameters.copy()
|
self.params = signature.parameters.copy()
|
||||||
|
|
||||||
# PEP-563 allows postponing evaluation of annotations with a __future__
|
# see: https://bugs.python.org/issue41341
|
||||||
# import. When postponed, Parameter.annotation will be a string and must
|
resolve = self._recursive_resolve if sys.version_info < (3, 9) else self._return_resolved
|
||||||
# be replaced with the real value for the converters to work later on
|
|
||||||
for key, value in self.params.items():
|
|
||||||
if isinstance(value.annotation, str):
|
|
||||||
self.params[key] = value = value.replace(annotation=eval(value.annotation, function.__globals__))
|
|
||||||
|
|
||||||
|
try:
|
||||||
|
type_hints = {k: resolve(v) for k, v in typing.get_type_hints(function).items()}
|
||||||
|
except NameError as e:
|
||||||
|
raise NameError(f'unresolved forward reference: {e.args[0]}') from None
|
||||||
|
|
||||||
|
for key, value in self.params.items():
|
||||||
|
# coalesce the forward references
|
||||||
|
self.params[key] = value = value.replace(annotation=type_hints.get(key))
|
||||||
# fail early for when someone passes an unparameterized Greedy type
|
# fail early for when someone passes an unparameterized Greedy type
|
||||||
if value.annotation is converters.Greedy:
|
if value.annotation is converters.Greedy:
|
||||||
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
||||||
|
|
||||||
|
def _return_resolved(self, type, **kwargs):
|
||||||
|
return type
|
||||||
|
|
||||||
|
def _recursive_resolve(self, type, *, globals=None):
|
||||||
|
if not isinstance(type, typing.ForwardRef):
|
||||||
|
return type
|
||||||
|
|
||||||
|
resolved = eval(type.__forward_arg__, globals)
|
||||||
|
args = typing.get_args(resolved)
|
||||||
|
for index, arg in enumerate(args):
|
||||||
|
inner_resolve_result = self._recursive_resolve(arg, globals=globals)
|
||||||
|
resolved[index] = inner_resolve_result
|
||||||
|
return resolved
|
||||||
|
|
||||||
def add_check(self, func):
|
def add_check(self, func):
|
||||||
"""Adds a check to the command.
|
"""Adds a check to the command.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user