[tasks] Improve typing parity
This commit is contained in:
parent
b2ac327bd8
commit
a2a7b0f076
@ -27,22 +27,20 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
from typing import (
|
from typing import (
|
||||||
Any,
|
Any,
|
||||||
Awaitable,
|
Awaitable,
|
||||||
Callable,
|
Callable,
|
||||||
Generic,
|
Generic,
|
||||||
List,
|
List,
|
||||||
Optional,
|
Optional,
|
||||||
Type,
|
Type,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Union,
|
Union,
|
||||||
cast,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import discord
|
import discord
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
@ -50,8 +48,6 @@ from collections.abc import Sequence
|
|||||||
from discord.backoff import ExponentialBackoff
|
from discord.backoff import ExponentialBackoff
|
||||||
from discord.utils import MISSING
|
from discord.utils import MISSING
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
'loop',
|
'loop',
|
||||||
)
|
)
|
||||||
@ -61,7 +57,6 @@ _func = Callable[..., Awaitable[Any]]
|
|||||||
LF = TypeVar('LF', bound=_func)
|
LF = TypeVar('LF', bound=_func)
|
||||||
FT = TypeVar('FT', bound=_func)
|
FT = TypeVar('FT', bound=_func)
|
||||||
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
ET = TypeVar('ET', bound=Callable[[Any, BaseException], Awaitable[Any]])
|
||||||
LT = TypeVar('LT', bound='Loop')
|
|
||||||
|
|
||||||
|
|
||||||
class SleepHandle:
|
class SleepHandle:
|
||||||
@ -78,7 +73,7 @@ class SleepHandle:
|
|||||||
relative_delta = discord.utils.compute_timedelta(dt)
|
relative_delta = discord.utils.compute_timedelta(dt)
|
||||||
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
self.handle = self.loop.call_later(relative_delta, self.future.set_result, True)
|
||||||
|
|
||||||
def wait(self) -> asyncio.Future:
|
def wait(self) -> asyncio.Future[Any]:
|
||||||
return self.future
|
return self.future
|
||||||
|
|
||||||
def done(self) -> bool:
|
def done(self) -> bool:
|
||||||
@ -94,7 +89,9 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
The main interface to create this is through :func:`loop`.
|
The main interface to create this is through :func:`loop`.
|
||||||
"""
|
"""
|
||||||
def __init__(self,
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
coro: LF,
|
coro: LF,
|
||||||
seconds: float,
|
seconds: float,
|
||||||
hours: float,
|
hours: float,
|
||||||
@ -102,15 +99,15 @@ class Loop(Generic[LF]):
|
|||||||
time: Union[datetime.time, Sequence[datetime.time]],
|
time: Union[datetime.time, Sequence[datetime.time]],
|
||||||
count: Optional[int],
|
count: Optional[int],
|
||||||
reconnect: bool,
|
reconnect: bool,
|
||||||
loop: Optional[asyncio.AbstractEventLoop],
|
loop: asyncio.AbstractEventLoop,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.coro: LF = coro
|
self.coro: LF = coro
|
||||||
self.reconnect: bool = reconnect
|
self.reconnect: bool = reconnect
|
||||||
self.loop: Optional[asyncio.AbstractEventLoop] = loop
|
self.loop: asyncio.AbstractEventLoop = loop
|
||||||
self.count: Optional[int] = count
|
self.count: Optional[int] = count
|
||||||
self._current_loop = 0
|
self._current_loop = 0
|
||||||
self._handle = None
|
self._handle: SleepHandle = MISSING
|
||||||
self._task = None
|
self._task: asyncio.Task[None] = MISSING
|
||||||
self._injected = None
|
self._injected = None
|
||||||
self._valid_exception = (
|
self._valid_exception = (
|
||||||
OSError,
|
OSError,
|
||||||
@ -131,7 +128,7 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
|
self.change_interval(seconds=seconds, minutes=minutes, hours=hours, time=time)
|
||||||
self._last_iteration_failed = False
|
self._last_iteration_failed = False
|
||||||
self._last_iteration = None
|
self._last_iteration: datetime.datetime = MISSING
|
||||||
self._next_iteration = None
|
self._next_iteration = None
|
||||||
|
|
||||||
if not inspect.iscoroutinefunction(self.coro):
|
if not inspect.iscoroutinefunction(self.coro):
|
||||||
@ -147,9 +144,8 @@ class Loop(Generic[LF]):
|
|||||||
else:
|
else:
|
||||||
await coro(*args, **kwargs)
|
await coro(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _try_sleep_until(self, dt: datetime.datetime):
|
def _try_sleep_until(self, dt: datetime.datetime):
|
||||||
self._handle = SleepHandle(dt=dt, loop=self.loop) # type: ignore
|
self._handle = SleepHandle(dt=dt, loop=self.loop)
|
||||||
return self._handle.wait()
|
return self._handle.wait()
|
||||||
|
|
||||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||||
@ -178,7 +174,7 @@ class Loop(Generic[LF]):
|
|||||||
await asyncio.sleep(backoff.delay())
|
await asyncio.sleep(backoff.delay())
|
||||||
else:
|
else:
|
||||||
await self._try_sleep_until(self._next_iteration)
|
await self._try_sleep_until(self._next_iteration)
|
||||||
|
|
||||||
if self._stop_next_iteration:
|
if self._stop_next_iteration:
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -211,14 +207,14 @@ class Loop(Generic[LF]):
|
|||||||
if obj is None:
|
if obj is None:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
copy = Loop(
|
copy: Loop[LF] = Loop(
|
||||||
self.coro,
|
self.coro,
|
||||||
seconds=self._seconds,
|
seconds=self._seconds,
|
||||||
hours=self._hours,
|
hours=self._hours,
|
||||||
minutes=self._minutes,
|
minutes=self._minutes,
|
||||||
time=self._time,
|
time=self._time,
|
||||||
count=self.count,
|
count=self.count,
|
||||||
reconnect=self.reconnect,
|
reconnect=self.reconnect,
|
||||||
loop=self.loop,
|
loop=self.loop,
|
||||||
)
|
)
|
||||||
copy._injected = obj
|
copy._injected = obj
|
||||||
@ -237,7 +233,7 @@ class Loop(Generic[LF]):
|
|||||||
"""
|
"""
|
||||||
if self._seconds is not MISSING:
|
if self._seconds is not MISSING:
|
||||||
return self._seconds
|
return self._seconds
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def minutes(self) -> Optional[float]:
|
def minutes(self) -> Optional[float]:
|
||||||
"""Optional[:class:`float`]: Read-only value for the number of minutes
|
"""Optional[:class:`float`]: Read-only value for the number of minutes
|
||||||
@ -247,7 +243,7 @@ class Loop(Generic[LF]):
|
|||||||
"""
|
"""
|
||||||
if self._minutes is not MISSING:
|
if self._minutes is not MISSING:
|
||||||
return self._minutes
|
return self._minutes
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def hours(self) -> Optional[float]:
|
def hours(self) -> Optional[float]:
|
||||||
"""Optional[:class:`float`]: Read-only value for the number of hours
|
"""Optional[:class:`float`]: Read-only value for the number of hours
|
||||||
@ -279,7 +275,7 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
.. versionadded:: 1.3
|
.. versionadded:: 1.3
|
||||||
"""
|
"""
|
||||||
if self._task is None:
|
if self._task is MISSING:
|
||||||
return None
|
return None
|
||||||
elif self._task and self._task.done() or self._stop_next_iteration:
|
elif self._task and self._task.done() or self._stop_next_iteration:
|
||||||
return None
|
return None
|
||||||
@ -305,7 +301,7 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
return await self.coro(*args, **kwargs)
|
return await self.coro(*args, **kwargs)
|
||||||
|
|
||||||
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task:
|
def start(self, *args: Any, **kwargs: Any) -> asyncio.Task[None]:
|
||||||
r"""Starts the internal task in the event loop.
|
r"""Starts the internal task in the event loop.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -326,13 +322,13 @@ class Loop(Generic[LF]):
|
|||||||
The task that has been created.
|
The task that has been created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if self._task is not None and not self._task.done():
|
if self._task is not MISSING and not self._task.done():
|
||||||
raise RuntimeError('Task is already launched and is not completed.')
|
raise RuntimeError('Task is already launched and is not completed.')
|
||||||
|
|
||||||
if self._injected is not None:
|
if self._injected is not None:
|
||||||
args = (self._injected, *args)
|
args = (self._injected, *args)
|
||||||
|
|
||||||
if self.loop is None:
|
if self.loop is MISSING:
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
self._task = self.loop.create_task(self._loop(*args, **kwargs))
|
self._task = self.loop.create_task(self._loop(*args, **kwargs))
|
||||||
@ -356,7 +352,7 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
.. versionadded:: 1.2
|
.. versionadded:: 1.2
|
||||||
"""
|
"""
|
||||||
if self._task and not self._task.done():
|
if self._task is not MISSING and not self._task.done():
|
||||||
self._stop_next_iteration = True
|
self._stop_next_iteration = True
|
||||||
|
|
||||||
def _can_be_cancelled(self) -> bool:
|
def _can_be_cancelled(self) -> bool:
|
||||||
@ -383,7 +379,7 @@ class Loop(Generic[LF]):
|
|||||||
The keyword arguments to use.
|
The keyword arguments to use.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def restart_when_over(fut, *, args=args, kwargs=kwargs):
|
def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None:
|
||||||
self._task.remove_done_callback(restart_when_over)
|
self._task.remove_done_callback(restart_when_over)
|
||||||
self.start(*args, **kwargs)
|
self.start(*args, **kwargs)
|
||||||
|
|
||||||
@ -446,9 +442,9 @@ class Loop(Generic[LF]):
|
|||||||
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
|
self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions)
|
||||||
return len(self._valid_exception) == old_length - len(exceptions)
|
return len(self._valid_exception) == old_length - len(exceptions)
|
||||||
|
|
||||||
def get_task(self) -> Optional[asyncio.Task]:
|
def get_task(self) -> Optional[asyncio.Task[None]]:
|
||||||
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
|
"""Optional[:class:`asyncio.Task`]: Fetches the internal task or ``None`` if there isn't one running."""
|
||||||
return self._task
|
return self._task if self._task is not MISSING else None
|
||||||
|
|
||||||
def is_being_cancelled(self) -> bool:
|
def is_being_cancelled(self) -> bool:
|
||||||
"""Whether the task is being cancelled."""
|
"""Whether the task is being cancelled."""
|
||||||
@ -466,7 +462,7 @@ class Loop(Generic[LF]):
|
|||||||
|
|
||||||
.. versionadded:: 1.4
|
.. versionadded:: 1.4
|
||||||
"""
|
"""
|
||||||
return not bool(self._task.done()) if self._task else False
|
return not bool(self._task.done()) if self._task is not MISSING else False
|
||||||
|
|
||||||
async def _error(self, *args: Any) -> None:
|
async def _error(self, *args: Any) -> None:
|
||||||
exception: Exception = args[-1]
|
exception: Exception = args[-1]
|
||||||
@ -560,7 +556,9 @@ class Loop(Generic[LF]):
|
|||||||
self._time_index = 0
|
self._time_index = 0
|
||||||
if self._current_loop == 0:
|
if self._current_loop == 0:
|
||||||
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
||||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0])
|
return datetime.datetime.combine(
|
||||||
|
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
|
||||||
|
)
|
||||||
|
|
||||||
next_time = self._time[self._time_index]
|
next_time = self._time[self._time_index]
|
||||||
|
|
||||||
@ -568,7 +566,7 @@ class Loop(Generic[LF]):
|
|||||||
self._time_index += 1
|
self._time_index += 1
|
||||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
||||||
|
|
||||||
next_date = cast(datetime.datetime, self._last_iteration)
|
next_date = self._last_iteration
|
||||||
if self._time_index == 0:
|
if self._time_index == 0:
|
||||||
# we can assume that the earliest time should be scheduled for "tomorrow"
|
# we can assume that the earliest time should be scheduled for "tomorrow"
|
||||||
next_date += datetime.timedelta(days=1)
|
next_date += datetime.timedelta(days=1)
|
||||||
@ -576,12 +574,14 @@ class Loop(Generic[LF]):
|
|||||||
self._time_index += 1
|
self._time_index += 1
|
||||||
return datetime.datetime.combine(next_date, next_time)
|
return datetime.datetime.combine(next_date, next_time)
|
||||||
|
|
||||||
def _prepare_time_index(self, now: Optional[datetime.datetime] = None) -> None:
|
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
|
||||||
# now kwarg should be a datetime.datetime representing the time "now"
|
# now kwarg should be a datetime.datetime representing the time "now"
|
||||||
# to calculate the next time index from
|
# to calculate the next time index from
|
||||||
|
|
||||||
# pre-condition: self._time is set
|
# pre-condition: self._time is set
|
||||||
time_now = (now or datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)).timetz()
|
time_now = (
|
||||||
|
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
||||||
|
).timetz()
|
||||||
for idx, time in enumerate(self._time):
|
for idx, time in enumerate(self._time):
|
||||||
if time >= time_now:
|
if time >= time_now:
|
||||||
self._time_index = idx
|
self._time_index = idx
|
||||||
@ -597,20 +597,24 @@ class Loop(Generic[LF]):
|
|||||||
utc: datetime.timezone = datetime.timezone.utc,
|
utc: datetime.timezone = datetime.timezone.utc,
|
||||||
) -> List[datetime.time]:
|
) -> List[datetime.time]:
|
||||||
if isinstance(time, dt):
|
if isinstance(time, dt):
|
||||||
ret = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
inner = time if time.tzinfo is not None else time.replace(tzinfo=utc)
|
||||||
return [ret]
|
return [inner]
|
||||||
if not isinstance(time, Sequence):
|
if not isinstance(time, Sequence):
|
||||||
raise TypeError(f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.')
|
raise TypeError(
|
||||||
|
f'Expected datetime.time or a sequence of datetime.time for ``time``, received {type(time)!r} instead.'
|
||||||
|
)
|
||||||
if not time:
|
if not time:
|
||||||
raise ValueError('time parameter must not be an empty sequence.')
|
raise ValueError('time parameter must not be an empty sequence.')
|
||||||
|
|
||||||
ret = []
|
ret: List[datetime.time] = []
|
||||||
for index, t in enumerate(time):
|
for index, t in enumerate(time):
|
||||||
if not isinstance(t, dt):
|
if not isinstance(t, dt):
|
||||||
raise TypeError(f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.')
|
raise TypeError(
|
||||||
|
f'Expected a sequence of {dt!r} for ``time``, received {type(t).__name__!r} at index {index} instead.'
|
||||||
|
)
|
||||||
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
ret.append(t if t.tzinfo is not None else t.replace(tzinfo=utc))
|
||||||
|
|
||||||
ret = sorted(set(ret)) # de-dupe and sort times
|
ret = sorted(set(ret)) # de-dupe and sort times
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def change_interval(
|
def change_interval(
|
||||||
@ -691,7 +695,7 @@ def loop(
|
|||||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||||
count: Optional[int] = None,
|
count: Optional[int] = None,
|
||||||
reconnect: bool = True,
|
reconnect: bool = True,
|
||||||
loop: Optional[asyncio.AbstractEventLoop] = None,
|
loop: asyncio.AbstractEventLoop = MISSING,
|
||||||
) -> Callable[[LF], Loop[LF]]:
|
) -> Callable[[LF], Loop[LF]]:
|
||||||
"""A decorator that schedules a task in the background for you with
|
"""A decorator that schedules a task in the background for you with
|
||||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||||
@ -707,7 +711,7 @@ def loop(
|
|||||||
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
|
time: Union[:class:`datetime.time`, Sequence[:class:`datetime.time`]]
|
||||||
The exact times to run this loop at. Either a non-empty list or a single
|
The exact times to run this loop at. Either a non-empty list or a single
|
||||||
value of :class:`datetime.time` should be passed. Timezones are supported.
|
value of :class:`datetime.time` should be passed. Timezones are supported.
|
||||||
If no timezone is given for the times, it is assumed to represent UTC time.
|
If no timezone is given for the times, it is assumed to represent UTC time.
|
||||||
|
|
||||||
This cannot be used in conjunction with the relative time parameters.
|
This cannot be used in conjunction with the relative time parameters.
|
||||||
|
|
||||||
@ -724,7 +728,7 @@ def loop(
|
|||||||
Whether to handle errors and restart the task
|
Whether to handle errors and restart the task
|
||||||
using an exponential back-off algorithm similar to the
|
using an exponential back-off algorithm similar to the
|
||||||
one used in :meth:`discord.Client.connect`.
|
one used in :meth:`discord.Client.connect`.
|
||||||
loop: Optional[:class:`asyncio.AbstractEventLoop`]
|
loop: :class:`asyncio.AbstractEventLoop`
|
||||||
The loop to use to register the task, if not given
|
The loop to use to register the task, if not given
|
||||||
defaults to :func:`asyncio.get_event_loop`.
|
defaults to :func:`asyncio.get_event_loop`.
|
||||||
|
|
||||||
@ -736,15 +740,17 @@ def loop(
|
|||||||
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
The function was not a coroutine, an invalid value for the ``time`` parameter was passed,
|
||||||
or ``time`` parameter was passed in conjunction with relative time parameters.
|
or ``time`` parameter was passed in conjunction with relative time parameters.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: LF) -> Loop[LF]:
|
def decorator(func: LF) -> Loop[LF]:
|
||||||
kwargs = {
|
return Loop[LF](
|
||||||
'seconds': seconds,
|
func,
|
||||||
'minutes': minutes,
|
seconds=seconds,
|
||||||
'hours': hours,
|
minutes=minutes,
|
||||||
'count': count,
|
hours=hours,
|
||||||
'time': time,
|
count=count,
|
||||||
'reconnect': reconnect,
|
time=time,
|
||||||
'loop': loop,
|
reconnect=reconnect,
|
||||||
}
|
loop=loop,
|
||||||
return Loop(func, **kwargs)
|
)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
Loading…
x
Reference in New Issue
Block a user