Turn KeepAliveHandler into an asyncio Task
This commit is contained in:
parent
3260ec6643
commit
e649f356be
@ -153,38 +153,38 @@ class GatewayRatelimiter:
|
||||
await asyncio.sleep(delta)
|
||||
|
||||
|
||||
class KeepAliveHandler(threading.Thread):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
ws = kwargs.pop("ws")
|
||||
interval = kwargs.pop("interval", None)
|
||||
shard_id = kwargs.pop("shard_id", None)
|
||||
threading.Thread.__init__(self, *args, **kwargs)
|
||||
class KeepAliveHandler:
|
||||
def __init__(self, *, ws: DiscordWebSocket, shard_id: int = None, interval: float = None) -> None:
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._main_thread_id: int = ws.thread_id
|
||||
self.interval: Optional[float] = interval
|
||||
self.daemon: bool = True
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
self.interval: Optional[float] = interval
|
||||
self.heartbeat_timeout: float = self.ws._max_heartbeat_timeout
|
||||
|
||||
self.msg: str = "Keeping shard ID %s websocket alive with sequence %s."
|
||||
self.block_msg: str = "Shard ID %s heartbeat blocked for more than %s seconds."
|
||||
self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind."
|
||||
self._stop_ev: threading.Event = threading.Event()
|
||||
self._last_ack: float = time.perf_counter()
|
||||
self._stop_ev: asyncio.Event = asyncio.Event()
|
||||
self._last_send: float = time.perf_counter()
|
||||
self._last_recv: float = time.perf_counter()
|
||||
self._last_ack: float = time.perf_counter()
|
||||
self.latency: float = float("inf")
|
||||
self.heartbeat_timeout: float = ws._max_heartbeat_timeout
|
||||
|
||||
def run(self) -> None:
|
||||
while not self._stop_ev.wait(self.interval):
|
||||
async def run(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
await asyncio.wait_for(self._stop_ev.wait(), timeout=self.interval)
|
||||
except asyncio.TimeoutError:
|
||||
pass
|
||||
else:
|
||||
return
|
||||
|
||||
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
|
||||
_log.warning(
|
||||
"Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id
|
||||
)
|
||||
coro = self.ws.close(4000)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
|
||||
try:
|
||||
f.result()
|
||||
await self.ws.close(4000)
|
||||
except Exception:
|
||||
_log.exception("An error occurred while stopping the gateway. Ignoring.")
|
||||
finally:
|
||||
@ -193,24 +193,18 @@ class KeepAliveHandler(threading.Thread):
|
||||
|
||||
data = self.get_payload()
|
||||
_log.debug(self.msg, self.shard_id, data["d"])
|
||||
coro = self.ws.send_heartbeat(data)
|
||||
f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
try:
|
||||
# block until sending is complete
|
||||
total = 0
|
||||
while True:
|
||||
try:
|
||||
f.result(10)
|
||||
await asyncio.wait_for(self.ws.send_heartbeat(data), timeout=10)
|
||||
break
|
||||
except concurrent.futures.TimeoutError:
|
||||
except asyncio.TimeoutError:
|
||||
total += 10
|
||||
try:
|
||||
frame = sys._current_frames()[self._main_thread_id]
|
||||
except KeyError:
|
||||
msg = self.block_msg
|
||||
else:
|
||||
stack = "".join(traceback.format_stack(frame))
|
||||
msg = f"{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}"
|
||||
|
||||
stack = "".join(traceback.format_stack())
|
||||
msg = f"{self.block_msg}\nLoop traceback (most recent call last):\n{stack}"
|
||||
_log.warning(msg, self.shard_id, total)
|
||||
|
||||
except Exception:
|
||||
@ -225,6 +219,10 @@ class KeepAliveHandler(threading.Thread):
|
||||
"d": self.ws.sequence, # type: ignore
|
||||
}
|
||||
|
||||
|
||||
def start(self) -> None:
|
||||
self.ws.loop.create_task(self.run())
|
||||
|
||||
def stop(self) -> None:
|
||||
self._stop_ev.set()
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user