Rewrite gateway to use aiohttp instead of websockets
This commit is contained in:
		| @@ -31,7 +31,6 @@ from pathlib import Path | ||||
| import discord | ||||
| import pkg_resources | ||||
| import aiohttp | ||||
| import websockets | ||||
| import platform | ||||
|  | ||||
| def show_version(): | ||||
| @@ -46,7 +45,6 @@ def show_version(): | ||||
|             entries.append('    - discord.py pkg_resources: v{0}'.format(pkg.version)) | ||||
|  | ||||
|     entries.append('- aiohttp v{0.__version__}'.format(aiohttp)) | ||||
|     entries.append('- websockets v{0.__version__}'.format(websockets)) | ||||
|     uname = platform.uname() | ||||
|     entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname)) | ||||
|     print('\n'.join(entries)) | ||||
|   | ||||
| @@ -32,7 +32,6 @@ import sys | ||||
| import traceback | ||||
|  | ||||
| import aiohttp | ||||
| import websockets | ||||
|  | ||||
| from .user import User, Profile | ||||
| from .asset import Asset | ||||
| @@ -497,9 +496,7 @@ class Client: | ||||
|                     GatewayNotFound, | ||||
|                     ConnectionClosed, | ||||
|                     aiohttp.ClientError, | ||||
|                     asyncio.TimeoutError, | ||||
|                     websockets.InvalidHandshake, | ||||
|                     websockets.WebSocketProtocolError) as exc: | ||||
|                     asyncio.TimeoutError) as exc: | ||||
|  | ||||
|                 self.dispatch('disconnect') | ||||
|                 if not reconnect: | ||||
| @@ -632,7 +629,11 @@ class Client: | ||||
|             _cleanup_loop(loop) | ||||
|  | ||||
|         if not future.cancelled(): | ||||
|             return future.result() | ||||
|             try: | ||||
|                 return future.result() | ||||
|             except KeyboardInterrupt: | ||||
|                 # I am unsure why this gets raised here but suppress it anyway | ||||
|                 return None | ||||
|  | ||||
|     # properties | ||||
|  | ||||
|   | ||||
| @@ -159,10 +159,11 @@ class ConnectionClosed(ClientException): | ||||
|     shard_id: Optional[:class:`int`] | ||||
|         The shard ID that got closed if applicable. | ||||
|     """ | ||||
|     def __init__(self, original, *, shard_id): | ||||
|     def __init__(self, socket, *, shard_id): | ||||
|         # This exception is just the same exception except | ||||
|         # reconfigured to subclass ClientException for users | ||||
|         self.code = original.code | ||||
|         self.reason = original.reason | ||||
|         self.code = socket.close_code | ||||
|         # aiohttp doesn't seem to consistently provide close reason | ||||
|         self.reason = '' | ||||
|         self.shard_id = shard_id | ||||
|         super().__init__(str(original)) | ||||
|         super().__init__('Shard ID %s WebSocket closed with %s' % (self.shard_id, self.code)) | ||||
|   | ||||
| @@ -27,7 +27,6 @@ DEALINGS IN THE SOFTWARE. | ||||
| import asyncio | ||||
| import datetime | ||||
| import aiohttp | ||||
| import websockets | ||||
| import discord | ||||
| import inspect | ||||
| import logging | ||||
| @@ -58,8 +57,6 @@ class Loop: | ||||
|             discord.ConnectionClosed, | ||||
|             aiohttp.ClientError, | ||||
|             asyncio.TimeoutError, | ||||
|             websockets.InvalidHandshake, | ||||
|             websockets.WebSocketProtocolError, | ||||
|         ) | ||||
|  | ||||
|         self._before_loop = None | ||||
|   | ||||
| @@ -36,7 +36,7 @@ import threading | ||||
| import traceback | ||||
| import zlib | ||||
|  | ||||
| import websockets | ||||
| import aiohttp | ||||
|  | ||||
| from . import utils | ||||
| from .activity import BaseActivity | ||||
| @@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception): | ||||
|         self.resume = resume | ||||
|         self.op = 'RESUME' if resume else 'IDENTIFY' | ||||
|  | ||||
| class WebSocketClosure(Exception): | ||||
|     """An exception to make up for the fact that aiohttp doesn't signal closure.""" | ||||
|     pass | ||||
|  | ||||
| EventListener = namedtuple('EventListener', 'predicate event result future') | ||||
|  | ||||
| class KeepAliveHandler(threading.Thread): | ||||
| @@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler): | ||||
|         self.latency = ack_time - self._last_send | ||||
|         self.recent_ack_latencies.append(self.latency) | ||||
|  | ||||
| class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|     """Implements a WebSocket for Discord's gateway v6. | ||||
| # Monkey patch certain things from the aiohttp websocket code | ||||
| # Check this whenever we update dependencies. | ||||
| OLD_CLOSE = aiohttp.ClientWebSocketResponse.close | ||||
|  | ||||
|     This is created through :func:`create_main_websocket`. Library | ||||
|     users should never create this manually. | ||||
| async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool: | ||||
|     return await OLD_CLOSE(self, code=code, message=message) | ||||
|  | ||||
| aiohttp.ClientWebSocketResponse.close = _new_ws_close | ||||
|  | ||||
| class DiscordWebSocket: | ||||
|     """Implements a WebSocket for Discord's gateway v6. | ||||
|  | ||||
|     Attributes | ||||
|     ----------- | ||||
| @@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|     HEARTBEAT_ACK      = 11 | ||||
|     GUILD_SYNC         = 12 | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.max_size = None | ||||
|     def __init__(self, socket, *, loop): | ||||
|         self.socket = socket | ||||
|         self.loop = loop | ||||
|  | ||||
|         # an empty dispatcher to prevent crashes | ||||
|         self._dispatch = lambda *args: None | ||||
|         # generic event listeners | ||||
| @@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|         self._zlib = zlib.decompressobj() | ||||
|         self._buffer = bytearray() | ||||
|  | ||||
|     @property | ||||
|     def open(self): | ||||
|         return not self.socket.closed | ||||
|  | ||||
|     @classmethod | ||||
|     async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False): | ||||
|     async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False): | ||||
|         """Creates a main websocket for Discord from a :class:`Client`. | ||||
|  | ||||
|         This is for internal use only. | ||||
|         """ | ||||
|         gateway = await client.http.get_gateway() | ||||
|         ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) | ||||
|         gateway = gateway or await client.http.get_gateway() | ||||
|         socket = await client.http.ws_connect(gateway) | ||||
|         ws = cls(socket, loop=client.loop) | ||||
|  | ||||
|         # dynamically add attributes needed | ||||
|         ws.token = client.http.token | ||||
| @@ -267,14 +283,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|             return ws | ||||
|  | ||||
|         await ws.resume() | ||||
|         try: | ||||
|             await ws.ensure_open() | ||||
|         except websockets.exceptions.ConnectionClosed: | ||||
|             # ws got closed so let's just do a regular IDENTIFY connect. | ||||
|             log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id) | ||||
|             return await cls.from_client(client, shard_id=shard_id) | ||||
|         else: | ||||
|             return ws | ||||
|         return ws | ||||
|  | ||||
|     def wait_for(self, event, predicate, result=None): | ||||
|         """Waits for a DISPATCH'd event that meets the predicate. | ||||
| @@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|         heartbeat = self._keep_alive | ||||
|         return float('inf') if heartbeat is None else heartbeat.latency | ||||
|  | ||||
|     def _can_handle_close(self, code): | ||||
|         return code not in (1000, 4004, 4010, 4011) | ||||
|     def _can_handle_close(self): | ||||
|         return self.socket.close_code not in (1000, 4004, 4010, 4011) | ||||
|  | ||||
|     async def poll_event(self): | ||||
|         """Polls for a DISPATCH event and handles the general gateway loop. | ||||
| @@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|             The websocket connection was terminated for unhandled reasons. | ||||
|         """ | ||||
|         try: | ||||
|             msg = await self.recv() | ||||
|             await self.received_message(msg) | ||||
|         except websockets.exceptions.ConnectionClosed as exc: | ||||
|             if self._can_handle_close(exc.code): | ||||
|                 log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason) | ||||
|                 raise ReconnectWebSocket(self.shard_id) from exc | ||||
|             else: | ||||
|                 log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason) | ||||
|                 raise ConnectionClosed(exc, shard_id=self.shard_id) from exc | ||||
|             msg = await self.socket.receive() | ||||
|             if msg.type is aiohttp.WSMsgType.TEXT: | ||||
|                 await self.received_message(msg.data) | ||||
|             elif msg.type is aiohttp.WSMsgType.BINARY: | ||||
|                 await self.received_message(msg.data) | ||||
|             elif msg.type is aiohttp.WSMsgType.ERROR: | ||||
|                 log.debug('Received %s', msg) | ||||
|                 raise msg.data | ||||
|             elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): | ||||
|                 log.debug('Received %s', msg) | ||||
|                 raise WebSocketClosure('Unexpected WebSocket closure.') | ||||
|         except WebSocketClosure as e: | ||||
|             if self._can_handle_close(): | ||||
|                 log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code) | ||||
|                 raise ReconnectWebSocket(self.shard_id) from e | ||||
|             elif self.socket.close_code is not None: | ||||
|                 log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code) | ||||
|                 raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e | ||||
|  | ||||
|     async def send(self, data): | ||||
|         self._dispatch('socket_raw_send', data) | ||||
|         await super().send(data) | ||||
|         await self.socket.send_str(data) | ||||
|  | ||||
|     async def send_as_json(self, data): | ||||
|         try: | ||||
|             await self.send(utils.to_json(data)) | ||||
|         except websockets.exceptions.ConnectionClosed as exc: | ||||
|             if not self._can_handle_close(exc.code): | ||||
|                 raise ConnectionClosed(exc, shard_id=self.shard_id) from exc | ||||
|         except RuntimeError as exc: | ||||
|             if not self._can_handle_close(): | ||||
|                 raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc | ||||
|  | ||||
|     async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0): | ||||
|         if activity is not None: | ||||
| @@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|         log.debug('Updating our voice state to %s.', payload) | ||||
|         await self.send_as_json(payload) | ||||
|  | ||||
|     async def close(self, code=4000, reason=''): | ||||
|     async def close(self, code=4000): | ||||
|         if self._keep_alive: | ||||
|             self._keep_alive.stop() | ||||
|  | ||||
|         await super().close(code, reason) | ||||
|         await self.socket.close(code=code) | ||||
|  | ||||
|     async def close_connection(self, *args, **kwargs): | ||||
|         if self._keep_alive: | ||||
|             self._keep_alive.stop() | ||||
|  | ||||
|         await super().close_connection(*args, **kwargs) | ||||
|  | ||||
| class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): | ||||
| class DiscordVoiceWebSocket: | ||||
|     """Implements the websocket protocol for handling voice connections. | ||||
|  | ||||
|     Attributes | ||||
| @@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|     CLIENT_CONNECT      = 12 | ||||
|     CLIENT_DISCONNECT   = 13 | ||||
|  | ||||
|     def __init__(self, *args, **kwargs): | ||||
|         super().__init__(*args, **kwargs) | ||||
|         self.max_size = None | ||||
|     def __init__(self, socket): | ||||
|         self.ws = socket | ||||
|         self._keep_alive = None | ||||
|  | ||||
|     async def send_as_json(self, data): | ||||
|         log.debug('Sending voice websocket frame: %s.', data) | ||||
|         await self.send(utils.to_json(data)) | ||||
|         await self.ws.send_str(utils.to_json(data)) | ||||
|  | ||||
|     async def resume(self): | ||||
|         state = self._connection | ||||
| @@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|     async def from_client(cls, client, *, resume=False): | ||||
|         """Creates a voice websocket for the :class:`VoiceClient`.""" | ||||
|         gateway = 'wss://' + client.endpoint + '/?v=4' | ||||
|         ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None) | ||||
|         http = client._state.http | ||||
|         socket = await http.ws_connect(gateway) | ||||
|         ws = cls(socket) | ||||
|         ws.gateway = gateway | ||||
|         ws._connection = client | ||||
|         ws._max_heartbeat_timeout = 60.0 | ||||
| @@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol): | ||||
|         await self.speak(False) | ||||
|  | ||||
|     async def poll_event(self): | ||||
|         try: | ||||
|             msg = await asyncio.wait_for(self.recv(), timeout=30.0) | ||||
|             await self.received_message(json.loads(msg)) | ||||
|         except websockets.exceptions.ConnectionClosed as exc: | ||||
|             raise ConnectionClosed(exc, shard_id=None) from exc | ||||
|         # This exception is handled up the chain | ||||
|         msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) | ||||
|         if msg.type is aiohttp.WSMsgType.TEXT: | ||||
|             await self.received_message(json.loads(msg.data)) | ||||
|         elif msg.type is aiohttp.WSMsgType.ERROR: | ||||
|             log.debug('Received %s', msg) | ||||
|             raise ConnectionClosed(self.ws, shard_id=None) from msg.data | ||||
|         elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE): | ||||
|             log.debug('Received %s', msg) | ||||
|             raise ConnectionClosed(self.ws, shard_id=None) | ||||
|  | ||||
|     async def close_connection(self, *args, **kwargs): | ||||
|         if self._keep_alive: | ||||
|     async def close(self, code=1000): | ||||
|         if self._keep_alive is not None: | ||||
|             self._keep_alive.stop() | ||||
|  | ||||
|         await super().close_connection(*args, **kwargs) | ||||
|         await self.ws.close(code=code) | ||||
|   | ||||
| @@ -111,6 +111,17 @@ class HTTPClient: | ||||
|         if self.__session.closed: | ||||
|             self.__session = aiohttp.ClientSession(connector=self.connector) | ||||
|  | ||||
|     async def ws_connect(self, url): | ||||
|         kwargs = { | ||||
|             'proxy_auth': self.proxy_auth, | ||||
|             'proxy': self.proxy, | ||||
|             'max_msg_size': 0, | ||||
|             'timeout': 30.0, | ||||
|             'autoclose': False, | ||||
|         } | ||||
|  | ||||
|         return await self.__session.ws_connect(url, **kwargs) | ||||
|  | ||||
|     async def request(self, route, *, files=None, **kwargs): | ||||
|         bucket = route.bucket | ||||
|         method = route.method | ||||
|   | ||||
| @@ -28,8 +28,6 @@ import asyncio | ||||
| import itertools | ||||
| import logging | ||||
|  | ||||
| import websockets | ||||
|  | ||||
| from .state import AutoShardedConnectionState | ||||
| from .client import Client | ||||
| from .gateway import * | ||||
| @@ -191,31 +189,13 @@ class AutoShardedClient(Client): | ||||
|  | ||||
|     async def launch_shard(self, gateway, shard_id): | ||||
|         try: | ||||
|             coro = websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket, compression=None) | ||||
|             coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id) | ||||
|             ws = await asyncio.wait_for(coro, timeout=180.0) | ||||
|         except Exception: | ||||
|             log.info('Failed to connect for shard_id: %s. Retrying...', shard_id) | ||||
|             await asyncio.sleep(5.0) | ||||
|             return await self.launch_shard(gateway, shard_id) | ||||
|  | ||||
|         ws.token = self.http.token | ||||
|         ws._connection = self._connection | ||||
|         ws._discord_parsers = self._connection.parsers | ||||
|         ws._dispatch = self.dispatch | ||||
|         ws.gateway = gateway | ||||
|         ws.shard_id = shard_id | ||||
|         ws.shard_count = self.shard_count | ||||
|         ws._max_heartbeat_timeout = self._connection.heartbeat_timeout | ||||
|  | ||||
|         try: | ||||
|             # OP HELLO | ||||
|             await asyncio.wait_for(ws.poll_event(), timeout=180.0) | ||||
|             await asyncio.wait_for(ws.identify(), timeout=180.0) | ||||
|         except asyncio.TimeoutError: | ||||
|             log.info('Timed out when connecting for shard_id: %s. Retrying...', shard_id) | ||||
|             await asyncio.sleep(5.0) | ||||
|             return await self.launch_shard(gateway, shard_id) | ||||
|  | ||||
|         # keep reading the shard while others connect | ||||
|         self.shards[shard_id] = ret = Shard(ws, self) | ||||
|         ret.launch() | ||||
|   | ||||
		Reference in New Issue
	
	Block a user