Rewrite gateway to use aiohttp instead of websockets
This commit is contained in:
parent
45cb231161
commit
b8154e365f
@ -31,7 +31,6 @@ from pathlib import Path
|
|||||||
import discord
|
import discord
|
||||||
import pkg_resources
|
import pkg_resources
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import websockets
|
|
||||||
import platform
|
import platform
|
||||||
|
|
||||||
def show_version():
|
def show_version():
|
||||||
@ -46,7 +45,6 @@ def show_version():
|
|||||||
entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version))
|
entries.append(' - discord.py pkg_resources: v{0}'.format(pkg.version))
|
||||||
|
|
||||||
entries.append('- aiohttp v{0.__version__}'.format(aiohttp))
|
entries.append('- aiohttp v{0.__version__}'.format(aiohttp))
|
||||||
entries.append('- websockets v{0.__version__}'.format(websockets))
|
|
||||||
uname = platform.uname()
|
uname = platform.uname()
|
||||||
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
|
entries.append('- system info: {0.system} {0.release} {0.version}'.format(uname))
|
||||||
print('\n'.join(entries))
|
print('\n'.join(entries))
|
||||||
|
@ -32,7 +32,6 @@ import sys
|
|||||||
import traceback
|
import traceback
|
||||||
|
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import websockets
|
|
||||||
|
|
||||||
from .user import User, Profile
|
from .user import User, Profile
|
||||||
from .asset import Asset
|
from .asset import Asset
|
||||||
@ -497,9 +496,7 @@ class Client:
|
|||||||
GatewayNotFound,
|
GatewayNotFound,
|
||||||
ConnectionClosed,
|
ConnectionClosed,
|
||||||
aiohttp.ClientError,
|
aiohttp.ClientError,
|
||||||
asyncio.TimeoutError,
|
asyncio.TimeoutError) as exc:
|
||||||
websockets.InvalidHandshake,
|
|
||||||
websockets.WebSocketProtocolError) as exc:
|
|
||||||
|
|
||||||
self.dispatch('disconnect')
|
self.dispatch('disconnect')
|
||||||
if not reconnect:
|
if not reconnect:
|
||||||
@ -632,7 +629,11 @@ class Client:
|
|||||||
_cleanup_loop(loop)
|
_cleanup_loop(loop)
|
||||||
|
|
||||||
if not future.cancelled():
|
if not future.cancelled():
|
||||||
|
try:
|
||||||
return future.result()
|
return future.result()
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
# I am unsure why this gets raised here but suppress it anyway
|
||||||
|
return None
|
||||||
|
|
||||||
# properties
|
# properties
|
||||||
|
|
||||||
|
@ -159,10 +159,11 @@ class ConnectionClosed(ClientException):
|
|||||||
shard_id: Optional[:class:`int`]
|
shard_id: Optional[:class:`int`]
|
||||||
The shard ID that got closed if applicable.
|
The shard ID that got closed if applicable.
|
||||||
"""
|
"""
|
||||||
def __init__(self, original, *, shard_id):
|
def __init__(self, socket, *, shard_id):
|
||||||
# This exception is just the same exception except
|
# This exception is just the same exception except
|
||||||
# reconfigured to subclass ClientException for users
|
# reconfigured to subclass ClientException for users
|
||||||
self.code = original.code
|
self.code = socket.close_code
|
||||||
self.reason = original.reason
|
# aiohttp doesn't seem to consistently provide close reason
|
||||||
|
self.reason = ''
|
||||||
self.shard_id = shard_id
|
self.shard_id = shard_id
|
||||||
super().__init__(str(original))
|
super().__init__('Shard ID %s WebSocket closed with %s' % (self.shard_id, self.code))
|
||||||
|
@ -27,7 +27,6 @@ DEALINGS IN THE SOFTWARE.
|
|||||||
import asyncio
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
import aiohttp
|
import aiohttp
|
||||||
import websockets
|
|
||||||
import discord
|
import discord
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
@ -58,8 +57,6 @@ class Loop:
|
|||||||
discord.ConnectionClosed,
|
discord.ConnectionClosed,
|
||||||
aiohttp.ClientError,
|
aiohttp.ClientError,
|
||||||
asyncio.TimeoutError,
|
asyncio.TimeoutError,
|
||||||
websockets.InvalidHandshake,
|
|
||||||
websockets.WebSocketProtocolError,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
self._before_loop = None
|
self._before_loop = None
|
||||||
|
@ -36,7 +36,7 @@ import threading
|
|||||||
import traceback
|
import traceback
|
||||||
import zlib
|
import zlib
|
||||||
|
|
||||||
import websockets
|
import aiohttp
|
||||||
|
|
||||||
from . import utils
|
from . import utils
|
||||||
from .activity import BaseActivity
|
from .activity import BaseActivity
|
||||||
@ -60,6 +60,10 @@ class ReconnectWebSocket(Exception):
|
|||||||
self.resume = resume
|
self.resume = resume
|
||||||
self.op = 'RESUME' if resume else 'IDENTIFY'
|
self.op = 'RESUME' if resume else 'IDENTIFY'
|
||||||
|
|
||||||
|
class WebSocketClosure(Exception):
|
||||||
|
"""An exception to make up for the fact that aiohttp doesn't signal closure."""
|
||||||
|
pass
|
||||||
|
|
||||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||||
|
|
||||||
class KeepAliveHandler(threading.Thread):
|
class KeepAliveHandler(threading.Thread):
|
||||||
@ -160,11 +164,17 @@ class VoiceKeepAliveHandler(KeepAliveHandler):
|
|||||||
self.latency = ack_time - self._last_send
|
self.latency = ack_time - self._last_send
|
||||||
self.recent_ack_latencies.append(self.latency)
|
self.recent_ack_latencies.append(self.latency)
|
||||||
|
|
||||||
class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
# Monkey patch certain things from the aiohttp websocket code
|
||||||
"""Implements a WebSocket for Discord's gateway v6.
|
# Check this whenever we update dependencies.
|
||||||
|
OLD_CLOSE = aiohttp.ClientWebSocketResponse.close
|
||||||
|
|
||||||
This is created through :func:`create_main_websocket`. Library
|
async def _new_ws_close(self, *, code: int = 4000, message: bytes = b'') -> bool:
|
||||||
users should never create this manually.
|
return await OLD_CLOSE(self, code=code, message=message)
|
||||||
|
|
||||||
|
aiohttp.ClientWebSocketResponse.close = _new_ws_close
|
||||||
|
|
||||||
|
class DiscordWebSocket:
|
||||||
|
"""Implements a WebSocket for Discord's gateway v6.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
-----------
|
-----------
|
||||||
@ -217,9 +227,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
HEARTBEAT_ACK = 11
|
HEARTBEAT_ACK = 11
|
||||||
GUILD_SYNC = 12
|
GUILD_SYNC = 12
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, socket, *, loop):
|
||||||
super().__init__(*args, **kwargs)
|
self.socket = socket
|
||||||
self.max_size = None
|
self.loop = loop
|
||||||
|
|
||||||
# an empty dispatcher to prevent crashes
|
# an empty dispatcher to prevent crashes
|
||||||
self._dispatch = lambda *args: None
|
self._dispatch = lambda *args: None
|
||||||
# generic event listeners
|
# generic event listeners
|
||||||
@ -234,14 +245,19 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
self._zlib = zlib.decompressobj()
|
self._zlib = zlib.decompressobj()
|
||||||
self._buffer = bytearray()
|
self._buffer = bytearray()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def open(self):
|
||||||
|
return not self.socket.closed
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
async def from_client(cls, client, *, shard_id=None, session=None, sequence=None, resume=False):
|
async def from_client(cls, client, *, gateway=None, shard_id=None, session=None, sequence=None, resume=False):
|
||||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||||
|
|
||||||
This is for internal use only.
|
This is for internal use only.
|
||||||
"""
|
"""
|
||||||
gateway = await client.http.get_gateway()
|
gateway = gateway or await client.http.get_gateway()
|
||||||
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None)
|
socket = await client.http.ws_connect(gateway)
|
||||||
|
ws = cls(socket, loop=client.loop)
|
||||||
|
|
||||||
# dynamically add attributes needed
|
# dynamically add attributes needed
|
||||||
ws.token = client.http.token
|
ws.token = client.http.token
|
||||||
@ -267,13 +283,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
return ws
|
return ws
|
||||||
|
|
||||||
await ws.resume()
|
await ws.resume()
|
||||||
try:
|
|
||||||
await ws.ensure_open()
|
|
||||||
except websockets.exceptions.ConnectionClosed:
|
|
||||||
# ws got closed so let's just do a regular IDENTIFY connect.
|
|
||||||
log.warning('RESUME failed (the websocket decided to close) for Shard ID %s. Retrying.', shard_id)
|
|
||||||
return await cls.from_client(client, shard_id=shard_id)
|
|
||||||
else:
|
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
def wait_for(self, event, predicate, result=None):
|
def wait_for(self, event, predicate, result=None):
|
||||||
@ -472,8 +481,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
heartbeat = self._keep_alive
|
heartbeat = self._keep_alive
|
||||||
return float('inf') if heartbeat is None else heartbeat.latency
|
return float('inf') if heartbeat is None else heartbeat.latency
|
||||||
|
|
||||||
def _can_handle_close(self, code):
|
def _can_handle_close(self):
|
||||||
return code not in (1000, 4004, 4010, 4011)
|
return self.socket.close_code not in (1000, 4004, 4010, 4011)
|
||||||
|
|
||||||
async def poll_event(self):
|
async def poll_event(self):
|
||||||
"""Polls for a DISPATCH event and handles the general gateway loop.
|
"""Polls for a DISPATCH event and handles the general gateway loop.
|
||||||
@ -484,26 +493,35 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
The websocket connection was terminated for unhandled reasons.
|
The websocket connection was terminated for unhandled reasons.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
msg = await self.recv()
|
msg = await self.socket.receive()
|
||||||
await self.received_message(msg)
|
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||||
except websockets.exceptions.ConnectionClosed as exc:
|
await self.received_message(msg.data)
|
||||||
if self._can_handle_close(exc.code):
|
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||||
log.info('Websocket closed with %s (%s), attempting a reconnect.', exc.code, exc.reason)
|
await self.received_message(msg.data)
|
||||||
raise ReconnectWebSocket(self.shard_id) from exc
|
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||||
else:
|
log.debug('Received %s', msg)
|
||||||
log.info('Websocket closed with %s (%s), cannot reconnect.', exc.code, exc.reason)
|
raise msg.data
|
||||||
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
|
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
|
||||||
|
log.debug('Received %s', msg)
|
||||||
|
raise WebSocketClosure('Unexpected WebSocket closure.')
|
||||||
|
except WebSocketClosure as e:
|
||||||
|
if self._can_handle_close():
|
||||||
|
log.info('Websocket closed with %s, attempting a reconnect.', self.socket.close_code)
|
||||||
|
raise ReconnectWebSocket(self.shard_id) from e
|
||||||
|
elif self.socket.close_code is not None:
|
||||||
|
log.info('Websocket closed with %s, cannot reconnect.', self.socket.close_code)
|
||||||
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from e
|
||||||
|
|
||||||
async def send(self, data):
|
async def send(self, data):
|
||||||
self._dispatch('socket_raw_send', data)
|
self._dispatch('socket_raw_send', data)
|
||||||
await super().send(data)
|
await self.socket.send_str(data)
|
||||||
|
|
||||||
async def send_as_json(self, data):
|
async def send_as_json(self, data):
|
||||||
try:
|
try:
|
||||||
await self.send(utils.to_json(data))
|
await self.send(utils.to_json(data))
|
||||||
except websockets.exceptions.ConnectionClosed as exc:
|
except RuntimeError as exc:
|
||||||
if not self._can_handle_close(exc.code):
|
if not self._can_handle_close():
|
||||||
raise ConnectionClosed(exc, shard_id=self.shard_id) from exc
|
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||||
|
|
||||||
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
|
async def change_presence(self, *, activity=None, status=None, afk=False, since=0.0):
|
||||||
if activity is not None:
|
if activity is not None:
|
||||||
@ -570,19 +588,13 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
log.debug('Updating our voice state to %s.', payload)
|
log.debug('Updating our voice state to %s.', payload)
|
||||||
await self.send_as_json(payload)
|
await self.send_as_json(payload)
|
||||||
|
|
||||||
async def close(self, code=4000, reason=''):
|
async def close(self, code=4000):
|
||||||
if self._keep_alive:
|
if self._keep_alive:
|
||||||
self._keep_alive.stop()
|
self._keep_alive.stop()
|
||||||
|
|
||||||
await super().close(code, reason)
|
await self.socket.close(code=code)
|
||||||
|
|
||||||
async def close_connection(self, *args, **kwargs):
|
class DiscordVoiceWebSocket:
|
||||||
if self._keep_alive:
|
|
||||||
self._keep_alive.stop()
|
|
||||||
|
|
||||||
await super().close_connection(*args, **kwargs)
|
|
||||||
|
|
||||||
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
|
||||||
"""Implements the websocket protocol for handling voice connections.
|
"""Implements the websocket protocol for handling voice connections.
|
||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
@ -626,14 +638,13 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
CLIENT_CONNECT = 12
|
CLIENT_CONNECT = 12
|
||||||
CLIENT_DISCONNECT = 13
|
CLIENT_DISCONNECT = 13
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, socket):
|
||||||
super().__init__(*args, **kwargs)
|
self.ws = socket
|
||||||
self.max_size = None
|
|
||||||
self._keep_alive = None
|
self._keep_alive = None
|
||||||
|
|
||||||
async def send_as_json(self, data):
|
async def send_as_json(self, data):
|
||||||
log.debug('Sending voice websocket frame: %s.', data)
|
log.debug('Sending voice websocket frame: %s.', data)
|
||||||
await self.send(utils.to_json(data))
|
await self.ws.send_str(utils.to_json(data))
|
||||||
|
|
||||||
async def resume(self):
|
async def resume(self):
|
||||||
state = self._connection
|
state = self._connection
|
||||||
@ -664,7 +675,9 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
async def from_client(cls, client, *, resume=False):
|
async def from_client(cls, client, *, resume=False):
|
||||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||||
ws = await websockets.connect(gateway, loop=client.loop, klass=cls, compression=None)
|
http = client._state.http
|
||||||
|
socket = await http.ws_connect(gateway)
|
||||||
|
ws = cls(socket)
|
||||||
ws.gateway = gateway
|
ws.gateway = gateway
|
||||||
ws._connection = client
|
ws._connection = client
|
||||||
ws._max_heartbeat_timeout = 60.0
|
ws._max_heartbeat_timeout = 60.0
|
||||||
@ -785,14 +798,19 @@ class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
await self.speak(False)
|
await self.speak(False)
|
||||||
|
|
||||||
async def poll_event(self):
|
async def poll_event(self):
|
||||||
try:
|
# This exception is handled up the chain
|
||||||
msg = await asyncio.wait_for(self.recv(), timeout=30.0)
|
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||||
await self.received_message(json.loads(msg))
|
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||||
except websockets.exceptions.ConnectionClosed as exc:
|
await self.received_message(json.loads(msg.data))
|
||||||
raise ConnectionClosed(exc, shard_id=None) from exc
|
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||||
|
log.debug('Received %s', msg)
|
||||||
|
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||||
|
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE):
|
||||||
|
log.debug('Received %s', msg)
|
||||||
|
raise ConnectionClosed(self.ws, shard_id=None)
|
||||||
|
|
||||||
async def close_connection(self, *args, **kwargs):
|
async def close(self, code=1000):
|
||||||
if self._keep_alive:
|
if self._keep_alive is not None:
|
||||||
self._keep_alive.stop()
|
self._keep_alive.stop()
|
||||||
|
|
||||||
await super().close_connection(*args, **kwargs)
|
await self.ws.close(code=code)
|
||||||
|
@ -111,6 +111,17 @@ class HTTPClient:
|
|||||||
if self.__session.closed:
|
if self.__session.closed:
|
||||||
self.__session = aiohttp.ClientSession(connector=self.connector)
|
self.__session = aiohttp.ClientSession(connector=self.connector)
|
||||||
|
|
||||||
|
async def ws_connect(self, url):
|
||||||
|
kwargs = {
|
||||||
|
'proxy_auth': self.proxy_auth,
|
||||||
|
'proxy': self.proxy,
|
||||||
|
'max_msg_size': 0,
|
||||||
|
'timeout': 30.0,
|
||||||
|
'autoclose': False,
|
||||||
|
}
|
||||||
|
|
||||||
|
return await self.__session.ws_connect(url, **kwargs)
|
||||||
|
|
||||||
async def request(self, route, *, files=None, **kwargs):
|
async def request(self, route, *, files=None, **kwargs):
|
||||||
bucket = route.bucket
|
bucket = route.bucket
|
||||||
method = route.method
|
method = route.method
|
||||||
|
@ -28,8 +28,6 @@ import asyncio
|
|||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
import websockets
|
|
||||||
|
|
||||||
from .state import AutoShardedConnectionState
|
from .state import AutoShardedConnectionState
|
||||||
from .client import Client
|
from .client import Client
|
||||||
from .gateway import *
|
from .gateway import *
|
||||||
@ -191,31 +189,13 @@ class AutoShardedClient(Client):
|
|||||||
|
|
||||||
async def launch_shard(self, gateway, shard_id):
|
async def launch_shard(self, gateway, shard_id):
|
||||||
try:
|
try:
|
||||||
coro = websockets.connect(gateway, loop=self.loop, klass=DiscordWebSocket, compression=None)
|
coro = DiscordWebSocket.from_client(self, gateway=gateway, shard_id=shard_id)
|
||||||
ws = await asyncio.wait_for(coro, timeout=180.0)
|
ws = await asyncio.wait_for(coro, timeout=180.0)
|
||||||
except Exception:
|
except Exception:
|
||||||
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
log.info('Failed to connect for shard_id: %s. Retrying...', shard_id)
|
||||||
await asyncio.sleep(5.0)
|
await asyncio.sleep(5.0)
|
||||||
return await self.launch_shard(gateway, shard_id)
|
return await self.launch_shard(gateway, shard_id)
|
||||||
|
|
||||||
ws.token = self.http.token
|
|
||||||
ws._connection = self._connection
|
|
||||||
ws._discord_parsers = self._connection.parsers
|
|
||||||
ws._dispatch = self.dispatch
|
|
||||||
ws.gateway = gateway
|
|
||||||
ws.shard_id = shard_id
|
|
||||||
ws.shard_count = self.shard_count
|
|
||||||
ws._max_heartbeat_timeout = self._connection.heartbeat_timeout
|
|
||||||
|
|
||||||
try:
|
|
||||||
# OP HELLO
|
|
||||||
await asyncio.wait_for(ws.poll_event(), timeout=180.0)
|
|
||||||
await asyncio.wait_for(ws.identify(), timeout=180.0)
|
|
||||||
except asyncio.TimeoutError:
|
|
||||||
log.info('Timed out when connecting for shard_id: %s. Retrying...', shard_id)
|
|
||||||
await asyncio.sleep(5.0)
|
|
||||||
return await self.launch_shard(gateway, shard_id)
|
|
||||||
|
|
||||||
# keep reading the shard while others connect
|
# keep reading the shard while others connect
|
||||||
self.shards[shard_id] = ret = Shard(ws, self)
|
self.shards[shard_id] = ret = Shard(ws, self)
|
||||||
ret.launch()
|
ret.launch()
|
||||||
|
@ -1,2 +1 @@
|
|||||||
aiohttp>=3.6.0,<3.7.0
|
aiohttp>=3.6.0,<3.7.0
|
||||||
websockets>=6.0,!=7.0,!=8.0,!=8.0.1,<9.0
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user