Turn KeepAliveHandler into an asyncio Task
This commit is contained in:
		| @@ -153,38 +153,38 @@ class GatewayRatelimiter: | |||||||
|                 await asyncio.sleep(delta) |                 await asyncio.sleep(delta) | ||||||
|  |  | ||||||
|  |  | ||||||
| class KeepAliveHandler(threading.Thread): | class KeepAliveHandler: | ||||||
|     def __init__(self, *args: Any, **kwargs: Any) -> None: |     def __init__(self, *, ws: DiscordWebSocket, shard_id: int = None, interval: float = None) -> None: | ||||||
|         ws = kwargs.pop("ws") |  | ||||||
|         interval = kwargs.pop("interval", None) |  | ||||||
|         shard_id = kwargs.pop("shard_id", None) |  | ||||||
|         threading.Thread.__init__(self, *args, **kwargs) |  | ||||||
|         self.ws: DiscordWebSocket = ws |         self.ws: DiscordWebSocket = ws | ||||||
|         self._main_thread_id: int = ws.thread_id |  | ||||||
|         self.interval: Optional[float] = interval |  | ||||||
|         self.daemon: bool = True |  | ||||||
|         self.shard_id: Optional[int] = shard_id |         self.shard_id: Optional[int] = shard_id | ||||||
|  |         self.interval: Optional[float] = interval | ||||||
|  |         self.heartbeat_timeout: float = self.ws._max_heartbeat_timeout | ||||||
|  |  | ||||||
|         self.msg: str = "Keeping shard ID %s websocket alive with sequence %s." |         self.msg: str = "Keeping shard ID %s websocket alive with sequence %s." | ||||||
|         self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds." |         self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds." | ||||||
|         self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind." |         self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind." | ||||||
|         self._stop_ev: threading.Event = threading.Event() |         self._stop_ev: asyncio.Event = asyncio.Event() | ||||||
|         self._last_ack: float = time.perf_counter() |  | ||||||
|         self._last_send: float = time.perf_counter() |         self._last_send: float = time.perf_counter() | ||||||
|         self._last_recv: float = time.perf_counter() |         self._last_recv: float = time.perf_counter() | ||||||
|  |         self._last_ack: float = time.perf_counter() | ||||||
|         self.latency: float = float("inf") |         self.latency: float = float("inf") | ||||||
|         self.heartbeat_timeout: float = ws._max_heartbeat_timeout |  | ||||||
|  |  | ||||||
|     def run(self) -> None: |     async def run(self) -> None: | ||||||
|         while not self._stop_ev.wait(self.interval): |         while True: | ||||||
|  |             try: | ||||||
|  |                 await asyncio.wait_for(self._stop_ev.wait(), timeout=self.interval) | ||||||
|  |             except asyncio.TimeoutError: | ||||||
|  |                 pass | ||||||
|  |             else: | ||||||
|  |                 return | ||||||
|  |  | ||||||
|             if self._last_recv + self.heartbeat_timeout < time.perf_counter(): |             if self._last_recv + self.heartbeat_timeout < time.perf_counter(): | ||||||
|                 _log.warning( |                 _log.warning( | ||||||
|                     "Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id |                     "Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id | ||||||
|                 ) |                 ) | ||||||
|                 coro = self.ws.close(4000) |  | ||||||
|                 f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) |  | ||||||
|  |  | ||||||
|                 try: |                 try: | ||||||
|                     f.result() |                     await self.ws.close(4000) | ||||||
|                 except Exception: |                 except Exception: | ||||||
|                     _log.exception("An error occurred while stopping the gateway. Ignoring.") |                     _log.exception("An error occurred while stopping the gateway. Ignoring.") | ||||||
|                 finally: |                 finally: | ||||||
| @@ -193,24 +193,18 @@ class KeepAliveHandler(threading.Thread): | |||||||
|  |  | ||||||
|             data = self.get_payload() |             data = self.get_payload() | ||||||
|             _log.debug(self.msg, self.shard_id, data["d"]) |             _log.debug(self.msg, self.shard_id, data["d"]) | ||||||
|             coro = self.ws.send_heartbeat(data) |  | ||||||
|             f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop) |  | ||||||
|             try: |             try: | ||||||
|                 # block until sending is complete |                 # block until sending is complete | ||||||
|                 total = 0 |                 total = 0 | ||||||
|                 while True: |                 while True: | ||||||
|                     try: |                     try: | ||||||
|                         f.result(10) |                         await asyncio.wait_for(self.ws.send_heartbeat(data), timeout=10) | ||||||
|                         break |                         break | ||||||
|                     except concurrent.futures.TimeoutError: |                     except asyncio.TimeoutError: | ||||||
|                         total += 10 |                         total += 10 | ||||||
|                         try: |  | ||||||
|                             frame = sys._current_frames()[self._main_thread_id] |                         stack = "".join(traceback.format_stack()) | ||||||
|                         except KeyError: |                         msg = f"{self.block_msg}\nLoop traceback (most recent call last):\n{stack}" | ||||||
|                             msg = self.block_msg |  | ||||||
|                         else: |  | ||||||
|                             stack = "".join(traceback.format_stack(frame)) |  | ||||||
|                             msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}" |  | ||||||
|                         _log.warning(msg, self.shard_id, total) |                         _log.warning(msg, self.shard_id, total) | ||||||
|  |  | ||||||
|             except Exception: |             except Exception: | ||||||
| @@ -225,6 +219,10 @@ class KeepAliveHandler(threading.Thread): | |||||||
|             "d": self.ws.sequence,  # type: ignore |             "d": self.ws.sequence,  # type: ignore | ||||||
|         } |         } | ||||||
|  |  | ||||||
|  |  | ||||||
|  |     def start(self) -> None: | ||||||
|  |         self.ws.loop.create_task(self.run()) | ||||||
|  |  | ||||||
|     def stop(self) -> None: |     def stop(self) -> None: | ||||||
|         self._stop_ev.set() |         self._stop_ev.set() | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user