mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-05 09:26:10 +00:00
Fix typing issues and improve typing completeness across the library
Co-authored-by: Danny <Rapptz@users.noreply.github.com> Co-authored-by: Josh <josh.ja.butt@gmail.com>
This commit is contained in:
@ -29,6 +29,7 @@ from typing import (
|
||||
Any,
|
||||
AsyncIterable,
|
||||
AsyncIterator,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
Dict,
|
||||
@ -42,6 +43,7 @@ from typing import (
|
||||
NamedTuple,
|
||||
Optional,
|
||||
Protocol,
|
||||
Set,
|
||||
Sequence,
|
||||
Tuple,
|
||||
Type,
|
||||
@ -66,7 +68,7 @@ import warnings
|
||||
import yarl
|
||||
|
||||
try:
|
||||
import orjson
|
||||
import orjson # type: ignore
|
||||
except ModuleNotFoundError:
|
||||
HAS_ORJSON = False
|
||||
else:
|
||||
@ -123,7 +125,7 @@ class _cached_property:
|
||||
if TYPE_CHECKING:
|
||||
from functools import cached_property as cached_property
|
||||
|
||||
from typing_extensions import ParamSpec
|
||||
from typing_extensions import ParamSpec, Self
|
||||
|
||||
from .permissions import Permissions
|
||||
from .abc import Snowflake
|
||||
@ -135,8 +137,16 @@ if TYPE_CHECKING:
|
||||
|
||||
P = ParamSpec('P')
|
||||
|
||||
MaybeCoroFunc = Union[
|
||||
Callable[P, Coroutine[Any, Any, 'T']],
|
||||
Callable[P, 'T'],
|
||||
]
|
||||
|
||||
_SnowflakeListBase = array.array[int]
|
||||
|
||||
else:
|
||||
cached_property = _cached_property
|
||||
_SnowflakeListBase = array.array
|
||||
|
||||
|
||||
T = TypeVar('T')
|
||||
@ -178,7 +188,7 @@ class classproperty(Generic[T_co]):
|
||||
def __get__(self, instance: Optional[Any], owner: Type[Any]) -> T_co:
|
||||
return self.fget(owner)
|
||||
|
||||
def __set__(self, instance, value) -> None:
|
||||
def __set__(self, instance: Optional[Any], value: Any) -> None:
|
||||
raise AttributeError('cannot set attribute')
|
||||
|
||||
|
||||
@ -210,7 +220,7 @@ class SequenceProxy(Sequence[T_co]):
|
||||
def __reversed__(self) -> Iterator[T_co]:
|
||||
return reversed(self.__proxied)
|
||||
|
||||
def index(self, value: Any, *args, **kwargs) -> int:
|
||||
def index(self, value: Any, *args: Any, **kwargs: Any) -> int:
|
||||
return self.__proxied.index(value, *args, **kwargs)
|
||||
|
||||
def count(self, value: Any) -> int:
|
||||
@ -578,7 +588,7 @@ def _is_submodule(parent: str, child: str) -> bool:
|
||||
|
||||
if HAS_ORJSON:
|
||||
|
||||
def _to_json(obj: Any) -> str: # type: ignore
|
||||
def _to_json(obj: Any) -> str:
|
||||
return orjson.dumps(obj).decode('utf-8')
|
||||
|
||||
_from_json = orjson.loads # type: ignore
|
||||
@ -602,15 +612,15 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float:
|
||||
return float(reset_after)
|
||||
|
||||
|
||||
async def maybe_coroutine(f, *args, **kwargs):
|
||||
async def maybe_coroutine(f: MaybeCoroFunc[P, T], *args: P.args, **kwargs: P.kwargs) -> T:
|
||||
value = f(*args, **kwargs)
|
||||
if _isawaitable(value):
|
||||
return await value
|
||||
else:
|
||||
return value
|
||||
return value # type: ignore
|
||||
|
||||
|
||||
async def async_all(gen, *, check=_isawaitable):
|
||||
async def async_all(gen: Iterable[Awaitable[T]], *, check: Callable[[T], bool] = _isawaitable) -> bool:
|
||||
for elem in gen:
|
||||
if check(elem):
|
||||
elem = await elem
|
||||
@ -619,7 +629,7 @@ async def async_all(gen, *, check=_isawaitable):
|
||||
return True
|
||||
|
||||
|
||||
async def sane_wait_for(futures, *, timeout):
|
||||
async def sane_wait_for(futures: Iterable[Awaitable[T]], *, timeout: Optional[float]) -> Set[asyncio.Task[T]]:
|
||||
ensured = [asyncio.ensure_future(fut) for fut in futures]
|
||||
done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
@ -637,7 +647,7 @@ def get_slots(cls: Type[Any]) -> Iterator[str]:
|
||||
continue
|
||||
|
||||
|
||||
def compute_timedelta(dt: datetime.datetime):
|
||||
def compute_timedelta(dt: datetime.datetime) -> float:
|
||||
if dt.tzinfo is None:
|
||||
dt = dt.astimezone()
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
@ -686,7 +696,7 @@ def valid_icon_size(size: int) -> bool:
|
||||
return not size & (size - 1) and 4096 >= size >= 16
|
||||
|
||||
|
||||
class SnowflakeList(array.array):
|
||||
class SnowflakeList(_SnowflakeListBase):
|
||||
"""Internal data storage class to efficiently store a list of snowflakes.
|
||||
|
||||
This should have the following characteristics:
|
||||
@ -705,7 +715,7 @@ class SnowflakeList(array.array):
|
||||
def __init__(self, data: Iterable[int], *, is_sorted: bool = False):
|
||||
...
|
||||
|
||||
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False):
|
||||
def __new__(cls, data: Iterable[int], *, is_sorted: bool = False) -> Self:
|
||||
return array.array.__new__(cls, 'Q', data if is_sorted else sorted(data)) # type: ignore
|
||||
|
||||
def add(self, element: int) -> None:
|
||||
@ -1010,7 +1020,7 @@ def evaluate_annotation(
|
||||
cache: Dict[str, Any],
|
||||
*,
|
||||
implicit_str: bool = True,
|
||||
):
|
||||
) -> Any:
|
||||
if isinstance(tp, ForwardRef):
|
||||
tp = tp.__forward_arg__
|
||||
# ForwardRefs always evaluate their internals
|
||||
|
Reference in New Issue
Block a user