mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-06 09:56:09 +00:00
Rewrite voice connection internals
This commit is contained in:
@ -20,40 +20,24 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
Some documentation to refer to:
|
||||
|
||||
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
|
||||
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
|
||||
- We pull the session_id from VOICE_STATE_UPDATE.
|
||||
- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE.
|
||||
- Then we initiate the voice web socket (vWS) pointing to the endpoint.
|
||||
- We send opcode 0 with the user_id, server_id, session_id and token using the vWS.
|
||||
- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval.
|
||||
- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE.
|
||||
- Then we send our IP and port via vWS with opcode 1.
|
||||
- When that's all done, we receive opcode 4 from the vWS.
|
||||
- Finally we can transmit data to endpoint:port.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union
|
||||
|
||||
from . import opus, utils
|
||||
from .backoff import ExponentialBackoff
|
||||
from . import opus
|
||||
from .gateway import *
|
||||
from .errors import ClientException, ConnectionClosed
|
||||
from .errors import ClientException
|
||||
from .player import AudioPlayer, AudioSource
|
||||
from .utils import MISSING
|
||||
from .voice_state import VoiceConnectionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .gateway import DiscordVoiceWebSocket
|
||||
from .client import Client
|
||||
from .guild import Guild
|
||||
from .state import ConnectionState
|
||||
@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol):
|
||||
"""
|
||||
|
||||
channel: VocalGuildChannel
|
||||
endpoint_ip: str
|
||||
voice_port: int
|
||||
ip: str
|
||||
port: int
|
||||
secret_key: List[int]
|
||||
ssrc: int
|
||||
|
||||
def __init__(self, client: Client, channel: abc.Connectable) -> None:
|
||||
if not has_nacl:
|
||||
@ -239,29 +217,18 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
super().__init__(client, channel)
|
||||
state = client._connection
|
||||
self.token: str = MISSING
|
||||
self.server_id: int = MISSING
|
||||
self.socket = MISSING
|
||||
self.loop: asyncio.AbstractEventLoop = state.loop
|
||||
self._state: ConnectionState = state
|
||||
# this will be used in the AudioPlayer thread
|
||||
self._connected: threading.Event = threading.Event()
|
||||
|
||||
self._handshaking: bool = False
|
||||
self._potentially_reconnecting: bool = False
|
||||
self._voice_state_complete: asyncio.Event = asyncio.Event()
|
||||
self._voice_server_complete: asyncio.Event = asyncio.Event()
|
||||
|
||||
self.mode: str = MISSING
|
||||
self._connections: int = 0
|
||||
self.sequence: int = 0
|
||||
self.timestamp: int = 0
|
||||
self.timeout: float = 0
|
||||
self._runner: asyncio.Task = MISSING
|
||||
self._player: Optional[AudioPlayer] = None
|
||||
self.encoder: Encoder = MISSING
|
||||
self._lite_nonce: int = 0
|
||||
self.ws: DiscordVoiceWebSocket = MISSING
|
||||
|
||||
self._connection: VoiceConnectionState = self.create_connection_state()
|
||||
|
||||
warn_nacl: bool = not has_nacl
|
||||
supported_modes: Tuple[SupportedModes, ...] = (
|
||||
@ -280,6 +247,38 @@ class VoiceClient(VoiceProtocol):
|
||||
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
|
||||
return self._state.user # type: ignore
|
||||
|
||||
@property
|
||||
def session_id(self) -> Optional[str]:
|
||||
return self._connection.session_id
|
||||
|
||||
@property
|
||||
def token(self) -> Optional[str]:
|
||||
return self._connection.token
|
||||
|
||||
@property
|
||||
def endpoint(self) -> Optional[str]:
|
||||
return self._connection.endpoint
|
||||
|
||||
@property
|
||||
def ssrc(self) -> int:
|
||||
return self._connection.ssrc
|
||||
|
||||
@property
|
||||
def mode(self) -> SupportedModes:
|
||||
return self._connection.mode
|
||||
|
||||
@property
|
||||
def secret_key(self) -> List[int]:
|
||||
return self._connection.secret_key
|
||||
|
||||
@property
|
||||
def ws(self) -> DiscordVoiceWebSocket:
|
||||
return self._connection.ws
|
||||
|
||||
@property
|
||||
def timeout(self) -> float:
|
||||
return self._connection.timeout
|
||||
|
||||
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
||||
val = getattr(self, attr)
|
||||
if val + value > limit:
|
||||
@ -289,149 +288,23 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
# connection related
|
||||
|
||||
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
self.session_id: str = data['session_id']
|
||||
channel_id = data['channel_id']
|
||||
def create_connection_state(self) -> VoiceConnectionState:
|
||||
return VoiceConnectionState(self)
|
||||
|
||||
if not self._handshaking or self._potentially_reconnecting:
|
||||
# If we're done handshaking then we just need to update ourselves
|
||||
# If we're potentially reconnecting due to a 4014, then we need to differentiate
|
||||
# a channel move and an actual force disconnect
|
||||
if channel_id is None:
|
||||
# We're being disconnected so cleanup
|
||||
await self.disconnect()
|
||||
else:
|
||||
self.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
|
||||
else:
|
||||
self._voice_state_complete.set()
|
||||
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
await self._connection.voice_state_update(data)
|
||||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
|
||||
if self._voice_server_complete.is_set():
|
||||
_log.warning('Ignoring extraneous voice server update.')
|
||||
return
|
||||
|
||||
self.token = data['token']
|
||||
self.server_id = int(data['guild_id'])
|
||||
endpoint = data.get('endpoint')
|
||||
|
||||
if endpoint is None or self.token is None:
|
||||
_log.warning(
|
||||
'Awaiting endpoint... This requires waiting. '
|
||||
'If timeout occurred considering raising the timeout and reconnecting.'
|
||||
)
|
||||
return
|
||||
|
||||
self.endpoint, _, _ = endpoint.rpartition(':')
|
||||
if self.endpoint.startswith('wss://'):
|
||||
# Just in case, strip it off since we're going to add it later
|
||||
self.endpoint: str = self.endpoint[6:]
|
||||
|
||||
# This gets set later
|
||||
self.endpoint_ip = MISSING
|
||||
|
||||
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.setblocking(False)
|
||||
|
||||
if not self._handshaking:
|
||||
# If we're not handshaking then we need to terminate our previous connection in the websocket
|
||||
await self.ws.close(4000)
|
||||
return
|
||||
|
||||
self._voice_server_complete.set()
|
||||
|
||||
async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None:
|
||||
await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute)
|
||||
|
||||
async def voice_disconnect(self) -> None:
|
||||
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
await self.channel.guild.change_voice_state(channel=None)
|
||||
|
||||
def prepare_handshake(self) -> None:
|
||||
self._voice_state_complete.clear()
|
||||
self._voice_server_complete.clear()
|
||||
self._handshaking = True
|
||||
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
self._connections += 1
|
||||
|
||||
def finish_handshake(self) -> None:
|
||||
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
self._handshaking = False
|
||||
self._voice_server_complete.clear()
|
||||
self._voice_state_complete.clear()
|
||||
|
||||
async def connect_websocket(self) -> DiscordVoiceWebSocket:
|
||||
ws = await DiscordVoiceWebSocket.from_client(self)
|
||||
self._connected.clear()
|
||||
while ws.secret_key is None:
|
||||
await ws.poll_event()
|
||||
self._connected.set()
|
||||
return ws
|
||||
await self._connection.voice_server_update(data)
|
||||
|
||||
async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None:
|
||||
_log.info('Connecting to voice...')
|
||||
self.timeout = timeout
|
||||
await self._connection.connect(
|
||||
reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
self.prepare_handshake()
|
||||
|
||||
# This has to be created before we start the flow.
|
||||
futures = [
|
||||
self._voice_state_complete.wait(),
|
||||
self._voice_server_complete.wait(),
|
||||
]
|
||||
|
||||
# Start the connection flow
|
||||
await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute)
|
||||
|
||||
try:
|
||||
await utils.sane_wait_for(futures, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
await self.disconnect(force=True)
|
||||
raise
|
||||
|
||||
self.finish_handshake()
|
||||
|
||||
try:
|
||||
self.ws = await self.connect_websocket()
|
||||
break
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
if reconnect:
|
||||
_log.exception('Failed to connect to voice... Retrying...')
|
||||
await asyncio.sleep(1 + i * 2.0)
|
||||
await self.voice_disconnect()
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
if self._runner is MISSING:
|
||||
self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
|
||||
|
||||
async def potential_reconnect(self) -> bool:
|
||||
# Attempt to stop the player thread from playing early
|
||||
self._connected.clear()
|
||||
self.prepare_handshake()
|
||||
self._potentially_reconnecting = True
|
||||
try:
|
||||
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
|
||||
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._potentially_reconnecting = False
|
||||
await self.disconnect(force=True)
|
||||
return False
|
||||
|
||||
self.finish_handshake()
|
||||
self._potentially_reconnecting = False
|
||||
|
||||
if self.ws:
|
||||
_log.debug("Closing existing voice websocket")
|
||||
await self.ws.close()
|
||||
|
||||
try:
|
||||
self.ws = await self.connect_websocket()
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool:
|
||||
self._connection.wait(timeout)
|
||||
return self._connection.is_connected()
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
@ -442,7 +315,7 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
ws = self.ws
|
||||
ws = self._connection.ws
|
||||
return float("inf") if not ws else ws.latency
|
||||
|
||||
@property
|
||||
@ -451,72 +324,19 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
ws = self.ws
|
||||
ws = self._connection.ws
|
||||
return float("inf") if not ws else ws.average_latency
|
||||
|
||||
async def poll_voice_ws(self, reconnect: bool) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
while True:
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except (ConnectionClosed, asyncio.TimeoutError) as exc:
|
||||
if isinstance(exc, ConnectionClosed):
|
||||
# The following close codes are undocumented so I will document them here.
|
||||
# 1000 - normal closure (obviously)
|
||||
# 4014 - voice channel has been deleted.
|
||||
# 4015 - voice server has crashed
|
||||
if exc.code in (1000, 4015):
|
||||
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
await self.disconnect()
|
||||
break
|
||||
if exc.code == 4014:
|
||||
_log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
successful = await self.potential_reconnect()
|
||||
if not successful:
|
||||
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
if not reconnect:
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
self._connected.clear()
|
||||
await asyncio.sleep(retry)
|
||||
await self.voice_disconnect()
|
||||
try:
|
||||
await self.connect(reconnect=True, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# at this point we've retried 5 times... let's continue the loop.
|
||||
_log.warning('Could not connect to voice... Retrying...')
|
||||
continue
|
||||
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects this voice client from voice.
|
||||
"""
|
||||
if not force and not self.is_connected():
|
||||
return
|
||||
|
||||
self.stop()
|
||||
self._connected.clear()
|
||||
await self._connection.disconnect(force=force)
|
||||
self.cleanup()
|
||||
|
||||
try:
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.voice_disconnect()
|
||||
finally:
|
||||
self.cleanup()
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
|
||||
async def move_to(self, channel: Optional[abc.Snowflake]) -> None:
|
||||
async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None:
|
||||
"""|coro|
|
||||
|
||||
Moves you to a different voice channel.
|
||||
@ -525,12 +345,22 @@ class VoiceClient(VoiceProtocol):
|
||||
-----------
|
||||
channel: Optional[:class:`abc.Snowflake`]
|
||||
The channel to move to. Must be a voice channel.
|
||||
timeout: Optional[:class:`float`]
|
||||
How long to wait for the move to complete.
|
||||
|
||||
.. versionadded:: 2.4
|
||||
|
||||
Raises
|
||||
-------
|
||||
asyncio.TimeoutError
|
||||
The move did not complete in time, but may still be ongoing.
|
||||
"""
|
||||
await self.channel.guild.change_voice_state(channel=channel)
|
||||
await self._connection.move_to(channel)
|
||||
await self._connection.wait_async(timeout)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Indicates if the voice client is connected to voice."""
|
||||
return self._connected.is_set()
|
||||
return self._connection.is_connected()
|
||||
|
||||
# audio related
|
||||
|
||||
@ -703,7 +533,7 @@ class VoiceClient(VoiceProtocol):
|
||||
if self._player is None:
|
||||
raise ValueError('Not playing anything.')
|
||||
|
||||
self._player._set_source(value)
|
||||
self._player.set_source(value)
|
||||
|
||||
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
|
||||
"""Sends an audio packet composed of the data.
|
||||
@ -732,8 +562,8 @@ class VoiceClient(VoiceProtocol):
|
||||
encoded_data = data
|
||||
packet = self._get_voice_packet(encoded_data)
|
||||
try:
|
||||
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
|
||||
except BlockingIOError:
|
||||
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
self._connection.send_packet(packet)
|
||||
except OSError:
|
||||
_log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
|
||||
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
|
Reference in New Issue
Block a user