[tasks] Use ParamSpec in ext-tasks #45

Open
paris-ci wants to merge 3 commits from paris-ci/pr7044 into 2.0

View File

@ -29,10 +29,12 @@ import datetime
from typing import (
Any,
Awaitable,
Callable,
Callable,
Coroutine,
Generic,
List,
Optional,
Optional,
TYPE_CHECKING,
Type,
TypeVar,
Union,
@ -50,15 +52,23 @@ from collections.abc import Sequence
from discord.backoff import ExponentialBackoff
from discord.utils import MISSING
if TYPE_CHECKING:
from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P")
else:
P = TypeVar("P") # hacky runtime fix
log = logging.getLogger(__name__)
__all__ = (
'loop',
)
C = TypeVar('C')
T = TypeVar('T')
_coro = Coroutine[Any, Any, T]
_func = Callable[..., Awaitable[Any]]
LF = TypeVar('LF', bound=_func)
FT = TypeVar('FT', bound=_func)
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
LT = TypeVar('LT', bound='Loop')
@ -89,13 +99,13 @@ class SleepHandle:
self.future.cancel()
class Loop(Generic[LF]):
class Loop(Generic[C, P, T]):
"""A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`.
"""
def __init__(self,
coro: LF,
coro: Callable[P, _coro[T]],
seconds: float,
hours: float,
minutes: float,
@ -104,14 +114,14 @@ class Loop(Generic[LF]):
reconnect: bool,
loop: Optional[asyncio.AbstractEventLoop],
) -> None:
self.coro: LF = coro
self.coro: Callable[P, _coro[T]] = coro
self.reconnect: bool = reconnect
self.loop: Optional[asyncio.AbstractEventLoop] = loop
self.count: Optional[int] = count
self._current_loop = 0
self._handle = None
self._task = None
self._injected = None
self._injected: Optional[C] = None
self._valid_exception = (
OSError,
discord.GatewayNotFound,
@ -207,11 +217,11 @@ class Loop(Generic[LF]):
self._stop_next_iteration = False
self._has_failed = False
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
def __get__(self, obj: C, objtype: Type[C]) -> Loop[C, P, T]:
if obj is None:
return self
copy = Loop(
copy = Loop[C, P, T](
self.coro,
seconds=self._seconds,
hours=self._hours,
@ -285,7 +295,7 @@ class Loop(Generic[LF]):
return None
return self._next_iteration
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
r"""|coro|
Calls the internal callback that the task holds.
@ -305,7 +315,7 @@ class Loop(Generic[LF]):
return await self.coro(*args, **kwargs)
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
def start(self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task:
r"""Starts the internal task in the event loop.
Parameters
@ -367,7 +377,7 @@ class Loop(Generic[LF]):
if self._can_be_cancelled():
self._task.cancel()
def restart(self, *args: Any, **kwargs: Any) -> None:
def restart(self, *args: P.args, **kwargs: P.kwargs) -> None:
r"""A convenience method to restart the internal task.
.. note::
@ -692,7 +702,7 @@ def loop(
count: Optional[int] = None,
reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Callable[[LF], Loop[LF]]:
) -> Callable[[Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]], Loop[C, P, T]]:
"""A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`.
@ -736,7 +746,7 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters.
"""
def decorator(func: LF) -> Loop[LF]:
def decorator(func: Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]) -> Loop[C, P, T]:
kwargs = {
'seconds': seconds,
'minutes': minutes,
@ -746,5 +756,5 @@ def loop(
'reconnect': reconnect,
'loop': loop,
}
return Loop(func, **kwargs)
return Loop[C, P, T](func, **kwargs)
return decorator