Refactor voice websocket into gateway.py
This commit is contained in:
@ -36,11 +36,13 @@ import logging
|
||||
import zlib, time, json
|
||||
from collections import namedtuple
|
||||
import threading
|
||||
import struct
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
|
||||
'KeepAliveHandler' ]
|
||||
'KeepAliveHandler', 'VoiceKeepAliveHandler',
|
||||
'DiscordVoiceWebSocket' ]
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to handle the RECONNECT opcode."""
|
||||
@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread):
|
||||
self.ws = ws
|
||||
self.interval = interval
|
||||
self.daemon = True
|
||||
self.msg = 'Keeping websocket alive with sequence {0[d]}'
|
||||
self._stop = threading.Event()
|
||||
|
||||
def run(self):
|
||||
while not self._stop.wait(self.interval):
|
||||
data = self.get_payload()
|
||||
msg = 'Keeping websocket alive with sequence {0[d]}'.format(data)
|
||||
log.debug(msg)
|
||||
log.debug(self.msg.format(data))
|
||||
coro = self.ws.send_as_json(data)
|
||||
f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop)
|
||||
try:
|
||||
@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread):
|
||||
def stop(self):
|
||||
self._stop.set()
|
||||
|
||||
class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.msg = 'Keeping voice websocket alive with timestamp {0[d]}'
|
||||
|
||||
def get_payload(self):
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': int(time.time() * 1000)
|
||||
}
|
||||
|
||||
|
||||
@asyncio.coroutine
|
||||
def get_gateway(token, *, loop=None):
|
||||
@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
connection=client.connection,
|
||||
loop=client.loop)
|
||||
|
||||
def wait_for(self, event, predicate, result):
|
||||
def wait_for(self, event, predicate, result=None):
|
||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||
|
||||
Parameters
|
||||
@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
properties. The data parameter is the 'd' key in the JSON message.
|
||||
result
|
||||
A function that takes the same data parameter and executes to send
|
||||
the result to the future.
|
||||
the result to the future. If None, returns the data.
|
||||
|
||||
Returns
|
||||
--------
|
||||
@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
# "reconnect" can only be handled by the Client
|
||||
# so we terminate our connection and raise an
|
||||
# internal exception signalling to reconnect.
|
||||
log.info('Receivede RECONNECT opcode.')
|
||||
yield from self.close()
|
||||
raise ReconnectWebSocket()
|
||||
|
||||
@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
removed.append(index)
|
||||
else:
|
||||
if valid:
|
||||
future.set_result(entry.result)
|
||||
ret = data if entry.result is None else entry.result(data)
|
||||
future.set_result(ret)
|
||||
removed.append(index)
|
||||
|
||||
for index in reversed(removed):
|
||||
@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
yield from self.received_message(msg)
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
if e.code in (4008, 4009) or e.code in range(1001, 1015):
|
||||
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
|
||||
raise ReconnectWebSocket() from e
|
||||
else:
|
||||
raise ConnectionClosed(e) from e
|
||||
@ -394,9 +410,171 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
status = Status.idle if idle_since else Status.online
|
||||
me.status = status
|
||||
|
||||
@asyncio.coroutine
|
||||
def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
||||
payload = {
|
||||
'op': self.VOICE_STATE,
|
||||
'd': {
|
||||
'guild_id': guild_id,
|
||||
'channel_id': channel_id,
|
||||
'self_mute': self_mute,
|
||||
'self_deaf': self_deaf
|
||||
}
|
||||
}
|
||||
|
||||
yield from self.send_as_json(payload)
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self, code=1000, reason=''):
|
||||
if self._keep_alive:
|
||||
self._keep_alive.stop()
|
||||
|
||||
yield from super().close(code, reason)
|
||||
|
||||
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
|
||||
"""Implements the websocket protocol for handling voice connections.
|
||||
|
||||
Attributes
|
||||
-----------
|
||||
IDENTIFY
|
||||
Send only. Starts a new voice session.
|
||||
SELECT_PROTOCOL
|
||||
Send only. Tells discord what encryption mode and how to connect for voice.
|
||||
READY
|
||||
Receive only. Tells the websocket that the initial connection has completed.
|
||||
HEARTBEAT
|
||||
Send only. Keeps your websocket connection alive.
|
||||
SESSION_DESCRIPTION
|
||||
Receive only. Gives you the secret key required for voice.
|
||||
SPEAKING
|
||||
Send only. Notifies the client if you are currently speaking.
|
||||
"""
|
||||
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
SESSION_DESCRIPTION = 4
|
||||
SPEAKING = 5
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.max_size = None
|
||||
self._keep_alive = None
|
||||
|
||||
@asyncio.coroutine
|
||||
def send_as_json(self, data):
|
||||
yield from self.send(utils.to_json(data))
|
||||
|
||||
@classmethod
|
||||
@asyncio.coroutine
|
||||
def from_client(cls, client):
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint
|
||||
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
|
||||
ws.gateway = gateway
|
||||
ws._connection = client
|
||||
|
||||
identify = {
|
||||
'op': cls.IDENTIFY,
|
||||
'd': {
|
||||
'server_id': client.guild_id,
|
||||
'user_id': client.user.id,
|
||||
'session_id': client.session_id,
|
||||
'token': client.token
|
||||
}
|
||||
}
|
||||
|
||||
yield from ws.send_as_json(identify)
|
||||
return ws
|
||||
|
||||
@asyncio.coroutine
|
||||
def select_protocol(self, ip, port):
|
||||
payload = {
|
||||
'op': self.SELECT_PROTOCOL,
|
||||
'd': {
|
||||
'protocol': 'udp',
|
||||
'data': {
|
||||
'address': ip,
|
||||
'port': port,
|
||||
'mode': 'xsalsa20_poly1305'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
yield from self.send_as_json(payload)
|
||||
log.debug('Selected protocol as {}'.format(payload))
|
||||
|
||||
@asyncio.coroutine
|
||||
def speak(self, is_speaking=True):
|
||||
payload = {
|
||||
'op': self.SPEAKING,
|
||||
'd': {
|
||||
'speaking': is_speaking,
|
||||
'delay': 0
|
||||
}
|
||||
}
|
||||
|
||||
yield from self.send_as_json(payload)
|
||||
log.debug('Voice speaking now set to {}'.format(is_speaking))
|
||||
|
||||
@asyncio.coroutine
|
||||
def received_message(self, msg):
|
||||
log.debug('Voice websocket frame received: {}'.format(msg))
|
||||
op = msg.get('op')
|
||||
data = msg.get('d')
|
||||
|
||||
if op == self.READY:
|
||||
interval = (data['heartbeat_interval'] / 100.0) - 5
|
||||
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval)
|
||||
self._keep_alive.start()
|
||||
yield from self.initial_connection(data)
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
yield from self.load_secret_key(data)
|
||||
|
||||
@asyncio.coroutine
|
||||
def initial_connection(self, data):
|
||||
state = self._connection
|
||||
state.ssrc = data.get('ssrc')
|
||||
state.voice_port = data.get('port')
|
||||
packet = bytearray(70)
|
||||
struct.pack_into('>I', packet, 0, state.ssrc)
|
||||
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
|
||||
recv = yield from self.loop.sock_recv(state.socket, 70)
|
||||
log.debug('received packet in initial_connection: {}'.format(recv))
|
||||
|
||||
# the ip is ascii starting at the 4th byte and ending at the first null
|
||||
ip_start = 4
|
||||
ip_end = recv.index(0, ip_start)
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
# the port is a little endian unsigned short in the last two bytes
|
||||
# yes, this is different endianness from everything else
|
||||
state.port = struct.unpack_from('<H', recv, len(recv) - 2)[0]
|
||||
|
||||
log.debug('detected ip: {0.ip} port: {0.port}'.format(state))
|
||||
yield from self.select_protocol(state.ip, state.port)
|
||||
log.info('selected the voice protocol for use')
|
||||
|
||||
@asyncio.coroutine
|
||||
def load_secret_key(self, data):
|
||||
log.info('received secret key for voice connection')
|
||||
self._connection.secret_key = data.get('secret_key')
|
||||
yield from self.speak()
|
||||
|
||||
@asyncio.coroutine
|
||||
def poll_event(self):
|
||||
try:
|
||||
msg = yield from self.recv()
|
||||
yield from self.received_message(json.loads(msg))
|
||||
except websockets.exceptions.ConnectionClosed as e:
|
||||
raise ConnectionClosed(e) from e
|
||||
|
||||
@asyncio.coroutine
|
||||
def close(self, code=1000, reason=''):
|
||||
if self._keep_alive:
|
||||
self._keep_alive.stop()
|
||||
|
||||
yield from super().close(code, reason)
|
||||
|
||||
|
||||
|
Reference in New Issue
Block a user