Merge pull request #44

* Typehint gateway.py

* Add relevant typehints to gateway.py to voice_client.py

* Change EventListener to subclass NamedTuple

* Add return type for DiscordWebSocket.wait_for

* Correct deque typehint

* Remove unnecessary typehints for literals

* Use type aliases

* Merge branch '2.0' into pr7422
This commit is contained in:
Arthur 2021-09-02 22:50:19 +02:00 committed by GitHub
parent 1032728311
commit 3ffe134895
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 161 additions and 99 deletions

View File

@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
DEALINGS IN THE SOFTWARE. DEALINGS IN THE SOFTWARE.
""" """
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque
import asyncio import asyncio
from collections import namedtuple, deque from collections import deque
import concurrent.futures import concurrent.futures
import logging import logging
import struct import struct
@ -38,9 +42,25 @@ import aiohttp
from . import utils from . import utils
from .activity import BaseActivity from .activity import BaseActivity
from .enums import SpeakingState from .enums import SpeakingState
from .errors import ConnectionClosed, InvalidArgument from .errors import ConnectionClosed, InvalidArgument
if TYPE_CHECKING:
from .client import Client
from .state import ConnectionState
from .voice_client import VoiceClient
T = TypeVar('T')
DWS = TypeVar('DWS', bound='DiscordWebSocket')
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
Coro = Callable[..., Coroutine[Any, Any, Any]]
Predicate = Callable[[Dict[str, Any]], bool]
DataCallable = Callable[[Dict[str, Any]], T]
Result = Optional[DataCallable[Any]]
_log: logging.Logger = logging.getLogger(__name__)
_log = logging.getLogger(__name__)
__all__ = ( __all__ = (
'DiscordWebSocket', 'DiscordWebSocket',
@ -50,36 +70,49 @@ __all__ = (
'ReconnectWebSocket', 'ReconnectWebSocket',
) )
class Heartbeat(TypedDict):
op: int
d: int
class ReconnectWebSocket(Exception): class ReconnectWebSocket(Exception):
"""Signals to safely reconnect the websocket.""" """Signals to safely reconnect the websocket."""
def __init__(self, shard_id, *, resume=True): def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
self.shard_id = shard_id self.shard_id: Optional[int] = shard_id
self.resume = resume self.resume: bool = resume
self.op = 'RESUME' if resume else 'IDENTIFY' self.op = 'RESUME' if resume else 'IDENTIFY'
class WebSocketClosure(Exception): class WebSocketClosure(Exception):
"""An exception to make up for the fact that aiohttp doesn't signal closure.""" """An exception to make up for the fact that aiohttp doesn't signal closure."""
pass pass
EventListener = namedtuple('EventListener', 'predicate event result future')
class EventListener(NamedTuple):
predicate: Predicate
event: str
result: Result
future: asyncio.Future
class GatewayRatelimiter: class GatewayRatelimiter:
def __init__(self, count=110, per=60.0): def __init__(self, count: int = 110, per: float = 60.0) -> None:
# The default is 110 to give room for at least 10 heartbeats per minute # The default is 110 to give room for at least 10 heartbeats per minute
self.max = count self.max: int = count
self.remaining = count self.remaining: int = count
self.window = 0.0 self.window: float = 0.0
self.per = per self.per: float = per
self.lock = asyncio.Lock() self.lock: asyncio.Lock = asyncio.Lock()
self.shard_id = None self.shard_id: Optional[int] = None
def is_ratelimited(self): def is_ratelimited(self) -> bool:
current = time.time() current = time.time()
if current > self.window + self.per: if current > self.window + self.per:
return False return False
return self.remaining == 0 return self.remaining == 0
def get_delay(self): def get_delay(self) -> float:
current = time.time() current = time.time()
if current > self.window + self.per: if current > self.window + self.per:
@ -97,7 +130,7 @@ class GatewayRatelimiter:
return 0.0 return 0.0
async def block(self): async def block(self) -> None:
async with self.lock: async with self.lock:
delta = self.get_delay() delta = self.get_delay()
if delta: if delta:
@ -106,27 +139,27 @@ class GatewayRatelimiter:
class KeepAliveHandler(threading.Thread): class KeepAliveHandler(threading.Thread):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
ws = kwargs.pop('ws', None) ws = kwargs.pop('ws')
interval = kwargs.pop('interval', None) interval = kwargs.pop('interval', None)
shard_id = kwargs.pop('shard_id', None) shard_id = kwargs.pop('shard_id', None)
threading.Thread.__init__(self, *args, **kwargs) threading.Thread.__init__(self, *args, **kwargs)
self.ws = ws self.ws: DiscordWebSocket = ws
self._main_thread_id = ws.thread_id self._main_thread_id: int = ws.thread_id
self.interval = interval self.interval: Optional[float] = interval
self.daemon = True self.daemon: bool = True
self.shard_id = shard_id self.shard_id: Optional[int] = shard_id
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
self._stop_ev = threading.Event() self._stop_ev: threading.Event = threading.Event()
self._last_ack = time.perf_counter() self._last_ack: float = time.perf_counter()
self._last_send = time.perf_counter() self._last_send: float = time.perf_counter()
self._last_recv = time.perf_counter() self._last_recv: float = time.perf_counter()
self.latency = float('inf') self.latency: float = float('inf')
self.heartbeat_timeout = ws._max_heartbeat_timeout self.heartbeat_timeout: float = ws._max_heartbeat_timeout
def run(self): def run(self) -> None:
while not self._stop_ev.wait(self.interval): while not self._stop_ev.wait(self.interval):
if self._last_recv + self.heartbeat_timeout < time.perf_counter(): if self._last_recv + self.heartbeat_timeout < time.perf_counter():
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
@ -168,19 +201,20 @@ class KeepAliveHandler(threading.Thread):
else: else:
self._last_send = time.perf_counter() self._last_send = time.perf_counter()
def get_payload(self): def get_payload(self) -> Heartbeat:
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': self.ws.sequence # the websocket's sequence won't be None here
'd': self.ws.sequence # type: ignore
} }
def stop(self): def stop(self) -> None:
self._stop_ev.set() self._stop_ev.set()
def tick(self): def tick(self) -> None:
self._last_recv = time.perf_counter() self._last_recv = time.perf_counter()
def ack(self): def ack(self) -> None:
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
@ -188,30 +222,32 @@ class KeepAliveHandler(threading.Thread):
_log.warning(self.behind_msg, self.shard_id, self.latency) _log.warning(self.behind_msg, self.shard_id, self.latency)
class VoiceKeepAliveHandler(KeepAliveHandler): class VoiceKeepAliveHandler(KeepAliveHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.recent_ack_latencies = deque(maxlen=20) self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
def get_payload(self): def get_payload(self) -> Heartbeat:
return { return {
'op': self.ws.HEARTBEAT, 'op': self.ws.HEARTBEAT,
'd': int(time.time() * 1000) 'd': int(time.time() * 1000)
} }
def ack(self): def ack(self) -> None:
ack_time = time.perf_counter() ack_time = time.perf_counter()
self._last_ack = ack_time self._last_ack = ack_time
self._last_recv = ack_time self._last_recv = ack_time
self.latency = ack_time - self._last_send self.latency = ack_time - self._last_send
self.recent_ack_latencies.append(self.latency) self.recent_ack_latencies.append(self.latency)
class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
return await super().close(code=code, message=message) return await super().close(code=code, message=message)
class DiscordWebSocket: class DiscordWebSocket:
"""Implements a WebSocket for Discord's gateway v6. """Implements a WebSocket for Discord's gateway v6.
@ -266,41 +302,53 @@ class DiscordWebSocket:
HEARTBEAT_ACK = 11 HEARTBEAT_ACK = 11
GUILD_SYNC = 12 GUILD_SYNC = 12
def __init__(self, socket, *, loop): def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
self.socket = socket self.socket: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
# an empty dispatcher to prevent crashes # an empty dispatcher to prevent crashes
self._dispatch = lambda *args: None self._dispatch = lambda *args: None
# generic event listeners # generic event listeners
self._dispatch_listeners = [] self._dispatch_listeners: List[EventListener] = []
# the keep alive # the keep alive
self._keep_alive = None self._keep_alive: Optional[KeepAliveHandler] = None
self.thread_id = threading.get_ident() self.thread_id: int = threading.get_ident()
# ws related stuff # ws related stuff
self.session_id = None self.session_id: Optional[str] = None
self.sequence = None self.sequence: Optional[int] = None
self._zlib = zlib.decompressobj() self._zlib = zlib.decompressobj()
self._buffer = bytearray() self._buffer: bytearray = bytearray()
self._close_code = None self._close_code: Optional[int] = None
self._rate_limiter = GatewayRatelimiter() self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
# attributes that get set in from_client
self.token: str = utils.MISSING
self._connection: ConnectionState = utils.MISSING
self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING
self.gateway: str = utils.MISSING
self.call_hooks: Coro = utils.MISSING
self._initial_identify: bool = utils.MISSING
self.shard_id: Optional[int] = utils.MISSING
self.shard_count: Optional[int] = utils.MISSING
self.session_id: Optional[str] = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
@property @property
def open(self): def open(self) -> bool:
return not self.socket.closed return not self.socket.closed
def is_ratelimited(self): def is_ratelimited(self) -> bool:
return self._rate_limiter.is_ratelimited() return self._rate_limiter.is_ratelimited()
def debug_log_receive(self, data, /): def debug_log_receive(self, data, /) -> None:
self._dispatch('socket_raw_receive', data) self._dispatch('socket_raw_receive', data)
def log_receive(self, _, /): def log_receive(self, _, /) -> None:
pass pass
@classmethod @classmethod
async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS:
"""Creates a main websocket for Discord from a :class:`Client`. """Creates a main websocket for Discord from a :class:`Client`.
This is for internal use only. This is for internal use only.
@ -310,7 +358,9 @@ class DiscordWebSocket:
ws = cls(socket, loop=client.loop) ws = cls(socket, loop=client.loop)
# dynamically add attributes needed # dynamically add attributes needed
ws.token = client.http.token
# the token won't be None here
ws.token = client.http.token # type: ignore
ws._connection = client._connection ws._connection = client._connection
ws._discord_parsers = client._connection.parsers ws._discord_parsers = client._connection.parsers
ws._dispatch = client.dispatch ws._dispatch = client.dispatch
@ -342,7 +392,7 @@ class DiscordWebSocket:
await ws.resume() await ws.resume()
return ws return ws
def wait_for(self, event, predicate, result=None): def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future:
"""Waits for a DISPATCH'd event that meets the predicate. """Waits for a DISPATCH'd event that meets the predicate.
Parameters Parameters
@ -367,7 +417,7 @@ class DiscordWebSocket:
self._dispatch_listeners.append(entry) self._dispatch_listeners.append(entry)
return future return future
async def identify(self): async def identify(self) -> None:
"""Sends the IDENTIFY packet.""" """Sends the IDENTIFY packet."""
payload = { payload = {
'op': self.IDENTIFY, 'op': self.IDENTIFY,
@ -405,7 +455,7 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
async def resume(self): async def resume(self) -> None:
"""Sends the RESUME packet.""" """Sends the RESUME packet."""
payload = { payload = {
'op': self.RESUME, 'op': self.RESUME,
@ -419,7 +469,8 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
async def received_message(self, msg, /):
async def received_message(self, msg, /) -> None:
if type(msg) is bytes: if type(msg) is bytes:
self._buffer.extend(msg) self._buffer.extend(msg)
@ -537,16 +588,16 @@ class DiscordWebSocket:
del self._dispatch_listeners[index] del self._dispatch_listeners[index]
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
def _can_handle_close(self): def _can_handle_close(self) -> bool:
code = self._close_code or self.socket.close_code code = self._close_code or self.socket.close_code
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
async def poll_event(self): async def poll_event(self) -> None:
"""Polls for a DISPATCH event and handles the general gateway loop. """Polls for a DISPATCH event and handles the general gateway loop.
Raises Raises
@ -584,23 +635,23 @@ class DiscordWebSocket:
_log.info('Websocket closed with %s, cannot reconnect.', code) _log.info('Websocket closed with %s, cannot reconnect.', code)
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
async def debug_send(self, data, /): async def debug_send(self, data, /) -> None:
await self._rate_limiter.block() await self._rate_limiter.block()
self._dispatch('socket_raw_send', data) self._dispatch('socket_raw_send', data)
await self.socket.send_str(data) await self.socket.send_str(data)
async def send(self, data, /): async def send(self, data, /) -> None:
await self._rate_limiter.block() await self._rate_limiter.block()
await self.socket.send_str(data) await self.socket.send_str(data)
async def send_as_json(self, data): async def send_as_json(self, data) -> None:
try: try:
await self.send(utils._to_json(data)) await self.send(utils._to_json(data))
except RuntimeError as exc: except RuntimeError as exc:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def send_heartbeat(self, data): async def send_heartbeat(self, data: Heartbeat) -> None:
# This bypasses the rate limit handling code since it has a higher priority # This bypasses the rate limit handling code since it has a higher priority
try: try:
await self.socket.send_str(utils._to_json(data)) await self.socket.send_str(utils._to_json(data))
@ -608,13 +659,13 @@ class DiscordWebSocket:
if not self._can_handle_close(): if not self._can_handle_close():
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
async def change_presence(self, *, activity=None, status=None, since=0.0): async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None:
if activity is not None: if activity is not None:
if not isinstance(activity, BaseActivity): if not isinstance(activity, BaseActivity):
raise InvalidArgument('activity must derive from BaseActivity.') raise InvalidArgument('activity must derive from BaseActivity.')
activity = [activity.to_dict()] activities = [activity.to_dict()]
else: else:
activity = [] activities = []
if status == 'idle': if status == 'idle':
since = int(time.time() * 1000) since = int(time.time() * 1000)
@ -622,7 +673,7 @@ class DiscordWebSocket:
payload = { payload = {
'op': self.PRESENCE, 'op': self.PRESENCE,
'd': { 'd': {
'activities': activity, 'activities': activities,
'afk': False, 'afk': False,
'since': since, 'since': since,
'status': status 'status': status
@ -633,7 +684,7 @@ class DiscordWebSocket:
_log.debug('Sending "%s" to change status', sent) _log.debug('Sending "%s" to change status', sent)
await self.send(sent) await self.send(sent)
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None:
payload = { payload = {
'op': self.REQUEST_MEMBERS, 'op': self.REQUEST_MEMBERS,
'd': { 'd': {
@ -655,7 +706,7 @@ class DiscordWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None:
payload = { payload = {
'op': self.VOICE_STATE, 'op': self.VOICE_STATE,
'd': { 'd': {
@ -669,7 +720,7 @@ class DiscordWebSocket:
_log.debug('Updating our voice state to %s.', payload) _log.debug('Updating our voice state to %s.', payload)
await self.send_as_json(payload) await self.send_as_json(payload)
async def close(self, code=4000): async def close(self, code: int = 4000) -> None:
if self._keep_alive: if self._keep_alive:
self._keep_alive.stop() self._keep_alive.stop()
self._keep_alive = None self._keep_alive = None
@ -721,25 +772,31 @@ class DiscordVoiceWebSocket:
CLIENT_CONNECT = 12 CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13 CLIENT_DISCONNECT = 13
def __init__(self, socket, loop, *, hook=None): def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None:
self.ws = socket self.ws: aiohttp.ClientWebSocketResponse = socket
self.loop = loop self.loop: asyncio.AbstractEventLoop = loop
self._keep_alive = None self._keep_alive: VoiceKeepAliveHandler = utils.MISSING
self._close_code = None self._close_code: Optional[int] = None
self.secret_key = None self.secret_key: Optional[List[int]] = None
self.gateway: str = utils.MISSING
self._connection: VoiceClient = utils.MISSING
self._max_heartbeat_timeout: float = utils.MISSING
self.thread_id: int = utils.MISSING
if hook: if hook:
self._hook = hook # we want to redeclare self._hook
self._hook = hook # type: ignore
async def _hook(self, *args): async def _hook(self, *args: Any) -> Any:
pass pass
async def send_as_json(self, data):
async def send_as_json(self, data) -> None:
_log.debug('Sending voice websocket frame: %s.', data) _log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data)) await self.ws.send_str(utils._to_json(data))
send_heartbeat = send_as_json send_heartbeat = send_as_json
async def resume(self): async def resume(self) -> None:
state = self._connection state = self._connection
payload = { payload = {
'op': self.RESUME, 'op': self.RESUME,
@ -765,7 +822,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
@classmethod @classmethod
async def from_client(cls, client, *, resume=False, hook=None): async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS:
"""Creates a voice websocket for the :class:`VoiceClient`.""" """Creates a voice websocket for the :class:`VoiceClient`."""
gateway = 'wss://' + client.endpoint + '/?v=4' gateway = 'wss://' + client.endpoint + '/?v=4'
http = client._state.http http = client._state.http
@ -783,7 +840,7 @@ class DiscordVoiceWebSocket:
return ws return ws
async def select_protocol(self, ip, port, mode): async def select_protocol(self, ip, port, mode) -> None:
payload = { payload = {
'op': self.SELECT_PROTOCOL, 'op': self.SELECT_PROTOCOL,
'd': { 'd': {
@ -798,7 +855,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def client_connect(self): async def client_connect(self) -> None:
payload = { payload = {
'op': self.CLIENT_CONNECT, 'op': self.CLIENT_CONNECT,
'd': { 'd': {
@ -808,7 +865,7 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def speak(self, state=SpeakingState.voice): async def speak(self, state=SpeakingState.voice) -> None:
payload = { payload = {
'op': self.SPEAKING, 'op': self.SPEAKING,
'd': { 'd': {
@ -819,7 +876,8 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload) await self.send_as_json(payload)
async def received_message(self, msg):
async def received_message(self, msg) -> None:
_log.debug('Voice websocket frame received: %s', msg) _log.debug('Voice websocket frame received: %s', msg)
op = msg['op'] op = msg['op']
data = msg.get('d') data = msg.get('d')
@ -840,7 +898,7 @@ class DiscordVoiceWebSocket:
await self._hook(self, msg) await self._hook(self, msg)
async def initial_connection(self, data): async def initial_connection(self, data) -> None:
state = self._connection state = self._connection
state.ssrc = data['ssrc'] state.ssrc = data['ssrc']
state.voice_port = data['port'] state.voice_port = data['port']
@ -871,13 +929,13 @@ class DiscordVoiceWebSocket:
_log.info('selected the voice protocol for use (%s)', mode) _log.info('selected the voice protocol for use (%s)', mode)
@property @property
def latency(self): def latency(self) -> float:
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
return float('inf') if heartbeat is None else heartbeat.latency return float('inf') if heartbeat is None else heartbeat.latency
@property @property
def average_latency(self): def average_latency(self) -> float:
""":class:`list`: Average of last 20 HEARTBEAT latencies.""" """:class:`list`: Average of last 20 HEARTBEAT latencies."""
heartbeat = self._keep_alive heartbeat = self._keep_alive
if heartbeat is None or not heartbeat.recent_ack_latencies: if heartbeat is None or not heartbeat.recent_ack_latencies:
@ -885,13 +943,14 @@ class DiscordVoiceWebSocket:
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
async def load_secret_key(self, data):
async def load_secret_key(self, data) -> None:
_log.info('received secret key for voice connection') _log.info('received secret key for voice connection')
self.secret_key = self._connection.secret_key = data.get('secret_key') self.secret_key = self._connection.secret_key = data.get('secret_key')
await self.speak() await self.speak()
await self.speak(False) await self.speak(False)
async def poll_event(self): async def poll_event(self) -> None:
# This exception is handled up the chain # This exception is handled up the chain
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT: if msg.type is aiohttp.WSMsgType.TEXT:
@ -903,7 +962,7 @@ class DiscordVoiceWebSocket:
_log.debug('Received %s', msg) _log.debug('Received %s', msg)
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
async def close(self, code=1000): async def close(self, code: int = 1000) -> None:
if self._keep_alive is not None: if self._keep_alive is not None:
self._keep_alive.stop() self._keep_alive.stop()

View File

@ -255,6 +255,9 @@ class VoiceClient(VoiceProtocol):
self.encoder: Encoder = MISSING self.encoder: Encoder = MISSING
self._lite_nonce: int = 0 self._lite_nonce: int = 0
self.ws: DiscordVoiceWebSocket = MISSING self.ws: DiscordVoiceWebSocket = MISSING
self.ip: str = MISSING
self.port: Tuple[Any, ...] = MISSING
warn_nacl = not has_nacl warn_nacl = not has_nacl
supported_modes: Tuple[SupportedModes, ...] = ( supported_modes: Tuple[SupportedModes, ...] = (