Rewrite voice connection internals

This commit is contained in:
Imayhaveborkedit
2023-09-28 17:51:22 -04:00
committed by GitHub
parent 555940352b
commit 44284ae107
5 changed files with 738 additions and 282 deletions

View File

@ -20,40 +20,24 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
Some documentation to refer to:
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
- We pull the session_id from VOICE_STATE_UPDATE.
- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE.
- Then we initiate the voice web socket (vWS) pointing to the endpoint.
- We send opcode 0 with the user_id, server_id, session_id and token using the vWS.
- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval.
- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE.
- Then we send our IP and port via vWS with opcode 1.
- When that's all done, we receive opcode 4 from the vWS.
- Finally we can transmit data to endpoint:port.
"""
from __future__ import annotations
import asyncio
import socket
import logging
import struct
import threading
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union
from . import opus, utils
from .backoff import ExponentialBackoff
from . import opus
from .gateway import *
from .errors import ClientException, ConnectionClosed
from .errors import ClientException
from .player import AudioPlayer, AudioSource
from .utils import MISSING
from .voice_state import VoiceConnectionState
if TYPE_CHECKING:
from .gateway import DiscordVoiceWebSocket
from .client import Client
from .guild import Guild
from .state import ConnectionState
@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol):
"""
channel: VocalGuildChannel
endpoint_ip: str
voice_port: int
ip: str
port: int
secret_key: List[int]
ssrc: int
def __init__(self, client: Client, channel: abc.Connectable) -> None:
if not has_nacl:
@ -239,29 +217,18 @@ class VoiceClient(VoiceProtocol):
super().__init__(client, channel)
state = client._connection
self.token: str = MISSING
self.server_id: int = MISSING
self.socket = MISSING
self.loop: asyncio.AbstractEventLoop = state.loop
self._state: ConnectionState = state
# this will be used in the AudioPlayer thread
self._connected: threading.Event = threading.Event()
self._handshaking: bool = False
self._potentially_reconnecting: bool = False
self._voice_state_complete: asyncio.Event = asyncio.Event()
self._voice_server_complete: asyncio.Event = asyncio.Event()
self.mode: str = MISSING
self._connections: int = 0
self.sequence: int = 0
self.timestamp: int = 0
self.timeout: float = 0
self._runner: asyncio.Task = MISSING
self._player: Optional[AudioPlayer] = None
self.encoder: Encoder = MISSING
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
self._connection: VoiceConnectionState = self.create_connection_state()
warn_nacl: bool = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = (
@ -280,6 +247,38 @@ class VoiceClient(VoiceProtocol):
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user # type: ignore
@property
def session_id(self) -> Optional[str]:
return self._connection.session_id
@property
def token(self) -> Optional[str]:
return self._connection.token
@property
def endpoint(self) -> Optional[str]:
return self._connection.endpoint
@property
def ssrc(self) -> int:
return self._connection.ssrc
@property
def mode(self) -> SupportedModes:
return self._connection.mode
@property
def secret_key(self) -> List[int]:
return self._connection.secret_key
@property
def ws(self) -> DiscordVoiceWebSocket:
return self._connection.ws
@property
def timeout(self) -> float:
return self._connection.timeout
def checked_add(self, attr: str, value: int, limit: int) -> None:
val = getattr(self, attr)
if val + value > limit:
@ -289,149 +288,23 @@ class VoiceClient(VoiceProtocol):
# connection related
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id: str = data['session_id']
channel_id = data['channel_id']
def create_connection_state(self) -> VoiceConnectionState:
return VoiceConnectionState(self)
if not self._handshaking or self._potentially_reconnecting:
# If we're done handshaking then we just need to update ourselves
# If we're potentially reconnecting due to a 4014, then we need to differentiate
# a channel move and an actual force disconnect
if channel_id is None:
# We're being disconnected so cleanup
await self.disconnect()
else:
self.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
else:
self._voice_state_complete.set()
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
await self._connection.voice_state_update(data)
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
_log.warning('Ignoring extraneous voice server update.')
return
self.token = data['token']
self.server_id = int(data['guild_id'])
endpoint = data.get('endpoint')
if endpoint is None or self.token is None:
_log.warning(
'Awaiting endpoint... This requires waiting. '
'If timeout occurred considering raising the timeout and reconnecting.'
)
return
self.endpoint, _, _ = endpoint.rpartition(':')
if self.endpoint.startswith('wss://'):
# Just in case, strip it off since we're going to add it later
self.endpoint: str = self.endpoint[6:]
# This gets set later
self.endpoint_ip = MISSING
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
if not self._handshaking:
# If we're not handshaking then we need to terminate our previous connection in the websocket
await self.ws.close(4000)
return
self._voice_server_complete.set()
async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None:
await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute)
async def voice_disconnect(self) -> None:
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
def finish_handshake(self) -> None:
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
async def connect_websocket(self) -> DiscordVoiceWebSocket:
ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while ws.secret_key is None:
await ws.poll_event()
self._connected.set()
return ws
await self._connection.voice_server_update(data)
async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None:
_log.info('Connecting to voice...')
self.timeout = timeout
await self._connection.connect(
reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False
)
for i in range(5):
self.prepare_handshake()
# This has to be created before we start the flow.
futures = [
self._voice_state_complete.wait(),
self._voice_server_complete.wait(),
]
# Start the connection flow
await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute)
try:
await utils.sane_wait_for(futures, timeout=timeout)
except asyncio.TimeoutError:
await self.disconnect(force=True)
raise
self.finish_handshake()
try:
self.ws = await self.connect_websocket()
break
except (ConnectionClosed, asyncio.TimeoutError):
if reconnect:
_log.exception('Failed to connect to voice... Retrying...')
await asyncio.sleep(1 + i * 2.0)
await self.voice_disconnect()
continue
else:
raise
if self._runner is MISSING:
self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early
self._connected.clear()
self.prepare_handshake()
self._potentially_reconnecting = True
try:
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
self._potentially_reconnecting = False
await self.disconnect(force=True)
return False
self.finish_handshake()
self._potentially_reconnecting = False
if self.ws:
_log.debug("Closing existing voice websocket")
await self.ws.close()
try:
self.ws = await self.connect_websocket()
except (ConnectionClosed, asyncio.TimeoutError):
return False
else:
return True
def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool:
self._connection.wait(timeout)
return self._connection.is_connected()
@property
def latency(self) -> float:
@ -442,7 +315,7 @@ class VoiceClient(VoiceProtocol):
.. versionadded:: 1.4
"""
ws = self.ws
ws = self._connection.ws
return float("inf") if not ws else ws.latency
@property
@ -451,72 +324,19 @@ class VoiceClient(VoiceProtocol):
.. versionadded:: 1.4
"""
ws = self.ws
ws = self._connection.ws
return float("inf") if not ws else ws.average_latency
async def poll_voice_ws(self, reconnect: bool) -> None:
backoff = ExponentialBackoff()
while True:
try:
await self.ws.poll_event()
except (ConnectionClosed, asyncio.TimeoutError) as exc:
if isinstance(exc, ConnectionClosed):
# The following close codes are undocumented so I will document them here.
# 1000 - normal closure (obviously)
# 4014 - voice channel has been deleted.
# 4015 - voice server has crashed
if exc.code in (1000, 4015):
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
await self.disconnect()
break
if exc.code == 4014:
_log.info('Disconnected from voice by force... potentially reconnecting.')
successful = await self.potential_reconnect()
if not successful:
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
await self.disconnect()
break
else:
continue
if not reconnect:
await self.disconnect()
raise
retry = backoff.delay()
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
self._connected.clear()
await asyncio.sleep(retry)
await self.voice_disconnect()
try:
await self.connect(reconnect=True, timeout=self.timeout)
except asyncio.TimeoutError:
# at this point we've retried 5 times... let's continue the loop.
_log.warning('Could not connect to voice... Retrying...')
continue
async def disconnect(self, *, force: bool = False) -> None:
"""|coro|
Disconnects this voice client from voice.
"""
if not force and not self.is_connected():
return
self.stop()
self._connected.clear()
await self._connection.disconnect(force=force)
self.cleanup()
try:
if self.ws:
await self.ws.close()
await self.voice_disconnect()
finally:
self.cleanup()
if self.socket:
self.socket.close()
async def move_to(self, channel: Optional[abc.Snowflake]) -> None:
async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None:
"""|coro|
Moves you to a different voice channel.
@ -525,12 +345,22 @@ class VoiceClient(VoiceProtocol):
-----------
channel: Optional[:class:`abc.Snowflake`]
The channel to move to. Must be a voice channel.
timeout: Optional[:class:`float`]
How long to wait for the move to complete.
.. versionadded:: 2.4
Raises
-------
asyncio.TimeoutError
The move did not complete in time, but may still be ongoing.
"""
await self.channel.guild.change_voice_state(channel=channel)
await self._connection.move_to(channel)
await self._connection.wait_async(timeout)
def is_connected(self) -> bool:
"""Indicates if the voice client is connected to voice."""
return self._connected.is_set()
return self._connection.is_connected()
# audio related
@ -703,7 +533,7 @@ class VoiceClient(VoiceProtocol):
if self._player is None:
raise ValueError('Not playing anything.')
self._player._set_source(value)
self._player.set_source(value)
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
"""Sends an audio packet composed of the data.
@ -732,8 +562,8 @@ class VoiceClient(VoiceProtocol):
encoded_data = data
packet = self._get_voice_packet(encoded_data)
try:
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
except BlockingIOError:
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
self._connection.send_packet(packet)
except OSError:
_log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)