Rewrite gateway to use aiohttp instead of websockets

This commit is contained in:
Rapptz
2020-04-07 21:53:55 -04:00
parent 45cb231161
commit b8154e365f
8 changed files with 97 additions and 92 deletions

View File

@ -36,7 +36,7 @@ import threading
import traceback
import zlib
import websockets
import aiohttp
from . import utils
from .activity import BaseActivity
@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception):
self.resume = resume
self.op = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
pass
EventListener = namedtuple('EventListener', 'predicate event result future')
class KeepAliveHandler(threading.Thread):
@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency)
class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
"""Implements a WebSocket for Discord's gateway v6.
# Monkey patch certain things from the aiohttp websocket code
# Check this whenever we update dependencies.
OLD_CLOSE = aiohttp.ClientWebSocketResponse.close
This is created through :func:`create_main_websocket`. Library
users should never create this manually.
async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool:
return await OLD_CLOSE(self, code=code, message=message)
aiohttp.ClientWebSocketResponse.close = _new_ws_close
class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6.
Attributes
-----------
@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
HEARTBEAT_ACK = 11
GUILD_SYNC = 12
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = None
def __init__(self, socket, *, loop):
self.socket = socket
self.loop = loop
# an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None
# generic event listeners
@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
self._zlib = zlib.decompressobj()
self._buffer = bytearray()
@property
def open(self):
return not self.socket.closed
@classmethod
async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
"""Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only.
"""
gateway = await client.http.get_gateway()
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None)
gateway = gateway or await client.http.get_gateway()
socket = await client.http.ws_connect(gateway)
ws = cls(socket, loop=client.loop)
# dynamically add attributes needed
ws.token = client.http.token
@ -267,14 +283,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
return ws
await ws.resume()
try:
await ws.ensure_open()
except websockets.exceptions.ConnectionClosed:
# ws got closed so let's just do a regular IDENTIFY connect.
log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id)
return await cls.from_client(client, shard_id=shard_id)
else:
return ws
return ws
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency
def _can_handle_close(self, code):
return code not in (1000, 4004, 4010, 4011)
def _can_handle_close(self):
return self.socket.close_code not in (1000, 4004, 4010, 4011)
async def poll_event(self):
"""Polls for a DISPATCH event and handles the general gateway loop.
@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
The websocket connection was terminated for unhandled reasons.
"""
try:
msg = await self.recv()
await self.received_message(msg)
except websockets.exceptions.ConnectionClosed as exc:
if self._can_handle_close(exc.code):
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason)
raise ReconnectWebSocket(self.shard_id) from exc
else:
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason)
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
msg = await self.socket.receive()
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
raise msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
log.debug('Received %s', msg)
raise WebSocketClosure('Unexpected WebSocket closure.')
except WebSocketClosure as e:
if self._can_handle_close():
log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code)
raise ReconnectWebSocket(self.shard_id) from e
elif self.socket.close_code is not None:
log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e
async def send(self, data):
self._dispatch('socket_raw_send', data)
await super().send(data)
await self.socket.send_str(data)
async def send_as_json(self, data):
try:
await self.send(utils.to_json(data))
except websockets.exceptions.ConnectionClosed as exc:
if not self._can_handle_close(exc.code):
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
except RuntimeError as exc:
if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
if activity is not None:
@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload)
async def close(self, code=4000, reason=''):
async def close(self, code=4000):
if self._keep_alive:
self._keep_alive.stop()
await super().close(code, reason)
await self.socket.close(code=code)
async def close_connection(self, *args, **kwargs):
if self._keep_alive:
self._keep_alive.stop()
await super().close_connection(*args, **kwargs)
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
class DiscordVoiceWebSocket:
"""Implements the websocket protocol for handling voice connections.
Attributes
@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = None
def __init__(self, socket):
self.ws = socket
self._keep_alive = None
async def send_as_json(self, data):
log.debug('Sending voice websocket frame: %s.', data)
await self.send(utils.to_json(data))
await self.ws.send_str(utils.to_json(data))
async def resume(self):
state = self._connection
@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
async def from_client(cls, client, *, resume=False):
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None)
http = client._state.http
socket = await http.ws_connect(gateway)
ws = cls(socket)
ws.gateway = gateway
ws._connection = client
ws._max_heartbeat_timeout = 60.0
@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
await self.speak(False)
async def poll_event(self):
try:
msg = await asyncio.wait_for(self.recv(), timeout=30.0)
await self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as exc:
raise ConnectionClosed(exc, shard_id=None) from exc
# This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(json.loads(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None)
async def close_connection(self, *args, **kwargs):
if self._keep_alive:
async def close(self, code=1000):
if self._keep_alive is not None:
self._keep_alive.stop()
await super().close_connection(*args, **kwargs)
await self.ws.close(code=code)