Use ParamSpec in ext-tasks

This commit is contained in:
Josh 2021-06-10 22:54:59 +10:00
parent 04788d0a06
commit 3a71d3be5f

View File

@ -29,10 +29,12 @@ import datetime
from typing import ( from typing import (
Any, Any,
Awaitable, Awaitable,
Callable, Callable,
Coroutine,
Generic, Generic,
List, List,
Optional, Optional,
TYPE_CHECKING,
Type, Type,
TypeVar, TypeVar,
Union, Union,
@ -50,6 +52,13 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
if TYPE_CHECKING:
from typing_extensions import ParamSpec
P = ParamSpec("P")
else:
P = TypeVar("P") # hacky runtime fix
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
__all__ = ( __all__ = (
@ -57,8 +66,9 @@ __all__ = (
) )
T = TypeVar('T') T = TypeVar('T')
OT = TypeVar('OT')
_coro = Coroutine[Any, Any, T]
_func = Callable[..., Awaitable[Any]] _func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func) FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop') LT = TypeVar('LT', bound='Loop')
@ -89,13 +99,13 @@ class SleepHandle:
self.future.cancel() self.future.cancel()
class Loop(Generic[LF]): class Loop(Generic[P, T]):
"""A background task helper that abstracts the loop and reconnection logic for you. """A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`. The main interface to create this is through :func:`loop`.
""" """
def __init__(self, def __init__(self,
coro: LF, coro: Callable[P, _coro[T]],
seconds: float, seconds: float,
hours: float, hours: float,
minutes: float, minutes: float,
@ -104,7 +114,7 @@ class Loop(Generic[LF]):
reconnect: bool, reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop], loop: Optional[asyncio.AbstractEventLoop],
) -> None: ) -> None:
self.coro: LF = coro self.coro: Callable[P, _coro[T]] = coro
self.reconnect: bool = reconnect self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop self.loop: Optional[asyncio.AbstractEventLoop] = loop
self.count: Optional[int] = count self.count: Optional[int] = count
@ -207,11 +217,11 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False self._stop_next_iteration = False
self._has_failed = False self._has_failed = False
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]: def __get__(self, obj: OT, objtype: Type[OT]) -> Loop[P, T]:
if obj is None: if obj is None:
return self return self
copy = Loop( copy = Loop[P, T](
self.coro, self.coro,
seconds=self._seconds, seconds=self._seconds,
hours=self._hours, hours=self._hours,
@ -285,7 +295,7 @@ class Loop(Generic[LF]):
return None return None
return self._next_iteration return self._next_iteration
async def __call__(self, *args: Any, **kwargs: Any) -> Any: async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
r"""|coro| r"""|coro|
Calls the internal callback that the task holds. Calls the internal callback that the task holds.
@ -305,7 +315,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs) return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: def start(self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task:
r"""Starts the internal task in the event loop. r"""Starts the internal task in the event loop.
Parameters Parameters
@ -367,7 +377,7 @@ class Loop(Generic[LF]):
if self._can_be_cancelled(): if self._can_be_cancelled():
self._task.cancel() self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None: def restart(self, *args: P.args, **kwargs: P.kwargs) -> None:
r"""A convenience method to restart the internal task. r"""A convenience method to restart the internal task.
.. note:: .. note::
@ -692,7 +702,7 @@ def loop(
count: Optional[int] = None, count: Optional[int] = None,
reconnect: bool = True, reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Callable[[LF], Loop[LF]]: ) -> Callable[[Callable[P, _coro[T]]], Loop[P, T]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@ -736,7 +746,7 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters. or ``time`` parameter was passed in conjunction with relative time parameters.
""" """
def decorator(func: LF) -> Loop[LF]: def decorator(func: Callable[P, _coro[T]]) -> Loop[P, T]:
kwargs = { kwargs = {
'seconds': seconds, 'seconds': seconds,
'minutes': minutes, 'minutes': minutes,
@ -746,5 +756,5 @@ def loop(
'reconnect': reconnect, 'reconnect': reconnect,
'loop': loop, 'loop': loop,
} }
return Loop(func, **kwargs) return Loop[P, T](func, **kwargs)
return decorator return decorator # type: ignore - pyright bug