Implement VoiceProtocol lower level hooks.

This allows changing the connect flow and taking control of it without
relying on internal events or tricks.
This commit is contained in:
Rapptz
2020-08-10 06:28:36 -04:00
parent 93fa46713a
commit 0b93fa3a82
9 changed files with 230 additions and 106 deletions

View File

@ -45,7 +45,7 @@ import logging
import struct
import threading
from . import opus
from . import opus, utils
from .backoff import ExponentialBackoff
from .gateway import *
from .errors import ClientException, ConnectionClosed
@ -59,7 +59,110 @@ except ImportError:
log = logging.getLogger(__name__)
class VoiceClient:
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
This is an abstract class. The library provides a concrete implementation
under :class:`VoiceClient`.
This class allows you to implement a protocol to allow for an external
method of sending voice, such as Lavalink_ or a native library implementation.
These classes are passed to :meth:`abc.Connectable.connect`.
.. _Lavalink: https://github.com/Frederikam/Lavalink
Parameters
------------
client: :class:`Client`
The client (or its subclasses) that started the connection request.
channel: :class:`abc.Connectable`
The voice channel that is being connected to.
"""
def __init__(self, client, channel):
self.client = client
self.channel = channel
async def on_voice_state_update(self, data):
"""|coro|
An abstract method that is called when the client's voice state
has changed. This corresponds to ``VOICE_STATE_UPDATE``.
Parameters
------------
data: :class:`dict`
The raw `voice state payload`_.
.. _voice state payload: https://discord.com/developers/docs/resources/voice#voice-state-object
"""
raise NotImplementedError
async def on_voice_server_update(self, data):
"""|coro|
An abstract method that is called when initially connecting to voice.
This corresponds to ``VOICE_SERVER_UPDATE``.
Parameters
------------
data: :class:`dict`
The raw `voice server update payload`__.
.. _VSU: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields
__ VSU_
"""
raise NotImplementedError
async def connect(self, *, timeout, reconnect):
"""|coro|
An abstract method called when the client initiates the connection request.
When a connection is requested initially, the library calls the following functions
in order:
- ``__init__``
Parameters
------------
timeout: :class:`float`
The timeout for the connection.
reconnect: :class:`bool`
Whether reconnection is expected.
"""
raise NotImplementedError
async def disconnect(self, *, force):
"""|coro|
An abstract method called when the client terminates the connection.
See :meth:`cleanup`.
Parameters
------------
force: :class:`bool`
Whether the disconnection was forced.
"""
raise NotImplementedError
def cleanup(self):
"""This method *must* be called to ensure proper clean-up during a disconnect.
It is advisable to call this from within :meth:`disconnect` when you are
completely done with the voice protocol instance.
This method removes it from the internal state cache that keeps track of
currently alive voice clients. Failure to clean-up will cause subsequent
connections to report that it's still connected.
"""
key_id, _ = self.channel._get_voice_client_key()
self.client._connection._remove_voice_client(key_id)
class VoiceClient(VoiceProtocol):
"""Represents a Discord voice connection.
You do not create these, you typically get them from
@ -85,14 +188,13 @@ class VoiceClient:
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
def __init__(self, state, timeout, channel):
def __init__(self, client, channel):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
self.channel = channel
self.main_ws = None
self.timeout = timeout
self.ws = None
super().__init__(client, channel)
state = client._connection
self.token = None
self.socket = None
self.loop = state.loop
self._state = state
@ -100,8 +202,8 @@ class VoiceClient:
self._connected = threading.Event()
self._handshaking = False
self._handshake_check = asyncio.Lock()
self._handshake_complete = asyncio.Event()
self._voice_state_complete = asyncio.Event()
self._voice_server_complete = asyncio.Event()
self.mode = None
self._connections = 0
@ -138,48 +240,24 @@ class VoiceClient:
# connection related
async def start_handshake(self):
log.info('Starting voice handshake...')
async def on_voice_state_update(self, data):
self.session_id = data['session_id']
channel_id = data['channel_id']
guild_id, channel_id = self.channel._get_voice_state_pair()
state = self._state
self.main_ws = ws = state._get_websocket(guild_id)
self._connections += 1
if not self._handshaking:
# If we're done handshaking then we just need to update ourselves
guild = self.guild
self.channel = channel_id and guild and guild.get_channel(int(channel_id))
else:
self._voice_state_complete.set()
# request joining
await ws.voice_state(guild_id, channel_id)
async def on_voice_server_update(self, data):
if self._voice_server_complete.is_set():
log.info('Ignoring extraneous voice server update.')
return
try:
await asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
await self.terminate_handshake(remove=True)
raise
log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip)
async def terminate_handshake(self, *, remove=False):
guild_id, channel_id = self.channel._get_voice_state_pair()
self._handshake_complete.clear()
await self.main_ws.voice_state(guild_id, None, self_mute=True)
self._handshaking = False
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', channel_id, guild_id)
if remove:
log.info('The voice client has been removed for Channel ID %s (Guild ID %s)', channel_id, guild_id)
key_id, _ = self.channel._get_voice_client_key()
self._state._remove_voice_client(key_id)
async def _create_socket(self, server_id, data):
async with self._handshake_check:
if self._handshaking:
log.info("Ignoring voice server update while handshake is in progress")
return
self._handshaking = True
self._connected.clear()
self.session_id = self.main_ws.session_id
self.server_id = server_id
self.token = data.get('token')
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
if endpoint is None or self.token is None:
@ -195,23 +273,77 @@ class VoiceClient:
# This gets set later
self.endpoint_ip = None
if self.socket:
try:
self.socket.close()
except Exception:
pass
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
if self._handshake_complete.is_set():
# terminate the websocket and handle the reconnect loop if necessary.
self._handshake_complete.clear()
self._handshaking = False
if not self._handshaking:
# If we're not handshaking then we need to terminate our previous connection in the websocket
await self.ws.close(4000)
return
self._handshake_complete.set()
self._voice_server_complete.set()
async def voice_connect(self):
self._connections += 1
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self):
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
async def connect(self, *, reconnect, timeout):
log.info('Connecting to voice...')
self.timeout = timeout
try:
del self.secret_key
except AttributeError:
pass
for i in range(5):
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
# This has to be created before we start the flow.
futures = [
self._voice_state_complete.wait(),
self._voice_server_complete.wait(),
]
# Start the connection flow
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
await self.voice_connect()
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
await self.disconnect(force=True)
raise
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
try:
self.ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while not hasattr(self, 'secret_key'):
await self.ws.poll_event()
self._connected.set()
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
else:
raise
if self._runner is None:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
@property
def latency(self):
@ -234,35 +366,6 @@ class VoiceClient:
ws = self.ws
return float("inf") if not ws else ws.average_latency
async def connect(self, *, reconnect=True, _tries=0, do_handshake=True):
log.info('Connecting to voice...')
try:
del self.secret_key
except AttributeError:
pass
if do_handshake:
await self.start_handshake()
try:
self.ws = await DiscordVoiceWebSocket.from_client(self)
self._handshaking = False
self._connected.clear()
while not hasattr(self, 'secret_key'):
await self.ws.poll_event()
self._connected.set()
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect and _tries < 5:
log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + _tries * 2.0)
await self.terminate_handshake()
await self.connect(reconnect=reconnect, _tries=_tries + 1)
else:
raise
if self._runner is None:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
async def poll_voice_ws(self, reconnect):
backoff = ExponentialBackoff()
while True:
@ -287,9 +390,9 @@ class VoiceClient:
log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear()
await asyncio.sleep(retry)
await self.terminate_handshake()
await self.voice_disconnect()
try:
await self.connect(reconnect=True)
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
log.warning('Could not connect to voice... Retrying...')
@ -310,8 +413,9 @@ class VoiceClient:
if self.ws:
await self.ws.close()
await self.terminate_handshake(remove=True)
await self.voice_disconnect()
finally:
self.cleanup()
if self.socket:
self.socket.close()
@ -325,8 +429,7 @@ class VoiceClient:
channel: :class:`abc.Snowflake`
The channel to move to. Must be a voice channel.
"""
guild_id, _ = self.channel._get_voice_state_pair()
await self.main_ws.voice_state(guild_id, channel.id)
await self.channel.guild.change_voice_state(channel=channel)
def is_connected(self):
"""Indicates if the voice client is connected to voice."""