mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-18 23:15:48 +00:00
[tasks] Refactor tasks to not store a time index state
It's better to recompute it every time rather than suffer from maintaining the extra state.
This commit is contained in:
parent
c02a3c0bb2
commit
6a43d60acf
@ -148,13 +148,17 @@ class Loop(Generic[LF]):
|
||||
self._handle = SleepHandle(dt=dt, loop=asyncio.get_running_loop())
|
||||
return self._handle.wait()
|
||||
|
||||
def _is_relative_time(self) -> bool:
|
||||
return self._time is MISSING
|
||||
|
||||
def _is_explicit_time(self) -> bool:
|
||||
return self._time is not MISSING
|
||||
|
||||
async def _loop(self, *args: Any, **kwargs: Any) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
await self._call_loop_function('before_loop')
|
||||
self._last_iteration_failed = False
|
||||
if self._time is not MISSING:
|
||||
# the time index should be prepared every time the internal loop is started
|
||||
self._prepare_time_index()
|
||||
if self._is_explicit_time():
|
||||
self._next_iteration = self._get_next_sleep_time()
|
||||
else:
|
||||
self._next_iteration = datetime.datetime.now(datetime.timezone.utc)
|
||||
@ -164,7 +168,7 @@ class Loop(Generic[LF]):
|
||||
return
|
||||
while True:
|
||||
# sleep before the body of the task for explicit time intervals
|
||||
if self._time is not MISSING:
|
||||
if self._is_explicit_time():
|
||||
await self._try_sleep_until(self._next_iteration)
|
||||
if not self._last_iteration_failed:
|
||||
self._last_iteration = self._next_iteration
|
||||
@ -182,7 +186,7 @@ class Loop(Generic[LF]):
|
||||
return
|
||||
|
||||
# sleep after the body of the task for relative time intervals
|
||||
if self._time is MISSING:
|
||||
if self._is_relative_time():
|
||||
await self._try_sleep_until(self._next_iteration)
|
||||
|
||||
self._current_loop += 1
|
||||
@ -553,47 +557,36 @@ class Loop(Generic[LF]):
|
||||
self._error = coro # type: ignore
|
||||
return coro
|
||||
|
||||
def _get_next_sleep_time(self) -> datetime.datetime:
|
||||
def _get_next_sleep_time(self, now: datetime.datetime = MISSING) -> datetime.datetime:
|
||||
if self._sleep is not MISSING:
|
||||
return self._last_iteration + datetime.timedelta(seconds=self._sleep)
|
||||
|
||||
if self._time_index >= len(self._time):
|
||||
self._time_index = 0
|
||||
if self._current_loop == 0:
|
||||
# if we're at the last index on the first iteration, we need to sleep until tomorrow
|
||||
return datetime.datetime.combine(
|
||||
datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0]
|
||||
)
|
||||
if now is MISSING:
|
||||
now = datetime.datetime.now(datetime.timezone.utc)
|
||||
|
||||
next_time = self._time[self._time_index]
|
||||
index = self._start_time_relative_to(now)
|
||||
|
||||
if self._current_loop == 0:
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time)
|
||||
if index is None:
|
||||
time = self._time[0]
|
||||
tomorrow = now + datetime.timedelta(days=1)
|
||||
date = tomorrow.date()
|
||||
else:
|
||||
date = now.date()
|
||||
time = self._time[index]
|
||||
|
||||
next_date = self._last_iteration
|
||||
if self._time_index == 0:
|
||||
# we can assume that the earliest time should be scheduled for "tomorrow"
|
||||
next_date += datetime.timedelta(days=1)
|
||||
return datetime.datetime.combine(date, time, tzinfo=time.tzinfo or datetime.timezone.utc)
|
||||
|
||||
self._time_index += 1
|
||||
return datetime.datetime.combine(next_date, next_time)
|
||||
|
||||
def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None:
|
||||
def _start_time_relative_to(self, now: datetime.datetime) -> Optional[int]:
|
||||
# now kwarg should be a datetime.datetime representing the time "now"
|
||||
# to calculate the next time index from
|
||||
|
||||
# pre-condition: self._time is set
|
||||
time_now = (
|
||||
now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0)
|
||||
).timetz()
|
||||
idx = -1
|
||||
time_now = now.timetz()
|
||||
for idx, time in enumerate(self._time):
|
||||
if time >= time_now:
|
||||
self._time_index = idx
|
||||
break
|
||||
return idx
|
||||
else:
|
||||
self._time_index = idx + 1
|
||||
return None
|
||||
|
||||
def _get_time_parameter(
|
||||
self,
|
||||
@ -683,10 +676,6 @@ class Loop(Generic[LF]):
|
||||
self._sleep = self._seconds = self._minutes = self._hours = MISSING
|
||||
|
||||
if self.is_running():
|
||||
if self._time is not MISSING:
|
||||
# prepare the next time index starting from after the last iteration
|
||||
self._prepare_time_index(now=self._last_iteration)
|
||||
|
||||
self._next_iteration = self._get_next_sleep_time()
|
||||
if self._handle and not self._handle.done():
|
||||
# the loop is sleeping, recalculate based on new interval
|
||||
@ -701,7 +690,6 @@ def loop(
|
||||
time: Union[datetime.time, Sequence[datetime.time]] = MISSING,
|
||||
count: Optional[int] = None,
|
||||
reconnect: bool = True,
|
||||
loop: asyncio.AbstractEventLoop = MISSING,
|
||||
) -> Callable[[LF], Loop[LF]]:
|
||||
"""A decorator that schedules a task in the background for you with
|
||||
optional reconnect logic. The decorator returns a :class:`Loop`.
|
||||
@ -745,6 +733,14 @@ def loop(
|
||||
"""
|
||||
|
||||
def decorator(func: LF) -> Loop[LF]:
|
||||
return Loop[LF](func, seconds=seconds, minutes=minutes, hours=hours, count=count, time=time, reconnect=reconnect)
|
||||
return Loop[LF](
|
||||
func,
|
||||
seconds=seconds,
|
||||
minutes=minutes,
|
||||
hours=hours,
|
||||
count=count,
|
||||
time=time,
|
||||
reconnect=reconnect,
|
||||
)
|
||||
|
||||
return decorator
|
||||
|
@ -75,3 +75,23 @@ async def test_explicit_initial_runs_tomorrow_multi():
|
||||
assert not has_run
|
||||
finally:
|
||||
loop.cancel()
|
||||
|
||||
|
||||
def test_task_regression_issue7659():
|
||||
jst = datetime.timezone(datetime.timedelta(hours=9))
|
||||
|
||||
# 00:00, 03:00, 06:00, 09:00, 12:00, 15:00, 18:00, 21:00
|
||||
times = [datetime.time(hour=h, tzinfo=jst) for h in range(0, 24, 3)]
|
||||
|
||||
@tasks.loop(time=times)
|
||||
async def loop():
|
||||
pass
|
||||
|
||||
before_midnight = datetime.datetime(2022, 3, 12, 23, 50, 59, tzinfo=jst)
|
||||
after_midnight = before_midnight + datetime.timedelta(minutes=9, seconds=2)
|
||||
|
||||
expected_before_midnight = datetime.datetime(2022, 3, 13, 0, 0, 0, tzinfo=jst)
|
||||
expected_after_midnight = datetime.datetime(2022, 3, 13, 3, 0, 0, tzinfo=jst)
|
||||
|
||||
assert loop._get_next_sleep_time(before_midnight) == expected_before_midnight
|
||||
assert loop._get_next_sleep_time(after_midnight) == expected_after_midnight
|
||||
|
Loading…
x
Reference in New Issue
Block a user