Fix issue when tasks loop used in classes.

This commit is contained in:
Josh 2021-06-28 15:45:24 +10:00
parent faa3e84cfb
commit a7c95966ac

View File

@ -53,7 +53,7 @@ from discord.backoff import ExponentialBackoff
from discord.utils import MISSING from discord.utils import MISSING
if TYPE_CHECKING: if TYPE_CHECKING:
from typing_extensions import ParamSpec from typing_extensions import Concatenate, ParamSpec
P = ParamSpec("P") P = ParamSpec("P")
else: else:
@ -65,8 +65,8 @@ __all__ = (
'loop', 'loop',
) )
C = TypeVar('C')
T = TypeVar('T') T = TypeVar('T')
OT = TypeVar('OT')
_coro = Coroutine[Any, Any, T] _coro = Coroutine[Any, Any, T]
_func = Callable[..., Awaitable[Any]] _func = Callable[..., Awaitable[Any]]
FT = TypeVar('FT', bound=_func) FT = TypeVar('FT', bound=_func)
@ -99,7 +99,7 @@ class SleepHandle:
self.future.cancel() self.future.cancel()
class Loop(Generic[P, T]): class Loop(Generic[C, P, T]):
"""A background task helper that abstracts the loop and reconnection logic for you. """A background task helper that abstracts the loop and reconnection logic for you.
The main interface to create this is through :func:`loop`. The main interface to create this is through :func:`loop`.
@ -121,7 +121,7 @@ class Loop(Generic[P, T]):
self._current_loop = 0 self._current_loop = 0
self._handle = None self._handle = None
self._task = None self._task = None
self._injected = None self._injected: Optional[C] = None
self._valid_exception = ( self._valid_exception = (
OSError, OSError,
discord.GatewayNotFound, discord.GatewayNotFound,
@ -217,11 +217,11 @@ class Loop(Generic[P, T]):
self._stop_next_iteration = False self._stop_next_iteration = False
self._has_failed = False self._has_failed = False
def __get__(self, obj: OT, objtype: Type[OT]) -> Loop[P, T]: def __get__(self, obj: C, objtype: Type[C]) -> Loop[C, P, T]:
if obj is None: if obj is None:
return self return self
copy = Loop[P, T]( copy = Loop[C, P, T](
self.coro, self.coro,
seconds=self._seconds, seconds=self._seconds,
hours=self._hours, hours=self._hours,
@ -702,7 +702,7 @@ def loop(
count: Optional[int] = None, count: Optional[int] = None,
reconnect: bool = True, reconnect: bool = True,
loop: Optional[asyncio.AbstractEventLoop] = None, loop: Optional[asyncio.AbstractEventLoop] = None,
) -> Callable[[Callable[P, _coro[T]]], Loop[P, T]]: ) -> Callable[[Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]], Loop[C, P, T]]:
"""A decorator that schedules a task in the background for you with """A decorator that schedules a task in the background for you with
optional reconnect logic. The decorator returns a :class:`Loop`. optional reconnect logic. The decorator returns a :class:`Loop`.
@ -746,7 +746,7 @@ def loop(
The function was not a coroutine, an invalid value for the ``time`` parameter was passed, The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
or ``time`` parameter was passed in conjunction with relative time parameters. or ``time`` parameter was passed in conjunction with relative time parameters.
""" """
def decorator(func: Callable[P, _coro[T]]) -> Loop[P, T]: def decorator(func: Union[Callable[Concatenate[Type[C], P], _coro[T]], Callable[P, _coro[T]]]) -> Loop[C, P, T]:
kwargs = { kwargs = {
'seconds': seconds, 'seconds': seconds,
'minutes': minutes, 'minutes': minutes,
@ -756,5 +756,5 @@ def loop(
'reconnect': reconnect, 'reconnect': reconnect,
'loop': loop, 'loop': loop,
} }
return Loop[P, T](func, **kwargs) return Loop[C, P, T](func, **kwargs)
return decorator return decorator