Type-hint voice_client / player

This commit is contained in:
Josh
2021-06-28 14:59:14 +10:00
committed by GitHub
parent cd6b453cb3
commit 5acea453cc
3 changed files with 225 additions and 128 deletions

View File

@ -20,9 +20,9 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE.
"""
"""Some documentation to refer to:
Some documentation to refer to:
- Our main web socket (mWS) sends opcode 4 with a guild ID and channel ID.
- The mWS receives VOICE_STATE_UPDATE and VOICE_SERVER_UPDATE.
@ -37,21 +37,41 @@ DEALINGS IN THE SOFTWARE.
- Finally we can transmit data to endpoint:port.
"""
from __future__ import annotations
import asyncio
import socket
import logging
import struct
import threading
from typing import Any, Callable
from typing import Any, Callable, List, Optional, TYPE_CHECKING, Tuple
from . import opus, utils
from .backoff import ExponentialBackoff
from .gateway import *
from .errors import ClientException, ConnectionClosed
from .player import AudioPlayer, AudioSource
from .utils import MISSING
if TYPE_CHECKING:
from .client import Client
from .guild import Guild
from .state import ConnectionState
from .user import ClientUser
from .opus import Encoder
from . import abc
from .types.voice import (
GuildVoiceState as GuildVoiceStatePayload,
VoiceServerUpdate as VoiceServerUpdatePayload,
SupportedModes,
)
has_nacl: bool
try:
import nacl.secret
import nacl.secret # type: ignore
has_nacl = True
except ImportError:
has_nacl = False
@ -61,7 +81,10 @@ __all__ = (
'VoiceClient',
)
log = logging.getLogger(__name__)
log: logging.Logger = logging.getLogger(__name__)
class VoiceProtocol:
"""A class that represents the Discord voice protocol.
@ -84,11 +107,11 @@ class VoiceProtocol:
The voice channel that is being connected to.
"""
def __init__(self, client, channel):
self.client = client
self.channel = channel
def __init__(self, client: Client, channel: abc.Connectable) -> None:
self.client: Client = client
self.channel: abc.Connectable = channel
async def on_voice_state_update(self, data):
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
"""|coro|
An abstract method that is called when the client's voice state
@ -105,7 +128,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
async def on_voice_server_update(self, data):
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
"""|coro|
An abstract method that is called when initially connecting to voice.
@ -122,7 +145,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
async def connect(self, *, timeout: float, reconnect: bool):
async def connect(self, *, timeout: float, reconnect: bool) -> None:
"""|coro|
An abstract method called when the client initiates the connection request.
@ -145,7 +168,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
async def disconnect(self, *, force: bool):
async def disconnect(self, *, force: bool) -> None:
"""|coro|
An abstract method called when the client terminates the connection.
@ -159,7 +182,7 @@ class VoiceProtocol:
"""
raise NotImplementedError
def cleanup(self):
def cleanup(self) -> None:
"""This method *must* be called to ensure proper clean-up during a disconnect.
It is advisable to call this from within :meth:`disconnect` when you are
@ -198,48 +221,55 @@ class VoiceClient(VoiceProtocol):
loop: :class:`asyncio.AbstractEventLoop`
The event loop that the voice client is running on.
"""
def __init__(self, client, channel):
endpoint_ip: str
voice_port: int
secret_key: List[int]
ssrc: int
def __init__(self, client: Client, channel: abc.Connectable):
if not has_nacl:
raise RuntimeError("PyNaCl library needed in order to use voice")
super().__init__(client, channel)
state = client._connection
self.token = None
self.socket = None
self.loop = state.loop
self._state = state
self.token: str = MISSING
self.socket = MISSING
self.loop: asyncio.AbstractEventLoop = state.loop
self._state: ConnectionState = state
# this will be used in the AudioPlayer thread
self._connected = threading.Event()
self._connected: threading.Event = threading.Event()
self._handshaking = False
self._potentially_reconnecting = False
self._voice_state_complete = asyncio.Event()
self._voice_server_complete = asyncio.Event()
self._handshaking: bool = False
self._potentially_reconnecting: bool = False
self._voice_state_complete: asyncio.Event = asyncio.Event()
self._voice_server_complete: asyncio.Event = asyncio.Event()
self.mode = None
self._connections = 0
self.sequence = 0
self.timestamp = 0
self._runner = None
self._player = None
self.encoder = None
self._lite_nonce = 0
self.ws = None
self.mode: str = MISSING
self._connections: int = 0
self.sequence: int = 0
self.timestamp: int = 0
self.timeout: float = 0
self._runner: asyncio.Task = MISSING
self._player: Optional[AudioPlayer] = None
self.encoder: Encoder = MISSING
self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING
warn_nacl = not has_nacl
supported_modes = (
supported_modes: Tuple[SupportedModes, ...] = (
'xsalsa20_poly1305_lite',
'xsalsa20_poly1305_suffix',
'xsalsa20_poly1305',
)
@property
def guild(self):
def guild(self) -> Optional[Guild]:
"""Optional[:class:`Guild`]: The guild we're connected to, if applicable."""
return getattr(self.channel, 'guild', None)
@property
def user(self):
def user(self) -> ClientUser:
""":class:`ClientUser`: The user connected to voice (i.e. ourselves)."""
return self._state.user
@ -252,7 +282,7 @@ class VoiceClient(VoiceProtocol):
# connection related
async def on_voice_state_update(self, data):
async def on_voice_state_update(self, data: GuildVoiceStatePayload) -> None:
self.session_id = data['session_id']
channel_id = data['channel_id']
@ -265,11 +295,11 @@ class VoiceClient(VoiceProtocol):
await self.disconnect()
else:
guild = self.guild
self.channel = channel_id and guild and guild.get_channel(int(channel_id))
self.channel = channel_id and guild and guild.get_channel(int(channel_id)) # type: ignore
else:
self._voice_state_complete.set()
async def on_voice_server_update(self, data):
async def on_voice_server_update(self, data: VoiceServerUpdatePayload) -> None:
if self._voice_server_complete.is_set():
log.info('Ignoring extraneous voice server update.')
return
@ -289,7 +319,7 @@ class VoiceClient(VoiceProtocol):
self.endpoint = self.endpoint[6:]
# This gets set later
self.endpoint_ip = None
self.endpoint_ip = MISSING
self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
self.socket.setblocking(False)
@ -301,27 +331,27 @@ class VoiceClient(VoiceProtocol):
self._voice_server_complete.set()
async def voice_connect(self):
async def voice_connect(self) -> None:
await self.channel.guild.change_voice_state(channel=self.channel)
async def voice_disconnect(self):
async def voice_disconnect(self) -> None:
log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
await self.channel.guild.change_voice_state(channel=None)
def prepare_handshake(self):
def prepare_handshake(self) -> None:
self._voice_state_complete.clear()
self._voice_server_complete.clear()
self._handshaking = True
log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
self._connections += 1
def finish_handshake(self):
def finish_handshake(self) -> None:
log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
self._handshaking = False
self._voice_server_complete.clear()
self._voice_state_complete.clear()
async def connect_websocket(self):
async def connect_websocket(self) -> DiscordVoiceWebSocket:
ws = await DiscordVoiceWebSocket.from_client(self)
self._connected.clear()
while ws.secret_key is None:
@ -329,7 +359,7 @@ class VoiceClient(VoiceProtocol):
self._connected.set()
return ws
async def connect(self, *, reconnect: bool, timeout: bool):
async def connect(self, *, reconnect: bool, timeout: float) ->None:
log.info('Connecting to voice...')
self.timeout = timeout
@ -365,10 +395,10 @@ class VoiceClient(VoiceProtocol):
else:
raise
if self._runner is None:
if self._runner is MISSING:
self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
async def potential_reconnect(self):
async def potential_reconnect(self) -> bool:
# Attempt to stop the player thread from playing early
self._connected.clear()
self.prepare_handshake()
@ -391,7 +421,7 @@ class VoiceClient(VoiceProtocol):
return True
@property
def latency(self):
def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.
This could be referred to as the Discord Voice WebSocket latency and is
@ -403,7 +433,7 @@ class VoiceClient(VoiceProtocol):
return float("inf") if not ws else ws.latency
@property
def average_latency(self):
def average_latency(self) -> float:
""":class:`float`: Average of most recent 20 HEARTBEAT latencies in seconds.
.. versionadded:: 1.4
@ -411,7 +441,7 @@ class VoiceClient(VoiceProtocol):
ws = self.ws
return float("inf") if not ws else ws.average_latency
async def poll_voice_ws(self, reconnect):
async def poll_voice_ws(self, reconnect: bool) -> None:
backoff = ExponentialBackoff()
while True:
try:
@ -452,7 +482,7 @@ class VoiceClient(VoiceProtocol):
log.warning('Could not connect to voice... Retrying...')
continue
async def disconnect(self, *, force: bool = False):
async def disconnect(self, *, force: bool = False) -> None:
"""|coro|
Disconnects this voice client from voice.
@ -473,7 +503,7 @@ class VoiceClient(VoiceProtocol):
if self.socket:
self.socket.close()
async def move_to(self, channel):
async def move_to(self, channel: abc.Snowflake) -> None:
"""|coro|
Moves you to a different voice channel.
@ -485,7 +515,7 @@ class VoiceClient(VoiceProtocol):
"""
await self.channel.guild.change_voice_state(channel=channel)
def is_connected(self):
def is_connected(self) -> bool:
"""Indicates if the voice client is connected to voice."""
return self._connected.is_set()
@ -504,20 +534,20 @@ class VoiceClient(VoiceProtocol):
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
return encrypt_packet(header, data)
def _encrypt_xsalsa20_poly1305(self, header, data):
def _encrypt_xsalsa20_poly1305(self, header: bytes, data) -> bytes:
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
nonce[:12] = header
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext
def _encrypt_xsalsa20_poly1305_suffix(self, header, data):
def _encrypt_xsalsa20_poly1305_suffix(self, header: bytes, data) -> bytes:
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = nacl.utils.random(nacl.secret.SecretBox.NONCE_SIZE)
return header + box.encrypt(bytes(data), nonce).ciphertext + nonce
def _encrypt_xsalsa20_poly1305_lite(self, header, data):
def _encrypt_xsalsa20_poly1305_lite(self, header: bytes, data) -> bytes:
box = nacl.secret.SecretBox(bytes(self.secret_key))
nonce = bytearray(24)
@ -526,7 +556,7 @@ class VoiceClient(VoiceProtocol):
return header + box.encrypt(bytes(data), bytes(nonce)).ciphertext + nonce[:4]
def play(self, source: AudioSource, *, after: Callable[[Exception], Any]=None):
def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any]=None) -> None:
"""Plays an :class:`AudioSource`.
The finalizer, ``after`` is called after the source has been exhausted
@ -570,32 +600,32 @@ class VoiceClient(VoiceProtocol):
self._player = AudioPlayer(source, self, after=after)
self._player.start()
def is_playing(self):
def is_playing(self) -> bool:
"""Indicates if we're currently playing audio."""
return self._player is not None and self._player.is_playing()
def is_paused(self):
def is_paused(self) -> bool:
"""Indicates if we're playing audio, but if we're paused."""
return self._player is not None and self._player.is_paused()
def stop(self):
def stop(self) -> None:
"""Stops playing audio."""
if self._player:
self._player.stop()
self._player = None
def pause(self):
def pause(self) -> None:
"""Pauses the audio playing."""
if self._player:
self._player.pause()
def resume(self):
def resume(self) -> None:
"""Resumes the audio playing."""
if self._player:
self._player.resume()
@property
def source(self):
def source(self) -> Optional[AudioSource]:
"""Optional[:class:`AudioSource`]: The audio source being played, if playing.
This property can also be used to change the audio source currently being played.
@ -603,7 +633,7 @@ class VoiceClient(VoiceProtocol):
return self._player.source if self._player else None
@source.setter
def source(self, value):
def source(self, value: AudioSource) -> None:
if not isinstance(value, AudioSource):
raise TypeError(f'expected AudioSource not {value.__class__.__name__}.')
@ -612,7 +642,7 @@ class VoiceClient(VoiceProtocol):
self._player._set_source(value)
def send_audio_packet(self, data, *, encode=True):
def send_audio_packet(self, data: bytes, *, encode: bool = True) -> None:
"""Sends an audio packet composed of the data.
You must be connected to play audio.