diff --git a/discord/gateway.py b/discord/gateway.py index aa0c6ba0..fbbc3c5e 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -22,8 +22,12 @@ FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ +from __future__ import annotations + +from typing import TYPE_CHECKING, TypedDict, Any, Optional, List, TypeVar, Type, Dict, Callable, Coroutine, NamedTuple, Deque + import asyncio -from collections import namedtuple, deque +from collections import deque import concurrent.futures import logging import struct @@ -38,9 +42,25 @@ import aiohttp from . import utils from .activity import BaseActivity from .enums import SpeakingState -from .errors import ConnectionClosed, InvalidArgument +from .errors import ConnectionClosed, InvalidArgument + +if TYPE_CHECKING: + from .client import Client + from .state import ConnectionState + from .voice_client import VoiceClient + + T = TypeVar('T') + DWS = TypeVar('DWS', bound='DiscordWebSocket') + DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket') + + Coro = Callable[..., Coroutine[Any, Any, Any]] + Predicate = Callable[[Dict[str, Any]], bool] + DataCallable = Callable[[Dict[str, Any]], T] + Result = Optional[DataCallable[Any]] + + +_log: logging.Logger = logging.getLogger(__name__) -_log = logging.getLogger(__name__) __all__ = ( 'DiscordWebSocket', @@ -50,36 +70,49 @@ __all__ = ( 'ReconnectWebSocket', ) + +class Heartbeat(TypedDict): + op: int + d: int + + class ReconnectWebSocket(Exception): """Signals to safely reconnect the websocket.""" - def __init__(self, shard_id, *, resume=True): - self.shard_id = shard_id - self.resume = resume + def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None: + self.shard_id: Optional[int] = shard_id + self.resume: bool = resume self.op = 'RESUME' if resume else 'IDENTIFY' + class WebSocketClosure(Exception): """An exception to make up for the fact that aiohttp doesn't signal closure.""" pass -EventListener = namedtuple('EventListener', 'predicate event result future') + +class EventListener(NamedTuple): + predicate: Predicate + event: str + result: Result + future: asyncio.Future + class GatewayRatelimiter: - def __init__(self, count=110, per=60.0): + def __init__(self, count: int = 110, per: float = 60.0) -> None: # The default is 110 to give room for at least 10 heartbeats per minute - self.max = count - self.remaining = count - self.window = 0.0 - self.per = per - self.lock = asyncio.Lock() - self.shard_id = None + self.max: int = count + self.remaining: int = count + self.window: float = 0.0 + self.per: float = per + self.lock: asyncio.Lock = asyncio.Lock() + self.shard_id: Optional[int] = None - def is_ratelimited(self): + def is_ratelimited(self) -> bool: current = time.time() if current > self.window + self.per: return False return self.remaining == 0 - def get_delay(self): + def get_delay(self) -> float: current = time.time() if current > self.window + self.per: @@ -97,7 +130,7 @@ class GatewayRatelimiter: return 0.0 - async def block(self): + async def block(self) -> None: async with self.lock: delta = self.get_delay() if delta: @@ -106,27 +139,27 @@ class GatewayRatelimiter: class KeepAliveHandler(threading.Thread): - def __init__(self, *args, **kwargs): - ws = kwargs.pop('ws', None) + def __init__(self, *args: Any, **kwargs: Any) -> None: + ws = kwargs.pop('ws') interval = kwargs.pop('interval', None) shard_id = kwargs.pop('shard_id', None) threading.Thread.__init__(self, *args, **kwargs) - self.ws = ws - self._main_thread_id = ws.thread_id - self.interval = interval - self.daemon = True - self.shard_id = shard_id - self.msg = 'Keeping shard ID %s websocket alive with sequence %s.' - self.block_msg = 'Shard ID %s heartbeat blocked for more than %s seconds.' - self.behind_msg = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' - self._stop_ev = threading.Event() - self._last_ack = time.perf_counter() - self._last_send = time.perf_counter() - self._last_recv = time.perf_counter() - self.latency = float('inf') - self.heartbeat_timeout = ws._max_heartbeat_timeout + self.ws: DiscordWebSocket = ws + self._main_thread_id: int = ws.thread_id + self.interval: Optional[float] = interval + self.daemon: bool = True + self.shard_id: Optional[int] = shard_id + self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.' + self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.' + self.behind_msg: str = 'Can\'t keep up, shard ID %s websocket is %.1fs behind.' + self._stop_ev: threading.Event = threading.Event() + self._last_ack: float = time.perf_counter() + self._last_send: float = time.perf_counter() + self._last_recv: float = time.perf_counter() + self.latency: float = float('inf') + self.heartbeat_timeout: float = ws._max_heartbeat_timeout - def run(self): + def run(self) -> None: while not self._stop_ev.wait(self.interval): if self._last_recv + self.heartbeat_timeout < time.perf_counter(): _log.warning("Shard ID %s has stopped responding to the gateway. Closing and restarting.", self.shard_id) @@ -168,19 +201,20 @@ class KeepAliveHandler(threading.Thread): else: self._last_send = time.perf_counter() - def get_payload(self): + def get_payload(self) -> Heartbeat: return { 'op': self.ws.HEARTBEAT, - 'd': self.ws.sequence + # the websocket's sequence won't be None here + 'd': self.ws.sequence # type: ignore } - def stop(self): + def stop(self) -> None: self._stop_ev.set() - def tick(self): + def tick(self) -> None: self._last_recv = time.perf_counter() - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self.latency = ack_time - self._last_send @@ -188,30 +222,32 @@ class KeepAliveHandler(threading.Thread): _log.warning(self.behind_msg, self.shard_id, self.latency) class VoiceKeepAliveHandler(KeepAliveHandler): - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) - self.recent_ack_latencies = deque(maxlen=20) + self.recent_ack_latencies: Deque[float] = deque(maxlen=20) self.msg = 'Keeping shard ID %s voice websocket alive with timestamp %s.' self.block_msg = 'Shard ID %s voice heartbeat blocked for more than %s seconds' self.behind_msg = 'High socket latency, shard ID %s heartbeat is %.1fs behind' - def get_payload(self): + def get_payload(self) -> Heartbeat: return { 'op': self.ws.HEARTBEAT, 'd': int(time.time() * 1000) } - def ack(self): + def ack(self) -> None: ack_time = time.perf_counter() self._last_ack = ack_time self._last_recv = ack_time self.latency = ack_time - self._last_send self.recent_ack_latencies.append(self.latency) + class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse): async def close(self, *, code: int = 4000, message: bytes = b'') -> bool: return await super().close(code=code, message=message) + class DiscordWebSocket: """Implements a WebSocket for Discord's gateway v6. @@ -266,41 +302,53 @@ class DiscordWebSocket: HEARTBEAT_ACK = 11 GUILD_SYNC = 12 - def __init__(self, socket, *, loop): - self.socket = socket - self.loop = loop + def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None: + self.socket: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop # an empty dispatcher to prevent crashes self._dispatch = lambda *args: None # generic event listeners - self._dispatch_listeners = [] + self._dispatch_listeners: List[EventListener] = [] # the keep alive - self._keep_alive = None - self.thread_id = threading.get_ident() + self._keep_alive: Optional[KeepAliveHandler] = None + self.thread_id: int = threading.get_ident() # ws related stuff - self.session_id = None - self.sequence = None + self.session_id: Optional[str] = None + self.sequence: Optional[int] = None self._zlib = zlib.decompressobj() - self._buffer = bytearray() - self._close_code = None - self._rate_limiter = GatewayRatelimiter() + self._buffer: bytearray = bytearray() + self._close_code: Optional[int] = None + self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter() + + # attributes that get set in from_client + self.token: str = utils.MISSING + self._connection: ConnectionState = utils.MISSING + self._discord_parsers: Dict[str, DataCallable[None]] = utils.MISSING + self.gateway: str = utils.MISSING + self.call_hooks: Coro = utils.MISSING + self._initial_identify: bool = utils.MISSING + self.shard_id: Optional[int] = utils.MISSING + self.shard_count: Optional[int] = utils.MISSING + self.session_id: Optional[str] = utils.MISSING + self._max_heartbeat_timeout: float = utils.MISSING @property - def open(self): + def open(self) -> bool: return not self.socket.closed - def is_ratelimited(self): + def is_ratelimited(self) -> bool: return self._rate_limiter.is_ratelimited() - def debug_log_receive(self, data, /): + def debug_log_receive(self, data, /) -> None: self._dispatch('socket_raw_receive', data) - def log_receive(self, _, /): + def log_receive(self, _, /) -> None: pass @classmethod - async def from_client(cls, client, *, initial=False, gateway=None, shard_id=None, session=None, sequence=None, resume=False): + async def from_client(cls: Type[DWS], client: Client, *, initial: bool = False, gateway: Optional[str] = None, shard_id: Optional[int] = None, session: Optional[str] = None, sequence: Optional[int] = None, resume: bool = False) -> DWS: """Creates a main websocket for Discord from a :class:`Client`. This is for internal use only. @@ -310,7 +358,9 @@ class DiscordWebSocket: ws = cls(socket, loop=client.loop) # dynamically add attributes needed - ws.token = client.http.token + + # the token won't be None here + ws.token = client.http.token # type: ignore ws._connection = client._connection ws._discord_parsers = client._connection.parsers ws._dispatch = client.dispatch @@ -342,7 +392,7 @@ class DiscordWebSocket: await ws.resume() return ws - def wait_for(self, event, predicate, result=None): + def wait_for(self, event: str, predicate: Predicate, result: Result = None) -> asyncio.Future: """Waits for a DISPATCH'd event that meets the predicate. Parameters @@ -367,7 +417,7 @@ class DiscordWebSocket: self._dispatch_listeners.append(entry) return future - async def identify(self): + async def identify(self) -> None: """Sends the IDENTIFY packet.""" payload = { 'op': self.IDENTIFY, @@ -405,7 +455,7 @@ class DiscordWebSocket: await self.send_as_json(payload) _log.info('Shard ID %s has sent the IDENTIFY payload.', self.shard_id) - async def resume(self): + async def resume(self) -> None: """Sends the RESUME packet.""" payload = { 'op': self.RESUME, @@ -419,7 +469,8 @@ class DiscordWebSocket: await self.send_as_json(payload) _log.info('Shard ID %s has sent the RESUME payload.', self.shard_id) - async def received_message(self, msg, /): + + async def received_message(self, msg, /) -> None: if type(msg) is bytes: self._buffer.extend(msg) @@ -537,16 +588,16 @@ class DiscordWebSocket: del self._dispatch_listeners[index] @property - def latency(self): + def latency(self) -> float: """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency - def _can_handle_close(self): + def _can_handle_close(self) -> bool: code = self._close_code or self.socket.close_code return code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014) - async def poll_event(self): + async def poll_event(self) -> None: """Polls for a DISPATCH event and handles the general gateway loop. Raises @@ -584,23 +635,23 @@ class DiscordWebSocket: _log.info('Websocket closed with %s, cannot reconnect.', code) raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None - async def debug_send(self, data, /): + async def debug_send(self, data, /) -> None: await self._rate_limiter.block() self._dispatch('socket_raw_send', data) await self.socket.send_str(data) - async def send(self, data, /): + async def send(self, data, /) -> None: await self._rate_limiter.block() await self.socket.send_str(data) - async def send_as_json(self, data): + async def send_as_json(self, data) -> None: try: await self.send(utils._to_json(data)) except RuntimeError as exc: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def send_heartbeat(self, data): + async def send_heartbeat(self, data: Heartbeat) -> None: # This bypasses the rate limit handling code since it has a higher priority try: await self.socket.send_str(utils._to_json(data)) @@ -608,13 +659,13 @@ class DiscordWebSocket: if not self._can_handle_close(): raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc - async def change_presence(self, *, activity=None, status=None, since=0.0): + async def change_presence(self, *, activity: Optional[BaseActivity] = None, status: Optional[str] = None, since: float = 0.0) -> None: if activity is not None: if not isinstance(activity, BaseActivity): raise InvalidArgument('activity must derive from BaseActivity.') - activity = [activity.to_dict()] + activities = [activity.to_dict()] else: - activity = [] + activities = [] if status == 'idle': since = int(time.time() * 1000) @@ -622,7 +673,7 @@ class DiscordWebSocket: payload = { 'op': self.PRESENCE, 'd': { - 'activities': activity, + 'activities': activities, 'afk': False, 'since': since, 'status': status @@ -633,7 +684,7 @@ class DiscordWebSocket: _log.debug('Sending "%s" to change status', sent) await self.send(sent) - async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): + async def request_chunks(self, guild_id: int, query: Optional[str] = None, *, limit: int, user_ids: Optional[List[int]] = None, presences: bool = False, nonce: Optional[int] = None) -> None: payload = { 'op': self.REQUEST_MEMBERS, 'd': { @@ -655,7 +706,7 @@ class DiscordWebSocket: await self.send_as_json(payload) - async def voice_state(self, guild_id, channel_id, self_mute=False, self_deaf=False): + async def voice_state(self, guild_id: int, channel_id: int, self_mute: bool = False, self_deaf: bool = False) -> None: payload = { 'op': self.VOICE_STATE, 'd': { @@ -669,7 +720,7 @@ class DiscordWebSocket: _log.debug('Updating our voice state to %s.', payload) await self.send_as_json(payload) - async def close(self, code=4000): + async def close(self, code: int = 4000) -> None: if self._keep_alive: self._keep_alive.stop() self._keep_alive = None @@ -721,25 +772,31 @@ class DiscordVoiceWebSocket: CLIENT_CONNECT = 12 CLIENT_DISCONNECT = 13 - def __init__(self, socket, loop, *, hook=None): - self.ws = socket - self.loop = loop - self._keep_alive = None - self._close_code = None - self.secret_key = None + def __init__(self, socket: aiohttp.ClientWebSocketResponse, loop: asyncio.AbstractEventLoop, *, hook: Optional[Coro] = None) -> None: + self.ws: aiohttp.ClientWebSocketResponse = socket + self.loop: asyncio.AbstractEventLoop = loop + self._keep_alive: VoiceKeepAliveHandler = utils.MISSING + self._close_code: Optional[int] = None + self.secret_key: Optional[List[int]] = None + self.gateway: str = utils.MISSING + self._connection: VoiceClient = utils.MISSING + self._max_heartbeat_timeout: float = utils.MISSING + self.thread_id: int = utils.MISSING if hook: - self._hook = hook + # we want to redeclare self._hook + self._hook = hook # type: ignore - async def _hook(self, *args): + async def _hook(self, *args: Any) -> Any: pass - async def send_as_json(self, data): + + async def send_as_json(self, data) -> None: _log.debug('Sending voice websocket frame: %s.', data) await self.ws.send_str(utils._to_json(data)) send_heartbeat = send_as_json - async def resume(self): + async def resume(self) -> None: state = self._connection payload = { 'op': self.RESUME, @@ -765,7 +822,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) @classmethod - async def from_client(cls, client, *, resume=False, hook=None): + async def from_client(cls: Type[DVWS], client: VoiceClient, *, resume: bool = False, hook: Optional[Coro] = None) -> DVWS: """Creates a voice websocket for the :class:`VoiceClient`.""" gateway = 'wss://' + client.endpoint + '/?v=4' http = client._state.http @@ -783,7 +840,7 @@ class DiscordVoiceWebSocket: return ws - async def select_protocol(self, ip, port, mode): + async def select_protocol(self, ip, port, mode) -> None: payload = { 'op': self.SELECT_PROTOCOL, 'd': { @@ -798,7 +855,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def client_connect(self): + async def client_connect(self) -> None: payload = { 'op': self.CLIENT_CONNECT, 'd': { @@ -808,7 +865,7 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def speak(self, state=SpeakingState.voice): + async def speak(self, state=SpeakingState.voice) -> None: payload = { 'op': self.SPEAKING, 'd': { @@ -819,7 +876,8 @@ class DiscordVoiceWebSocket: await self.send_as_json(payload) - async def received_message(self, msg): + + async def received_message(self, msg) -> None: _log.debug('Voice websocket frame received: %s', msg) op = msg['op'] data = msg.get('d') @@ -840,7 +898,7 @@ class DiscordVoiceWebSocket: await self._hook(self, msg) - async def initial_connection(self, data): + async def initial_connection(self, data) -> None: state = self._connection state.ssrc = data['ssrc'] state.voice_port = data['port'] @@ -871,13 +929,13 @@ class DiscordVoiceWebSocket: _log.info('selected the voice protocol for use (%s)', mode) @property - def latency(self): + def latency(self) -> float: """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds.""" heartbeat = self._keep_alive return float('inf') if heartbeat is None else heartbeat.latency @property - def average_latency(self): + def average_latency(self) -> float: """:class:`list`: Average of last 20 HEARTBEAT latencies.""" heartbeat = self._keep_alive if heartbeat is None or not heartbeat.recent_ack_latencies: @@ -885,13 +943,14 @@ class DiscordVoiceWebSocket: return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies) - async def load_secret_key(self, data): + + async def load_secret_key(self, data) -> None: _log.info('received secret key for voice connection') self.secret_key = self._connection.secret_key = data.get('secret_key') await self.speak() await self.speak(False) - async def poll_event(self): + async def poll_event(self) -> None: # This exception is handled up the chain msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0) if msg.type is aiohttp.WSMsgType.TEXT: @@ -903,7 +962,7 @@ class DiscordVoiceWebSocket: _log.debug('Received %s', msg) raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code) - async def close(self, code=1000): + async def close(self, code: int = 1000) -> None: if self._keep_alive is not None: self._keep_alive.stop() diff --git a/discord/voice_client.py b/discord/voice_client.py index d382a74d..123dd29b 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -255,6 +255,9 @@ class VoiceClient(VoiceProtocol): self.encoder: Encoder = MISSING self._lite_nonce: int = 0 self.ws: DiscordVoiceWebSocket = MISSING + self.ip: str = MISSING + self.port: Tuple[Any, ...] = MISSING + warn_nacl = not has_nacl supported_modes: Tuple[SupportedModes, ...] = (