From e649f356bea9e6dfe12d5da05e2b81c00ef894c3 Mon Sep 17 00:00:00 2001 From: Gnome Date: Thu, 7 Oct 2021 17:15:19 +0100 Subject: [PATCH] Turn KeepAliveHandler into an asyncio Task --- discord/gateway.py | 54 ++++++++++++++++++++++------------------------ 1 file changed, 26 insertions(+), 28 deletions(-) diff --git a/discord/gateway.py b/discord/gateway.py index 5ef651f1..bf0a5624 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -153,38 +153,38 @@ class GatewayRatelimiter: await asyncio.sleep(delta) -class KeepAliveHandler(threading.Thread): - def __init__(self, *args: Any, **kwargs: Any) -> None: - ws = kwargs.pop("ws") - interval = kwargs.pop("interval", None) - shard_id = kwargs.pop("shard_id", None) - threading.Thread.__init__(self, *args, **kwargs) +class KeepAliveHandler: + def __init__(self, *, ws: DiscordWebSocket, shard_id: int = None, interval: float = None) -> None: self.ws: DiscordWebSocket = ws - self._main_thread_id: int = ws.thread_id - self.interval: Optional[float] = interval - self.daemon: bool = True self.shard_id: Optional[int] = shard_id + self.interval: Optional[float] = interval + self.heartbeat_timeout: float = self.ws._max_heartbeat_timeout + self.msg: str = "Keeping shard ID %s websocket alive with sequence %s." self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds." self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind." - self._stop_ev: threading.Event = threading.Event() - self._last_ack: float = time.perf_counter() + self._stop_ev: asyncio.Event = asyncio.Event() self._last_send: float = time.perf_counter() self._last_recv: float = time.perf_counter() + self._last_ack: float = time.perf_counter() self.latency: float = float("inf") - self.heartbeat_timeout: float = ws._max_heartbeat_timeout - def run(self) -> None: - while not self._stop_ev.wait(self.interval): + async def run(self) -> None: + while True: + try: + await asyncio.wait_for(self._stop_ev.wait(), timeout=self.interval) + except asyncio.TimeoutError: + pass + else: + return + if self._last_recv + self.heartbeat_timeout < time.perf_counter(): _log.warning( "Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id ) - coro = self.ws.close(4000) - f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: - f.result() + await self.ws.close(4000) except Exception: _log.exception("An error occurred while stopping the gateway. Ignoring.") finally: @@ -193,24 +193,18 @@ class KeepAliveHandler(threading.Thread): data = self.get_payload() _log.debug(self.msg, self.shard_id, data["d"]) - coro = self.ws.send_heartbeat(data) - f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) try: # block until sending is complete total = 0 while True: try: - f.result(10) + await asyncio.wait_for(self.ws.send_heartbeat(data), timeout=10) break - except concurrent.futures.TimeoutError: + except asyncio.TimeoutError: total += 10 - try: - frame = sys._current_frames()[self._main_thread_id] - except KeyError: - msg = self.block_msg - else: - stack = "".join(traceback.format_stack(frame)) - msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}" + + stack = "".join(traceback.format_stack()) + msg = f"{self.block_msg}\nLoop traceback (most recent call last):\n{stack}" _log.warning(msg, self.shard_id, total) except Exception: @@ -225,6 +219,10 @@ class KeepAliveHandler(threading.Thread): "d": self.ws.sequence, # type: ignore } + + def start(self) -> None: + self.ws.loop.create_task(self.run()) + def stop(self) -> None: self._stop_ev.set()