Rewrite voice connection internals

This commit is contained in:
Imayhaveborkedit
2023-09-28 17:51:22 -04:00
committed by GitHub
parent 555940352b
commit 44284ae107
5 changed files with 738 additions and 282 deletions

View File

@ -34,7 +34,7 @@ import threading
import traceback
import zlib
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
import aiohttp
import yarl
@ -59,7 +59,7 @@ if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
from .voice_state import VoiceConnectionState
class ReconnectWebSocket(Exception):
@ -797,7 +797,7 @@ class DiscordVoiceWebSocket:
if TYPE_CHECKING:
thread_id: int
_connection: VoiceClient
_connection: VoiceConnectionState
gateway: str
_max_heartbeat_timeout: float
@ -866,16 +866,21 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
@classmethod
async def from_client(
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
async def from_connection_state(
cls,
state: VoiceConnectionState,
*,
resume: bool = False,
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
) -> Self:
"""Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4'
gateway = f'wss://{state.endpoint}/?v=4'
client = state.voice_client
http = client._state.http
socket = await http.ws_connect(gateway, compress=15)
ws = cls(socket, loop=client.loop, hook=hook)
ws.gateway = gateway
ws._connection = client
ws._connection = state
ws._max_heartbeat_timeout = 60.0
ws.thread_id = threading.get_ident()
@ -951,30 +956,50 @@ class DiscordVoiceWebSocket:
state.voice_port = data['port']
state.endpoint_ip = data['ip']
packet = bytearray(74)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
recv = await self.loop.sock_recv(state.socket, 74)
_log.debug('received packet in initial_connection: %s', recv)
# the ip is ascii starting at the 8th byte and ending at the first null
ip_start = 8
ip_end = recv.index(0, ip_start)
state.ip = recv[ip_start:ip_end].decode('ascii')
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', state.ip, state.port)
_log.debug('Connecting to voice socket')
await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))
state.ip, state.port = await self.discover_ip()
# there *should* always be at least one supported mode (xsalsa20_poly1305)
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
_log.debug('received supported encryption modes: %s', ", ".join(modes))
_log.debug('received supported encryption modes: %s', ', '.join(modes))
mode = modes[0]
await self.select_protocol(state.ip, state.port, mode)
_log.debug('selected the voice protocol for use (%s)', mode)
async def discover_ip(self) -> Tuple[str, int]:
state = self._connection
packet = bytearray(74)
struct.pack_into('>H', packet, 0, 1) # 1 = Send
struct.pack_into('>H', packet, 2, 70) # 70 = Length
struct.pack_into('>I', packet, 4, state.ssrc)
_log.debug('Sending ip discovery packet')
await self.loop.sock_sendall(state.socket, packet)
fut: asyncio.Future[bytes] = self.loop.create_future()
def get_ip_packet(data: bytes):
if data[1] == 0x02 and len(data) == 74:
self.loop.call_soon_threadsafe(fut.set_result, data)
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
state.add_socket_listener(get_ip_packet)
recv = await fut
_log.debug('Received ip discovery packet: %s', recv)
# the ip is ascii starting at the 8th byte and ending at the first null
ip_start = 8
ip_end = recv.index(0, ip_start)
ip = recv[ip_start:ip_end].decode('ascii')
port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
_log.debug('detected ip: %s port: %s', ip, port)
return ip, port
@property
def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
@ -995,9 +1020,8 @@ class DiscordVoiceWebSocket:
self.secret_key = self._connection.secret_key = data['secret_key']
# Send a speak command with the "not speaking" state.
# This also tells Discord our SSRC value, which Discord requires
# before sending any voice data (and is the real reason why we
# call this here).
# This also tells Discord our SSRC value, which Discord requires before
# sending any voice data (and is the real reason why we call this here).
await self.speak(SpeakingState.none)
async def poll_event(self) -> None:
@ -1006,10 +1030,10 @@ class DiscordVoiceWebSocket:
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
_log.debug('Received %s', msg)
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code: int = 1000) -> None: