mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-06-03 18:42:43 +00:00
Rewrite voice connection internals
This commit is contained in:
parent
555940352b
commit
44284ae107
@ -1842,7 +1842,7 @@ class Connectable(Protocol):
|
||||
async def connect(
|
||||
self,
|
||||
*,
|
||||
timeout: float = 60.0,
|
||||
timeout: float = 30.0,
|
||||
reconnect: bool = True,
|
||||
cls: Callable[[Client, Connectable], T] = VoiceClient,
|
||||
self_deaf: bool = False,
|
||||
@ -1858,7 +1858,7 @@ class Connectable(Protocol):
|
||||
Parameters
|
||||
-----------
|
||||
timeout: :class:`float`
|
||||
The timeout in seconds to wait for the voice endpoint.
|
||||
The timeout in seconds to wait the connection to complete.
|
||||
reconnect: :class:`bool`
|
||||
Whether the bot should automatically attempt
|
||||
a reconnect if a part of the handshake fails
|
||||
|
@ -34,7 +34,7 @@ import threading
|
||||
import traceback
|
||||
import zlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar
|
||||
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
|
||||
|
||||
import aiohttp
|
||||
import yarl
|
||||
@ -59,7 +59,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from .client import Client
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceClient
|
||||
from .voice_state import VoiceConnectionState
|
||||
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
@ -797,7 +797,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
thread_id: int
|
||||
_connection: VoiceClient
|
||||
_connection: VoiceConnectionState
|
||||
gateway: str
|
||||
_max_heartbeat_timeout: float
|
||||
|
||||
@ -866,16 +866,21 @@ class DiscordVoiceWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
|
||||
@classmethod
|
||||
async def from_client(
|
||||
cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None
|
||||
async def from_connection_state(
|
||||
cls,
|
||||
state: VoiceConnectionState,
|
||||
*,
|
||||
resume: bool = False,
|
||||
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
|
||||
) -> Self:
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||
gateway = f'wss://{state.endpoint}/?v=4'
|
||||
client = state.voice_client
|
||||
http = client._state.http
|
||||
socket = await http.ws_connect(gateway, compress=15)
|
||||
ws = cls(socket, loop=client.loop, hook=hook)
|
||||
ws.gateway = gateway
|
||||
ws._connection = client
|
||||
ws._connection = state
|
||||
ws._max_heartbeat_timeout = 60.0
|
||||
ws.thread_id = threading.get_ident()
|
||||
|
||||
@ -951,30 +956,50 @@ class DiscordVoiceWebSocket:
|
||||
state.voice_port = data['port']
|
||||
state.endpoint_ip = data['ip']
|
||||
|
||||
packet = bytearray(74)
|
||||
struct.pack_into('>H', packet, 0, 1) # 1 = Send
|
||||
struct.pack_into('>H', packet, 2, 70) # 70 = Length
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
state.socket.sendto(packet, (state.endpoint_ip, state.voice_port))
|
||||
recv = await self.loop.sock_recv(state.socket, 74)
|
||||
_log.debug('received packet in initial_connection: %s', recv)
|
||||
|
||||
# the ip is ascii starting at the 8th byte and ending at the first null
|
||||
ip_start = 8
|
||||
ip_end = recv.index(0, ip_start)
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
_log.debug('Connecting to voice socket')
|
||||
await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))
|
||||
|
||||
state.ip, state.port = await self.discover_ip()
|
||||
# there *should* always be at least one supported mode (xsalsa20_poly1305)
|
||||
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
|
||||
_log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
_log.debug('received supported encryption modes: %s', ', '.join(modes))
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
_log.debug('selected the voice protocol for use (%s)', mode)
|
||||
|
||||
async def discover_ip(self) -> Tuple[str, int]:
|
||||
state = self._connection
|
||||
packet = bytearray(74)
|
||||
struct.pack_into('>H', packet, 0, 1) # 1 = Send
|
||||
struct.pack_into('>H', packet, 2, 70) # 70 = Length
|
||||
struct.pack_into('>I', packet, 4, state.ssrc)
|
||||
|
||||
_log.debug('Sending ip discovery packet')
|
||||
await self.loop.sock_sendall(state.socket, packet)
|
||||
|
||||
fut: asyncio.Future[bytes] = self.loop.create_future()
|
||||
|
||||
def get_ip_packet(data: bytes):
|
||||
if data[1] == 0x02 and len(data) == 74:
|
||||
self.loop.call_soon_threadsafe(fut.set_result, data)
|
||||
|
||||
fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
|
||||
state.add_socket_listener(get_ip_packet)
|
||||
recv = await fut
|
||||
|
||||
_log.debug('Received ip discovery packet: %s', recv)
|
||||
|
||||
# the ip is ascii starting at the 8th byte and ending at the first null
|
||||
ip_start = 8
|
||||
ip_end = recv.index(0, ip_start)
|
||||
ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', ip, port)
|
||||
|
||||
return ip, port
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
|
||||
@ -995,9 +1020,8 @@ class DiscordVoiceWebSocket:
|
||||
self.secret_key = self._connection.secret_key = data['secret_key']
|
||||
|
||||
# Send a speak command with the "not speaking" state.
|
||||
# This also tells Discord our SSRC value, which Discord requires
|
||||
# before sending any voice data (and is the real reason why we
|
||||
# call this here).
|
||||
# This also tells Discord our SSRC value, which Discord requires before
|
||||
# sending any voice data (and is the real reason why we call this here).
|
||||
await self.speak(SpeakingState.none)
|
||||
|
||||
async def poll_event(self) -> None:
|
||||
@ -1006,10 +1030,10 @@ class DiscordVoiceWebSocket:
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug('Received %s', msg)
|
||||
_log.debug('Received voice %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
|
||||
_log.debug('Received %s', msg)
|
||||
_log.debug('Received voice %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
|
@ -703,7 +703,6 @@ class AudioPlayer(threading.Thread):
|
||||
self._resumed: threading.Event = threading.Event()
|
||||
self._resumed.set() # we are not paused
|
||||
self._current_error: Optional[Exception] = None
|
||||
self._connected: threading.Event = client._connected
|
||||
self._lock: threading.Lock = threading.Lock()
|
||||
|
||||
if after is not None and not callable(after):
|
||||
@ -714,7 +713,8 @@ class AudioPlayer(threading.Thread):
|
||||
self._start = time.perf_counter()
|
||||
|
||||
# getattr lookup speed ups
|
||||
play_audio = self.client.send_audio_packet
|
||||
client = self.client
|
||||
play_audio = client.send_audio_packet
|
||||
self._speak(SpeakingState.voice)
|
||||
|
||||
while not self._end.is_set():
|
||||
@ -725,22 +725,28 @@ class AudioPlayer(threading.Thread):
|
||||
self._resumed.wait()
|
||||
continue
|
||||
|
||||
# are we disconnected from voice?
|
||||
if not self._connected.is_set():
|
||||
# wait until we are connected
|
||||
self._connected.wait()
|
||||
# reset our internal data
|
||||
self.loops = 0
|
||||
self._start = time.perf_counter()
|
||||
|
||||
self.loops += 1
|
||||
data = self.source.read()
|
||||
|
||||
if not data:
|
||||
self.stop()
|
||||
break
|
||||
|
||||
# are we disconnected from voice?
|
||||
if not client.is_connected():
|
||||
_log.debug('Not connected, waiting for %ss...', client.timeout)
|
||||
# wait until we are connected, but not forever
|
||||
connected = client.wait_until_connected(client.timeout)
|
||||
if self._end.is_set() or not connected:
|
||||
_log.debug('Aborting playback')
|
||||
return
|
||||
_log.debug('Reconnected, resuming playback')
|
||||
self._speak(SpeakingState.voice)
|
||||
# reset our internal data
|
||||
self.loops = 0
|
||||
self._start = time.perf_counter()
|
||||
|
||||
play_audio(data, encode=not self.source.is_opus())
|
||||
self.loops += 1
|
||||
next_time = self._start + self.DELAY * self.loops
|
||||
delay = max(0, self.DELAY + (next_time - time.perf_counter()))
|
||||
time.sleep(delay)
|
||||
@ -792,7 +798,7 @@ class AudioPlayer(threading.Thread):
|
||||
def is_paused(self) -> bool:
|
||||
return not self._end.is_set() and not self._resumed.is_set()
|
||||
|
||||
def _set_source(self, source: AudioSource) -> None:
|
||||
def set_source(self, source: AudioSource) -> None:
|
||||
with self._lock:
|
||||
self.pause(update_speaking=False)
|
||||
self.source = source
|
||||
|
@ -20,40 +20,24 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
Some documentation to refer to:
|
||||
|
||||
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
|
||||
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
|
||||
- We pull the session_id from VOICE_STATE_UPDATE.
|
||||
- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE.
|
||||
- Then we initiate the voice web socket (vWS) pointing to the endpoint.
|
||||
- We send opcode 0 with the user_id, server_id, session_id and token using the vWS.
|
||||
- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval.
|
||||
- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE.
|
||||
- Then we send our IP and port via vWS with opcode 1.
|
||||
- When that's all done, we receive opcode 4 from the vWS.
|
||||
- Finally we can transmit data to endpoint:port.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import socket
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union
|
||||
|
||||
from . import opus, utils
|
||||
from .backoff import ExponentialBackoff
|
||||
from . import opus
|
||||
from .gateway import *
|
||||
from .errors import ClientException, ConnectionClosed
|
||||
from .errors import ClientException
|
||||
from .player import AudioPlayer, AudioSource
|
||||
from .utils import MISSING
|
||||
from .voice_state import VoiceConnectionState
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .gateway import DiscordVoiceWebSocket
|
||||
from .client import Client
|
||||
from .guild import Guild
|
||||
from .state import ConnectionState
|
||||
@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol):
|
||||
"""
|
||||
|
||||
channel: VocalGuildChannel
|
||||
endpoint_ip: str
|
||||
voice_port: int
|
||||
ip: str
|
||||
port: int
|
||||
secret_key: List[int]
|
||||
ssrc: int
|
||||
|
||||
def __init__(self, client: Client, channel: abc.Connectable) -> None:
|
||||
if not has_nacl:
|
||||
@ -239,29 +217,18 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
super().__init__(client, channel)
|
||||
state = client._connection
|
||||
self.token: str = MISSING
|
||||
self.server_id: int = MISSING
|
||||
self.socket = MISSING
|
||||
self.loop: asyncio.AbstractEventLoop = state.loop
|
||||
self._state: ConnectionState = state
|
||||
# this will be used in the AudioPlayer thread
|
||||
self._connected: threading.Event = threading.Event()
|
||||
|
||||
self._handshaking: bool = False
|
||||
self._potentially_reconnecting: bool = False
|
||||
self._voice_state_complete: asyncio.Event = asyncio.Event()
|
||||
self._voice_server_complete: asyncio.Event = asyncio.Event()
|
||||
|
||||
self.mode: str = MISSING
|
||||
self._connections: int = 0
|
||||
self.sequence: int = 0
|
||||
self.timestamp: int = 0
|
||||
self.timeout: float = 0
|
||||
self._runner: asyncio.Task = MISSING
|
||||
self._player: Optional[AudioPlayer] = None
|
||||
self.encoder: Encoder = MISSING
|
||||
self._lite_nonce: int = 0
|
||||
self.ws: DiscordVoiceWebSocket = MISSING
|
||||
|
||||
self._connection: VoiceConnectionState = self.create_connection_state()
|
||||
|
||||
warn_nacl: bool = not has_nacl
|
||||
supported_modes: Tuple[SupportedModes, ...] = (
|
||||
@ -280,6 +247,38 @@ class VoiceClient(VoiceProtocol):
|
||||
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
|
||||
return self._state.user # type: ignore
|
||||
|
||||
@property
|
||||
def session_id(self) -> Optional[str]:
|
||||
return self._connection.session_id
|
||||
|
||||
@property
|
||||
def token(self) -> Optional[str]:
|
||||
return self._connection.token
|
||||
|
||||
@property
|
||||
def endpoint(self) -> Optional[str]:
|
||||
return self._connection.endpoint
|
||||
|
||||
@property
|
||||
def ssrc(self) -> int:
|
||||
return self._connection.ssrc
|
||||
|
||||
@property
|
||||
def mode(self) -> SupportedModes:
|
||||
return self._connection.mode
|
||||
|
||||
@property
|
||||
def secret_key(self) -> List[int]:
|
||||
return self._connection.secret_key
|
||||
|
||||
@property
|
||||
def ws(self) -> DiscordVoiceWebSocket:
|
||||
return self._connection.ws
|
||||
|
||||
@property
|
||||
def timeout(self) -> float:
|
||||
return self._connection.timeout
|
||||
|
||||
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
||||
val = getattr(self, attr)
|
||||
if val + value > limit:
|
||||
@ -289,149 +288,23 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
# connection related
|
||||
|
||||
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
self.session_id: str = data['session_id']
|
||||
channel_id = data['channel_id']
|
||||
def create_connection_state(self) -> VoiceConnectionState:
|
||||
return VoiceConnectionState(self)
|
||||
|
||||
if not self._handshaking or self._potentially_reconnecting:
|
||||
# If we're done handshaking then we just need to update ourselves
|
||||
# If we're potentially reconnecting due to a 4014, then we need to differentiate
|
||||
# a channel move and an actual force disconnect
|
||||
if channel_id is None:
|
||||
# We're being disconnected so cleanup
|
||||
await self.disconnect()
|
||||
else:
|
||||
self.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
|
||||
else:
|
||||
self._voice_state_complete.set()
|
||||
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
await self._connection.voice_state_update(data)
|
||||
|
||||
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
|
||||
if self._voice_server_complete.is_set():
|
||||
_log.warning('Ignoring extraneous voice server update.')
|
||||
return
|
||||
|
||||
self.token = data['token']
|
||||
self.server_id = int(data['guild_id'])
|
||||
endpoint = data.get('endpoint')
|
||||
|
||||
if endpoint is None or self.token is None:
|
||||
_log.warning(
|
||||
'Awaiting endpoint... This requires waiting. '
|
||||
'If timeout occurred considering raising the timeout and reconnecting.'
|
||||
)
|
||||
return
|
||||
|
||||
self.endpoint, _, _ = endpoint.rpartition(':')
|
||||
if self.endpoint.startswith('wss://'):
|
||||
# Just in case, strip it off since we're going to add it later
|
||||
self.endpoint: str = self.endpoint[6:]
|
||||
|
||||
# This gets set later
|
||||
self.endpoint_ip = MISSING
|
||||
|
||||
self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.setblocking(False)
|
||||
|
||||
if not self._handshaking:
|
||||
# If we're not handshaking then we need to terminate our previous connection in the websocket
|
||||
await self.ws.close(4000)
|
||||
return
|
||||
|
||||
self._voice_server_complete.set()
|
||||
|
||||
async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None:
|
||||
await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute)
|
||||
|
||||
async def voice_disconnect(self) -> None:
|
||||
_log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
|
||||
await self.channel.guild.change_voice_state(channel=None)
|
||||
|
||||
def prepare_handshake(self) -> None:
|
||||
self._voice_state_complete.clear()
|
||||
self._voice_server_complete.clear()
|
||||
self._handshaking = True
|
||||
_log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
|
||||
self._connections += 1
|
||||
|
||||
def finish_handshake(self) -> None:
|
||||
_log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
|
||||
self._handshaking = False
|
||||
self._voice_server_complete.clear()
|
||||
self._voice_state_complete.clear()
|
||||
|
||||
async def connect_websocket(self) -> DiscordVoiceWebSocket:
|
||||
ws = await DiscordVoiceWebSocket.from_client(self)
|
||||
self._connected.clear()
|
||||
while ws.secret_key is None:
|
||||
await ws.poll_event()
|
||||
self._connected.set()
|
||||
return ws
|
||||
await self._connection.voice_server_update(data)
|
||||
|
||||
async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None:
|
||||
_log.info('Connecting to voice...')
|
||||
self.timeout = timeout
|
||||
await self._connection.connect(
|
||||
reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False
|
||||
)
|
||||
|
||||
for i in range(5):
|
||||
self.prepare_handshake()
|
||||
|
||||
# This has to be created before we start the flow.
|
||||
futures = [
|
||||
self._voice_state_complete.wait(),
|
||||
self._voice_server_complete.wait(),
|
||||
]
|
||||
|
||||
# Start the connection flow
|
||||
await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute)
|
||||
|
||||
try:
|
||||
await utils.sane_wait_for(futures, timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
await self.disconnect(force=True)
|
||||
raise
|
||||
|
||||
self.finish_handshake()
|
||||
|
||||
try:
|
||||
self.ws = await self.connect_websocket()
|
||||
break
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
if reconnect:
|
||||
_log.exception('Failed to connect to voice... Retrying...')
|
||||
await asyncio.sleep(1 + i * 2.0)
|
||||
await self.voice_disconnect()
|
||||
continue
|
||||
else:
|
||||
raise
|
||||
|
||||
if self._runner is MISSING:
|
||||
self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect))
|
||||
|
||||
async def potential_reconnect(self) -> bool:
|
||||
# Attempt to stop the player thread from playing early
|
||||
self._connected.clear()
|
||||
self.prepare_handshake()
|
||||
self._potentially_reconnecting = True
|
||||
try:
|
||||
# We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected
|
||||
await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
self._potentially_reconnecting = False
|
||||
await self.disconnect(force=True)
|
||||
return False
|
||||
|
||||
self.finish_handshake()
|
||||
self._potentially_reconnecting = False
|
||||
|
||||
if self.ws:
|
||||
_log.debug("Closing existing voice websocket")
|
||||
await self.ws.close()
|
||||
|
||||
try:
|
||||
self.ws = await self.connect_websocket()
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool:
|
||||
self._connection.wait(timeout)
|
||||
return self._connection.is_connected()
|
||||
|
||||
@property
|
||||
def latency(self) -> float:
|
||||
@ -442,7 +315,7 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
ws = self.ws
|
||||
ws = self._connection.ws
|
||||
return float("inf") if not ws else ws.latency
|
||||
|
||||
@property
|
||||
@ -451,72 +324,19 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
.. versionadded:: 1.4
|
||||
"""
|
||||
ws = self.ws
|
||||
ws = self._connection.ws
|
||||
return float("inf") if not ws else ws.average_latency
|
||||
|
||||
async def poll_voice_ws(self, reconnect: bool) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
while True:
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except (ConnectionClosed, asyncio.TimeoutError) as exc:
|
||||
if isinstance(exc, ConnectionClosed):
|
||||
# The following close codes are undocumented so I will document them here.
|
||||
# 1000 - normal closure (obviously)
|
||||
# 4014 - voice channel has been deleted.
|
||||
# 4015 - voice server has crashed
|
||||
if exc.code in (1000, 4015):
|
||||
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
await self.disconnect()
|
||||
break
|
||||
if exc.code == 4014:
|
||||
_log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
successful = await self.potential_reconnect()
|
||||
if not successful:
|
||||
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
if not reconnect:
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
self._connected.clear()
|
||||
await asyncio.sleep(retry)
|
||||
await self.voice_disconnect()
|
||||
try:
|
||||
await self.connect(reconnect=True, timeout=self.timeout)
|
||||
except asyncio.TimeoutError:
|
||||
# at this point we've retried 5 times... let's continue the loop.
|
||||
_log.warning('Could not connect to voice... Retrying...')
|
||||
continue
|
||||
|
||||
async def disconnect(self, *, force: bool = False) -> None:
|
||||
"""|coro|
|
||||
|
||||
Disconnects this voice client from voice.
|
||||
"""
|
||||
if not force and not self.is_connected():
|
||||
return
|
||||
|
||||
self.stop()
|
||||
self._connected.clear()
|
||||
await self._connection.disconnect(force=force)
|
||||
self.cleanup()
|
||||
|
||||
try:
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
|
||||
await self.voice_disconnect()
|
||||
finally:
|
||||
self.cleanup()
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
|
||||
async def move_to(self, channel: Optional[abc.Snowflake]) -> None:
|
||||
async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None:
|
||||
"""|coro|
|
||||
|
||||
Moves you to a different voice channel.
|
||||
@ -525,12 +345,22 @@ class VoiceClient(VoiceProtocol):
|
||||
-----------
|
||||
channel: Optional[:class:`abc.Snowflake`]
|
||||
The channel to move to. Must be a voice channel.
|
||||
timeout: Optional[:class:`float`]
|
||||
How long to wait for the move to complete.
|
||||
|
||||
.. versionadded:: 2.4
|
||||
|
||||
Raises
|
||||
-------
|
||||
asyncio.TimeoutError
|
||||
The move did not complete in time, but may still be ongoing.
|
||||
"""
|
||||
await self.channel.guild.change_voice_state(channel=channel)
|
||||
await self._connection.move_to(channel)
|
||||
await self._connection.wait_async(timeout)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
"""Indicates if the voice client is connected to voice."""
|
||||
return self._connected.is_set()
|
||||
return self._connection.is_connected()
|
||||
|
||||
# audio related
|
||||
|
||||
@ -703,7 +533,7 @@ class VoiceClient(VoiceProtocol):
|
||||
if self._player is None:
|
||||
raise ValueError('Not playing anything.')
|
||||
|
||||
self._player._set_source(value)
|
||||
self._player.set_source(value)
|
||||
|
||||
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
|
||||
"""Sends an audio packet composed of the data.
|
||||
@ -732,8 +562,8 @@ class VoiceClient(VoiceProtocol):
|
||||
encoded_data = data
|
||||
packet = self._get_voice_packet(encoded_data)
|
||||
try:
|
||||
self.socket.sendto(packet, (self.endpoint_ip, self.voice_port))
|
||||
except BlockingIOError:
|
||||
_log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
self._connection.send_packet(packet)
|
||||
except OSError:
|
||||
_log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp)
|
||||
|
||||
self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295)
|
||||
|
596
discord/voice_state.py
Normal file
596
discord/voice_state.py
Normal file
@ -0,0 +1,596 @@
|
||||
"""
|
||||
The MIT License (MIT)
|
||||
|
||||
Copyright (c) 2015-present Rapptz
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a
|
||||
copy of this software and associated documentation files (the "Software"),
|
||||
to deal in the Software without restriction, including without limitation
|
||||
the rights to use, copy, modify, merge, publish, distribute, sublicense,
|
||||
and/or sell copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in
|
||||
all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
|
||||
OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
|
||||
|
||||
Some documentation to refer to:
|
||||
|
||||
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
|
||||
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
|
||||
- We pull the session_id from VOICE_STATE_UPDATE.
|
||||
- We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE.
|
||||
- Then we initiate the voice web socket (vWS) pointing to the endpoint.
|
||||
- We send opcode 0 with the user_id, server_id, session_id and token using the vWS.
|
||||
- The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval.
|
||||
- We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE.
|
||||
- Then we send our IP and port via vWS with opcode 1.
|
||||
- When that's all done, we receive opcode 4 from the vWS.
|
||||
- Finally we can transmit data to endpoint:port.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import select
|
||||
import socket
|
||||
import asyncio
|
||||
import logging
|
||||
import threading
|
||||
|
||||
import async_timeout
|
||||
|
||||
from typing import TYPE_CHECKING, Optional, Dict, List, Callable, Coroutine, Any, Tuple
|
||||
|
||||
from .enums import Enum
|
||||
from .utils import MISSING, sane_wait_for
|
||||
from .errors import ConnectionClosed
|
||||
from .backoff import ExponentialBackoff
|
||||
from .gateway import DiscordVoiceWebSocket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from . import abc
|
||||
from .guild import Guild
|
||||
from .user import ClientUser
|
||||
from .member import VoiceState
|
||||
from .voice_client import VoiceClient
|
||||
|
||||
from .types.voice import (
|
||||
GuildVoiceState as GuildVoiceStatePayload,
|
||||
VoiceServerUpdate as VoiceServerUpdatePayload,
|
||||
SupportedModes,
|
||||
)
|
||||
|
||||
WebsocketHook = Optional[Callable[['VoiceConnectionState', Dict[str, Any]], Coroutine[Any, Any, Any]]]
|
||||
SocketReaderCallback = Callable[[bytes], Any]
|
||||
|
||||
|
||||
__all__ = ('VoiceConnectionState',)
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SocketReader(threading.Thread):
|
||||
def __init__(self, state: VoiceConnectionState) -> None:
|
||||
super().__init__(daemon=True, name=f'voice-socket-reader:{id(self):#x}')
|
||||
self.state: VoiceConnectionState = state
|
||||
self._callbacks: List[SocketReaderCallback] = []
|
||||
self._running = threading.Event()
|
||||
self._end = threading.Event()
|
||||
# If we have paused reading due to having no callbacks
|
||||
self._idle_paused: bool = True
|
||||
|
||||
def register(self, callback: SocketReaderCallback) -> None:
|
||||
self._callbacks.append(callback)
|
||||
if self._idle_paused:
|
||||
self._idle_paused = False
|
||||
self._running.set()
|
||||
|
||||
def unregister(self, callback: SocketReaderCallback) -> None:
|
||||
try:
|
||||
self._callbacks.remove(callback)
|
||||
except ValueError:
|
||||
pass
|
||||
else:
|
||||
if not self._callbacks and self._running.is_set():
|
||||
# If running is not set, we are either explicitly paused and
|
||||
# should be explicitly resumed, or we are already idle paused
|
||||
self._idle_paused = True
|
||||
self._running.clear()
|
||||
|
||||
def pause(self) -> None:
|
||||
self._idle_paused = False
|
||||
self._running.clear()
|
||||
|
||||
def resume(self, *, force: bool = False) -> None:
|
||||
if self._running.is_set():
|
||||
return
|
||||
# Don't resume if there are no callbacks registered
|
||||
if not force and not self._callbacks:
|
||||
# We tried to resume but there was nothing to do, so resume when ready
|
||||
self._idle_paused = True
|
||||
return
|
||||
self._idle_paused = False
|
||||
self._running.set()
|
||||
|
||||
def stop(self) -> None:
|
||||
self._end.set()
|
||||
self._running.set()
|
||||
|
||||
def run(self) -> None:
|
||||
self._end.clear()
|
||||
self._running.set()
|
||||
try:
|
||||
self._do_run()
|
||||
except Exception:
|
||||
_log.exception('Error in %s', self)
|
||||
finally:
|
||||
self.stop()
|
||||
self._running.clear()
|
||||
self._callbacks.clear()
|
||||
|
||||
def _do_run(self) -> None:
|
||||
while not self._end.is_set():
|
||||
if not self._running.is_set():
|
||||
self._running.wait()
|
||||
continue
|
||||
|
||||
# Since this socket is a non blocking socket, select has to be used to wait on it for reading.
|
||||
try:
|
||||
readable, _, _ = select.select([self.state.socket], [], [], 30)
|
||||
except (ValueError, TypeError):
|
||||
# The socket is either closed or doesn't exist at the moment
|
||||
continue
|
||||
|
||||
if not readable:
|
||||
continue
|
||||
|
||||
try:
|
||||
data = self.state.socket.recv(2048)
|
||||
except OSError:
|
||||
_log.debug('Error reading from socket in %s, this should be safe to ignore', self, exc_info=True)
|
||||
else:
|
||||
for cb in self._callbacks:
|
||||
try:
|
||||
cb(data)
|
||||
except Exception:
|
||||
_log.exception('Error calling %s in %s', cb, self)
|
||||
|
||||
|
||||
class ConnectionFlowState(Enum):
|
||||
"""Enum representing voice connection flow state."""
|
||||
|
||||
# fmt: off
|
||||
disconnected = 0
|
||||
set_guild_voice_state = 1
|
||||
got_voice_state_update = 2
|
||||
got_voice_server_update = 3
|
||||
got_both_voice_updates = 4
|
||||
websocket_connected = 5
|
||||
got_websocket_ready = 6
|
||||
got_ip_discovery = 7
|
||||
connected = 8
|
||||
# fmt: on
|
||||
|
||||
|
||||
class VoiceConnectionState:
|
||||
"""Represents the internal state of a voice connection."""
|
||||
|
||||
def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] = None) -> None:
|
||||
self.voice_client = voice_client
|
||||
self.hook = hook
|
||||
|
||||
self.timeout: float = 30.0
|
||||
self.reconnect: bool = True
|
||||
self.self_deaf: bool = False
|
||||
self.self_mute: bool = False
|
||||
self.token: Optional[str] = None
|
||||
self.session_id: Optional[str] = None
|
||||
self.endpoint: Optional[str] = None
|
||||
self.endpoint_ip: Optional[str] = None
|
||||
self.server_id: Optional[int] = None
|
||||
self.ip: Optional[str] = None
|
||||
self.port: Optional[int] = None
|
||||
self.voice_port: Optional[int] = None
|
||||
self.secret_key: List[int] = MISSING
|
||||
self.ssrc: int = MISSING
|
||||
self.mode: SupportedModes = MISSING
|
||||
self.socket: socket.socket = MISSING
|
||||
self.ws: DiscordVoiceWebSocket = MISSING
|
||||
|
||||
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
|
||||
self._expecting_disconnect: bool = False
|
||||
self._connected = threading.Event()
|
||||
self._state_event = asyncio.Event()
|
||||
self._runner: Optional[asyncio.Task] = None
|
||||
self._connector: Optional[asyncio.Task] = None
|
||||
self._socket_reader = SocketReader(self)
|
||||
self._socket_reader.start()
|
||||
|
||||
@property
|
||||
def state(self) -> ConnectionFlowState:
|
||||
return self._state
|
||||
|
||||
@state.setter
|
||||
def state(self, state: ConnectionFlowState) -> None:
|
||||
if state is not self._state:
|
||||
_log.debug('Connection state changed to %s', state.name)
|
||||
self._state = state
|
||||
self._state_event.set()
|
||||
self._state_event.clear()
|
||||
|
||||
if state is ConnectionFlowState.connected:
|
||||
self._connected.set()
|
||||
else:
|
||||
self._connected.clear()
|
||||
|
||||
@property
|
||||
def guild(self) -> Guild:
|
||||
return self.voice_client.guild
|
||||
|
||||
@property
|
||||
def user(self) -> ClientUser:
|
||||
return self.voice_client.user
|
||||
|
||||
@property
|
||||
def supported_modes(self) -> Tuple[SupportedModes, ...]:
|
||||
return self.voice_client.supported_modes
|
||||
|
||||
@property
|
||||
def self_voice_state(self) -> Optional[VoiceState]:
|
||||
return self.guild.me.voice
|
||||
|
||||
async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
channel_id = data['channel_id']
|
||||
|
||||
if channel_id is None:
|
||||
# If we know we're going to get a voice_state_update where we have no channel due to
|
||||
# being in the reconnect flow, we ignore it. Otherwise, it probably wasn't from us.
|
||||
if self._expecting_disconnect:
|
||||
self._expecting_disconnect = False
|
||||
else:
|
||||
_log.debug('We were externally disconnected from voice.')
|
||||
await self.disconnect()
|
||||
|
||||
return
|
||||
|
||||
self.session_id = data['session_id']
|
||||
|
||||
# we got the event while connecting
|
||||
if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_server_update):
|
||||
if self.state is ConnectionFlowState.set_guild_voice_state:
|
||||
self.state = ConnectionFlowState.got_voice_state_update
|
||||
else:
|
||||
self.state = ConnectionFlowState.got_both_voice_updates
|
||||
return
|
||||
|
||||
if self.state is ConnectionFlowState.connected:
|
||||
self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
|
||||
|
||||
elif self.state is not ConnectionFlowState.disconnected:
|
||||
if channel_id != self.voice_client.channel.id:
|
||||
# For some unfortunate reason we were moved during the connection flow
|
||||
_log.info('Handling channel move while connecting...')
|
||||
|
||||
self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id)) # type: ignore
|
||||
|
||||
await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update)
|
||||
await self.connect(
|
||||
reconnect=self.reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=False,
|
||||
wait=False,
|
||||
)
|
||||
else:
|
||||
_log.debug('Ignoring unexpected voice_state_update event')
|
||||
|
||||
async def voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
|
||||
self.token = data['token']
|
||||
self.server_id = int(data['guild_id'])
|
||||
endpoint = data.get('endpoint')
|
||||
|
||||
if self.token is None or endpoint is None:
|
||||
_log.warning(
|
||||
'Awaiting endpoint... This requires waiting. '
|
||||
'If timeout occurred considering raising the timeout and reconnecting.'
|
||||
)
|
||||
return
|
||||
|
||||
self.endpoint, _, _ = endpoint.rpartition(':')
|
||||
if self.endpoint.startswith('wss://'):
|
||||
# Just in case, strip it off since we're going to add it later
|
||||
self.endpoint = self.endpoint[6:]
|
||||
|
||||
# we got the event while connecting
|
||||
if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update):
|
||||
# This gets set after READY is received
|
||||
self.endpoint_ip = MISSING
|
||||
self._create_socket()
|
||||
|
||||
if self.state is ConnectionFlowState.set_guild_voice_state:
|
||||
self.state = ConnectionFlowState.got_voice_server_update
|
||||
else:
|
||||
self.state = ConnectionFlowState.got_both_voice_updates
|
||||
|
||||
elif self.state is ConnectionFlowState.connected:
|
||||
_log.debug('Voice server update, closing old voice websocket')
|
||||
await self.ws.close(4014)
|
||||
self.state = ConnectionFlowState.got_voice_server_update
|
||||
|
||||
elif self.state is not ConnectionFlowState.disconnected:
|
||||
_log.debug('Unexpected server update event, attempting to handle')
|
||||
|
||||
await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update)
|
||||
await self.connect(
|
||||
reconnect=self.reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=False,
|
||||
wait=False,
|
||||
)
|
||||
self._create_socket()
|
||||
|
||||
async def connect(
|
||||
self, *, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool, wait: bool = True
|
||||
) -> None:
|
||||
if self._connector:
|
||||
self._connector.cancel()
|
||||
self._connector = None
|
||||
|
||||
if self._runner:
|
||||
self._runner.cancel()
|
||||
self._runner = None
|
||||
|
||||
self.timeout = timeout
|
||||
self.reconnect = reconnect
|
||||
self._connector = self.voice_client.loop.create_task(
|
||||
self._wrap_connect(reconnect, timeout, self_deaf, self_mute, resume), name='Voice connector'
|
||||
)
|
||||
if wait:
|
||||
await self._connector
|
||||
|
||||
async def _wrap_connect(self, *args: Any) -> None:
|
||||
try:
|
||||
await self._connect(*args)
|
||||
except asyncio.CancelledError:
|
||||
_log.debug('Cancelling voice connection')
|
||||
await self.soft_disconnect()
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
_log.info('Timed out connecting to voice')
|
||||
await self.disconnect()
|
||||
raise
|
||||
except Exception:
|
||||
_log.exception('Error connecting to voice... disconnecting')
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None:
|
||||
_log.info('Connecting to voice...')
|
||||
|
||||
async with async_timeout.timeout(timeout):
|
||||
for i in range(5):
|
||||
_log.info('Starting voice handshake... (connection attempt %d)', i + 1)
|
||||
|
||||
await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute)
|
||||
# Setting this unnecessarily will break reconnecting
|
||||
if self.state is ConnectionFlowState.disconnected:
|
||||
self.state = ConnectionFlowState.set_guild_voice_state
|
||||
|
||||
await self._wait_for_state(ConnectionFlowState.got_both_voice_updates)
|
||||
|
||||
_log.info('Voice handshake complete. Endpoint found: %s', self.endpoint)
|
||||
|
||||
try:
|
||||
self.ws = await self._connect_websocket(resume)
|
||||
await self._handshake_websocket()
|
||||
break
|
||||
except ConnectionClosed:
|
||||
if reconnect:
|
||||
wait = 1 + i * 2.0
|
||||
_log.exception('Failed to connect to voice... Retrying in %ss...', wait)
|
||||
await self.disconnect(cleanup=False)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
else:
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
_log.info('Voice connection complete.')
|
||||
|
||||
if not self._runner:
|
||||
self._runner = self.voice_client.loop.create_task(self._poll_voice_ws(reconnect), name='Voice websocket poller')
|
||||
|
||||
async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None:
|
||||
if not force and not self.is_connected():
|
||||
return
|
||||
|
||||
try:
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
await self._voice_disconnect()
|
||||
except Exception:
|
||||
_log.debug('Ignoring exception disconnecting from voice', exc_info=True)
|
||||
finally:
|
||||
self.ip = MISSING
|
||||
self.port = MISSING
|
||||
self.state = ConnectionFlowState.disconnected
|
||||
self._socket_reader.pause()
|
||||
|
||||
# Flip the connected event to unlock any waiters
|
||||
self._connected.set()
|
||||
self._connected.clear()
|
||||
|
||||
if cleanup:
|
||||
self._socket_reader.stop()
|
||||
self.voice_client.cleanup()
|
||||
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
|
||||
async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None:
|
||||
_log.debug('Soft disconnecting from voice')
|
||||
# Stop the websocket reader because closing the websocket will trigger an unwanted reconnect
|
||||
if self._runner:
|
||||
self._runner.cancel()
|
||||
self._runner = None
|
||||
|
||||
try:
|
||||
if self.ws:
|
||||
await self.ws.close()
|
||||
except Exception:
|
||||
_log.debug('Ignoring exception soft disconnecting from voice', exc_info=True)
|
||||
finally:
|
||||
self.ip = MISSING
|
||||
self.port = MISSING
|
||||
self.state = with_state
|
||||
self._socket_reader.pause()
|
||||
|
||||
if self.socket:
|
||||
self.socket.close()
|
||||
|
||||
async def move_to(self, channel: Optional[abc.Snowflake]) -> None:
|
||||
if channel is None:
|
||||
await self.disconnect()
|
||||
return
|
||||
|
||||
await self.voice_client.channel.guild.change_voice_state(channel=channel)
|
||||
self.state = ConnectionFlowState.set_guild_voice_state
|
||||
|
||||
def wait(self, timeout: Optional[float] = None) -> bool:
|
||||
return self._connected.wait(timeout)
|
||||
|
||||
async def wait_async(self, timeout: Optional[float] = None) -> None:
|
||||
await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout)
|
||||
|
||||
def is_connected(self) -> bool:
|
||||
return self.state is ConnectionFlowState.connected
|
||||
|
||||
def send_packet(self, packet: bytes) -> None:
|
||||
self.socket.sendall(packet)
|
||||
|
||||
def add_socket_listener(self, callback: SocketReaderCallback) -> None:
|
||||
_log.debug('Registering socket listener callback %s', callback)
|
||||
self._socket_reader.register(callback)
|
||||
|
||||
def remove_socket_listener(self, callback: SocketReaderCallback) -> None:
|
||||
_log.debug('Unregistering socket listener callback %s', callback)
|
||||
self._socket_reader.unregister(callback)
|
||||
|
||||
async def _wait_for_state(
|
||||
self, state: ConnectionFlowState, *other_states: ConnectionFlowState, timeout: Optional[float] = None
|
||||
) -> None:
|
||||
states = (state, *other_states)
|
||||
while True:
|
||||
if self.state in states:
|
||||
return
|
||||
await sane_wait_for([self._state_event.wait()], timeout=timeout)
|
||||
|
||||
async def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None:
|
||||
channel = self.voice_client.channel
|
||||
await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute)
|
||||
|
||||
async def _voice_disconnect(self) -> None:
|
||||
_log.info(
|
||||
'The voice handshake is being terminated for Channel ID %s (Guild ID %s)',
|
||||
self.voice_client.channel.id,
|
||||
self.voice_client.guild.id,
|
||||
)
|
||||
self.state = ConnectionFlowState.disconnected
|
||||
await self.voice_client.channel.guild.change_voice_state(channel=None)
|
||||
self._expecting_disconnect = True
|
||||
|
||||
async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket:
|
||||
ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook)
|
||||
self.state = ConnectionFlowState.websocket_connected
|
||||
return ws
|
||||
|
||||
async def _handshake_websocket(self) -> None:
|
||||
while not self.ip:
|
||||
await self.ws.poll_event()
|
||||
self.state = ConnectionFlowState.got_ip_discovery
|
||||
while self.ws.secret_key is None:
|
||||
await self.ws.poll_event()
|
||||
self.state = ConnectionFlowState.connected
|
||||
|
||||
def _create_socket(self) -> None:
|
||||
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
self.socket.setblocking(False)
|
||||
self._socket_reader.resume()
|
||||
|
||||
async def _poll_voice_ws(self, reconnect: bool) -> None:
|
||||
backoff = ExponentialBackoff()
|
||||
while True:
|
||||
try:
|
||||
await self.ws.poll_event()
|
||||
except asyncio.CancelledError:
|
||||
return
|
||||
except (ConnectionClosed, asyncio.TimeoutError) as exc:
|
||||
if isinstance(exc, ConnectionClosed):
|
||||
# The following close codes are undocumented so I will document them here.
|
||||
# 1000 - normal closure (obviously)
|
||||
# 4014 - we were externally disconnected (voice channel deleted, we were moved, etc)
|
||||
# 4015 - voice server has crashed
|
||||
if exc.code in (1000, 4015):
|
||||
_log.info('Disconnecting from voice normally, close code %d.', exc.code)
|
||||
await self.disconnect()
|
||||
break
|
||||
|
||||
if exc.code == 4014:
|
||||
_log.info('Disconnected from voice by force... potentially reconnecting.')
|
||||
successful = await self._potential_reconnect()
|
||||
if not successful:
|
||||
_log.info('Reconnect was unsuccessful, disconnecting from voice normally...')
|
||||
await self.disconnect()
|
||||
break
|
||||
else:
|
||||
continue
|
||||
|
||||
_log.debug('Not handling close code %s (%s)', exc.code, exc.reason or 'no reason')
|
||||
|
||||
if not reconnect:
|
||||
await self.disconnect()
|
||||
raise
|
||||
|
||||
retry = backoff.delay()
|
||||
_log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
|
||||
await asyncio.sleep(retry)
|
||||
await self.disconnect(cleanup=False)
|
||||
|
||||
try:
|
||||
await self._connect(
|
||||
reconnect=reconnect,
|
||||
timeout=self.timeout,
|
||||
self_deaf=(self.self_voice_state or self).self_deaf,
|
||||
self_mute=(self.self_voice_state or self).self_mute,
|
||||
resume=False,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
# at this point we've retried 5 times... let's continue the loop.
|
||||
_log.warning('Could not connect to voice... Retrying...')
|
||||
continue
|
||||
|
||||
async def _potential_reconnect(self) -> bool:
|
||||
try:
|
||||
await self._wait_for_state(
|
||||
ConnectionFlowState.got_voice_server_update, ConnectionFlowState.got_both_voice_updates, timeout=self.timeout
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
return False
|
||||
try:
|
||||
self.ws = await self._connect_websocket(False)
|
||||
await self._handshake_websocket()
|
||||
except (ConnectionClosed, asyncio.TimeoutError):
|
||||
return False
|
||||
else:
|
||||
return True
|
Loading…
x
Reference in New Issue
Block a user