[tasks] Improve typing parity
This commit is contained in:
		| @@ -36,13 +36,11 @@ from typing import ( | |||||||
|     Type, |     Type, | ||||||
|     TypeVar, |     TypeVar, | ||||||
|     Union, |     Union, | ||||||
|     cast, |  | ||||||
| ) | ) | ||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
| import discord | import discord | ||||||
| import inspect | import inspect | ||||||
| import logging |  | ||||||
| import sys | import sys | ||||||
| import traceback | import traceback | ||||||
|  |  | ||||||
| @@ -50,8 +48,6 @@ from collections.abc import Sequence | |||||||
| from discord.backoff import ExponentialBackoff | from discord.backoff import ExponentialBackoff | ||||||
| from discord.utils import MISSING | from discord.utils import MISSING | ||||||
|  |  | ||||||
| _log = logging.getLogger(__name__) |  | ||||||
|  |  | ||||||
| __all__ = ( | __all__ = ( | ||||||
|     'loop', |     'loop', | ||||||
| ) | ) | ||||||
| @@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]] | |||||||
| LF = TypeVar('LF', bound=_func) | LF = TypeVar('LF', bound=_func) | ||||||
| FT = TypeVar('FT', bound=_func) | FT = TypeVar('FT', bound=_func) | ||||||
| ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) | ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]]) | ||||||
| LT = TypeVar('LT', bound='Loop') |  | ||||||
|  |  | ||||||
|  |  | ||||||
| class SleepHandle: | class SleepHandle: | ||||||
| @@ -78,7 +73,7 @@ class SleepHandle: | |||||||
|         relative_delta = discord.utils.compute_timedelta(dt) |         relative_delta = discord.utils.compute_timedelta(dt) | ||||||
|         self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) |         self.handle = self.loop.call_later(relative_delta, self.future.set_result, True) | ||||||
|  |  | ||||||
|     def wait(self) -> asyncio.Future: |     def wait(self) -> asyncio.Future[Any]: | ||||||
|         return self.future |         return self.future | ||||||
|  |  | ||||||
|     def done(self) -> bool: |     def done(self) -> bool: | ||||||
| @@ -94,7 +89,9 @@ class Loop(Generic[LF]): | |||||||
|  |  | ||||||
|     The main interface to create this is through :func:`loop`. |     The main interface to create this is through :func:`loop`. | ||||||
|     """ |     """ | ||||||
|     def __init__(self, |  | ||||||
|  |     def __init__( | ||||||
|  |         self, | ||||||
|         coro: LF, |         coro: LF, | ||||||
|         seconds: float, |         seconds: float, | ||||||
|         hours: float, |         hours: float, | ||||||
| @@ -102,15 +99,15 @@ class Loop(Generic[LF]): | |||||||
|         time: Union[datetime.time, Sequence[datetime.time]], |         time: Union[datetime.time, Sequence[datetime.time]], | ||||||
|         count: Optional[int], |         count: Optional[int], | ||||||
|         reconnect: bool, |         reconnect: bool, | ||||||
|         loop: Optional[asyncio.AbstractEventLoop], |         loop: asyncio.AbstractEventLoop, | ||||||
|     ) -> None: |     ) -> None: | ||||||
|         self.coro: LF = coro |         self.coro: LF = coro | ||||||
|         self.reconnect: bool = reconnect |         self.reconnect: bool = reconnect | ||||||
|         self.loop: Optional[asyncio.AbstractEventLoop] = loop |         self.loop: asyncio.AbstractEventLoop = loop | ||||||
|         self.count: Optional[int] = count |         self.count: Optional[int] = count | ||||||
|         self._current_loop = 0 |         self._current_loop = 0 | ||||||
|         self._handle = None |         self._handle: SleepHandle = MISSING | ||||||
|         self._task = None |         self._task: asyncio.Task[None] = MISSING | ||||||
|         self._injected = None |         self._injected = None | ||||||
|         self._valid_exception = ( |         self._valid_exception = ( | ||||||
|             OSError, |             OSError, | ||||||
| @@ -131,7 +128,7 @@ class Loop(Generic[LF]): | |||||||
|  |  | ||||||
|         self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) |         self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time) | ||||||
|         self._last_iteration_failed = False |         self._last_iteration_failed = False | ||||||
|         self._last_iteration = None |         self._last_iteration: datetime.datetime = MISSING | ||||||
|         self._next_iteration = None |         self._next_iteration = None | ||||||
|  |  | ||||||
|         if not inspect.iscoroutinefunction(self.coro): |         if not inspect.iscoroutinefunction(self.coro): | ||||||
| @@ -147,9 +144,8 @@ class Loop(Generic[LF]): | |||||||
|         else: |         else: | ||||||
|             await coro(*args, **kwargs) |             await coro(*args, **kwargs) | ||||||
|  |  | ||||||
|      |  | ||||||
|     def _try_sleep_until(self, dt: datetime.datetime): |     def _try_sleep_until(self, dt: datetime.datetime): | ||||||
|         self._handle = SleepHandle(dt=dt, loop=self.loop)  # type: ignore |         self._handle = SleepHandle(dt=dt, loop=self.loop) | ||||||
|         return self._handle.wait() |         return self._handle.wait() | ||||||
|  |  | ||||||
|     async def _loop(self, *args: Any, **kwargs: Any) -> None: |     async def _loop(self, *args: Any, **kwargs: Any) -> None: | ||||||
| @@ -211,7 +207,7 @@ class Loop(Generic[LF]): | |||||||
|         if obj is None: |         if obj is None: | ||||||
|             return self |             return self | ||||||
|  |  | ||||||
|         copy = Loop( |         copy: Loop[LF] = Loop( | ||||||
|             self.coro, |             self.coro, | ||||||
|             seconds=self._seconds, |             seconds=self._seconds, | ||||||
|             hours=self._hours, |             hours=self._hours, | ||||||
| @@ -279,7 +275,7 @@ class Loop(Generic[LF]): | |||||||
|  |  | ||||||
|         .. versionadded:: 1.3 |         .. versionadded:: 1.3 | ||||||
|         """ |         """ | ||||||
|         if self._task is None: |         if self._task is MISSING: | ||||||
|             return None |             return None | ||||||
|         elif self._task and self._task.done() or self._stop_next_iteration: |         elif self._task and self._task.done() or self._stop_next_iteration: | ||||||
|             return None |             return None | ||||||
| @@ -305,7 +301,7 @@ class Loop(Generic[LF]): | |||||||
|  |  | ||||||
|         return await self.coro(*args, **kwargs) |         return await self.coro(*args, **kwargs) | ||||||
|  |  | ||||||
|     def start(self, *args: Any, **kwargs: Any) -> asyncio.Task: |     def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]: | ||||||
|         r"""Starts the internal task in the event loop. |         r"""Starts the internal task in the event loop. | ||||||
|  |  | ||||||
|         Parameters |         Parameters | ||||||
| @@ -326,13 +322,13 @@ class Loop(Generic[LF]): | |||||||
|             The task that has been created. |             The task that has been created. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         if self._task is not None and not self._task.done(): |         if self._task is not MISSING and not self._task.done(): | ||||||
|             raise RuntimeError('Task is already launched and is not completed.') |             raise RuntimeError('Task is already launched and is not completed.') | ||||||
|  |  | ||||||
|         if self._injected is not None: |         if self._injected is not None: | ||||||
|             args = (self._injected, *args) |             args = (self._injected, *args) | ||||||
|  |  | ||||||
|         if self.loop is None: |         if self.loop is MISSING: | ||||||
|             self.loop = asyncio.get_event_loop() |             self.loop = asyncio.get_event_loop() | ||||||
|  |  | ||||||
|         self._task = self.loop.create_task(self._loop(*args, **kwargs)) |         self._task = self.loop.create_task(self._loop(*args, **kwargs)) | ||||||
| @@ -356,7 +352,7 @@ class Loop(Generic[LF]): | |||||||
|  |  | ||||||
|         .. versionadded:: 1.2 |         .. versionadded:: 1.2 | ||||||
|         """ |         """ | ||||||
|         if self._task and not self._task.done(): |         if self._task is not MISSING and not self._task.done(): | ||||||
|             self._stop_next_iteration = True |             self._stop_next_iteration = True | ||||||
|  |  | ||||||
|     def _can_be_cancelled(self) -> bool: |     def _can_be_cancelled(self) -> bool: | ||||||
| @@ -383,7 +379,7 @@ class Loop(Generic[LF]): | |||||||
|             The keyword arguments to use. |             The keyword arguments to use. | ||||||
|         """ |         """ | ||||||
|  |  | ||||||
|         def restart_when_over(fut, *, args=args, kwargs=kwargs): |         def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None: | ||||||
|             self._task.remove_done_callback(restart_when_over) |             self._task.remove_done_callback(restart_when_over) | ||||||
|             self.start(*args, **kwargs) |             self.start(*args, **kwargs) | ||||||
|  |  | ||||||
| @@ -446,9 +442,9 @@ class Loop(Generic[LF]): | |||||||
|         self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) |         self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) | ||||||
|         return len(self._valid_exception) == old_length - len(exceptions) |         return len(self._valid_exception) == old_length - len(exceptions) | ||||||
|  |  | ||||||
|     def get_task(self) -> Optional[asyncio.Task]: |     def get_task(self) -> Optional[asyncio.Task[None]]: | ||||||
|         """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" |         """Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running.""" | ||||||
|         return self._task |         return self._task if self._task is not MISSING else None | ||||||
|  |  | ||||||
|     def is_being_cancelled(self) -> bool: |     def is_being_cancelled(self) -> bool: | ||||||
|         """Whether the task is being cancelled.""" |         """Whether the task is being cancelled.""" | ||||||
| @@ -466,7 +462,7 @@ class Loop(Generic[LF]): | |||||||
|  |  | ||||||
|         .. versionadded:: 1.4 |         .. versionadded:: 1.4 | ||||||
|         """ |         """ | ||||||
|         return not bool(self._task.done()) if self._task else False |         return not bool(self._task.done()) if self._task is not MISSING else False | ||||||
|  |  | ||||||
|     async def _error(self, *args: Any) -> None: |     async def _error(self, *args: Any) -> None: | ||||||
|         exception: Exception = args[-1] |         exception: Exception = args[-1] | ||||||
| @@ -560,7 +556,9 @@ class Loop(Generic[LF]): | |||||||
|             self._time_index = 0 |             self._time_index = 0 | ||||||
|             if self._current_loop == 0: |             if self._current_loop == 0: | ||||||
|                 # if we're at the last index on the first iteration, we need to sleep until tomorrow |                 # if we're at the last index on the first iteration, we need to sleep until tomorrow | ||||||
|                 return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]) |                 return datetime.datetime.combine( | ||||||
|  |                     datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0] | ||||||
|  |                 ) | ||||||
|  |  | ||||||
|         next_time = self._time[self._time_index] |         next_time = self._time[self._time_index] | ||||||
|  |  | ||||||
| @@ -568,7 +566,7 @@ class Loop(Generic[LF]): | |||||||
|             self._time_index += 1 |             self._time_index += 1 | ||||||
|             return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) |             return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) | ||||||
|  |  | ||||||
|         next_date = cast(datetime.datetime, self._last_iteration) |         next_date = self._last_iteration | ||||||
|         if self._time_index == 0: |         if self._time_index == 0: | ||||||
|             # we can assume that the earliest time should be scheduled for "tomorrow" |             # we can assume that the earliest time should be scheduled for "tomorrow" | ||||||
|             next_date += datetime.timedelta(days=1) |             next_date += datetime.timedelta(days=1) | ||||||
| @@ -576,12 +574,14 @@ class Loop(Generic[LF]): | |||||||
|         self._time_index += 1 |         self._time_index += 1 | ||||||
|         return datetime.datetime.combine(next_date, next_time) |         return datetime.datetime.combine(next_date, next_time) | ||||||
|  |  | ||||||
|     def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None: |     def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None: | ||||||
|         # now kwarg should be a datetime.datetime representing the time "now" |         # now kwarg should be a datetime.datetime representing the time "now" | ||||||
|         # to calculate the next time index from |         # to calculate the next time index from | ||||||
|  |  | ||||||
|         # pre-condition: self._time is set |         # pre-condition: self._time is set | ||||||
|         time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz() |         time_now = ( | ||||||
|  |             now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) | ||||||
|  |         ).timetz() | ||||||
|         for idx, time in enumerate(self._time): |         for idx, time in enumerate(self._time): | ||||||
|             if time >= time_now: |             if time >= time_now: | ||||||
|                 self._time_index = idx |                 self._time_index = idx | ||||||
| @@ -597,17 +597,21 @@ class Loop(Generic[LF]): | |||||||
|         utc: datetime.timezone = datetime.timezone.utc, |         utc: datetime.timezone = datetime.timezone.utc, | ||||||
|     ) -> List[datetime.time]: |     ) -> List[datetime.time]: | ||||||
|         if isinstance(time, dt): |         if isinstance(time, dt): | ||||||
|             ret = time if time.tzinfo is not None else time.replace(tzinfo=utc) |             inner = time if time.tzinfo is not None else time.replace(tzinfo=utc) | ||||||
|             return [ret] |             return [inner] | ||||||
|         if not isinstance(time, Sequence): |         if not isinstance(time, Sequence): | ||||||
|             raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.') |             raise TypeError( | ||||||
|  |                 f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.' | ||||||
|  |             ) | ||||||
|         if not time: |         if not time: | ||||||
|             raise ValueError('time parameter must not be an empty sequence.') |             raise ValueError('time parameter must not be an empty sequence.') | ||||||
|  |  | ||||||
|         ret = [] |         ret: List[datetime.time] = [] | ||||||
|         for index, t in enumerate(time): |         for index, t in enumerate(time): | ||||||
|             if not isinstance(t, dt): |             if not isinstance(t, dt): | ||||||
|                 raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.') |                 raise TypeError( | ||||||
|  |                     f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.' | ||||||
|  |                 ) | ||||||
|             ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) |             ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc)) | ||||||
|  |  | ||||||
|         ret = sorted(set(ret))  # de-dupe and sort times |         ret = sorted(set(ret))  # de-dupe and sort times | ||||||
| @@ -691,7 +695,7 @@ def loop( | |||||||
|     time: Union[datetime.time, Sequence[datetime.time]] = MISSING, |     time: Union[datetime.time, Sequence[datetime.time]] = MISSING, | ||||||
|     count: Optional[int] = None, |     count: Optional[int] = None, | ||||||
|     reconnect: bool = True, |     reconnect: bool = True, | ||||||
|     loop: Optional[asyncio.AbstractEventLoop] = None, |     loop: asyncio.AbstractEventLoop = MISSING, | ||||||
| ) -> Callable[[LF], Loop[LF]]: | ) -> Callable[[LF], Loop[LF]]: | ||||||
|     """A decorator that schedules a task in the background for you with |     """A decorator that schedules a task in the background for you with | ||||||
|     optional reconnect logic. The decorator returns a :class:`Loop`. |     optional reconnect logic. The decorator returns a :class:`Loop`. | ||||||
| @@ -724,7 +728,7 @@ def loop( | |||||||
|         Whether to handle errors and restart the task |         Whether to handle errors and restart the task | ||||||
|         using an exponential back-off algorithm similar to the |         using an exponential back-off algorithm similar to the | ||||||
|         one used in :meth:`discord.Client.connect`. |         one used in :meth:`discord.Client.connect`. | ||||||
|     loop: Optional[:class:`asyncio.AbstractEventLoop`] |     loop: :class:`asyncio.AbstractEventLoop` | ||||||
|         The loop to use to register the task, if not given |         The loop to use to register the task, if not given | ||||||
|         defaults to :func:`asyncio.get_event_loop`. |         defaults to :func:`asyncio.get_event_loop`. | ||||||
|  |  | ||||||
| @@ -736,15 +740,17 @@ def loop( | |||||||
|         The function was not a coroutine, an invalid value for the ``time`` parameter was passed, |         The function was not a coroutine, an invalid value for the ``time`` parameter was passed, | ||||||
|         or ``time`` parameter was passed in conjunction with relative time parameters. |         or ``time`` parameter was passed in conjunction with relative time parameters. | ||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     def decorator(func: LF) -> Loop[LF]: |     def decorator(func: LF) -> Loop[LF]: | ||||||
|         kwargs = { |         return Loop[LF]( | ||||||
|             'seconds': seconds, |             func, | ||||||
|             'minutes': minutes, |             seconds=seconds, | ||||||
|             'hours': hours, |             minutes=minutes, | ||||||
|             'count': count, |             hours=hours, | ||||||
|             'time': time, |             count=count, | ||||||
|             'reconnect': reconnect, |             time=time, | ||||||
|             'loop': loop, |             reconnect=reconnect, | ||||||
|         } |             loop=loop, | ||||||
|         return Loop(func, **kwargs) |         ) | ||||||
|  |  | ||||||
|     return decorator |     return decorator | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user