[tasks] Type hint the tasks extension
This commit is contained in:
parent
f5727ff0d0
commit
ef22178dee
@ -22,8 +22,23 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
Awaitable,
|
||||
Callable,
|
||||
Generic,
|
||||
List,
|
||||
Optional,
|
||||
Type,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
)
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import inspect
|
||||
@ -33,6 +48,7 @@ import traceback
|
||||
|
||||
from collections.abc import Sequence
|
||||
from discord.backoff import ExponentialBackoff
|
||||
from discord.utils import MISSING
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@ -40,41 +56,58 @@ __all__ = (
|
||||
'loop',
|
||||
)
|
||||
|
||||
T = TypeVar('T')
|
||||
_func = Callable[..., Awaitable[Any]]
|
||||
LF = TypeVar('LF', bound=_func)
|
||||
FT = TypeVar('FT', bound=_func)
|
||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||
LT = TypeVar('LT', bound='Loop')
|
||||
|
||||
|
||||
class SleepHandle:
|
||||
__slots__ = ('future', 'loop', 'handle')
|
||||
|
||||
def __init__(self, dt, *, loop):
|
||||
def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.loop = loop
|
||||
self.future = future = loop.create_future()
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
self.handle = loop.call_later(relative_delta, future.set_result, True)
|
||||
|
||||
def recalculate(self, dt):
|
||||
def recalculate(self, dt: datetime.datetime) -> None:
|
||||
self.handle.cancel()
|
||||
relative_delta = discord.utils.compute_timedelta(dt)
|
||||
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
||||
|
||||
def wait(self):
|
||||
def wait(self) -> asyncio.Future:
|
||||
return self.future
|
||||
|
||||
def done(self):
|
||||
def done(self) -> bool:
|
||||
return self.future.done()
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> None:
|
||||
self.handle.cancel()
|
||||
self.future.cancel()
|
||||
|
||||
|
||||
class Loop:
|
||||
class Loop(Generic[LF]):
|
||||
"""A background task helper that abstracts the loop and reconnection logic for you.
|
||||
|
||||
The main interface to create this is through :func:`loop`.
|
||||
"""
|
||||
def __init__(self, coro, seconds, hours, minutes, time, count, reconnect, loop):
|
||||
self.coro = coro
|
||||
self.reconnect = reconnect
|
||||
self.loop = loop
|
||||
self.count = count
|
||||
def __init__(self,
|
||||
coro: LF,
|
||||
seconds: float,
|
||||
hours: float,
|
||||
minutes: float,
|
||||
time: Union[datetime.time, Sequence[datetime.time]],
|
||||
count: Optional[int],
|
||||
reconnect: bool,
|
||||
loop: Optional[asyncio.AbstractEventLoop],
|
||||
) -> None:
|
||||
self.coro: LF = coro
|
||||
self.reconnect: bool = reconnect
|
||||
self.loop: Optional[asyncio.AbstractEventLoop] = loop
|
||||
self.count: Optional[int] = count
|
||||
self._current_loop = 0
|
||||
self._handle = None
|
||||
self._task = None
|
||||
@ -104,7 +137,7 @@ class Loop:
|
||||
if not inspect.iscoroutinefunction(self.coro):
|
||||
raise TypeError(f'Expected coroutine function, not {type(self.coro).__name__!r}.')
|
||||
|
||||
async def _call_loop_function(self, name, *args, **kwargs):
|
||||
async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None:
|
||||
coro = getattr(self, '_' + name)
|
||||
if coro is None:
|
||||
return
|
||||
@ -114,16 +147,16 @@ class Loop:
|
||||
else:
|
||||
await coro(*args, **kwargs)
|
||||
|
||||
def _try_sleep_until(self, dt):
|
||||
self._handle = SleepHandle(dt=dt, loop=self.loop)
|
||||
|
||||
def _try_sleep_until(self, dt: datetime.datetime):
|
||||
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
|
||||
return self._handle.wait()
|
||||
|
||||
async def _loop(self, *args, **kwargs):
|
||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
await self._call_loop_function('before_loop')
|
||||
sleep_until = discord.utils.sleep_until
|
||||
self._last_iteration_failed = False
|
||||
if self._time is not None:
|
||||
if self._time is not MISSING:
|
||||
# the time index should be prepared every time the internal loop is started
|
||||
self._prepare_time_index()
|
||||
self._next_iteration = self._get_next_sleep_time()
|
||||
@ -174,7 +207,7 @@ class Loop:
|
||||
self._stop_next_iteration = False
|
||||
self._has_failed = False
|
||||
|
||||
def __get__(self, obj, objtype):
|
||||
def __get__(self, obj: T, objtype: Type[T]) -> Loop[LF]:
|
||||
if obj is None:
|
||||
return self
|
||||
|
||||
@ -183,8 +216,8 @@ class Loop:
|
||||
seconds=self._seconds,
|
||||
hours=self._hours,
|
||||
minutes=self._minutes,
|
||||
count=self.count,
|
||||
time=self._time,
|
||||
count=self.count,
|
||||
reconnect=self.reconnect,
|
||||
loop=self.loop,
|
||||
)
|
||||
@ -196,49 +229,52 @@ class Loop:
|
||||
return copy
|
||||
|
||||
@property
|
||||
def seconds(self):
|
||||
def seconds(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of seconds
|
||||
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._seconds
|
||||
if self._seconds is not MISSING:
|
||||
return self._seconds
|
||||
|
||||
@property
|
||||
def minutes(self):
|
||||
def minutes(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of minutes
|
||||
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._minutes
|
||||
if self._minutes is not MISSING:
|
||||
return self._minutes
|
||||
|
||||
@property
|
||||
def hours(self):
|
||||
def hours(self) -> Optional[float]:
|
||||
"""Optional[:class:`float`]: Read-only value for the number of hours
|
||||
between each iteration. ``None`` if an explicit ``time`` value was passed instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
return self._hours
|
||||
if self._hours is not MISSING:
|
||||
return self._hours
|
||||
|
||||
@property
|
||||
def time(self):
|
||||
def time(self) -> Optional[List[datetime.time]]:
|
||||
"""Optional[List[:class:`datetime.time`]]: Read-only list for the exact times this loop runs at.
|
||||
``None`` if relative times were passed instead.
|
||||
|
||||
.. versionadded:: 2.0
|
||||
"""
|
||||
if self._time is not None:
|
||||
if self._time is not MISSING:
|
||||
return self._time.copy()
|
||||
|
||||
@property
|
||||
def current_loop(self):
|
||||
def current_loop(self) -> int:
|
||||
""":class:`int`: The current iteration of the loop."""
|
||||
return self._current_loop
|
||||
|
||||
@property
|
||||
def next_iteration(self):
|
||||
def next_iteration(self) -> Optional[datetime.datetime]:
|
||||
"""Optional[:class:`datetime.datetime`]: When the next iteration of the loop will occur.
|
||||
|
||||
.. versionadded:: 1.3
|
||||
@ -249,7 +285,7 @@ class Loop:
|
||||
return None
|
||||
return self._next_iteration
|
||||
|
||||
async def __call__(self, *args, **kwargs):
|
||||
async def __call__(self, *args: Any, **kwargs: Any) -> Any:
|
||||
r"""|coro|
|
||||
|
||||
Calls the internal callback that the task holds.
|
||||
@ -269,7 +305,7 @@ class Loop:
|
||||
|
||||
return await self.coro(*args, **kwargs)
|
||||
|
||||
def start(self, *args, **kwargs):
|
||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
|
||||
r"""Starts the internal task in the event loop.
|
||||
|
||||
Parameters
|
||||
@ -302,7 +338,7 @@ class Loop:
|
||||
self._task = self.loop.create_task(self._loop(*args, **kwargs))
|
||||
return self._task
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
r"""Gracefully stops the task from running.
|
||||
|
||||
Unlike :meth:`cancel`\, this allows the task to finish its
|
||||
@ -323,15 +359,15 @@ class Loop:
|
||||
if self._task and not self._task.done():
|
||||
self._stop_next_iteration = True
|
||||
|
||||
def _can_be_cancelled(self):
|
||||
return not self._is_being_cancelled and self._task and not self._task.done()
|
||||
def _can_be_cancelled(self) -> bool:
|
||||
return bool(not self._is_being_cancelled and self._task and not self._task.done())
|
||||
|
||||
def cancel(self):
|
||||
def cancel(self) -> None:
|
||||
"""Cancels the internal task, if it is running."""
|
||||
if self._can_be_cancelled():
|
||||
self._task.cancel()
|
||||
|
||||
def restart(self, *args, **kwargs):
|
||||
def restart(self, *args: Any, **kwargs: Any) -> None:
|
||||
r"""A convenience method to restart the internal task.
|
||||
|
||||
.. note::
|
||||
@ -355,7 +391,7 @@ class Loop:
|
||||
self._task.add_done_callback(restart_when_over)
|
||||
self._task.cancel()
|
||||
|
||||
def add_exception_type(self, *exceptions):
|
||||
def add_exception_type(self, *exceptions: Type[BaseException]) -> None:
|
||||
r"""Adds exception types to be handled during the reconnect logic.
|
||||
|
||||
By default the exception types handled are those handled by
|
||||
@ -384,7 +420,7 @@ class Loop:
|
||||
|
||||
self._valid_exception = (*self._valid_exception, *exceptions)
|
||||
|
||||
def clear_exception_types(self):
|
||||
def clear_exception_types(self) -> None:
|
||||
"""Removes all exception types that are handled.
|
||||
|
||||
.. note::
|
||||
@ -393,7 +429,7 @@ class Loop:
|
||||
"""
|
||||
self._valid_exception = tuple()
|
||||
|
||||
def remove_exception_type(self, *exceptions):
|
||||
def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool:
|
||||
r"""Removes exception types from being handled during the reconnect logic.
|
||||
|
||||
Parameters
|
||||
@ -410,34 +446,34 @@ class Loop:
|
||||
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
|
||||
return len(self._valid_exception) == old_length - len(exceptions)
|
||||
|
||||
def get_task(self):
|
||||
def get_task(self) -> Optional[asyncio.Task]:
|
||||
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
|
||||
return self._task
|
||||
|
||||
def is_being_cancelled(self):
|
||||
def is_being_cancelled(self) -> bool:
|
||||
"""Whether the task is being cancelled."""
|
||||
return self._is_being_cancelled
|
||||
|
||||
def failed(self):
|
||||
def failed(self) -> bool:
|
||||
""":class:`bool`: Whether the internal task has failed.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
"""
|
||||
return self._has_failed
|
||||
|
||||
def is_running(self):
|
||||
def is_running(self) -> bool:
|
||||
""":class:`bool`: Check if the task is currently running.
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
return not bool(self._task.done()) if self._task else False
|
||||
|
||||
async def _error(self, *args):
|
||||
exception = args[-1]
|
||||
async def _error(self, *args: Any) -> None:
|
||||
exception: Exception = args[-1]
|
||||
print(f'Unhandled exception in internal background task {self.coro.__name__!r}.', file=sys.stderr)
|
||||
traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr)
|
||||
|
||||
def before_loop(self, coro):
|
||||
def before_loop(self, coro: FT) -> FT:
|
||||
"""A decorator that registers a coroutine to be called before the loop starts running.
|
||||
|
||||
This is useful if you want to wait for some bot state before the loop starts,
|
||||
@ -462,7 +498,7 @@ class Loop:
|
||||
self._before_loop = coro
|
||||
return coro
|
||||
|
||||
def after_loop(self, coro):
|
||||
def after_loop(self, coro: FT) -> FT:
|
||||
"""A decorator that register a coroutine to be called after the loop finished running.
|
||||
|
||||
The coroutine must take no arguments (except ``self`` in a class context).
|
||||
@ -490,7 +526,7 @@ class Loop:
|
||||
self._after_loop = coro
|
||||
return coro
|
||||
|
||||
def error(self, coro):
|
||||
def error(self, coro: ET) -> ET:
|
||||
"""A decorator that registers a coroutine to be called if the task encounters an unhandled exception.
|
||||
|
||||
The coroutine must take only one argument the exception raised (except ``self`` in a class context).
|
||||
@ -513,11 +549,11 @@ class Loop:
|
||||
if not inspect.iscoroutinefunction(coro):
|
||||
raise TypeError(f'Expected coroutine function, received {coro.__class__.__name__!r}.')
|
||||
|
||||
self._error = coro
|
||||
self._error = coro # type: ignore
|
||||
return coro
|
||||
|
||||
def _get_next_sleep_time(self):
|
||||
if self._sleep is not None:
|
||||
def _get_next_sleep_time(self) -> datetime.datetime:
|
||||
if self._sleep is not MISSING:
|
||||
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
|
||||
|
||||
if self._time_index >= len(self._time):
|
||||
@ -532,7 +568,7 @@ class Loop:
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
||||
|
||||
next_date = self._last_iteration
|
||||
next_date = cast(datetime.datetime, self._last_iteration)
|
||||
if self._time_index == 0:
|
||||
# we can assume that the earliest time should be scheduled for "tomorrow"
|
||||
next_date += datetime.timedelta(days=1)
|
||||
@ -540,7 +576,7 @@ class Loop:
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(next_date, next_time)
|
||||
|
||||
def _prepare_time_index(self, now=None):
|
||||
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
|
||||
# now kwarg should be a datetime.datetime representing the time "now"
|
||||
# to calculate the next time index from
|
||||
|
||||
@ -553,25 +589,38 @@ class Loop:
|
||||
else:
|
||||
self._time_index = 0
|
||||
|
||||
def _get_time_parameter(self, time, *, inst=isinstance, dt=datetime.time, utc=datetime.timezone.utc):
|
||||
if inst(time, dt):
|
||||
def _get_time_parameter(
|
||||
self,
|
||||
time: Union[datetime.time, Sequence[datetime.time]],
|
||||
*,
|
||||
dt: Type[datetime.time] = datetime.time,
|
||||
utc: datetime.timezone = datetime.timezone.utc,
|
||||
) -> List[datetime.time]:
|
||||
if isinstance(time, dt):
|
||||
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
||||
return [ret]
|
||||
if not inst(time, Sequence):
|
||||
if not isinstance(time, Sequence):
|
||||
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
|
||||
if not time:
|
||||
raise ValueError('time parameter must not be an empty sequence.')
|
||||
|
||||
ret = []
|
||||
for index, t in enumerate(time):
|
||||
if not inst(t, dt):
|
||||
if not isinstance(t, dt):
|
||||
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
|
||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||
|
||||
ret = sorted(set(ret)) # de-dupe and sort times
|
||||
return ret
|
||||
|
||||
def change_interval(self, *, seconds=0, minutes=0, hours=0, time=None):
|
||||
def change_interval(
|
||||
self,
|
||||
*,
|
||||
seconds: float = 0,
|
||||
minutes: float = 0,
|
||||
hours: float = 0,
|
||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||
) -> None:
|
||||
"""Changes the interval for the sleep time.
|
||||
|
||||
.. versionadded:: 1.2
|
||||
@ -604,7 +653,10 @@ class Loop:
|
||||
``time`` parameter was passed in conjunction with relative time parameters.
|
||||
"""
|
||||
|
||||
if time is None:
|
||||
if time is MISSING:
|
||||
seconds = seconds or 0
|
||||
minutes = minutes or 0
|
||||
hours = hours or 0
|
||||
sleep = seconds + (minutes * 60.0) + (hours * 3600.0)
|
||||
if sleep < 0:
|
||||
raise ValueError('Total number of seconds cannot be less than zero.')
|
||||
@ -613,12 +665,12 @@ class Loop:
|
||||
self._seconds = float(seconds)
|
||||
self._hours = float(hours)
|
||||
self._minutes = float(minutes)
|
||||
self._time = None
|
||||
self._time: List[datetime.time] = MISSING
|
||||
else:
|
||||
if any((seconds, minutes, hours)):
|
||||
raise TypeError('Cannot mix explicit time with relative time')
|
||||
self._time = self._get_time_parameter(time)
|
||||
self._sleep = self._seconds = self._minutes = self._hours = None
|
||||
self._sleep = self._seconds = self._minutes = self._hours = MISSING
|
||||
|
||||
if self.is_running():
|
||||
if self._time is not None:
|
||||
@ -631,7 +683,16 @@ class Loop:
|
||||
self._handle.recalculate(self._next_iteration)
|
||||
|
||||
|
||||
def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True, loop=None):
|
||||
def loop(
|
||||
*,
|
||||
seconds: float = MISSING,
|
||||
minutes: float = MISSING,
|
||||
hours: float = MISSING,
|
||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||
count: Optional[int] = None,
|
||||
reconnect: bool = True,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
||||
) -> Callable[[LF], Loop[LF]]:
|
||||
"""A decorator that schedules a task in the background for you with
|
||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||
|
||||
@ -663,7 +724,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True
|
||||
Whether to handle errors and restart the task
|
||||
using an exponential back-off algorithm similar to the
|
||||
one used in :meth:`discord.Client.connect`.
|
||||
loop: :class:`asyncio.AbstractEventLoop`
|
||||
loop: Optional[:class:`asyncio.AbstractEventLoop`]
|
||||
The loop to use to register the task, if not given
|
||||
defaults to :func:`asyncio.get_event_loop`.
|
||||
|
||||
@ -675,7 +736,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True
|
||||
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
||||
or ``time`` parameter was passed in conjunction with relative time parameters.
|
||||
"""
|
||||
def decorator(func):
|
||||
def decorator(func: LF) -> Loop[LF]:
|
||||
kwargs = {
|
||||
'seconds': seconds,
|
||||
'minutes': minutes,
|
||||
@ -683,7 +744,7 @@ def loop(*, seconds=0, minutes=0, hours=0, count=None, time=None, reconnect=True
|
||||
'count': count,
|
||||
'time': time,
|
||||
'reconnect': reconnect,
|
||||
'loop': loop
|
||||
'loop': loop,
|
||||
}
|
||||
return Loop(func, **kwargs)
|
||||
return decorator
|
||||
|
Loading…
x
Reference in New Issue
Block a user