Compare commits
3 Commits
wasi-maste
...
paris-ci/p
Author | SHA1 | Date | |
---|---|---|---|
|
a7c95966ac | ||
|
faa3e84cfb | ||
|
3a71d3be5f |
@@ -30,9 +30,11 @@ from typing import (
|
|||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
|
Coroutine,
|
||||||
Generic,
|
Generic,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
|
TYPE_CHECKING,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
@@ -50,15 +52,23 @@ from collections.abc import Sequence
|
|||||||
from discord.backoff import ExponentialBackoff
|
from discord.backoff import ExponentialBackoff
|
||||||
from discord.utils import MISSING
|
from discord.utils import MISSING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from typing_extensions import Concatenate, ParamSpec
|
||||||
|
|
||||||
|
P = ParamSpec("P")
|
||||||
|
else:
|
||||||
|
P = TypeVar("P") # hacky runtime fix
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'loop',
|
'loop',
|
||||||
)
|
)
|
||||||
|
|
||||||
|
C = TypeVar('C')
|
||||||
T = TypeVar('T')
|
T = TypeVar('T')
|
||||||
|
_coro = Coroutine[Any, Any, T]
|
||||||
_func = Callable[..., Awaitable[Any]]
|
_func = Callable[..., Awaitable[Any]]
|
||||||
LF = TypeVar('LF', bound=_func)
|
|
||||||
FT = TypeVar('FT', bound=_func)
|
FT = TypeVar('FT', bound=_func)
|
||||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||||
LT = TypeVar('LT', bound='Loop')
|
LT = TypeVar('LT', bound='Loop')
|
||||||
@@ -89,13 +99,13 @@ class SleepHandle:
|
|||||||
self.future.cancel()
|
self.future.cancel()
|
||||||
|
|
||||||
|
|
||||||
class Loop(Generic[LF]):
|
class Loop(Generic[C, P, T]):
|
||||||
"""A background task helper that abstracts the loop and reconnection logic for you.
|
"""A background task helper that abstracts the loop and reconnection logic for you.
|
||||||
|
|
||||||
The main interface to create this is through :func:`loop`.
|
The main interface to create this is through :func:`loop`.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
def __init__(self,
|
||||||
coro: LF,
|
coro: Callable[P, _coro[T]],
|
||||||
seconds: float,
|
seconds: float,
|
||||||
hours: float,
|
hours: float,
|
||||||
minutes: float,
|
minutes: float,
|
||||||
@@ -104,14 +114,14 @@ class Loop(Generic[LF]):
|
|||||||
reconnect: bool,
|
reconnect: bool,
|
||||||
loop: Optional[asyncio.AbstractEventLoop],
|
loop: Optional[asyncio.AbstractEventLoop],
|
||||||
) -> None:
|
) -> None:
|
||||||
self.coro: LF = coro
|
self.coro: Callable[P, _coro[T]] = coro
|
||||||
self.reconnect: bool = reconnect
|
self.reconnect: bool = reconnect
|
||||||
self.loop: Optional[asyncio.AbstractEventLoop] = loop
|
self.loop: Optional[asyncio.AbstractEventLoop] = loop
|
||||||
self.count: Optional[int] = count
|
self.count: Optional[int] = count
|
||||||
self._current_loop = 0
|
self._current_loop = 0
|
||||||
self._handle = None
|
self._handle = None
|
||||||
self._task = None
|
self._task = None
|
||||||
self._injected = None
|
self._injected: Optional[C] = None
|
||||||
self._valid_exception = (
|
self._valid_exception = (
|
||||||
OSError,
|
OSError,
|
||||||
discord.GatewayNotFound,
|
discord.GatewayNotFound,
|
||||||
@@ -207,11 +217,11 @@ class Loop(Generic[LF]):
|
|||||||
self._stop_next_iteration = False
|
self._stop_next_iteration = False
|
||||||
self._has_failed = False
|
self._has_failed = False
|
||||||
|
|
||||||
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
|
def __get__(self, obj: C, objtype: Type[C]) -> Loop[C, P, T]:
|
||||||
if obj is None:
|
if obj is None:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
copy = Loop(
|
copy = Loop[C, P, T](
|
||||||
self.coro,
|
self.coro,
|
||||||
seconds=self._seconds,
|
seconds=self._seconds,
|
||||||
hours=self._hours,
|
hours=self._hours,
|
||||||
@@ -285,7 +295,7 @@ class Loop(Generic[LF]):
|
|||||||
return None
|
return None
|
||||||
return self._next_iteration
|
return self._next_iteration
|
||||||
|
|
||||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
async def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T:
|
||||||
r"""|coro|
|
r"""|coro|
|
||||||
|
|
||||||
Calls the internal callback that the task holds.
|
Calls the internal callback that the task holds.
|
||||||
@@ -305,7 +315,7 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
return await self.coro(*args, **kwargs)
|
return await self.coro(*args, **kwargs)
|
||||||
|
|
||||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
|
def start(self, *args: P.args, **kwargs: P.kwargs) -> asyncio.Task:
|
||||||
r"""Starts the internal task in the event loop.
|
r"""Starts the internal task in the event loop.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@@ -367,7 +377,7 @@ class Loop(Generic[LF]):
|
|||||||
if self._can_be_cancelled():
|
if self._can_be_cancelled():
|
||||||
self._task.cancel()
|
self._task.cancel()
|
||||||
|
|
||||||
def restart(self, *args: Any, **kwargs: Any) -> None:
|
def restart(self, *args: P.args, **kwargs: P.kwargs) -> None:
|
||||||
r"""A convenience method to restart the internal task.
|
r"""A convenience method to restart the internal task.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
@@ -692,7 +702,7 @@ def loop(
|
|||||||
count: Optional[int] = None,
|
count: Optional[int] = None,
|
||||||
reconnect: bool = True,
|
reconnect: bool = True,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||||
) -> Callable[[LF], Loop[LF]]:
|
) -> Callable[[Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]], Loop[C, P, T]]:
|
||||||
"""A decorator that schedules a task in the background for you with
|
"""A decorator that schedules a task in the background for you with
|
||||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||||
|
|
||||||
@@ -736,7 +746,7 @@ def loop(
|
|||||||
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
||||||
or ``time`` parameter was passed in conjunction with relative time parameters.
|
or ``time`` parameter was passed in conjunction with relative time parameters.
|
||||||
"""
|
"""
|
||||||
def decorator(func: LF) -> Loop[LF]:
|
def decorator(func: Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]) -> Loop[C, P, T]:
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'seconds': seconds,
|
'seconds': seconds,
|
||||||
'minutes': minutes,
|
'minutes': minutes,
|
||||||
@@ -746,5 +756,5 @@ def loop(
|
|||||||
'reconnect': reconnect,
|
'reconnect': reconnect,
|
||||||
'loop': loop,
|
'loop': loop,
|
||||||
}
|
}
|
||||||
return Loop(func, **kwargs)
|
return Loop[C, P, T](func, **kwargs)
|
||||||
return decorator
|
return decorator
|
||||||
|
Reference in New Issue
Block a user