mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-09-06 09:56:09 +00:00
Rewrite voice connection internals
This commit is contained in:
@ -34,7 +34,7 @@ import threading
|
||||
import traceback
|
||||
import zlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
|
||||
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
|
||||
|
||||
import aiohttp
|
||||
import yarl
|
||||
@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from .client import Client
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceClient
|
||||
from .voice_state import VoiceConnectionState
|
||||
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
@ -797,7 +797,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
thread_id: int
|
||||
_connection: VoiceClient
|
||||
_connection: VoiceConnectionState
|
||||
gateway: str
|
||||
_max_heartbeat_timeout: float
|
||||
|
||||
@ -866,16 +866,21 @@ class DiscordVoiceWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
|
||||
@classmethod
|
||||
async def from_client(
|
||||
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
|
||||
async def from_connection_state(
|
||||
cls,
|
||||
state: VoiceConnectionState,
|
||||
*,
|
||||
resume: bool = False,
|
||||
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
|
||||
) -> Self:
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||
gateway = f'wss://{state.endpoint}/?v=4'
|
||||
client = state.voice_client
|
||||
http = client._state.http
|
||||
socket = await http.ws_connect(gateway, compress=15)
|
||||
ws = cls(socket, loop=client.loop, hook=hook)
|
||||
ws.gateway = gateway
|
||||
ws._connection = client
|
||||
ws._connection = state
|
||||
ws._max_heartbeat_timeout = 60.0
|
||||
ws.thread_id = threading.get_ident()
|
||||
|
||||
@ -951,30 +956,50 @@ class DiscordVoiceWebSocket:
|
||||
state.voice_port = data['port']
|
||||
state.endpoint_ip = data['ip']
|
||||
|
||||
packet = bytearray(74)
|
||||
struct.pack_into('>H', packet, 0, 1) # 1 = Send
|
||||
struct.pack_into('>H', packet, 2, 70) # 70 = Length
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
|
||||
recv = await self.loop.sock_recv(state.socket, 74)
|
||||
_log.debug('received packet in initial_connection: %s', recv)
|
||||
|
||||
# the ip is ascii starting at the 8th byte and ending at the first null
|
||||
ip_start = 8
|
||||
ip_end = recv.index(0, ip_start)
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
_log.debug('Connecting to voice socket')
|
||||
await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))
|
||||
|
||||
state.ip, state.port = await self.discover_ip()
|
||||
# there *should* always be at least one supported mode (xsalsa20_poly1305)
|
||||
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
|
||||
_log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
_log.debug('received supported encryption modes: %s', ', '.join(modes))
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
_log.debug('selected the voice protocol for use (%s)', mode)
|
||||
|
||||
async def discover_ip(self) -> Tuple[str, int]:
|
||||
state = self._connection
|
||||
packet = bytearray(74)
|
||||
struct.pack_into('>H', packet, 0, 1) # 1 = Send
|
||||
struct.pack_into('>H', packet, 2, 70) # 70 = Length
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
|
||||
_log.debug('Sending ip discovery packet')
|
||||
await self.loop.sock_sendall(state.socket, packet)
|
||||
|
||||
fut: asyncio.Future[bytes] = self.loop.create_future()
|
||||
|
||||
def get_ip_packet(data: bytes):
|
||||
if data[1] == 0x02 and len(data) == 74:
|
||||
self.loop.call_soon_threadsafe(fut.set_result, data)
|
||||
|
||||
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
|
||||
state.add_socket_listener(get_ip_packet)
|
||||
recv = await fut
|
||||
|
||||
_log.debug('Received ip discovery packet: %s', recv)
|
||||
|
||||
# the ip is ascii starting at the 8th byte and ending at the first null
|
||||
ip_start = 8
|
||||
ip_end = recv.index(0, ip_start)
|
||||
ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', ip, port)
|
||||
|
||||
return ip, port
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
|
||||
@ -995,9 +1020,8 @@ class DiscordVoiceWebSocket:
|
||||
self.secret_key = self._connection.secret_key = data['secret_key']
|
||||
|
||||
# Send a speak command with the "not speaking" state.
|
||||
# This also tells Discord our SSRC value, which Discord requires
|
||||
# before sending any voice data (and is the real reason why we
|
||||
# call this here).
|
||||
# This also tells Discord our SSRC value, which Discord requires before
|
||||
# sending any voice data (and is the real reason why we call this here).
|
||||
await self.speak(SpeakingState.none)
|
||||
|
||||
async def poll_event(self) -> None:
|
||||
@ -1006,10 +1030,10 @@ class DiscordVoiceWebSocket:
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug('Received %s', msg)
|
||||
_log.debug('Received voice %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
|
||||
_log.debug('Received %s', msg)
|
||||
_log.debug('Received voice %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
|
Reference in New Issue
Block a user