mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-07-11 20:35:26 +00:00
[tasks] Refactor tasks to not store a time index state
It's better to recompute it every time rather than suffer from maintaining the extra state.
This commit is contained in:
parent
c02a3c0bb2
commit
6a43d60acf
@ -148,13 +148,17 @@ class Loop(Generic[LF]):
|
|||||||
self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
|
self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
|
||||||
return self._handle.wait()
|
return self._handle.wait()
|
||||||
|
|
||||||
|
def _is_relative_time(self) -> bool:
|
||||||
|
return self._time is MISSING
|
||||||
|
|
||||||
|
def _is_explicit_time(self) -> bool:
|
||||||
|
return self._time is not MISSING
|
||||||
|
|
||||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||||
backoff = ExponentialBackoff()
|
backoff = ExponentialBackoff()
|
||||||
await self._call_loop_function('before_loop')
|
await self._call_loop_function('before_loop')
|
||||||
self._last_iteration_failed = False
|
self._last_iteration_failed = False
|
||||||
if self._time is not MISSING:
|
if self._is_explicit_time():
|
||||||
# the time index should be prepared every time the internal loop is started
|
|
||||||
self._prepare_time_index()
|
|
||||||
self._next_iteration = self._get_next_sleep_time()
|
self._next_iteration = self._get_next_sleep_time()
|
||||||
else:
|
else:
|
||||||
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
|
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
|
||||||
@ -164,7 +168,7 @@ class Loop(Generic[LF]):
|
|||||||
return
|
return
|
||||||
while True:
|
while True:
|
||||||
# sleep before the body of the task for explicit time intervals
|
# sleep before the body of the task for explicit time intervals
|
||||||
if self._time is not MISSING:
|
if self._is_explicit_time():
|
||||||
await self._try_sleep_until(self._next_iteration)
|
await self._try_sleep_until(self._next_iteration)
|
||||||
if not self._last_iteration_failed:
|
if not self._last_iteration_failed:
|
||||||
self._last_iteration = self._next_iteration
|
self._last_iteration = self._next_iteration
|
||||||
@ -182,7 +186,7 @@ class Loop(Generic[LF]):
|
|||||||
return
|
return
|
||||||
|
|
||||||
# sleep after the body of the task for relative time intervals
|
# sleep after the body of the task for relative time intervals
|
||||||
if self._time is MISSING:
|
if self._is_relative_time():
|
||||||
await self._try_sleep_until(self._next_iteration)
|
await self._try_sleep_until(self._next_iteration)
|
||||||
|
|
||||||
self._current_loop += 1
|
self._current_loop += 1
|
||||||
@ -553,47 +557,36 @@ class Loop(Generic[LF]):
|
|||||||
self._error = coro # type: ignore
|
self._error = coro # type: ignore
|
||||||
return coro
|
return coro
|
||||||
|
|
||||||
def _get_next_sleep_time(self) -> datetime.datetime:
|
def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime:
|
||||||
if self._sleep is not MISSING:
|
if self._sleep is not MISSING:
|
||||||
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
|
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
|
||||||
|
|
||||||
if self._time_index >= len(self._time):
|
if now is MISSING:
|
||||||
self._time_index = 0
|
now = datetime.datetime.now(datetime.timezone.utc)
|
||||||
if self._current_loop == 0:
|
|
||||||
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
|
||||||
return datetime.datetime.combine(
|
|
||||||
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
next_time = self._time[self._time_index]
|
index = self._start_time_relative_to(now)
|
||||||
|
|
||||||
if self._current_loop == 0:
|
if index is None:
|
||||||
self._time_index += 1
|
time = self._time[0]
|
||||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
tomorrow = now + datetime.timedelta(days=1)
|
||||||
|
date = tomorrow.date()
|
||||||
|
else:
|
||||||
|
date = now.date()
|
||||||
|
time = self._time[index]
|
||||||
|
|
||||||
next_date = self._last_iteration
|
return datetime.datetime.combine(date, time, tzinfo=time.tzinfo or datetime.timezone.utc)
|
||||||
if self._time_index == 0:
|
|
||||||
# we can assume that the earliest time should be scheduled for "tomorrow"
|
|
||||||
next_date += datetime.timedelta(days=1)
|
|
||||||
|
|
||||||
self._time_index += 1
|
def _start_time_relative_to(self, now: datetime.datetime) -> Optional[int]:
|
||||||
return datetime.datetime.combine(next_date, next_time)
|
|
||||||
|
|
||||||
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
|
|
||||||
# now kwarg should be a datetime.datetime representing the time "now"
|
# now kwarg should be a datetime.datetime representing the time "now"
|
||||||
# to calculate the next time index from
|
# to calculate the next time index from
|
||||||
|
|
||||||
# pre-condition: self._time is set
|
# pre-condition: self._time is set
|
||||||
time_now = (
|
time_now = now.timetz()
|
||||||
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
|
||||||
).timetz()
|
|
||||||
idx = -1
|
|
||||||
for idx, time in enumerate(self._time):
|
for idx, time in enumerate(self._time):
|
||||||
if time >= time_now:
|
if time >= time_now:
|
||||||
self._time_index = idx
|
return idx
|
||||||
break
|
|
||||||
else:
|
else:
|
||||||
self._time_index = idx + 1
|
return None
|
||||||
|
|
||||||
def _get_time_parameter(
|
def _get_time_parameter(
|
||||||
self,
|
self,
|
||||||
@ -683,10 +676,6 @@ class Loop(Generic[LF]):
|
|||||||
self._sleep = self._seconds = self._minutes = self._hours = MISSING
|
self._sleep = self._seconds = self._minutes = self._hours = MISSING
|
||||||
|
|
||||||
if self.is_running():
|
if self.is_running():
|
||||||
if self._time is not MISSING:
|
|
||||||
# prepare the next time index starting from after the last iteration
|
|
||||||
self._prepare_time_index(now=self._last_iteration)
|
|
||||||
|
|
||||||
self._next_iteration = self._get_next_sleep_time()
|
self._next_iteration = self._get_next_sleep_time()
|
||||||
if self._handle and not self._handle.done():
|
if self._handle and not self._handle.done():
|
||||||
# the loop is sleeping, recalculate based on new interval
|
# the loop is sleeping, recalculate based on new interval
|
||||||
@ -701,7 +690,6 @@ def loop(
|
|||||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||||
count: Optional[int] = None,
|
count: Optional[int] = None,
|
||||||
reconnect: bool = True,
|
reconnect: bool = True,
|
||||||
loop: asyncio.AbstractEventLoop = MISSING,
|
|
||||||
) -> Callable[[LF], Loop[LF]]:
|
) -> Callable[[LF], Loop[LF]]:
|
||||||
"""A decorator that schedules a task in the background for you with
|
"""A decorator that schedules a task in the background for you with
|
||||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||||
@ -745,6 +733,14 @@ def loop(
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def decorator(func: LF) -> Loop[LF]:
|
def decorator(func: LF) -> Loop[LF]:
|
||||||
return Loop[LF](func, seconds=seconds, minutes=minutes, hours=hours, count=count, time=time, reconnect=reconnect)
|
return Loop[LF](
|
||||||
|
func,
|
||||||
|
seconds=seconds,
|
||||||
|
minutes=minutes,
|
||||||
|
hours=hours,
|
||||||
|
count=count,
|
||||||
|
time=time,
|
||||||
|
reconnect=reconnect,
|
||||||
|
)
|
||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
@ -75,3 +75,23 @@ async def test_explicit_initial_runs_tomorrow_multi():
|
|||||||
assert not has_run
|
assert not has_run
|
||||||
finally:
|
finally:
|
||||||
loop.cancel()
|
loop.cancel()
|
||||||
|
|
||||||
|
|
||||||
|
def test_task_regression_issue7659():
|
||||||
|
jst = datetime.timezone(datetime.timedelta(hours=9))
|
||||||
|
|
||||||
|
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
|
||||||
|
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
|
||||||
|
|
||||||
|
@tasks.loop(time=times)
|
||||||
|
async def loop():
|
||||||
|
pass
|
||||||
|
|
||||||
|
before_midnight = datetime.datetime(2022, 3, 12, 23, 50, 59, tzinfo=jst)
|
||||||
|
after_midnight = before_midnight + datetime.timedelta(minutes=9, seconds=2)
|
||||||
|
|
||||||
|
expected_before_midnight = datetime.datetime(2022, 3, 13, 0, 0, 0, tzinfo=jst)
|
||||||
|
expected_after_midnight = datetime.datetime(2022, 3, 13, 3, 0, 0, tzinfo=jst)
|
||||||
|
|
||||||
|
assert loop._get_next_sleep_time(before_midnight) == expected_before_midnight
|
||||||
|
assert loop._get_next_sleep_time(after_midnight) == expected_after_midnight
|
||||||
|
Loading…
x
Reference in New Issue
Block a user