mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-25 02:23:04 +00:00 
			
		
		
		
	Rewrite voice connection internals
This commit is contained in:
		| @@ -1842,7 +1842,7 @@ class Connectable(Protocol): | |||||||
|     async def connect( |     async def connect( | ||||||
|         self, |         self, | ||||||
|         *, |         *, | ||||||
|         timeout: float = 60.0, |         timeout: float = 30.0, | ||||||
|         reconnect: bool = True, |         reconnect: bool = True, | ||||||
|         cls: Callable[[Client, Connectable], T] = VoiceClient, |         cls: Callable[[Client, Connectable], T] = VoiceClient, | ||||||
|         self_deaf: bool = False, |         self_deaf: bool = False, | ||||||
| @@ -1858,7 +1858,7 @@ class Connectable(Protocol): | |||||||
|         Parameters |         Parameters | ||||||
|         ----------- |         ----------- | ||||||
|         timeout: :class:`float` |         timeout: :class:`float` | ||||||
|             The timeout in seconds to wait for the voice endpoint. |             The timeout in seconds to wait the connection to complete. | ||||||
|         reconnect: :class:`bool` |         reconnect: :class:`bool` | ||||||
|             Whether the bot should automatically attempt |             Whether the bot should automatically attempt | ||||||
|             a reconnect if a part of the handshake fails |             a reconnect if a part of the handshake fails | ||||||
|   | |||||||
| @@ -34,7 +34,7 @@ import threading | |||||||
| import traceback | import traceback | ||||||
| import zlib | import zlib | ||||||
|  |  | ||||||
| from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar | from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple | ||||||
|  |  | ||||||
| import aiohttp | import aiohttp | ||||||
| import yarl | import yarl | ||||||
| @@ -59,7 +59,7 @@ if TYPE_CHECKING: | |||||||
|  |  | ||||||
|     from .client import Client |     from .client import Client | ||||||
|     from .state import ConnectionState |     from .state import ConnectionState | ||||||
|     from .voice_client import VoiceClient |     from .voice_state import VoiceConnectionState | ||||||
|  |  | ||||||
|  |  | ||||||
| class ReconnectWebSocket(Exception): | class ReconnectWebSocket(Exception): | ||||||
| @@ -797,7 +797,7 @@ class DiscordVoiceWebSocket: | |||||||
|  |  | ||||||
|     if TYPE_CHECKING: |     if TYPE_CHECKING: | ||||||
|         thread_id: int |         thread_id: int | ||||||
|         _connection: VoiceClient |         _connection: VoiceConnectionState | ||||||
|         gateway: str |         gateway: str | ||||||
|         _max_heartbeat_timeout: float |         _max_heartbeat_timeout: float | ||||||
|  |  | ||||||
| @@ -866,16 +866,21 @@ class DiscordVoiceWebSocket: | |||||||
|         await self.send_as_json(payload) |         await self.send_as_json(payload) | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     async def from_client( |     async def from_connection_state( | ||||||
|         cls, client: VoiceClient, *, resume: bool = False, hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None |         cls, | ||||||
|  |         state: VoiceConnectionState, | ||||||
|  |         *, | ||||||
|  |         resume: bool = False, | ||||||
|  |         hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None, | ||||||
|     ) -> Self: |     ) -> Self: | ||||||
|         """Creates a voice websocket for the :class:`VoiceClient`.""" |         """Creates a voice websocket for the :class:`VoiceClient`.""" | ||||||
|         gateway = 'wss://' + client.endpoint + '/?v=4' |         gateway = f'wss://{state.endpoint}/?v=4' | ||||||
|  |         client = state.voice_client | ||||||
|         http = client._state.http |         http = client._state.http | ||||||
|         socket = await http.ws_connect(gateway, compress=15) |         socket = await http.ws_connect(gateway, compress=15) | ||||||
|         ws = cls(socket, loop=client.loop, hook=hook) |         ws = cls(socket, loop=client.loop, hook=hook) | ||||||
|         ws.gateway = gateway |         ws.gateway = gateway | ||||||
|         ws._connection = client |         ws._connection = state | ||||||
|         ws._max_heartbeat_timeout = 60.0 |         ws._max_heartbeat_timeout = 60.0 | ||||||
|         ws.thread_id = threading.get_ident() |         ws.thread_id = threading.get_ident() | ||||||
|  |  | ||||||
| @@ -951,30 +956,50 @@ class DiscordVoiceWebSocket: | |||||||
|         state.voice_port = data['port'] |         state.voice_port = data['port'] | ||||||
|         state.endpoint_ip = data['ip'] |         state.endpoint_ip = data['ip'] | ||||||
|  |  | ||||||
|         packet = bytearray(74) |         _log.debug('Connecting to voice socket') | ||||||
|         struct.pack_into('>H', packet, 0, 1)  # 1 = Send |         await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port)) | ||||||
|         struct.pack_into('>H', packet, 2, 70)  # 70 = Length |  | ||||||
|         struct.pack_into('>I', packet, 4, state.ssrc) |  | ||||||
|         state.socket.sendto(packet, (state.endpoint_ip, state.voice_port)) |  | ||||||
|         recv = await self.loop.sock_recv(state.socket, 74) |  | ||||||
|         _log.debug('received packet in initial_connection: %s', recv) |  | ||||||
|  |  | ||||||
|         # the ip is ascii starting at the 8th byte and ending at the first null |  | ||||||
|         ip_start = 8 |  | ||||||
|         ip_end = recv.index(0, ip_start) |  | ||||||
|         state.ip = recv[ip_start:ip_end].decode('ascii') |  | ||||||
|  |  | ||||||
|         state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0] |  | ||||||
|         _log.debug('detected ip: %s port: %s', state.ip, state.port) |  | ||||||
|  |  | ||||||
|  |         state.ip, state.port = await self.discover_ip() | ||||||
|         # there *should* always be at least one supported mode (xsalsa20_poly1305) |         # there *should* always be at least one supported mode (xsalsa20_poly1305) | ||||||
|         modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] |         modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes] | ||||||
|         _log.debug('received supported encryption modes: %s', ", ".join(modes)) |         _log.debug('received supported encryption modes: %s', ', '.join(modes)) | ||||||
|  |  | ||||||
|         mode = modes[0] |         mode = modes[0] | ||||||
|         await self.select_protocol(state.ip, state.port, mode) |         await self.select_protocol(state.ip, state.port, mode) | ||||||
|         _log.debug('selected the voice protocol for use (%s)', mode) |         _log.debug('selected the voice protocol for use (%s)', mode) | ||||||
|  |  | ||||||
|  |     async def discover_ip(self) -> Tuple[str, int]: | ||||||
|  |         state = self._connection | ||||||
|  |         packet = bytearray(74) | ||||||
|  |         struct.pack_into('>H', packet, 0, 1)  # 1 = Send | ||||||
|  |         struct.pack_into('>H', packet, 2, 70)  # 70 = Length | ||||||
|  |         struct.pack_into('>I', packet, 4, state.ssrc) | ||||||
|  |  | ||||||
|  |         _log.debug('Sending ip discovery packet') | ||||||
|  |         await self.loop.sock_sendall(state.socket, packet) | ||||||
|  |  | ||||||
|  |         fut: asyncio.Future[bytes] = self.loop.create_future() | ||||||
|  |  | ||||||
|  |         def get_ip_packet(data: bytes): | ||||||
|  |             if data[1] == 0x02 and len(data) == 74: | ||||||
|  |                 self.loop.call_soon_threadsafe(fut.set_result, data) | ||||||
|  |  | ||||||
|  |         fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet)) | ||||||
|  |         state.add_socket_listener(get_ip_packet) | ||||||
|  |         recv = await fut | ||||||
|  |  | ||||||
|  |         _log.debug('Received ip discovery packet: %s', recv) | ||||||
|  |  | ||||||
|  |         # the ip is ascii starting at the 8th byte and ending at the first null | ||||||
|  |         ip_start = 8 | ||||||
|  |         ip_end = recv.index(0, ip_start) | ||||||
|  |         ip = recv[ip_start:ip_end].decode('ascii') | ||||||
|  |  | ||||||
|  |         port = struct.unpack_from('>H', recv, len(recv) - 2)[0] | ||||||
|  |         _log.debug('detected ip: %s port: %s', ip, port) | ||||||
|  |  | ||||||
|  |         return ip, port | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def latency(self) -> float: |     def latency(self) -> float: | ||||||
|         """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" |         """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" | ||||||
| @@ -995,9 +1020,8 @@ class DiscordVoiceWebSocket: | |||||||
|         self.secret_key = self._connection.secret_key = data['secret_key'] |         self.secret_key = self._connection.secret_key = data['secret_key'] | ||||||
|  |  | ||||||
|         # Send a speak command with the "not speaking" state. |         # Send a speak command with the "not speaking" state. | ||||||
|         # This also tells Discord our SSRC value, which Discord requires |         # This also tells Discord our SSRC value, which Discord requires before | ||||||
|         # before sending any voice data (and is the real reason why we |         # sending any voice data (and is the real reason why we call this here). | ||||||
|         # call this here). |  | ||||||
|         await self.speak(SpeakingState.none) |         await self.speak(SpeakingState.none) | ||||||
|  |  | ||||||
|     async def poll_event(self) -> None: |     async def poll_event(self) -> None: | ||||||
| @@ -1006,10 +1030,10 @@ class DiscordVoiceWebSocket: | |||||||
|         if msg.type is aiohttp.WSMsgType.TEXT: |         if msg.type is aiohttp.WSMsgType.TEXT: | ||||||
|             await self.received_message(utils._from_json(msg.data)) |             await self.received_message(utils._from_json(msg.data)) | ||||||
|         elif msg.type is aiohttp.WSMsgType.ERROR: |         elif msg.type is aiohttp.WSMsgType.ERROR: | ||||||
|             _log.debug('Received %s', msg) |             _log.debug('Received voice %s', msg) | ||||||
|             raise ConnectionClosed(self.ws, shard_id=None) from msg.data |             raise ConnectionClosed(self.ws, shard_id=None) from msg.data | ||||||
|         elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): |         elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING): | ||||||
|             _log.debug('Received %s', msg) |             _log.debug('Received voice %s', msg) | ||||||
|             raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) |             raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) | ||||||
|  |  | ||||||
|     async def close(self, code: int = 1000) -> None: |     async def close(self, code: int = 1000) -> None: | ||||||
|   | |||||||
| @@ -703,7 +703,6 @@ class AudioPlayer(threading.Thread): | |||||||
|         self._resumed: threading.Event = threading.Event() |         self._resumed: threading.Event = threading.Event() | ||||||
|         self._resumed.set()  # we are not paused |         self._resumed.set()  # we are not paused | ||||||
|         self._current_error: Optional[Exception] = None |         self._current_error: Optional[Exception] = None | ||||||
|         self._connected: threading.Event = client._connected |  | ||||||
|         self._lock: threading.Lock = threading.Lock() |         self._lock: threading.Lock = threading.Lock() | ||||||
|  |  | ||||||
|         if after is not None and not callable(after): |         if after is not None and not callable(after): | ||||||
| @@ -714,7 +713,8 @@ class AudioPlayer(threading.Thread): | |||||||
|         self._start = time.perf_counter() |         self._start = time.perf_counter() | ||||||
|  |  | ||||||
|         # getattr lookup speed ups |         # getattr lookup speed ups | ||||||
|         play_audio = self.client.send_audio_packet |         client = self.client | ||||||
|  |         play_audio = client.send_audio_packet | ||||||
|         self._speak(SpeakingState.voice) |         self._speak(SpeakingState.voice) | ||||||
|  |  | ||||||
|         while not self._end.is_set(): |         while not self._end.is_set(): | ||||||
| @@ -725,22 +725,28 @@ class AudioPlayer(threading.Thread): | |||||||
|                 self._resumed.wait() |                 self._resumed.wait() | ||||||
|                 continue |                 continue | ||||||
|  |  | ||||||
|             # are we disconnected from voice? |  | ||||||
|             if not self._connected.is_set(): |  | ||||||
|                 # wait until we are connected |  | ||||||
|                 self._connected.wait() |  | ||||||
|                 # reset our internal data |  | ||||||
|                 self.loops = 0 |  | ||||||
|                 self._start = time.perf_counter() |  | ||||||
|  |  | ||||||
|             self.loops += 1 |  | ||||||
|             data = self.source.read() |             data = self.source.read() | ||||||
|  |  | ||||||
|             if not data: |             if not data: | ||||||
|                 self.stop() |                 self.stop() | ||||||
|                 break |                 break | ||||||
|  |  | ||||||
|  |             # are we disconnected from voice? | ||||||
|  |             if not client.is_connected(): | ||||||
|  |                 _log.debug('Not connected, waiting for %ss...', client.timeout) | ||||||
|  |                 # wait until we are connected, but not forever | ||||||
|  |                 connected = client.wait_until_connected(client.timeout) | ||||||
|  |                 if self._end.is_set() or not connected: | ||||||
|  |                     _log.debug('Aborting playback') | ||||||
|  |                     return | ||||||
|  |                 _log.debug('Reconnected, resuming playback') | ||||||
|  |                 self._speak(SpeakingState.voice) | ||||||
|  |                 # reset our internal data | ||||||
|  |                 self.loops = 0 | ||||||
|  |                 self._start = time.perf_counter() | ||||||
|  |  | ||||||
|             play_audio(data, encode=not self.source.is_opus()) |             play_audio(data, encode=not self.source.is_opus()) | ||||||
|  |             self.loops += 1 | ||||||
|             next_time = self._start + self.DELAY * self.loops |             next_time = self._start + self.DELAY * self.loops | ||||||
|             delay = max(0, self.DELAY + (next_time - time.perf_counter())) |             delay = max(0, self.DELAY + (next_time - time.perf_counter())) | ||||||
|             time.sleep(delay) |             time.sleep(delay) | ||||||
| @@ -792,7 +798,7 @@ class AudioPlayer(threading.Thread): | |||||||
|     def is_paused(self) -> bool: |     def is_paused(self) -> bool: | ||||||
|         return not self._end.is_set() and not self._resumed.is_set() |         return not self._end.is_set() and not self._resumed.is_set() | ||||||
|  |  | ||||||
|     def _set_source(self, source: AudioSource) -> None: |     def set_source(self, source: AudioSource) -> None: | ||||||
|         with self._lock: |         with self._lock: | ||||||
|             self.pause(update_speaking=False) |             self.pause(update_speaking=False) | ||||||
|             self.source = source |             self.source = source | ||||||
|   | |||||||
| @@ -20,40 +20,24 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | |||||||
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||||
| FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | ||||||
| DEALINGS IN THE SOFTWARE. | DEALINGS IN THE SOFTWARE. | ||||||
|  |  | ||||||
|  |  | ||||||
| Some documentation to refer to: |  | ||||||
|  |  | ||||||
| - Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. |  | ||||||
| - The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. |  | ||||||
| - We pull the session_id from VOICE_STATE_UPDATE. |  | ||||||
| - We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. |  | ||||||
| - Then we initiate the voice web socket (vWS) pointing to the endpoint. |  | ||||||
| - We send opcode 0 with the user_id, server_id, session_id and token using the vWS. |  | ||||||
| - The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. |  | ||||||
| - We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. |  | ||||||
| - Then we send our IP and port via vWS with opcode 1. |  | ||||||
| - When that's all done, we receive opcode 4 from the vWS. |  | ||||||
| - Finally we can transmit data to endpoint:port. |  | ||||||
| """ | """ | ||||||
|  |  | ||||||
| from __future__ import annotations | from __future__ import annotations | ||||||
|  |  | ||||||
| import asyncio | import asyncio | ||||||
| import socket |  | ||||||
| import logging | import logging | ||||||
| import struct | import struct | ||||||
| import threading |  | ||||||
| from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union | from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple, Union | ||||||
|  |  | ||||||
| from . import opus, utils | from . import opus | ||||||
| from .backoff import ExponentialBackoff |  | ||||||
| from .gateway import * | from .gateway import * | ||||||
| from .errors import ClientException, ConnectionClosed | from .errors import ClientException | ||||||
| from .player import AudioPlayer, AudioSource | from .player import AudioPlayer, AudioSource | ||||||
| from .utils import MISSING | from .utils import MISSING | ||||||
|  | from .voice_state import VoiceConnectionState | ||||||
|  |  | ||||||
| if TYPE_CHECKING: | if TYPE_CHECKING: | ||||||
|  |     from .gateway import DiscordVoiceWebSocket | ||||||
|     from .client import Client |     from .client import Client | ||||||
|     from .guild import Guild |     from .guild import Guild | ||||||
|     from .state import ConnectionState |     from .state import ConnectionState | ||||||
| @@ -226,12 +210,6 @@ class VoiceClient(VoiceProtocol): | |||||||
|     """ |     """ | ||||||
|  |  | ||||||
|     channel: VocalGuildChannel |     channel: VocalGuildChannel | ||||||
|     endpoint_ip: str |  | ||||||
|     voice_port: int |  | ||||||
|     ip: str |  | ||||||
|     port: int |  | ||||||
|     secret_key: List[int] |  | ||||||
|     ssrc: int |  | ||||||
|  |  | ||||||
|     def __init__(self, client: Client, channel: abc.Connectable) -> None: |     def __init__(self, client: Client, channel: abc.Connectable) -> None: | ||||||
|         if not has_nacl: |         if not has_nacl: | ||||||
| @@ -239,29 +217,18 @@ class VoiceClient(VoiceProtocol): | |||||||
|  |  | ||||||
|         super().__init__(client, channel) |         super().__init__(client, channel) | ||||||
|         state = client._connection |         state = client._connection | ||||||
|         self.token: str = MISSING |  | ||||||
|         self.server_id: int = MISSING |         self.server_id: int = MISSING | ||||||
|         self.socket = MISSING |         self.socket = MISSING | ||||||
|         self.loop: asyncio.AbstractEventLoop = state.loop |         self.loop: asyncio.AbstractEventLoop = state.loop | ||||||
|         self._state: ConnectionState = state |         self._state: ConnectionState = state | ||||||
|         # this will be used in the AudioPlayer thread |  | ||||||
|         self._connected: threading.Event = threading.Event() |  | ||||||
|  |  | ||||||
|         self._handshaking: bool = False |  | ||||||
|         self._potentially_reconnecting: bool = False |  | ||||||
|         self._voice_state_complete: asyncio.Event = asyncio.Event() |  | ||||||
|         self._voice_server_complete: asyncio.Event = asyncio.Event() |  | ||||||
|  |  | ||||||
|         self.mode: str = MISSING |  | ||||||
|         self._connections: int = 0 |  | ||||||
|         self.sequence: int = 0 |         self.sequence: int = 0 | ||||||
|         self.timestamp: int = 0 |         self.timestamp: int = 0 | ||||||
|         self.timeout: float = 0 |  | ||||||
|         self._runner: asyncio.Task = MISSING |  | ||||||
|         self._player: Optional[AudioPlayer] = None |         self._player: Optional[AudioPlayer] = None | ||||||
|         self.encoder: Encoder = MISSING |         self.encoder: Encoder = MISSING | ||||||
|         self._lite_nonce: int = 0 |         self._lite_nonce: int = 0 | ||||||
|         self.ws: DiscordVoiceWebSocket = MISSING |  | ||||||
|  |         self._connection: VoiceConnectionState = self.create_connection_state() | ||||||
|  |  | ||||||
|     warn_nacl: bool = not has_nacl |     warn_nacl: bool = not has_nacl | ||||||
|     supported_modes: Tuple[SupportedModes, ...] = ( |     supported_modes: Tuple[SupportedModes, ...] = ( | ||||||
| @@ -280,6 +247,38 @@ class VoiceClient(VoiceProtocol): | |||||||
|         """:class:`ClientUser`: The user connected to voice (i.e. ourselves).""" |         """:class:`ClientUser`: The user connected to voice (i.e. ourselves).""" | ||||||
|         return self._state.user  # type: ignore |         return self._state.user  # type: ignore | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def session_id(self) -> Optional[str]: | ||||||
|  |         return self._connection.session_id | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def token(self) -> Optional[str]: | ||||||
|  |         return self._connection.token | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def endpoint(self) -> Optional[str]: | ||||||
|  |         return self._connection.endpoint | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def ssrc(self) -> int: | ||||||
|  |         return self._connection.ssrc | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def mode(self) -> SupportedModes: | ||||||
|  |         return self._connection.mode | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def secret_key(self) -> List[int]: | ||||||
|  |         return self._connection.secret_key | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def ws(self) -> DiscordVoiceWebSocket: | ||||||
|  |         return self._connection.ws | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def timeout(self) -> float: | ||||||
|  |         return self._connection.timeout | ||||||
|  |  | ||||||
|     def checked_add(self, attr: str, value: int, limit: int) -> None: |     def checked_add(self, attr: str, value: int, limit: int) -> None: | ||||||
|         val = getattr(self, attr) |         val = getattr(self, attr) | ||||||
|         if val + value > limit: |         if val + value > limit: | ||||||
| @@ -289,149 +288,23 @@ class VoiceClient(VoiceProtocol): | |||||||
|  |  | ||||||
|     # connection related |     # connection related | ||||||
|  |  | ||||||
|     async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: |     def create_connection_state(self) -> VoiceConnectionState: | ||||||
|         self.session_id: str = data['session_id'] |         return VoiceConnectionState(self) | ||||||
|         channel_id = data['channel_id'] |  | ||||||
|  |  | ||||||
|         if not self._handshaking or self._potentially_reconnecting: |     async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None: | ||||||
|             # If we're done handshaking then we just need to update ourselves |         await self._connection.voice_state_update(data) | ||||||
|             # If we're potentially reconnecting due to a 4014, then we need to differentiate |  | ||||||
|             # a channel move and an actual force disconnect |  | ||||||
|             if channel_id is None: |  | ||||||
|                 # We're being disconnected so cleanup |  | ||||||
|                 await self.disconnect() |  | ||||||
|             else: |  | ||||||
|                 self.channel = channel_id and self.guild.get_channel(int(channel_id))  # type: ignore |  | ||||||
|         else: |  | ||||||
|             self._voice_state_complete.set() |  | ||||||
|  |  | ||||||
|     async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: |     async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None: | ||||||
|         if self._voice_server_complete.is_set(): |         await self._connection.voice_server_update(data) | ||||||
|             _log.warning('Ignoring extraneous voice server update.') |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         self.token = data['token'] |  | ||||||
|         self.server_id = int(data['guild_id']) |  | ||||||
|         endpoint = data.get('endpoint') |  | ||||||
|  |  | ||||||
|         if endpoint is None or self.token is None: |  | ||||||
|             _log.warning( |  | ||||||
|                 'Awaiting endpoint... This requires waiting. ' |  | ||||||
|                 'If timeout occurred considering raising the timeout and reconnecting.' |  | ||||||
|             ) |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         self.endpoint, _, _ = endpoint.rpartition(':') |  | ||||||
|         if self.endpoint.startswith('wss://'): |  | ||||||
|             # Just in case, strip it off since we're going to add it later |  | ||||||
|             self.endpoint: str = self.endpoint[6:] |  | ||||||
|  |  | ||||||
|         # This gets set later |  | ||||||
|         self.endpoint_ip = MISSING |  | ||||||
|  |  | ||||||
|         self.socket: socket.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) |  | ||||||
|         self.socket.setblocking(False) |  | ||||||
|  |  | ||||||
|         if not self._handshaking: |  | ||||||
|             # If we're not handshaking then we need to terminate our previous connection in the websocket |  | ||||||
|             await self.ws.close(4000) |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         self._voice_server_complete.set() |  | ||||||
|  |  | ||||||
|     async def voice_connect(self, self_deaf: bool = False, self_mute: bool = False) -> None: |  | ||||||
|         await self.channel.guild.change_voice_state(channel=self.channel, self_deaf=self_deaf, self_mute=self_mute) |  | ||||||
|  |  | ||||||
|     async def voice_disconnect(self) -> None: |  | ||||||
|         _log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) |  | ||||||
|         await self.channel.guild.change_voice_state(channel=None) |  | ||||||
|  |  | ||||||
|     def prepare_handshake(self) -> None: |  | ||||||
|         self._voice_state_complete.clear() |  | ||||||
|         self._voice_server_complete.clear() |  | ||||||
|         self._handshaking = True |  | ||||||
|         _log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) |  | ||||||
|         self._connections += 1 |  | ||||||
|  |  | ||||||
|     def finish_handshake(self) -> None: |  | ||||||
|         _log.info('Voice handshake complete. Endpoint found %s', self.endpoint) |  | ||||||
|         self._handshaking = False |  | ||||||
|         self._voice_server_complete.clear() |  | ||||||
|         self._voice_state_complete.clear() |  | ||||||
|  |  | ||||||
|     async def connect_websocket(self) -> DiscordVoiceWebSocket: |  | ||||||
|         ws = await DiscordVoiceWebSocket.from_client(self) |  | ||||||
|         self._connected.clear() |  | ||||||
|         while ws.secret_key is None: |  | ||||||
|             await ws.poll_event() |  | ||||||
|         self._connected.set() |  | ||||||
|         return ws |  | ||||||
|  |  | ||||||
|     async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None: |     async def connect(self, *, reconnect: bool, timeout: float, self_deaf: bool = False, self_mute: bool = False) -> None: | ||||||
|         _log.info('Connecting to voice...') |         await self._connection.connect( | ||||||
|         self.timeout = timeout |             reconnect=reconnect, timeout=timeout, self_deaf=self_deaf, self_mute=self_mute, resume=False | ||||||
|  |         ) | ||||||
|  |  | ||||||
|         for i in range(5): |     def wait_until_connected(self, timeout: Optional[float] = 30.0) -> bool: | ||||||
|             self.prepare_handshake() |         self._connection.wait(timeout) | ||||||
|  |         return self._connection.is_connected() | ||||||
|             # This has to be created before we start the flow. |  | ||||||
|             futures = [ |  | ||||||
|                 self._voice_state_complete.wait(), |  | ||||||
|                 self._voice_server_complete.wait(), |  | ||||||
|             ] |  | ||||||
|  |  | ||||||
|             # Start the connection flow |  | ||||||
|             await self.voice_connect(self_deaf=self_deaf, self_mute=self_mute) |  | ||||||
|  |  | ||||||
|             try: |  | ||||||
|                 await utils.sane_wait_for(futures, timeout=timeout) |  | ||||||
|             except asyncio.TimeoutError: |  | ||||||
|                 await self.disconnect(force=True) |  | ||||||
|                 raise |  | ||||||
|  |  | ||||||
|             self.finish_handshake() |  | ||||||
|  |  | ||||||
|             try: |  | ||||||
|                 self.ws = await self.connect_websocket() |  | ||||||
|                 break |  | ||||||
|             except (ConnectionClosed, asyncio.TimeoutError): |  | ||||||
|                 if reconnect: |  | ||||||
|                     _log.exception('Failed to connect to voice... Retrying...') |  | ||||||
|                     await asyncio.sleep(1 + i * 2.0) |  | ||||||
|                     await self.voice_disconnect() |  | ||||||
|                     continue |  | ||||||
|                 else: |  | ||||||
|                     raise |  | ||||||
|  |  | ||||||
|         if self._runner is MISSING: |  | ||||||
|             self._runner = self.client.loop.create_task(self.poll_voice_ws(reconnect)) |  | ||||||
|  |  | ||||||
|     async def potential_reconnect(self) -> bool: |  | ||||||
|         # Attempt to stop the player thread from playing early |  | ||||||
|         self._connected.clear() |  | ||||||
|         self.prepare_handshake() |  | ||||||
|         self._potentially_reconnecting = True |  | ||||||
|         try: |  | ||||||
|             # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected |  | ||||||
|             await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) |  | ||||||
|         except asyncio.TimeoutError: |  | ||||||
|             self._potentially_reconnecting = False |  | ||||||
|             await self.disconnect(force=True) |  | ||||||
|             return False |  | ||||||
|  |  | ||||||
|         self.finish_handshake() |  | ||||||
|         self._potentially_reconnecting = False |  | ||||||
|  |  | ||||||
|         if self.ws: |  | ||||||
|             _log.debug("Closing existing voice websocket") |  | ||||||
|             await self.ws.close() |  | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             self.ws = await self.connect_websocket() |  | ||||||
|         except (ConnectionClosed, asyncio.TimeoutError): |  | ||||||
|             return False |  | ||||||
|         else: |  | ||||||
|             return True |  | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def latency(self) -> float: |     def latency(self) -> float: | ||||||
| @@ -442,7 +315,7 @@ class VoiceClient(VoiceProtocol): | |||||||
|  |  | ||||||
|         .. versionadded:: 1.4 |         .. versionadded:: 1.4 | ||||||
|         """ |         """ | ||||||
|         ws = self.ws |         ws = self._connection.ws | ||||||
|         return float("inf") if not ws else ws.latency |         return float("inf") if not ws else ws.latency | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
| @@ -451,72 +324,19 @@ class VoiceClient(VoiceProtocol): | |||||||
|  |  | ||||||
|         .. versionadded:: 1.4 |         .. versionadded:: 1.4 | ||||||
|         """ |         """ | ||||||
|         ws = self.ws |         ws = self._connection.ws | ||||||
|         return float("inf") if not ws else ws.average_latency |         return float("inf") if not ws else ws.average_latency | ||||||
|  |  | ||||||
|     async def poll_voice_ws(self, reconnect: bool) -> None: |  | ||||||
|         backoff = ExponentialBackoff() |  | ||||||
|         while True: |  | ||||||
|             try: |  | ||||||
|                 await self.ws.poll_event() |  | ||||||
|             except (ConnectionClosed, asyncio.TimeoutError) as exc: |  | ||||||
|                 if isinstance(exc, ConnectionClosed): |  | ||||||
|                     # The following close codes are undocumented so I will document them here. |  | ||||||
|                     # 1000 - normal closure (obviously) |  | ||||||
|                     # 4014 - voice channel has been deleted. |  | ||||||
|                     # 4015 - voice server has crashed |  | ||||||
|                     if exc.code in (1000, 4015): |  | ||||||
|                         _log.info('Disconnecting from voice normally, close code %d.', exc.code) |  | ||||||
|                         await self.disconnect() |  | ||||||
|                         break |  | ||||||
|                     if exc.code == 4014: |  | ||||||
|                         _log.info('Disconnected from voice by force... potentially reconnecting.') |  | ||||||
|                         successful = await self.potential_reconnect() |  | ||||||
|                         if not successful: |  | ||||||
|                             _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') |  | ||||||
|                             await self.disconnect() |  | ||||||
|                             break |  | ||||||
|                         else: |  | ||||||
|                             continue |  | ||||||
|  |  | ||||||
|                 if not reconnect: |  | ||||||
|                     await self.disconnect() |  | ||||||
|                     raise |  | ||||||
|  |  | ||||||
|                 retry = backoff.delay() |  | ||||||
|                 _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) |  | ||||||
|                 self._connected.clear() |  | ||||||
|                 await asyncio.sleep(retry) |  | ||||||
|                 await self.voice_disconnect() |  | ||||||
|                 try: |  | ||||||
|                     await self.connect(reconnect=True, timeout=self.timeout) |  | ||||||
|                 except asyncio.TimeoutError: |  | ||||||
|                     # at this point we've retried 5 times... let's continue the loop. |  | ||||||
|                     _log.warning('Could not connect to voice... Retrying...') |  | ||||||
|                     continue |  | ||||||
|  |  | ||||||
|     async def disconnect(self, *, force: bool = False) -> None: |     async def disconnect(self, *, force: bool = False) -> None: | ||||||
|         """|coro| |         """|coro| | ||||||
|  |  | ||||||
|         Disconnects this voice client from voice. |         Disconnects this voice client from voice. | ||||||
|         """ |         """ | ||||||
|         if not force and not self.is_connected(): |  | ||||||
|             return |  | ||||||
|  |  | ||||||
|         self.stop() |         self.stop() | ||||||
|         self._connected.clear() |         await self._connection.disconnect(force=force) | ||||||
|  |  | ||||||
|         try: |  | ||||||
|             if self.ws: |  | ||||||
|                 await self.ws.close() |  | ||||||
|  |  | ||||||
|             await self.voice_disconnect() |  | ||||||
|         finally: |  | ||||||
|         self.cleanup() |         self.cleanup() | ||||||
|             if self.socket: |  | ||||||
|                 self.socket.close() |  | ||||||
|  |  | ||||||
|     async def move_to(self, channel: Optional[abc.Snowflake]) -> None: |     async def move_to(self, channel: Optional[abc.Snowflake], *, timeout: Optional[float] = 30.0) -> None: | ||||||
|         """|coro| |         """|coro| | ||||||
|  |  | ||||||
|         Moves you to a different voice channel. |         Moves you to a different voice channel. | ||||||
| @@ -525,12 +345,22 @@ class VoiceClient(VoiceProtocol): | |||||||
|         ----------- |         ----------- | ||||||
|         channel: Optional[:class:`abc.Snowflake`] |         channel: Optional[:class:`abc.Snowflake`] | ||||||
|             The channel to move to. Must be a voice channel. |             The channel to move to. Must be a voice channel. | ||||||
|  |         timeout: Optional[:class:`float`] | ||||||
|  |             How long to wait for the move to complete. | ||||||
|  |  | ||||||
|  |             .. versionadded:: 2.4 | ||||||
|  |  | ||||||
|  |         Raises | ||||||
|  |         ------- | ||||||
|  |         asyncio.TimeoutError | ||||||
|  |             The move did not complete in time, but may still be ongoing. | ||||||
|         """ |         """ | ||||||
|         await self.channel.guild.change_voice_state(channel=channel) |         await self._connection.move_to(channel) | ||||||
|  |         await self._connection.wait_async(timeout) | ||||||
|  |  | ||||||
|     def is_connected(self) -> bool: |     def is_connected(self) -> bool: | ||||||
|         """Indicates if the voice client is connected to voice.""" |         """Indicates if the voice client is connected to voice.""" | ||||||
|         return self._connected.is_set() |         return self._connection.is_connected() | ||||||
|  |  | ||||||
|     # audio related |     # audio related | ||||||
|  |  | ||||||
| @@ -703,7 +533,7 @@ class VoiceClient(VoiceProtocol): | |||||||
|         if self._player is None: |         if self._player is None: | ||||||
|             raise ValueError('Not playing anything.') |             raise ValueError('Not playing anything.') | ||||||
|  |  | ||||||
|         self._player._set_source(value) |         self._player.set_source(value) | ||||||
|  |  | ||||||
|     def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: |     def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None: | ||||||
|         """Sends an audio packet composed of the data. |         """Sends an audio packet composed of the data. | ||||||
| @@ -732,8 +562,8 @@ class VoiceClient(VoiceProtocol): | |||||||
|             encoded_data = data |             encoded_data = data | ||||||
|         packet = self._get_voice_packet(encoded_data) |         packet = self._get_voice_packet(encoded_data) | ||||||
|         try: |         try: | ||||||
|             self.socket.sendto(packet, (self.endpoint_ip, self.voice_port)) |             self._connection.send_packet(packet) | ||||||
|         except BlockingIOError: |         except OSError: | ||||||
|             _log.warning('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) |             _log.info('A packet has been dropped (seq: %s, timestamp: %s)', self.sequence, self.timestamp) | ||||||
|  |  | ||||||
|         self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) |         self.checked_add('timestamp', opus.Encoder.SAMPLES_PER_FRAME, 4294967295) | ||||||
|   | |||||||
							
								
								
									
										596
									
								
								discord/voice_state.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										596
									
								
								discord/voice_state.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,596 @@ | |||||||
