Turn KeepAliveHandler into an asyncio Task
This commit is contained in:
		| @@ -153,38 +153,38 @@ class GatewayRatelimiter: | ||||
|                 await asyncio.sleep(delta) | ||||
|  | ||||
|  | ||||
| class KeepAliveHandler(threading.Thread): | ||||
|     def __init__(self, *args: Any, **kwargs: Any) -> None: | ||||
|         ws = kwargs.pop("ws") | ||||
|         interval = kwargs.pop("interval", None) | ||||
|         shard_id = kwargs.pop("shard_id", None) | ||||
|         threading.Thread.__init__(self, *args, **kwargs) | ||||
| class KeepAliveHandler: | ||||
|     def __init__(self, *, ws: DiscordWebSocket, shard_id: int = None, interval: float = None) -> None: | ||||
|         self.ws: DiscordWebSocket = ws | ||||
|         self._main_thread_id: int = ws.thread_id | ||||
|         self.interval: Optional[float] = interval | ||||
|         self.daemon: bool = True | ||||
|         self.shard_id: Optional[int] = shard_id | ||||
|         self.interval: Optional[float] = interval | ||||
|         self.heartbeat_timeout: float = self.ws._max_heartbeat_timeout | ||||
|  | ||||
|         self.msg: str = "Keeping shard ID %s websocket alive with sequence %s." | ||||
|         self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds." | ||||
|         self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind." | ||||
|         self._stop_ev: threading.Event = threading.Event() | ||||
|         self._last_ack: float = time.perf_counter() | ||||
|         self._stop_ev: asyncio.Event = asyncio.Event() | ||||
|         self._last_send: float = time.perf_counter() | ||||
|         self._last_recv: float = time.perf_counter() | ||||
|         self._last_ack: float = time.perf_counter() | ||||
|         self.latency: float = float("inf") | ||||
|         self.heartbeat_timeout: float = ws._max_heartbeat_timeout | ||||
|  | ||||
|     def run(self) -> None: | ||||
|         while not self._stop_ev.wait(self.interval): | ||||
|     async def run(self) -> None: | ||||
|         while True: | ||||
|             try: | ||||
|                 await asyncio.wait_for(self._stop_ev.wait(), timeout=self.interval) | ||||
|             except asyncio.TimeoutError: | ||||
|                 pass | ||||
|             else: | ||||
|                 return | ||||
|  | ||||
|             if self._last_recv + self.heartbeat_timeout < time.perf_counter(): | ||||
|                 _log.warning( | ||||
|                     "Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id | ||||
|                 ) | ||||
|                 coro = self.ws.close(4000) | ||||
|                 f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) | ||||
|  | ||||
|                 try: | ||||
|                     f.result() | ||||
|                     await self.ws.close(4000) | ||||
|                 except Exception: | ||||
|                     _log.exception("An error occurred while stopping the gateway. Ignoring.") | ||||
|                 finally: | ||||
| @@ -193,24 +193,18 @@ class KeepAliveHandler(threading.Thread): | ||||
|  | ||||
|             data = self.get_payload() | ||||
|             _log.debug(self.msg, self.shard_id, data["d"]) | ||||
|             coro = self.ws.send_heartbeat(data) | ||||
|             f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) | ||||
|             try: | ||||
|                 # block until sending is complete | ||||
|                 total = 0 | ||||
|                 while True: | ||||
|                     try: | ||||
|                         f.result(10) | ||||
|                         await asyncio.wait_for(self.ws.send_heartbeat(data), timeout=10) | ||||
|                         break | ||||
|                     except concurrent.futures.TimeoutError: | ||||
|                     except asyncio.TimeoutError: | ||||
|                         total += 10 | ||||
|                         try: | ||||
|                             frame = sys._current_frames()[self._main_thread_id] | ||||
|                         except KeyError: | ||||
|                             msg = self.block_msg | ||||
|                         else: | ||||
|                             stack = "".join(traceback.format_stack(frame)) | ||||
|                             msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}" | ||||
|  | ||||
|                         stack = "".join(traceback.format_stack()) | ||||
|                         msg = f"{self.block_msg}\nLoop traceback (most recent call last):\n{stack}" | ||||
|                         _log.warning(msg, self.shard_id, total) | ||||
|  | ||||
|             except Exception: | ||||
| @@ -225,6 +219,10 @@ class KeepAliveHandler(threading.Thread): | ||||
|             "d": self.ws.sequence,  # type: ignore | ||||
|         } | ||||
|  | ||||
|  | ||||
|     def start(self) -> None: | ||||
|         self.ws.loop.create_task(self.run()) | ||||
|  | ||||
|     def stop(self) -> None: | ||||
|         self._stop_ev.set() | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user