Refactor voice websocket into gateway.py

This commit is contained in:
Rapptz
2016-04-27 17:37:25 -04:00
parent 1c623ccf11
commit c1b5a52823
3 changed files with 211 additions and 211 deletions

View File

@ -36,11 +36,13 @@ import logging
import zlib, time, json
from collections import namedtuple
import threading
import struct
log = logging.getLogger(__name__)
__all__ = [ 'ReconnectWebSocket', 'get_gateway', 'DiscordWebSocket',
'KeepAliveHandler' ]
'KeepAliveHandler', 'VoiceKeepAliveHandler',
'DiscordVoiceWebSocket' ]
class ReconnectWebSocket(Exception):
"""Signals to handle the RECONNECT opcode."""
@ -56,13 +58,13 @@ class KeepAliveHandler(threading.Thread):
self.ws = ws
self.interval = interval
self.daemon = True
self.msg = 'Keeping websocket alive with sequence {0[d]}'
self._stop = threading.Event()
def run(self):
while not self._stop.wait(self.interval):
data = self.get_payload()
msg = 'Keeping websocket alive with sequence {0[d]}'.format(data)
log.debug(msg)
log.debug(self.msg.format(data))
coro = self.ws.send_as_json(data)
f = compat.run_coroutine_threadsafe(coro, loop=self.ws.loop)
try:
@ -80,6 +82,17 @@ class KeepAliveHandler(threading.Thread):
def stop(self):
self._stop.set()
class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.msg = 'Keeping voice websocket alive with timestamp {0[d]}'
def get_payload(self):
return {
'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000)
}
@asyncio.coroutine
def get_gateway(token, *, loop=None):
@ -212,7 +225,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
connection=client.connection,
loop=client.loop)
def wait_for(self, event, predicate, result):
def wait_for(self, event, predicate, result=None):
"""Waits for a DISPATCH'd event that meets the predicate.
Parameters
@ -224,7 +237,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
properties. The data parameter is the 'd' key in the JSON message.
result
A function that takes the same data parameter and executes to send
the result to the future.
the result to the future. If None, returns the data.
Returns
--------
@ -281,6 +294,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
# "reconnect" can only be handled by the Client
# so we terminate our connection and raise an
# internal exception signalling to reconnect.
log.info('Receivede RECONNECT opcode.')
yield from self.close()
raise ReconnectWebSocket()
@ -332,7 +346,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
removed.append(index)
else:
if valid:
future.set_result(entry.result)
ret = data if entry.result is None else entry.result(data)
future.set_result(ret)
removed.append(index)
for index in reversed(removed):
@ -352,6 +367,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
yield from self.received_message(msg)
except websockets.exceptions.ConnectionClosed as e:
if e.code in (4008, 4009) or e.code in range(1001, 1015):
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
raise ReconnectWebSocket() from e
else:
raise ConnectionClosed(e) from e
@ -394,9 +410,171 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
status = Status.idle if idle_since else Status.online
me.status = status
@asyncio.coroutine
def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
payload = {
'op': self.VOICE_STATE,
'd': {
'guild_id': guild_id,
'channel_id': channel_id,
'self_mute': self_mute,
'self_deaf': self_deaf
}
}
yield from self.send_as_json(payload)
@asyncio.coroutine
def close(self, code=1000, reason=''):
if self._keep_alive:
self._keep_alive.stop()
yield from super().close(code, reason)
class DiscordVoiceWebSocket(websockets.client.WebSocketClientProtocol):
"""Implements the websocket protocol for handling voice connections.
Attributes
-----------
IDENTIFY
Send only. Starts a new voice session.
SELECT_PROTOCOL
Send only. Tells discord what encryption mode and how to connect for voice.
READY
Receive only. Tells the websocket that the initial connection has completed.
HEARTBEAT
Send only. Keeps your websocket connection alive.
SESSION_DESCRIPTION
Receive only. Gives you the secret key required for voice.
SPEAKING
Send only. Notifies the client if you are currently speaking.
"""
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.max_size = None
self._keep_alive = None
@asyncio.coroutine
def send_as_json(self, data):
yield from self.send(utils.to_json(data))
@classmethod
@asyncio.coroutine
def from_client(cls, client):
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
ws.gateway = gateway
ws._connection = client
identify = {
'op': cls.IDENTIFY,
'd': {
'server_id': client.guild_id,
'user_id': client.user.id,
'session_id': client.session_id,
'token': client.token
}
}
yield from ws.send_as_json(identify)
return ws
@asyncio.coroutine
def select_protocol(self, ip, port):
payload = {
'op': self.SELECT_PROTOCOL,
'd': {
'protocol': 'udp',
'data': {
'address': ip,
'port': port,
'mode': 'xsalsa20_poly1305'
}
}
}
yield from self.send_as_json(payload)
log.debug('Selected protocol as {}'.format(payload))
@asyncio.coroutine
def speak(self, is_speaking=True):
payload = {
'op': self.SPEAKING,
'd': {
'speaking': is_speaking,
'delay': 0
}
}
yield from self.send_as_json(payload)
log.debug('Voice speaking now set to {}'.format(is_speaking))
@asyncio.coroutine
def received_message(self, msg):
log.debug('Voice websocket frame received: {}'.format(msg))
op = msg.get('op')
data = msg.get('d')
if op == self.READY:
interval = (data['heartbeat_interval'] / 100.0) - 5
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=interval)
self._keep_alive.start()
yield from self.initial_connection(data)
elif op == self.SESSION_DESCRIPTION:
yield from self.load_secret_key(data)
@asyncio.coroutine
def initial_connection(self, data):
state = self._connection
state.ssrc = data.get('ssrc')
state.voice_port = data.get('port')
packet = bytearray(70)
struct.pack_into('>I', packet, 0, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = yield from self.loop.sock_recv(state.socket, 70)
log.debug('received packet in initial_connection: {}'.format(recv))
# the ip is ascii starting at the 4th byte and ending at the first null
ip_start = 4
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
# the port is a little endian unsigned short in the last two bytes
# yes, this is different endianness from everything else
state.port = struct.unpack_from('<H', recv, len(recv) - 2)[0]
log.debug('detected ip: {0.ip} port: {0.port}'.format(state))
yield from self.select_protocol(state.ip, state.port)
log.info('selected the voice protocol for use')
@asyncio.coroutine
def load_secret_key(self, data):
log.info('received secret key for voice connection')
self._connection.secret_key = data.get('secret_key')
yield from self.speak()
@asyncio.coroutine
def poll_event(self):
try:
msg = yield from self.recv()
yield from self.received_message(json.loads(msg))
except websockets.exceptions.ConnectionClosed as e:
raise ConnectionClosed(e) from e
@asyncio.coroutine
def close(self, code=1000, reason=''):
if self._keep_alive:
self._keep_alive.stop()
yield from super().close(code, reason)