mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-07 12:18:59 +00:00
Split annotation resolution to discord.utils
This commit is contained in:
parent
69da87f455
commit
9f3551926a
@ -25,11 +25,7 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Dict,
|
Dict,
|
||||||
ForwardRef,
|
|
||||||
Iterable,
|
|
||||||
Literal,
|
Literal,
|
||||||
Optional,
|
|
||||||
Tuple,
|
|
||||||
Union,
|
Union,
|
||||||
)
|
)
|
||||||
import asyncio
|
import asyncio
|
||||||
@ -37,7 +33,6 @@ import functools
|
|||||||
import inspect
|
import inspect
|
||||||
import datetime
|
import datetime
|
||||||
import types
|
import types
|
||||||
import sys
|
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
@ -74,102 +69,12 @@ __all__ = (
|
|||||||
'bot_has_guild_permissions'
|
'bot_has_guild_permissions'
|
||||||
)
|
)
|
||||||
|
|
||||||
PY_310 = sys.version_info >= (3, 10)
|
|
||||||
|
|
||||||
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
|
||||||
params = []
|
|
||||||
literal_cls = type(Literal[0])
|
|
||||||
for p in parameters:
|
|
||||||
if isinstance(p, literal_cls):
|
|
||||||
params.extend(p.__args__)
|
|
||||||
else:
|
|
||||||
params.append(p)
|
|
||||||
return tuple(params)
|
|
||||||
|
|
||||||
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
|
||||||
none_cls = type(None)
|
|
||||||
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
|
|
||||||
|
|
||||||
def _evaluate_annotation(
|
|
||||||
tp: Any,
|
|
||||||
globals: Dict[str, Any],
|
|
||||||
locals: Dict[str, Any],
|
|
||||||
cache: Dict[str, Any],
|
|
||||||
*,
|
|
||||||
implicit_str: bool = True,
|
|
||||||
):
|
|
||||||
if isinstance(tp, ForwardRef):
|
|
||||||
tp = tp.__forward_arg__
|
|
||||||
# ForwardRefs always evaluate their internals
|
|
||||||
implicit_str = True
|
|
||||||
|
|
||||||
if implicit_str and isinstance(tp, str):
|
|
||||||
if tp in cache:
|
|
||||||
return cache[tp]
|
|
||||||
evaluated = eval(tp, globals, locals)
|
|
||||||
cache[tp] = evaluated
|
|
||||||
return _evaluate_annotation(evaluated, globals, locals, cache)
|
|
||||||
|
|
||||||
if hasattr(tp, '__args__'):
|
|
||||||
implicit_str = True
|
|
||||||
is_literal = False
|
|
||||||
args = tp.__args__
|
|
||||||
if not hasattr(tp, '__origin__'):
|
|
||||||
if PY_310 and tp.__class__ is types.Union:
|
|
||||||
converted = Union[args] # type: ignore
|
|
||||||
return _evaluate_annotation(converted, globals, locals, cache)
|
|
||||||
|
|
||||||
return tp
|
|
||||||
if tp.__origin__ is Union:
|
|
||||||
try:
|
|
||||||
if args.index(type(None)) != len(args) - 1:
|
|
||||||
args = normalise_optional_params(tp.__args__)
|
|
||||||
except ValueError:
|
|
||||||
pass
|
|
||||||
if tp.__origin__ is Literal:
|
|
||||||
if not PY_310:
|
|
||||||
args = flatten_literal_params(tp.__args__)
|
|
||||||
implicit_str = False
|
|
||||||
is_literal = True
|
|
||||||
|
|
||||||
evaluated_args = tuple(
|
|
||||||
_evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
|
|
||||||
raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
|
|
||||||
|
|
||||||
if evaluated_args == args:
|
|
||||||
return tp
|
|
||||||
|
|
||||||
try:
|
|
||||||
return tp.copy_with(evaluated_args)
|
|
||||||
except AttributeError:
|
|
||||||
return tp.__origin__[evaluated_args]
|
|
||||||
|
|
||||||
return tp
|
|
||||||
|
|
||||||
def resolve_annotation(
|
|
||||||
annotation: Any,
|
|
||||||
globalns: Dict[str, Any],
|
|
||||||
localns: Optional[Dict[str, Any]],
|
|
||||||
cache: Optional[Dict[str, Any]],
|
|
||||||
) -> Any:
|
|
||||||
if annotation is None:
|
|
||||||
return type(None)
|
|
||||||
if isinstance(annotation, str):
|
|
||||||
annotation = ForwardRef(annotation)
|
|
||||||
|
|
||||||
locals = globalns if localns is None else localns
|
|
||||||
if cache is None:
|
|
||||||
cache = {}
|
|
||||||
return _evaluate_annotation(annotation, globalns, locals, cache)
|
|
||||||
|
|
||||||
def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]:
|
def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.Parameter]:
|
||||||
globalns = function.__globals__
|
globalns = function.__globals__
|
||||||
signature = inspect.signature(function)
|
signature = inspect.signature(function)
|
||||||
params = {}
|
params = {}
|
||||||
cache: Dict[str, Any] = {}
|
cache: Dict[str, Any] = {}
|
||||||
|
eval_annotation = discord.utils.evaluate_annotation
|
||||||
for name, parameter in signature.parameters.items():
|
for name, parameter in signature.parameters.items():
|
||||||
annotation = parameter.annotation
|
annotation = parameter.annotation
|
||||||
if annotation is parameter.empty:
|
if annotation is parameter.empty:
|
||||||
@ -179,7 +84,7 @@ def get_signature_parameters(function: types.FunctionType) -> Dict[str, inspect.
|
|||||||
params[name] = parameter.replace(annotation=type(None))
|
params[name] = parameter.replace(annotation=type(None))
|
||||||
continue
|
continue
|
||||||
|
|
||||||
annotation = _evaluate_annotation(annotation, globalns, globalns, cache)
|
annotation = eval_annotation(annotation, globalns, globalns, cache)
|
||||||
if annotation is Greedy:
|
if annotation is Greedy:
|
||||||
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
raise TypeError('Unparameterized Greedy[...] is disallowed in signature.')
|
||||||
|
|
||||||
|
@ -32,7 +32,7 @@ from .errors import (
|
|||||||
MissingRequiredFlag,
|
MissingRequiredFlag,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .core import resolve_annotation
|
from discord.utils import resolve_annotation
|
||||||
from .view import StringView
|
from .view import StringView
|
||||||
from .converter import run_converters
|
from .converter import run_converters
|
||||||
|
|
||||||
|
101
discord/utils.py
101
discord/utils.py
@ -31,13 +31,16 @@ from typing import (
|
|||||||
AsyncIterator,
|
AsyncIterator,
|
||||||
Callable,
|
Callable,
|
||||||
Dict,
|
Dict,
|
||||||
|
ForwardRef,
|
||||||
Generic,
|
Generic,
|
||||||
Iterable,
|
Iterable,
|
||||||
Iterator,
|
Iterator,
|
||||||
List,
|
List,
|
||||||
|
Literal,
|
||||||
Optional,
|
Optional,
|
||||||
Protocol,
|
Protocol,
|
||||||
Sequence,
|
Sequence,
|
||||||
|
Tuple,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@ -53,6 +56,8 @@ from inspect import isawaitable as _isawaitable, signature as _signature
|
|||||||
from operator import attrgetter
|
from operator import attrgetter
|
||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
import sys
|
||||||
|
import types
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .errors import InvalidArgument
|
from .errors import InvalidArgument
|
||||||
@ -99,6 +104,7 @@ if TYPE_CHECKING:
|
|||||||
class _RequestLike(Protocol):
|
class _RequestLike(Protocol):
|
||||||
headers: Dict[str, Any]
|
headers: Dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
cached_property = _cached_property
|
cached_property = _cached_property
|
||||||
|
|
||||||
@ -741,6 +747,7 @@ def _chunk(iterator: Iterator[T], max_size: int) -> Iterator[List[T]]:
|
|||||||
if ret:
|
if ret:
|
||||||
yield ret
|
yield ret
|
||||||
|
|
||||||
|
|
||||||
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
|
async def _achunk(iterator: AsyncIterator[T], max_size: int) -> AsyncIterator[List[T]]:
|
||||||
ret = []
|
ret = []
|
||||||
n = 0
|
n = 0
|
||||||
@ -793,3 +800,97 @@ def as_chunks(iterator: _Iter[T], max_size: int) -> _Iter[List[T]]:
|
|||||||
if isinstance(iterator, AsyncIterator):
|
if isinstance(iterator, AsyncIterator):
|
||||||
return _achunk(iterator, max_size)
|
return _achunk(iterator, max_size)
|
||||||
return _chunk(iterator, max_size)
|
return _chunk(iterator, max_size)
|
||||||
|
|
||||||
|
|
||||||
|
PY_310 = sys.version_info >= (3, 10)
|
||||||
|
|
||||||
|
|
||||||
|
def flatten_literal_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
||||||
|
params = []
|
||||||
|
literal_cls = type(Literal[0])
|
||||||
|
for p in parameters:
|
||||||
|
if isinstance(p, literal_cls):
|
||||||
|
params.extend(p.__args__)
|
||||||
|
else:
|
||||||
|
params.append(p)
|
||||||
|
return tuple(params)
|
||||||
|
|
||||||
|
|
||||||
|
def normalise_optional_params(parameters: Iterable[Any]) -> Tuple[Any, ...]:
|
||||||
|
none_cls = type(None)
|
||||||
|
return tuple(p for p in parameters if p is not none_cls) + (none_cls,)
|
||||||
|
|
||||||
|
|
||||||
|
def evaluate_annotation(
|
||||||
|
tp: Any,
|
||||||
|
globals: Dict[str, Any],
|
||||||
|
locals: Dict[str, Any],
|
||||||
|
cache: Dict[str, Any],
|
||||||
|
*,
|
||||||
|
implicit_str: bool = True,
|
||||||
|
):
|
||||||
|
if isinstance(tp, ForwardRef):
|
||||||
|
tp = tp.__forward_arg__
|
||||||
|
# ForwardRefs always evaluate their internals
|
||||||
|
implicit_str = True
|
||||||
|
|
||||||
|
if implicit_str and isinstance(tp, str):
|
||||||
|
if tp in cache:
|
||||||
|
return cache[tp]
|
||||||
|
evaluated = eval(tp, globals, locals)
|
||||||
|
cache[tp] = evaluated
|
||||||
|
return evaluate_annotation(evaluated, globals, locals, cache)
|
||||||
|
|
||||||
|
if hasattr(tp, '__args__'):
|
||||||
|
implicit_str = True
|
||||||
|
is_literal = False
|
||||||
|
args = tp.__args__
|
||||||
|
if not hasattr(tp, '__origin__'):
|
||||||
|
if PY_310 and tp.__class__ is types.Union:
|
||||||
|
converted = Union[args] # type: ignore
|
||||||
|
return evaluate_annotation(converted, globals, locals, cache)
|
||||||
|
|
||||||
|
return tp
|
||||||
|
if tp.__origin__ is Union:
|
||||||
|
try:
|
||||||
|
if args.index(type(None)) != len(args) - 1:
|
||||||
|
args = normalise_optional_params(tp.__args__)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
if tp.__origin__ is Literal:
|
||||||
|
if not PY_310:
|
||||||
|
args = flatten_literal_params(tp.__args__)
|
||||||
|
implicit_str = False
|
||||||
|
is_literal = True
|
||||||
|
|
||||||
|
evaluated_args = tuple(evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args)
|
||||||
|
|
||||||
|
if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args):
|
||||||
|
raise TypeError('Literal arguments must be of type str, int, bool, float or complex.')
|
||||||
|
|
||||||
|
if evaluated_args == args:
|
||||||
|
return tp
|
||||||
|
|
||||||
|
try:
|
||||||
|
return tp.copy_with(evaluated_args)
|
||||||
|
except AttributeError:
|
||||||
|
return tp.__origin__[evaluated_args]
|
||||||
|
|
||||||
|
return tp
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_annotation(
|
||||||
|
annotation: Any,
|
||||||
|
globalns: Dict[str, Any],
|
||||||
|
localns: Optional[Dict[str, Any]],
|
||||||
|
cache: Optional[Dict[str, Any]],
|
||||||
|
) -> Any:
|
||||||
|
if annotation is None:
|
||||||
|
return type(None)
|
||||||
|
if isinstance(annotation, str):
|
||||||
|
annotation = ForwardRef(annotation)
|
||||||
|
|
||||||
|
locals = globalns if localns is None else localns
|
||||||
|
if cache is None:
|
||||||
|
cache = {}
|
||||||
|
return evaluate_annotation(annotation, globalns, locals, cache)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user