|  | """ | ||||||
|  | The MIT License (MIT) | ||||||
|  |  | ||||||
|  | Copyright (c) 2015-present Rapptz | ||||||
|  |  | ||||||
|  | Permission is hereby granted, free of charge, to any person obtaining a | ||||||
|  | copy of this software and associated documentation files (the "Software"), | ||||||
|  | to deal in the Software without restriction, including without limitation | ||||||
|  | the rights to use, copy, modify, merge, publish, distribute, sublicense, | ||||||
|  | and/or sell copies of the Software, and to permit persons to whom the | ||||||
|  | Software is furnished to do so, subject to the following conditions: | ||||||
|  |  | ||||||
|  | The above copyright notice and this permission notice shall be included in | ||||||
|  | all copies or substantial portions of the Software. | ||||||
|  |  | ||||||
|  | THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS | ||||||
|  | OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||||||
|  | FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||||||
|  | AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||||||
|  | LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING | ||||||
|  | FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER | ||||||
|  | DEALINGS IN THE SOFTWARE. | ||||||
|  |  | ||||||
|  |  | ||||||
|  | Some documentation to refer to: | ||||||
|  |  | ||||||
|  | - Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID. | ||||||
|  | - The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE. | ||||||
|  | - We pull the session_id from VOICE_STATE_UPDATE. | ||||||
|  | - We pull the token, endpoint and server_id from VOICE_SERVER_UPDATE. | ||||||
|  | - Then we initiate the voice web socket (vWS) pointing to the endpoint. | ||||||
|  | - We send opcode 0 with the user_id, server_id, session_id and token using the vWS. | ||||||
|  | - The vWS sends back opcode 2 with an ssrc, port, modes(array) and heartbeat_interval. | ||||||
|  | - We send a UDP discovery packet to endpoint:port and receive our IP and our port in LE. | ||||||
|  | - Then we send our IP and port via vWS with opcode 1. | ||||||
|  | - When that's all done, we receive opcode 4 from the vWS. | ||||||
|  | - Finally we can transmit data to endpoint:port. | ||||||
|  | """ | ||||||
|  |  | ||||||
|  | from __future__ import annotations | ||||||
|  |  | ||||||
|  | import select | ||||||
|  | import socket | ||||||
|  | import asyncio | ||||||
|  | import logging | ||||||
|  | import threading | ||||||
|  |  | ||||||
|  | import async_timeout | ||||||
|  |  | ||||||
|  | from typing import TYPE_CHECKING, Optional, Dict, List, Callable, Coroutine, Any, Tuple | ||||||
|  |  | ||||||
|  | from .enums import Enum | ||||||
|  | from .utils import MISSING, sane_wait_for | ||||||
|  | from .errors import ConnectionClosed | ||||||
|  | from .backoff import ExponentialBackoff | ||||||
|  | from .gateway import DiscordVoiceWebSocket | ||||||
|  |  | ||||||
|  | if TYPE_CHECKING: | ||||||
|  |     from . import abc | ||||||
|  |     from .guild import Guild | ||||||
|  |     from .user import ClientUser | ||||||
|  |     from .member import VoiceState | ||||||
|  |     from .voice_client import VoiceClient | ||||||
|  |  | ||||||
|  |     from .types.voice import ( | ||||||
|  |         GuildVoiceState as GuildVoiceStatePayload, | ||||||
|  |         VoiceServerUpdate as VoiceServerUpdatePayload, | ||||||
|  |         SupportedModes, | ||||||
|  |     ) | ||||||
|  |  | ||||||
|  |     WebsocketHook = Optional[Callable[['VoiceConnectionState', Dict[str, Any]], Coroutine[Any, Any, Any]]] | ||||||
|  |     SocketReaderCallback = Callable[[bytes], Any] | ||||||
|  |  | ||||||
|  |  | ||||||
|  | __all__ = ('VoiceConnectionState',) | ||||||
|  |  | ||||||
|  | _log = logging.getLogger(__name__) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class SocketReader(threading.Thread): | ||||||
|  |     def __init__(self, state: VoiceConnectionState) -> None: | ||||||
|  |         super().__init__(daemon=True, name=f'voice-socket-reader:{id(self):#x}') | ||||||
|  |         self.state: VoiceConnectionState = state | ||||||
|  |         self._callbacks: List[SocketReaderCallback] = [] | ||||||
|  |         self._running = threading.Event() | ||||||
|  |         self._end = threading.Event() | ||||||
|  |         # If we have paused reading due to having no callbacks | ||||||
|  |         self._idle_paused: bool = True | ||||||
|  |  | ||||||
|  |     def register(self, callback: SocketReaderCallback) -> None: | ||||||
|  |         self._callbacks.append(callback) | ||||||
|  |         if self._idle_paused: | ||||||
|  |             self._idle_paused = False | ||||||
|  |             self._running.set() | ||||||
|  |  | ||||||
|  |     def unregister(self, callback: SocketReaderCallback) -> None: | ||||||
|  |         try: | ||||||
|  |             self._callbacks.remove(callback) | ||||||
|  |         except ValueError: | ||||||
|  |             pass | ||||||
|  |         else: | ||||||
|  |             if not self._callbacks and self._running.is_set(): | ||||||
|  |                 # If running is not set, we are either explicitly paused and | ||||||
|  |                 # should be explicitly resumed, or we are already idle paused | ||||||
|  |                 self._idle_paused = True | ||||||
|  |                 self._running.clear() | ||||||
|  |  | ||||||
|  |     def pause(self) -> None: | ||||||
|  |         self._idle_paused = False | ||||||
|  |         self._running.clear() | ||||||
|  |  | ||||||
|  |     def resume(self, *, force: bool = False) -> None: | ||||||
|  |         if self._running.is_set(): | ||||||
|  |             return | ||||||
|  |         # Don't resume if there are no callbacks registered | ||||||
|  |         if not force and not self._callbacks: | ||||||
|  |             # We tried to resume but there was nothing to do, so resume when ready | ||||||
|  |             self._idle_paused = True | ||||||
|  |             return | ||||||
|  |         self._idle_paused = False | ||||||
|  |         self._running.set() | ||||||
|  |  | ||||||
|  |     def stop(self) -> None: | ||||||
|  |         self._end.set() | ||||||
|  |         self._running.set() | ||||||
|  |  | ||||||
|  |     def run(self) -> None: | ||||||
|  |         self._end.clear() | ||||||
|  |         self._running.set() | ||||||
|  |         try: | ||||||
|  |             self._do_run() | ||||||
|  |         except Exception: | ||||||
|  |             _log.exception('Error in %s', self) | ||||||
|  |         finally: | ||||||
|  |             self.stop() | ||||||
|  |             self._running.clear() | ||||||
|  |             self._callbacks.clear() | ||||||
|  |  | ||||||
|  |     def _do_run(self) -> None: | ||||||
|  |         while not self._end.is_set(): | ||||||
|  |             if not self._running.is_set(): | ||||||
|  |                 self._running.wait() | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             # Since this socket is a non blocking socket, select has to be used to wait on it for reading. | ||||||
|  |             try: | ||||||
|  |                 readable, _, _ = select.select([self.state.socket], [], [], 30) | ||||||
|  |             except (ValueError, TypeError): | ||||||
|  |                 # The socket is either closed or doesn't exist at the moment | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             if not readable: | ||||||
|  |                 continue | ||||||
|  |  | ||||||
|  |             try: | ||||||
|  |                 data = self.state.socket.recv(2048) | ||||||
|  |             except OSError: | ||||||
|  |                 _log.debug('Error reading from socket in %s, this should be safe to ignore', self, exc_info=True) | ||||||
|  |             else: | ||||||
|  |                 for cb in self._callbacks: | ||||||
|  |                     try: | ||||||
|  |                         cb(data) | ||||||
|  |                     except Exception: | ||||||
|  |                         _log.exception('Error calling %s in %s', cb, self) | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class ConnectionFlowState(Enum): | ||||||
|  |     """Enum representing voice connection flow state.""" | ||||||
|  |  | ||||||
|  |     # fmt: off | ||||||
|  |     disconnected            = 0 | ||||||
|  |     set_guild_voice_state   = 1 | ||||||
|  |     got_voice_state_update  = 2 | ||||||
|  |     got_voice_server_update = 3 | ||||||
|  |     got_both_voice_updates  = 4 | ||||||
|  |     websocket_connected     = 5 | ||||||
|  |     got_websocket_ready     = 6 | ||||||
|  |     got_ip_discovery        = 7 | ||||||
|  |     connected               = 8 | ||||||
|  |     # fmt: on | ||||||
|  |  | ||||||
|  |  | ||||||
|  | class VoiceConnectionState: | ||||||
|  |     """Represents the internal state of a voice connection.""" | ||||||
|  |  | ||||||
|  |     def __init__(self, voice_client: VoiceClient, *, hook: Optional[WebsocketHook] = None) -> None: | ||||||
|  |         self.voice_client = voice_client | ||||||
|  |         self.hook = hook | ||||||
|  |  | ||||||
|  |         self.timeout: float = 30.0 | ||||||
|  |         self.reconnect: bool = True | ||||||
|  |         self.self_deaf: bool = False | ||||||
|  |         self.self_mute: bool = False | ||||||
|  |         self.token: Optional[str] = None | ||||||
|  |         self.session_id: Optional[str] = None | ||||||
|  |         self.endpoint: Optional[str] = None | ||||||
|  |         self.endpoint_ip: Optional[str] = None | ||||||
|  |         self.server_id: Optional[int] = None | ||||||
|  |         self.ip: Optional[str] = None | ||||||
|  |         self.port: Optional[int] = None | ||||||
|  |         self.voice_port: Optional[int] = None | ||||||
|  |         self.secret_key: List[int] = MISSING | ||||||
|  |         self.ssrc: int = MISSING | ||||||
|  |         self.mode: SupportedModes = MISSING | ||||||
|  |         self.socket: socket.socket = MISSING | ||||||
|  |         self.ws: DiscordVoiceWebSocket = MISSING | ||||||
|  |  | ||||||
|  |         self._state: ConnectionFlowState = ConnectionFlowState.disconnected | ||||||
|  |         self._expecting_disconnect: bool = False | ||||||
|  |         self._connected = threading.Event() | ||||||
|  |         self._state_event = asyncio.Event() | ||||||
|  |         self._runner: Optional[asyncio.Task] = None | ||||||
|  |         self._connector: Optional[asyncio.Task] = None | ||||||
|  |         self._socket_reader = SocketReader(self) | ||||||
|  |         self._socket_reader.start() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def state(self) -> ConnectionFlowState: | ||||||
|  |         return self._state | ||||||
|  |  | ||||||
|  |     @state.setter | ||||||
|  |     def state(self, state: ConnectionFlowState) -> None: | ||||||
|  |         if state is not self._state: | ||||||
|  |             _log.debug('Connection state changed to %s', state.name) | ||||||
|  |         self._state = state | ||||||
|  |         self._state_event.set() | ||||||
|  |         self._state_event.clear() | ||||||
|  |  | ||||||
|  |         if state is ConnectionFlowState.connected: | ||||||
|  |             self._connected.set() | ||||||
|  |         else: | ||||||
|  |             self._connected.clear() | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def guild(self) -> Guild: | ||||||
|  |         return self.voice_client.guild | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def user(self) -> ClientUser: | ||||||
|  |         return self.voice_client.user | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def supported_modes(self) -> Tuple[SupportedModes, ...]: | ||||||
|  |         return self.voice_client.supported_modes | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def self_voice_state(self) -> Optional[VoiceState]: | ||||||
|  |         return self.guild.me.voice | ||||||
|  |  | ||||||
|  |     async def voice_state_update(self, data: GuildVoiceStatePayload) -> None: | ||||||
|  |         channel_id = data['channel_id'] | ||||||
|  |  | ||||||
|  |         if channel_id is None: | ||||||
|  |             # If we know we're going to get a voice_state_update where we have no channel due to | ||||||
|  |             # being in the reconnect flow, we ignore it.  Otherwise, it probably wasn't from us. | ||||||
|  |             if self._expecting_disconnect: | ||||||
|  |                 self._expecting_disconnect = False | ||||||
|  |             else: | ||||||
|  |                 _log.debug('We were externally disconnected from voice.') | ||||||
|  |                 await self.disconnect() | ||||||
|  |  | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         self.session_id = data['session_id'] | ||||||
|  |  | ||||||
|  |         # we got the event while connecting | ||||||
|  |         if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_server_update): | ||||||
|  |             if self.state is ConnectionFlowState.set_guild_voice_state: | ||||||
|  |                 self.state = ConnectionFlowState.got_voice_state_update | ||||||
|  |             else: | ||||||
|  |                 self.state = ConnectionFlowState.got_both_voice_updates | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         if self.state is ConnectionFlowState.connected: | ||||||
|  |             self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id))  # type: ignore | ||||||
|  |  | ||||||
|  |         elif self.state is not ConnectionFlowState.disconnected: | ||||||
|  |             if channel_id != self.voice_client.channel.id: | ||||||
|  |                 # For some unfortunate reason we were moved during the connection flow | ||||||
|  |                 _log.info('Handling channel move while connecting...') | ||||||
|  |  | ||||||
|  |                 self.voice_client.channel = channel_id and self.guild.get_channel(int(channel_id))  # type: ignore | ||||||
|  |  | ||||||
|  |                 await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_state_update) | ||||||
|  |                 await self.connect( | ||||||
|  |                     reconnect=self.reconnect, | ||||||
|  |                     timeout=self.timeout, | ||||||
|  |                     self_deaf=(self.self_voice_state or self).self_deaf, | ||||||
|  |                     self_mute=(self.self_voice_state or self).self_mute, | ||||||
|  |                     resume=False, | ||||||
|  |                     wait=False, | ||||||
|  |                 ) | ||||||
|  |             else: | ||||||
|  |                 _log.debug('Ignoring unexpected voice_state_update event') | ||||||
|  |  | ||||||
|  |     async def voice_server_update(self, data: VoiceServerUpdatePayload) -> None: | ||||||
|  |         self.token = data['token'] | ||||||
|  |         self.server_id = int(data['guild_id']) | ||||||
|  |         endpoint = data.get('endpoint') | ||||||
|  |  | ||||||
|  |         if self.token is None or endpoint is None: | ||||||
|  |             _log.warning( | ||||||
|  |                 'Awaiting endpoint... This requires waiting. ' | ||||||
|  |                 'If timeout occurred considering raising the timeout and reconnecting.' | ||||||
|  |             ) | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         self.endpoint, _, _ = endpoint.rpartition(':') | ||||||
|  |         if self.endpoint.startswith('wss://'): | ||||||
|  |             # Just in case, strip it off since we're going to add it later | ||||||
|  |             self.endpoint = self.endpoint[6:] | ||||||
|  |  | ||||||
|  |         # we got the event while connecting | ||||||
|  |         if self.state in (ConnectionFlowState.set_guild_voice_state, ConnectionFlowState.got_voice_state_update): | ||||||
|  |             # This gets set after READY is received | ||||||
|  |             self.endpoint_ip = MISSING | ||||||
|  |             self._create_socket() | ||||||
|  |  | ||||||
|  |             if self.state is ConnectionFlowState.set_guild_voice_state: | ||||||
|  |                 self.state = ConnectionFlowState.got_voice_server_update | ||||||
|  |             else: | ||||||
|  |                 self.state = ConnectionFlowState.got_both_voice_updates | ||||||
|  |  | ||||||
|  |         elif self.state is ConnectionFlowState.connected: | ||||||
|  |             _log.debug('Voice server update, closing old voice websocket') | ||||||
|  |             await self.ws.close(4014) | ||||||
|  |             self.state = ConnectionFlowState.got_voice_server_update | ||||||
|  |  | ||||||
|  |         elif self.state is not ConnectionFlowState.disconnected: | ||||||
|  |             _log.debug('Unexpected server update event, attempting to handle') | ||||||
|  |  | ||||||
|  |             await self.soft_disconnect(with_state=ConnectionFlowState.got_voice_server_update) | ||||||
|  |             await self.connect( | ||||||
|  |                 reconnect=self.reconnect, | ||||||
|  |                 timeout=self.timeout, | ||||||
|  |                 self_deaf=(self.self_voice_state or self).self_deaf, | ||||||
|  |                 self_mute=(self.self_voice_state or self).self_mute, | ||||||
|  |                 resume=False, | ||||||
|  |                 wait=False, | ||||||
|  |             ) | ||||||
|  |             self._create_socket() | ||||||
|  |  | ||||||
|  |     async def connect( | ||||||
|  |         self, *, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool, wait: bool = True | ||||||
|  |     ) -> None: | ||||||
|  |         if self._connector: | ||||||
|  |             self._connector.cancel() | ||||||
|  |             self._connector = None | ||||||
|  |  | ||||||
|  |         if self._runner: | ||||||
|  |             self._runner.cancel() | ||||||
|  |             self._runner = None | ||||||
|  |  | ||||||
|  |         self.timeout = timeout | ||||||
|  |         self.reconnect = reconnect | ||||||
|  |         self._connector = self.voice_client.loop.create_task( | ||||||
|  |             self._wrap_connect(reconnect, timeout, self_deaf, self_mute, resume), name='Voice connector' | ||||||
|  |         ) | ||||||
|  |         if wait: | ||||||
|  |             await self._connector | ||||||
|  |  | ||||||
|  |     async def _wrap_connect(self, *args: Any) -> None: | ||||||
|  |         try: | ||||||
|  |             await self._connect(*args) | ||||||
|  |         except asyncio.CancelledError: | ||||||
|  |             _log.debug('Cancelling voice connection') | ||||||
|  |             await self.soft_disconnect() | ||||||
|  |             raise | ||||||
|  |         except asyncio.TimeoutError: | ||||||
|  |             _log.info('Timed out connecting to voice') | ||||||
|  |             await self.disconnect() | ||||||
|  |             raise | ||||||
|  |         except Exception: | ||||||
|  |             _log.exception('Error connecting to voice... disconnecting') | ||||||
|  |             await self.disconnect() | ||||||
|  |             raise | ||||||
|  |  | ||||||
|  |     async def _connect(self, reconnect: bool, timeout: float, self_deaf: bool, self_mute: bool, resume: bool) -> None: | ||||||
|  |         _log.info('Connecting to voice...') | ||||||
|  |  | ||||||
|  |         async with async_timeout.timeout(timeout): | ||||||
|  |             for i in range(5): | ||||||
|  |                 _log.info('Starting voice handshake... (connection attempt %d)', i + 1) | ||||||
|  |  | ||||||
|  |                 await self._voice_connect(self_deaf=self_deaf, self_mute=self_mute) | ||||||
|  |                 # Setting this unnecessarily will break reconnecting | ||||||
|  |                 if self.state is ConnectionFlowState.disconnected: | ||||||
|  |                     self.state = ConnectionFlowState.set_guild_voice_state | ||||||
|  |  | ||||||
|  |                 await self._wait_for_state(ConnectionFlowState.got_both_voice_updates) | ||||||
|  |  | ||||||
|  |                 _log.info('Voice handshake complete. Endpoint found: %s', self.endpoint) | ||||||
|  |  | ||||||
|  |                 try: | ||||||
|  |                     self.ws = await self._connect_websocket(resume) | ||||||
|  |                     await self._handshake_websocket() | ||||||
|  |                     break | ||||||
|  |                 except ConnectionClosed: | ||||||
|  |                     if reconnect: | ||||||
|  |                         wait = 1 + i * 2.0 | ||||||
|  |                         _log.exception('Failed to connect to voice... Retrying in %ss...', wait) | ||||||
|  |                         await self.disconnect(cleanup=False) | ||||||
|  |                         await asyncio.sleep(wait) | ||||||
|  |                         continue | ||||||
|  |                     else: | ||||||
|  |                         await self.disconnect() | ||||||
|  |                         raise | ||||||
|  |  | ||||||
|  |         _log.info('Voice connection complete.') | ||||||
|  |  | ||||||
|  |         if not self._runner: | ||||||
|  |             self._runner = self.voice_client.loop.create_task(self._poll_voice_ws(reconnect), name='Voice websocket poller') | ||||||
|  |  | ||||||
|  |     async def disconnect(self, *, force: bool = True, cleanup: bool = True) -> None: | ||||||
|  |         if not force and not self.is_connected(): | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             if self.ws: | ||||||
|  |                 await self.ws.close() | ||||||
|  |             await self._voice_disconnect() | ||||||
|  |         except Exception: | ||||||
|  |             _log.debug('Ignoring exception disconnecting from voice', exc_info=True) | ||||||
|  |         finally: | ||||||
|  |             self.ip = MISSING | ||||||
|  |             self.port = MISSING | ||||||
|  |             self.state = ConnectionFlowState.disconnected | ||||||
|  |             self._socket_reader.pause() | ||||||
|  |  | ||||||
|  |             # Flip the connected event to unlock any waiters | ||||||
|  |             self._connected.set() | ||||||
|  |             self._connected.clear() | ||||||
|  |  | ||||||
|  |             if cleanup: | ||||||
|  |                 self._socket_reader.stop() | ||||||
|  |                 self.voice_client.cleanup() | ||||||
|  |  | ||||||
|  |             if self.socket: | ||||||
|  |                 self.socket.close() | ||||||
|  |  | ||||||
|  |     async def soft_disconnect(self, *, with_state: ConnectionFlowState = ConnectionFlowState.got_both_voice_updates) -> None: | ||||||
|  |         _log.debug('Soft disconnecting from voice') | ||||||
|  |         # Stop the websocket reader because closing the websocket will trigger an unwanted reconnect | ||||||
|  |         if self._runner: | ||||||
|  |             self._runner.cancel() | ||||||
|  |             self._runner = None | ||||||
|  |  | ||||||
|  |         try: | ||||||
|  |             if self.ws: | ||||||
|  |                 await self.ws.close() | ||||||
|  |         except Exception: | ||||||
|  |             _log.debug('Ignoring exception soft disconnecting from voice', exc_info=True) | ||||||
|  |         finally: | ||||||
|  |             self.ip = MISSING | ||||||
|  |             self.port = MISSING | ||||||
|  |             self.state = with_state | ||||||
|  |             self._socket_reader.pause() | ||||||
|  |  | ||||||
|  |             if self.socket: | ||||||
|  |                 self.socket.close() | ||||||
|  |  | ||||||
|  |     async def move_to(self, channel: Optional[abc.Snowflake]) -> None: | ||||||
|  |         if channel is None: | ||||||
|  |             await self.disconnect() | ||||||
|  |             return | ||||||
|  |  | ||||||
|  |         await self.voice_client.channel.guild.change_voice_state(channel=channel) | ||||||
|  |         self.state = ConnectionFlowState.set_guild_voice_state | ||||||
|  |  | ||||||
|  |     def wait(self, timeout: Optional[float] = None) -> bool: | ||||||
|  |         return self._connected.wait(timeout) | ||||||
|  |  | ||||||
|  |     async def wait_async(self, timeout: Optional[float] = None) -> None: | ||||||
|  |         await self._wait_for_state(ConnectionFlowState.connected, timeout=timeout) | ||||||
|  |  | ||||||
|  |     def is_connected(self) -> bool: | ||||||
|  |         return self.state is ConnectionFlowState.connected | ||||||
|  |  | ||||||
|  |     def send_packet(self, packet: bytes) -> None: | ||||||
|  |         self.socket.sendall(packet) | ||||||
|  |  | ||||||
|  |     def add_socket_listener(self, callback: SocketReaderCallback) -> None: | ||||||
|  |         _log.debug('Registering socket listener callback %s', callback) | ||||||
|  |         self._socket_reader.register(callback) | ||||||
|  |  | ||||||
|  |     def remove_socket_listener(self, callback: SocketReaderCallback) -> None: | ||||||
|  |         _log.debug('Unregistering socket listener callback %s', callback) | ||||||
|  |         self._socket_reader.unregister(callback) | ||||||
|  |  | ||||||
|  |     async def _wait_for_state( | ||||||
|  |         self, state: ConnectionFlowState, *other_states: ConnectionFlowState, timeout: Optional[float] = None | ||||||
|  |     ) -> None: | ||||||
|  |         states = (state, *other_states) | ||||||
|  |         while True: | ||||||
|  |             if self.state in states: | ||||||
|  |                 return | ||||||
|  |             await sane_wait_for([self._state_event.wait()], timeout=timeout) | ||||||
|  |  | ||||||
|  |     async def _voice_connect(self, *, self_deaf: bool = False, self_mute: bool = False) -> None: | ||||||
|  |         channel = self.voice_client.channel | ||||||
|  |         await channel.guild.change_voice_state(channel=channel, self_deaf=self_deaf, self_mute=self_mute) | ||||||
|  |  | ||||||
|  |     async def _voice_disconnect(self) -> None: | ||||||
|  |         _log.info( | ||||||
|  |             'The voice handshake is being terminated for Channel ID %s (Guild ID %s)', | ||||||
|  |             self.voice_client.channel.id, | ||||||
|  |             self.voice_client.guild.id, | ||||||
|  |         ) | ||||||
|  |         self.state = ConnectionFlowState.disconnected | ||||||
|  |         await self.voice_client.channel.guild.change_voice_state(channel=None) | ||||||
|  |         self._expecting_disconnect = True | ||||||
|  |  | ||||||
|  |     async def _connect_websocket(self, resume: bool) -> DiscordVoiceWebSocket: | ||||||
|  |         ws = await DiscordVoiceWebSocket.from_connection_state(self, resume=resume, hook=self.hook) | ||||||
|  |         self.state = ConnectionFlowState.websocket_connected | ||||||
|  |         return ws | ||||||
|  |  | ||||||
|  |     async def _handshake_websocket(self) -> None: | ||||||
|  |         while not self.ip: | ||||||
|  |             await self.ws.poll_event() | ||||||
|  |         self.state = ConnectionFlowState.got_ip_discovery | ||||||
|  |         while self.ws.secret_key is None: | ||||||
|  |             await self.ws.poll_event() | ||||||
|  |         self.state = ConnectionFlowState.connected | ||||||
|  |  | ||||||
|  |     def _create_socket(self) -> None: | ||||||
|  |         self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||||
|  |         self.socket.setblocking(False) | ||||||
|  |         self._socket_reader.resume() | ||||||
|  |  | ||||||
|  |     async def _poll_voice_ws(self, reconnect: bool) -> None: | ||||||
|  |         backoff = ExponentialBackoff() | ||||||
|  |         while True: | ||||||
|  |             try: | ||||||
|  |                 await self.ws.poll_event() | ||||||
|  |             except asyncio.CancelledError: | ||||||
|  |                 return | ||||||
|  |             except (ConnectionClosed, asyncio.TimeoutError) as exc: | ||||||
|  |                 if isinstance(exc, ConnectionClosed): | ||||||
|  |                     # The following close codes are undocumented so I will document them here. | ||||||
|  |                     # 1000 - normal closure (obviously) | ||||||
|  |                     # 4014 - we were externally disconnected (voice channel deleted, we were moved, etc) | ||||||
|  |                     # 4015 - voice server has crashed | ||||||
|  |                     if exc.code in (1000, 4015): | ||||||
|  |                         _log.info('Disconnecting from voice normally, close code %d.', exc.code) | ||||||
|  |                         await self.disconnect() | ||||||
|  |                         break | ||||||
|  |  | ||||||
|  |                     if exc.code == 4014: | ||||||
|  |                         _log.info('Disconnected from voice by force... potentially reconnecting.') | ||||||
|  |                         successful = await self._potential_reconnect() | ||||||
|  |                         if not successful: | ||||||
|  |                             _log.info('Reconnect was unsuccessful, disconnecting from voice normally...') | ||||||
|  |                             await self.disconnect() | ||||||
|  |                             break | ||||||
|  |                         else: | ||||||
|  |                             continue | ||||||
|  |  | ||||||
|  |                     _log.debug('Not handling close code %s (%s)', exc.code, exc.reason or 'no reason') | ||||||
|  |  | ||||||
|  |                 if not reconnect: | ||||||
|  |                     await self.disconnect() | ||||||
|  |                     raise | ||||||
|  |  | ||||||
|  |                 retry = backoff.delay() | ||||||
|  |                 _log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) | ||||||
|  |                 await asyncio.sleep(retry) | ||||||
|  |                 await self.disconnect(cleanup=False) | ||||||
|  |  | ||||||
|  |                 try: | ||||||
|  |                     await self._connect( | ||||||
|  |                         reconnect=reconnect, | ||||||
|  |                         timeout=self.timeout, | ||||||
|  |                         self_deaf=(self.self_voice_state or self).self_deaf, | ||||||
|  |                         self_mute=(self.self_voice_state or self).self_mute, | ||||||
|  |                         resume=False, | ||||||
|  |                     ) | ||||||
|  |                 except asyncio.TimeoutError: | ||||||
|  |                     # at this point we've retried 5 times... let's continue the loop. | ||||||
|  |                     _log.warning('Could not connect to voice... Retrying...') | ||||||
|  |                     continue | ||||||
|  |  | ||||||
|  |     async def _potential_reconnect(self) -> bool: | ||||||
|  |         try: | ||||||
|  |             await self._wait_for_state( | ||||||
|  |                 ConnectionFlowState.got_voice_server_update, ConnectionFlowState.got_both_voice_updates, timeout=self.timeout | ||||||
|  |             ) | ||||||
|  |         except asyncio.TimeoutError: | ||||||
|  |             return False | ||||||
|  |         try: | ||||||
|  |             self.ws = await self._connect_websocket(False) | ||||||
|  |             await self._handshake_websocket() | ||||||
|  |         except (ConnectionClosed, asyncio.TimeoutError): | ||||||
|  |             return False | ||||||
|  |         else: | ||||||
|  |             return True | ||||||
		Reference in New Issue
	
	Block a user