mirror of
https://github.com/Rapptz/discord.py.git
synced 2025-04-21 00:07:51 +00:00
Type-hint gateway
This commit is contained in:
parent
f5e087c5c3
commit
c8064ba6f2
@ -21,9 +21,10 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
|
||||
DEALINGS IN THE SOFTWARE.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections import namedtuple, deque
|
||||
from collections import deque
|
||||
import concurrent.futures
|
||||
import logging
|
||||
import struct
|
||||
@ -33,6 +34,8 @@ import threading
|
||||
import traceback
|
||||
import zlib
|
||||
|
||||
from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Type
|
||||
|
||||
import aiohttp
|
||||
|
||||
from . import utils
|
||||
@ -50,6 +53,11 @@ __all__ = (
|
||||
'ReconnectWebSocket',
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .client import Client
|
||||
from .state import ConnectionState
|
||||
from .voice_client import VoiceClient
|
||||
|
||||
|
||||
class ReconnectWebSocket(Exception):
|
||||
"""Signals to safely reconnect the websocket."""
|
||||
@ -66,26 +74,30 @@ class WebSocketClosure(Exception):
|
||||
pass
|
||||
|
||||
|
||||
EventListener = namedtuple('EventListener', 'predicate event result future')
|
||||
class EventListener(NamedTuple):
|
||||
predicate: Callable[[Dict[str, Any]], bool]
|
||||
event: str
|
||||
result: Optional[Callable[[Dict[str, Any]], Any]]
|
||||
future: asyncio.Future[Any]
|
||||
|
||||
|
||||
class GatewayRatelimiter:
|
||||
def __init__(self, count=110, per=60.0):
|
||||
def __init__(self, count: int = 110, per: float = 60.0) -> None:
|
||||
# The default is 110 to give room for at least 10 heartbeats per minute
|
||||
self.max = count
|
||||
self.remaining = count
|
||||
self.window = 0.0
|
||||
self.per = per
|
||||
self.lock = asyncio.Lock()
|
||||
self.shard_id = None
|
||||
self.max: int = count
|
||||
self.remaining: int = count
|
||||
self.window: float = 0.0
|
||||
self.per: float = per
|
||||
self.lock: asyncio.Lock = asyncio.Lock()
|
||||
self.shard_id: Optional[int] = None
|
||||
|
||||
def is_ratelimited(self):
|
||||
def is_ratelimited(self) -> bool:
|
||||
current = time.time()
|
||||
if current > self.window + self.per:
|
||||
return False
|
||||
return self.remaining == 0
|
||||
|
||||
def get_delay(self):
|
||||
def get_delay(self) -> float:
|
||||
current = time.time()
|
||||
|
||||
if current > self.window + self.per:
|
||||
@ -103,7 +115,7 @@ class GatewayRatelimiter:
|
||||
|
||||
return 0.0
|
||||
|
||||
async def block(self):
|
||||
async def block(self) -> None:
|
||||
async with self.lock:
|
||||
delta = self.get_delay()
|
||||
if delta:
|
||||
@ -112,27 +124,31 @@ class GatewayRatelimiter:
|
||||
|
||||
|
||||
class KeepAliveHandler(threading.Thread):
|
||||
def __init__(self, *args, **kwargs):
|
||||
ws = kwargs.pop('ws', None)
|
||||
interval = kwargs.pop('interval', None)
|
||||
shard_id = kwargs.pop('shard_id', None)
|
||||
threading.Thread.__init__(self, *args, **kwargs)
|
||||
self.ws = ws
|
||||
self._main_thread_id = ws.thread_id
|
||||
self.interval = interval
|
||||
self.daemon = True
|
||||
self.shard_id = shard_id
|
||||
self.msg = 'Keeping shard ID %s websocket alive with sequence %s.'
|
||||
self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.'
|
||||
self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
|
||||
self._stop_ev = threading.Event()
|
||||
self._last_ack = time.perf_counter()
|
||||
self._last_send = time.perf_counter()
|
||||
self._last_recv = time.perf_counter()
|
||||
self.latency = float('inf')
|
||||
self.heartbeat_timeout = ws._max_heartbeat_timeout
|
||||
def __init__(
|
||||
self,
|
||||
*args: Any,
|
||||
ws: DiscordWebSocket,
|
||||
interval: Optional[float] = None,
|
||||
shard_id: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.ws: DiscordWebSocket = ws
|
||||
self._main_thread_id: int = ws.thread_id
|
||||
self.interval: Optional[float] = interval
|
||||
self.daemon: bool = True
|
||||
self.shard_id: Optional[int] = shard_id
|
||||
self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
|
||||
self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
|
||||
self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.'
|
||||
self._stop_ev: threading.Event = threading.Event()
|
||||
self._last_ack: float = time.perf_counter()
|
||||
self._last_send: float = time.perf_counter()
|
||||
self._last_recv: float = time.perf_counter()
|
||||
self.latency: float = float('inf')
|
||||
self.heartbeat_timeout: float = ws._max_heartbeat_timeout
|
||||
|
||||
def run(self):
|
||||
def run(self) -> None:
|
||||
while not self._stop_ev.wait(self.interval):
|
||||
if self._last_recv + self.heartbeat_timeout < time.perf_counter():
|
||||
_log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id)
|
||||
@ -174,19 +190,19 @@ class KeepAliveHandler(threading.Thread):
|
||||
else:
|
||||
self._last_send = time.perf_counter()
|
||||
|
||||
def get_payload(self):
|
||||
def get_payload(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': self.ws.sequence,
|
||||
}
|
||||
|
||||
def stop(self):
|
||||
def stop(self) -> None:
|
||||
self._stop_ev.set()
|
||||
|
||||
def tick(self):
|
||||
def tick(self) -> None:
|
||||
self._last_recv = time.perf_counter()
|
||||
|
||||
def ack(self):
|
||||
def ack(self) -> None:
|
||||
ack_time = time.perf_counter()
|
||||
self._last_ack = ack_time
|
||||
self.latency = ack_time - self._last_send
|
||||
@ -195,20 +211,20 @@ class KeepAliveHandler(threading.Thread):
|
||||
|
||||
|
||||
class VoiceKeepAliveHandler(KeepAliveHandler):
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self.recent_ack_latencies = deque(maxlen=20)
|
||||
self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
|
||||
self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
|
||||
self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
|
||||
self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
|
||||
self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
|
||||
self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
|
||||
self.behind_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
|
||||
|
||||
def get_payload(self):
|
||||
def get_payload(self) -> Dict[str, Any]:
|
||||
return {
|
||||
'op': self.ws.HEARTBEAT,
|
||||
'd': int(time.time() * 1000),
|
||||
}
|
||||
|
||||
def ack(self):
|
||||
def ack(self) -> None:
|
||||
ack_time = time.perf_counter()
|
||||
self._last_ack = ack_time
|
||||
self._last_recv = ack_time
|
||||
@ -221,6 +237,9 @@ class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
|
||||
return await super().close(code=code, message=message)
|
||||
|
||||
|
||||
DWS = TypeVar('DWS', bound='DiscordWebSocket')
|
||||
|
||||
|
||||
class DiscordWebSocket:
|
||||
"""Implements a WebSocket for Discord's gateway v6.
|
||||
|
||||
@ -261,6 +280,17 @@ class DiscordWebSocket:
|
||||
The authentication token for discord.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
token: Optional[str]
|
||||
_connection: ConnectionState
|
||||
_discord_parsers: Dict[str, Callable[..., Any]]
|
||||
call_hooks: Callable[..., Any]
|
||||
_initial_identify: bool
|
||||
shard_id: Optional[int]
|
||||
shard_count: Optional[int]
|
||||
gateway: str
|
||||
_max_heartbeat_timeout: float
|
||||
|
||||
# fmt: off
|
||||
DISPATCH = 0
|
||||
HEARTBEAT = 1
|
||||
@ -277,51 +307,51 @@ class DiscordWebSocket:
|
||||
GUILD_SYNC = 12
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, socket, *, loop):
|
||||
self.socket = socket
|
||||
self.loop = loop
|
||||
def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
|
||||
self.socket: aiohttp.ClientWebSocketResponse = socket
|
||||
self.loop: asyncio.AbstractEventLoop = loop
|
||||
|
||||
# an empty dispatcher to prevent crashes
|
||||
self._dispatch = lambda *args: None
|
||||
self._dispatch: Callable[..., Any] = lambda *args: None
|
||||
# generic event listeners
|
||||
self._dispatch_listeners = []
|
||||
self._dispatch_listeners: List[EventListener] = []
|
||||
# the keep alive
|
||||
self._keep_alive = None
|
||||
self.thread_id = threading.get_ident()
|
||||
self._keep_alive: Optional[KeepAliveHandler] = None
|
||||
self.thread_id: int = threading.get_ident()
|
||||
|
||||
# ws related stuff
|
||||
self.session_id = None
|
||||
self.sequence = None
|
||||
self._zlib = zlib.decompressobj()
|
||||
self._buffer = bytearray()
|
||||
self._close_code = None
|
||||
self._rate_limiter = GatewayRatelimiter()
|
||||
self.session_id: Optional[str] = None
|
||||
self.sequence: Optional[int] = None
|
||||
self._zlib: zlib._Decompress = zlib.decompressobj()
|
||||
self._buffer: bytearray = bytearray()
|
||||
self._close_code: Optional[int] = None
|
||||
self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
|
||||
|
||||
@property
|
||||
def open(self):
|
||||
def open(self) -> bool:
|
||||
return not self.socket.closed
|
||||
|
||||
def is_ratelimited(self):
|
||||
def is_ratelimited(self) -> bool:
|
||||
return self._rate_limiter.is_ratelimited()
|
||||
|
||||
def debug_log_receive(self, data, /):
|
||||
def debug_log_receive(self, data: Dict[str, Any], /) -> None:
|
||||
self._dispatch('socket_raw_receive', data)
|
||||
|
||||
def log_receive(self, _, /):
|
||||
def log_receive(self, _: Dict[str, Any], /) -> None:
|
||||
pass
|
||||
|
||||
@classmethod
|
||||
async def from_client(
|
||||
cls,
|
||||
client,
|
||||
cls: Type[DWS],
|
||||
client: Client,
|
||||
*,
|
||||
initial=False,
|
||||
gateway=None,
|
||||
shard_id=None,
|
||||
session=None,
|
||||
sequence=None,
|
||||
resume=False,
|
||||
):
|
||||
initial: bool = False,
|
||||
gateway: Optional[str] = None,
|
||||
shard_id: Optional[int] = None,
|
||||
session: Optional[str] = None,
|
||||
sequence: Optional[int] = None,
|
||||
resume: bool = False,
|
||||
) -> DWS:
|
||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||
|
||||
This is for internal use only.
|
||||
@ -363,7 +393,12 @@ class DiscordWebSocket:
|
||||
await ws.resume()
|
||||
return ws
|
||||
|
||||
def wait_for(self, event, predicate, result=None):
|
||||
def wait_for(
|
||||
self,
|
||||
event: str,
|
||||
predicate: Callable[[Dict[str, Any]], bool],
|
||||
result: Optional[Callable[[Dict[str, Any]], Any]] = None,
|
||||
) -> asyncio.Future[Any]:
|
||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||
|
||||
Parameters
|
||||
@ -388,7 +423,7 @@ class DiscordWebSocket:
|
||||
self._dispatch_listeners.append(entry)
|
||||
return future
|
||||
|
||||
async def identify(self):
|
||||
async def identify(self) -> None:
|
||||
"""Sends the IDENTIFY packet."""
|
||||
payload = {
|
||||
'op': self.IDENTIFY,
|
||||
@ -426,7 +461,7 @@ class DiscordWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
_log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
|
||||
|
||||
async def resume(self):
|
||||
async def resume(self) -> None:
|
||||
"""Sends the RESUME packet."""
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
@ -440,7 +475,7 @@ class DiscordWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
_log.info('Shard ID %s has sent the RESUME payload.', self.shard_id)
|
||||
|
||||
async def received_message(self, msg, /):
|
||||
async def received_message(self, msg: Any, /) -> None:
|
||||
if type(msg) is bytes:
|
||||
self._buffer.extend(msg)
|
||||
|
||||
@ -566,16 +601,16 @@ class DiscordWebSocket:
|
||||
del self._dispatch_listeners[index]
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
|
||||
heartbeat = self._keep_alive
|
||||
return float('inf') if heartbeat is None else heartbeat.latency
|
||||
|
||||
def _can_handle_close(self):
|
||||
def _can_handle_close(self) -> bool:
|
||||
code = self._close_code or self.socket.close_code
|
||||
return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
|
||||
|
||||
async def poll_event(self):
|
||||
async def poll_event(self) -> None:
|
||||
"""Polls for a DISPATCH event and handles the general gateway loop.
|
||||
|
||||
Raises
|
||||
@ -613,23 +648,23 @@ class DiscordWebSocket:
|
||||
_log.info('Websocket closed with %s, cannot reconnect.', code)
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
|
||||
|
||||
async def debug_send(self, data, /):
|
||||
async def debug_send(self, data: str, /) -> None:
|
||||
await self._rate_limiter.block()
|
||||
self._dispatch('socket_raw_send', data)
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send(self, data, /):
|
||||
async def send(self, data: str, /) -> None:
|
||||
await self._rate_limiter.block()
|
||||
await self.socket.send_str(data)
|
||||
|
||||
async def send_as_json(self, data):
|
||||
async def send_as_json(self, data: Any) -> None:
|
||||
try:
|
||||
await self.send(utils._to_json(data))
|
||||
except RuntimeError as exc:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
|
||||
async def send_heartbeat(self, data):
|
||||
async def send_heartbeat(self, data: Any) -> None:
|
||||
# This bypasses the rate limit handling code since it has a higher priority
|
||||
try:
|
||||
await self.socket.send_str(utils._to_json(data))
|
||||
@ -637,13 +672,19 @@ class DiscordWebSocket:
|
||||
if not self._can_handle_close():
|
||||
raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
|
||||
|
||||
async def change_presence(self, *, activity=None, status=None, since=0.0):
|
||||
async def change_presence(
|
||||
self,
|
||||
*,
|
||||
activity: Optional[BaseActivity] = None,
|
||||
status: Optional[str] = None,
|
||||
since: float = 0.0,
|
||||
) -> None:
|
||||
if activity is not None:
|
||||
if not isinstance(activity, BaseActivity):
|
||||
raise InvalidArgument('activity must derive from BaseActivity.')
|
||||
activity = [activity.to_dict()]
|
||||
activities = [activity.to_dict()]
|
||||
else:
|
||||
activity = []
|
||||
activities = []
|
||||
|
||||
if status == 'idle':
|
||||
since = int(time.time() * 1000)
|
||||
@ -651,7 +692,7 @@ class DiscordWebSocket:
|
||||
payload = {
|
||||
'op': self.PRESENCE,
|
||||
'd': {
|
||||
'activities': activity,
|
||||
'activities': activities,
|
||||
'afk': False,
|
||||
'since': since,
|
||||
'status': status,
|
||||
@ -662,7 +703,16 @@ class DiscordWebSocket:
|
||||
_log.debug('Sending "%s" to change status', sent)
|
||||
await self.send(sent)
|
||||
|
||||
async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None):
|
||||
async def request_chunks(
|
||||
self,
|
||||
guild_id: int,
|
||||
query: Optional[str] = None,
|
||||
*,
|
||||
limit: int,
|
||||
user_ids: Optional[List[int]] = None,
|
||||
presences: bool = False,
|
||||
nonce: Optional[str] = None,
|
||||
) -> None:
|
||||
payload = {
|
||||
'op': self.REQUEST_MEMBERS,
|
||||
'd': {
|
||||
@ -683,7 +733,13 @@ class DiscordWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False):
|
||||
async def voice_state(
|
||||
self,
|
||||
guild_id: int,
|
||||
channel_id: int,
|
||||
self_mute: bool = False,
|
||||
self_deaf: bool = False,
|
||||
) -> None:
|
||||
payload = {
|
||||
'op': self.VOICE_STATE,
|
||||
'd': {
|
||||
@ -697,7 +753,7 @@ class DiscordWebSocket:
|
||||
_log.debug('Updating our voice state to %s.', payload)
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def close(self, code=4000):
|
||||
async def close(self, code: int = 4000) -> None:
|
||||
if self._keep_alive:
|
||||
self._keep_alive.stop()
|
||||
self._keep_alive = None
|
||||
@ -706,6 +762,9 @@ class DiscordWebSocket:
|
||||
await self.socket.close(code=code)
|
||||
|
||||
|
||||
DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
|
||||
|
||||
|
||||
class DiscordVoiceWebSocket:
|
||||
"""Implements the websocket protocol for handling voice connections.
|
||||
|
||||
@ -737,6 +796,12 @@ class DiscordVoiceWebSocket:
|
||||
Receive only. Indicates a user has disconnected from voice.
|
||||
"""
|
||||
|
||||
if TYPE_CHECKING:
|
||||
thread_id: int
|
||||
_connection: VoiceClient
|
||||
gateway: str
|
||||
_max_heartbeat_timeout: float
|
||||
|
||||
# fmt: off
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
@ -752,25 +817,31 @@ class DiscordVoiceWebSocket:
|
||||
CLIENT_DISCONNECT = 13
|
||||
# fmt: on
|
||||
|
||||
def __init__(self, socket, loop, *, hook=None):
|
||||
def __init__(
|
||||
self,
|
||||
socket: aiohttp.ClientWebSocketResponse,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
*,
|
||||
hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
|
||||
) -> None:
|
||||
self.ws = socket
|
||||
self.loop = loop
|
||||
self._keep_alive = None
|
||||
self._close_code = None
|
||||
self.secret_key = None
|
||||
if hook:
|
||||
self._hook = hook
|
||||
self._hook = hook # type: ignore - type-checker doesn't like overriding methods
|
||||
|
||||
async def _hook(self, *args):
|
||||
async def _hook(self, *args: Any) -> None:
|
||||
pass
|
||||
|
||||
async def send_as_json(self, data):
|
||||
async def send_as_json(self, data: Any) -> None:
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
async def resume(self):
|
||||
async def resume(self) -> None:
|
||||
state = self._connection
|
||||
payload = {
|
||||
'op': self.RESUME,
|
||||
@ -782,7 +853,7 @@ class DiscordVoiceWebSocket:
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def identify(self):
|
||||
async def identify(self) -> None:
|
||||
state = self._connection
|
||||
payload = {
|
||||
'op': self.IDENTIFY,
|
||||
@ -796,7 +867,7 @@ class DiscordVoiceWebSocket:
|
||||
await self.send_as_json(payload)
|
||||
|
||||
@classmethod
|
||||
async def from_client(cls, client, *, resume=False, hook=None):
|
||||
async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume=False, hook=None) -> DVWS:
|
||||
"""Creates a voice websocket for the :class:`VoiceClient`."""
|
||||
gateway = 'wss://' + client.endpoint + '/?v=4'
|
||||
http = client._state.http
|
||||
@ -814,7 +885,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
return ws
|
||||
|
||||
async def select_protocol(self, ip, port, mode):
|
||||
async def select_protocol(self, ip: str, port: int, mode: int) -> None:
|
||||
payload = {
|
||||
'op': self.SELECT_PROTOCOL,
|
||||
'd': {
|
||||
@ -829,7 +900,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def client_connect(self):
|
||||
async def client_connect(self) -> None:
|
||||
payload = {
|
||||
'op': self.CLIENT_CONNECT,
|
||||
'd': {
|
||||
@ -839,7 +910,7 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def speak(self, state=SpeakingState.voice):
|
||||
async def speak(self, state: SpeakingState = SpeakingState.voice) -> None:
|
||||
payload = {
|
||||
'op': self.SPEAKING,
|
||||
'd': {
|
||||
@ -850,28 +921,29 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg):
|
||||
async def received_message(self, msg: Dict[str, Any]) -> None:
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
data = msg.get('d')
|
||||
|
||||
if op == self.READY:
|
||||
await self.initial_connection(data)
|
||||
await self.initial_connection(data) # type: ignore - type-checker thinks data could be None
|
||||
elif op == self.HEARTBEAT_ACK:
|
||||
self._keep_alive.ack()
|
||||
self._keep_alive.ack() # type: ignore - _keep_alive can't be None at this point
|
||||
elif op == self.RESUMED:
|
||||
_log.info('Voice RESUME succeeded.')
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
self._connection.mode = data['mode']
|
||||
await self.load_secret_key(data)
|
||||
# type-checker thinks data could be None
|
||||
self._connection.mode = data['mode'] # type: ignore
|
||||
await self.load_secret_key(data) # type: ignore
|
||||
elif op == self.HELLO:
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
interval = data['heartbeat_interval'] / 1000.0 # type: ignore - type-checker thinks data could be None
|
||||
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
|
||||
self._keep_alive.start()
|
||||
|
||||
await self._hook(self, msg)
|
||||
|
||||
async def initial_connection(self, data):
|
||||
async def initial_connection(self, data: Dict[str, Any]) -> None:
|
||||
state = self._connection
|
||||
state.ssrc = data['ssrc']
|
||||
state.voice_port = data['port']
|
||||
@ -888,41 +960,41 @@ class DiscordVoiceWebSocket:
|
||||
# the ip is ascii starting at the 4th byte and ending at the first null
|
||||
ip_start = 4
|
||||
ip_end = recv.index(0, ip_start)
|
||||
state.ip = recv[ip_start:ip_end].decode('ascii')
|
||||
state.endpoint_ip = recv[ip_start:ip_end].decode('ascii')
|
||||
|
||||
state.port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', state.ip, state.port)
|
||||
state.voice_port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
|
||||
_log.debug('detected ip: %s port: %s', state.endpoint_ip, state.voice_port)
|
||||
|
||||
# there *should* always be at least one supported mode (xsalsa20_poly1305)
|
||||
modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
|
||||
_log.debug('received supported encryption modes: %s', ", ".join(modes))
|
||||
|
||||
mode = modes[0]
|
||||
await self.select_protocol(state.ip, state.port, mode)
|
||||
await self.select_protocol(state.endpoint_ip, state.voice_port, mode)
|
||||
_log.info('selected the voice protocol for use (%s)', mode)
|
||||
|
||||
@property
|
||||
def latency(self):
|
||||
def latency(self) -> float:
|
||||
""":class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
|
||||
heartbeat = self._keep_alive
|
||||
return float('inf') if heartbeat is None else heartbeat.latency
|
||||
|
||||
@property
|
||||
def average_latency(self):
|
||||
""":class:`list`: Average of last 20 HEARTBEAT latencies."""
|
||||
def average_latency(self) -> float:
|
||||
""":class:`float`: Average of last 20 HEARTBEAT latencies."""
|
||||
heartbeat = self._keep_alive
|
||||
if heartbeat is None or not heartbeat.recent_ack_latencies:
|
||||
return float('inf')
|
||||
|
||||
return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
|
||||
|
||||
async def load_secret_key(self, data):
|
||||
async def load_secret_key(self, data: Dict[str, Any]) -> None:
|
||||
_log.info('received secret key for voice connection')
|
||||
self.secret_key = self._connection.secret_key = data.get('secret_key')
|
||||
self.secret_key = self._connection.secret_key = data.get('secret_key') # type: ignore - type-checker thinks secret_key could be None
|
||||
await self.speak()
|
||||
await self.speak(False)
|
||||
await self.speak(SpeakingState.none)
|
||||
|
||||
async def poll_event(self):
|
||||
async def poll_event(self) -> None:
|
||||
# This exception is handled up the chain
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
@ -934,7 +1006,7 @@ class DiscordVoiceWebSocket:
|
||||
_log.debug('Received %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
|
||||
|
||||
async def close(self, code=1000):
|
||||
async def close(self, code: int = 1000) -> None:
|
||||
if self._keep_alive is not None:
|
||||
self._keep_alive.stop()
|
||||
|
||||
|
@ -341,7 +341,7 @@ class HTTPClient:
|
||||
connector=self.connector, ws_response_class=DiscordClientWebSocketResponse
|
||||
)
|
||||
|
||||
async def ws_connect(self, url: str, *, compress: int = 0) -> Any:
|
||||
async def ws_connect(self, url: str, *, compress: int = 0) -> aiohttp.ClientWebSocketResponse:
|
||||
kwargs = {
|
||||
'proxy_auth': self.proxy_auth,
|
||||
'proxy': self.proxy,
|
||||
|
@ -239,6 +239,7 @@ class VoiceClient(VoiceProtocol):
|
||||
super().__init__(client, channel)
|
||||
state = client._connection
|
||||
self.token: str = MISSING
|
||||
self.server_id: int = MISSING
|
||||
self.socket = MISSING
|
||||
self.loop: asyncio.AbstractEventLoop = state.loop
|
||||
self._state: ConnectionState = state
|
||||
|
Loading…
x
Reference in New Issue
Block a user