mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-25 10:32:59 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			1061 lines
		
	
	
		
			36 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			1061 lines
		
	
	
		
			36 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| """
 | |
| The MIT License (MIT)
 | |
| 
 | |
| Copyright (c) 2015-present Rapptz
 | |
| 
 | |
| Permission is hereby granted, free of charge, to any person obtaining a
 | |
| copy of this software and associated documentation files (the "Software"),
 | |
| to deal in the Software without restriction, including without limitation
 | |
| the rights to use, copy, modify, merge, publish, distribute, sublicense,
 | |
| and/or sell copies of the Software, and to permit persons to whom the
 | |
| Software is furnished to do so, subject to the following conditions:
 | |
| 
 | |
| The above copyright notice and this permission notice shall be included in
 | |
| all copies or substantial portions of the Software.
 | |
| 
 | |
| THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
 | |
| OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 | |
| FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 | |
| AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 | |
| LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 | |
| FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
 | |
| DEALINGS IN THE SOFTWARE.
 | |
| """
 | |
| 
 | |
| from __future__ import annotations
 | |
| 
 | |
| import asyncio
 | |
| from collections import deque
 | |
| import concurrent.futures
 | |
| import logging
 | |
| import struct
 | |
| import sys
 | |
| import time
 | |
| import threading
 | |
| import traceback
 | |
| 
 | |
| from typing import Any, Callable, Coroutine, Deque, Dict, List, TYPE_CHECKING, NamedTuple, Optional, TypeVar, Tuple
 | |
| 
 | |
| import aiohttp
 | |
| import yarl
 | |
| 
 | |
| from . import utils
 | |
| from .activity import BaseActivity
 | |
| from .enums import SpeakingState
 | |
| from .errors import ConnectionClosed
 | |
| 
 | |
| _log = logging.getLogger(__name__)
 | |
| 
 | |
| __all__ = (
 | |
|     'DiscordWebSocket',
 | |
|     'KeepAliveHandler',
 | |
|     'VoiceKeepAliveHandler',
 | |
|     'DiscordVoiceWebSocket',
 | |
|     'ReconnectWebSocket',
 | |
| )
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from typing_extensions import Self
 | |
| 
 | |
|     from .client import Client
 | |
|     from .state import ConnectionState
 | |
|     from .voice_state import VoiceConnectionState
 | |
| 
 | |
| 
 | |
| class ReconnectWebSocket(Exception):
 | |
|     """Signals to safely reconnect the websocket."""
 | |
| 
 | |
|     def __init__(self, shard_id: Optional[int], *, resume: bool = True) -> None:
 | |
|         self.shard_id: Optional[int] = shard_id
 | |
|         self.resume: bool = resume
 | |
|         self.op: str = 'RESUME' if resume else 'IDENTIFY'
 | |
| 
 | |
| 
 | |
| class WebSocketClosure(Exception):
 | |
|     """An exception to make up for the fact that aiohttp doesn't signal closure."""
 | |
| 
 | |
|     pass
 | |
| 
 | |
| 
 | |
| class EventListener(NamedTuple):
 | |
|     predicate: Callable[[Dict[str, Any]], bool]
 | |
|     event: str
 | |
|     result: Optional[Callable[[Dict[str, Any]], Any]]
 | |
|     future: asyncio.Future[Any]
 | |
| 
 | |
| 
 | |
| class GatewayRatelimiter:
 | |
|     def __init__(self, count: int = 110, per: float = 60.0) -> None:
 | |
|         # The default is 110 to give room for at least 10 heartbeats per minute
 | |
|         self.max: int = count
 | |
|         self.remaining: int = count
 | |
|         self.window: float = 0.0
 | |
|         self.per: float = per
 | |
|         self.lock: asyncio.Lock = asyncio.Lock()
 | |
|         self.shard_id: Optional[int] = None
 | |
| 
 | |
|     def is_ratelimited(self) -> bool:
 | |
|         current = time.time()
 | |
|         if current > self.window + self.per:
 | |
|             return False
 | |
|         return self.remaining == 0
 | |
| 
 | |
|     def get_delay(self) -> float:
 | |
|         current = time.time()
 | |
| 
 | |
|         if current > self.window + self.per:
 | |
|             self.remaining = self.max
 | |
| 
 | |
|         if self.remaining == self.max:
 | |
|             self.window = current
 | |
| 
 | |
|         if self.remaining == 0:
 | |
|             return self.per - (current - self.window)
 | |
| 
 | |
|         self.remaining -= 1
 | |
|         return 0.0
 | |
| 
 | |
|     async def block(self) -> None:
 | |
|         async with self.lock:
 | |
|             delta = self.get_delay()
 | |
|             if delta:
 | |
|                 _log.warning('WebSocket in shard ID %s is ratelimited, waiting %.2f seconds', self.shard_id, delta)
 | |
|                 await asyncio.sleep(delta)
 | |
| 
 | |
| 
 | |
| class KeepAliveHandler(threading.Thread):
 | |
|     def __init__(
 | |
|         self,
 | |
|         *args: Any,
 | |
|         ws: DiscordWebSocket,
 | |
|         interval: Optional[float] = None,
 | |
|         shard_id: Optional[int] = None,
 | |
|         **kwargs: Any,
 | |
|     ) -> None:
 | |
|         daemon: bool = kwargs.pop('daemon', True)
 | |
|         name: str = kwargs.pop('name', f'keep-alive-handler:shard-{shard_id}')
 | |
|         super().__init__(*args, daemon=daemon, name=name, **kwargs)
 | |
|         self.ws: DiscordWebSocket = ws
 | |
|         self._main_thread_id: int = ws.thread_id
 | |
|         self.interval: Optional[float] = interval
 | |
|         self.shard_id: Optional[int] = shard_id
 | |
|         self.msg: str = 'Keeping shard ID %s websocket alive with sequence %s.'
 | |
|         self.block_msg: str = 'Shard ID %s heartbeat blocked for more than %s seconds.'
 | |
|         self.behind_msg: str = "Can't keep up, shard ID %s websocket is %.1fs behind."
 | |
|         self._stop_ev: threading.Event = threading.Event()
 | |
|         self._last_ack: float = time.perf_counter()
 | |
|         self._last_send: float = time.perf_counter()
 | |
|         self._last_recv: float = time.perf_counter()
 | |
|         self.latency: float = float('inf')
 | |
|         self.heartbeat_timeout: float = ws._max_heartbeat_timeout
 | |
| 
 | |
|     def run(self) -> None:
 | |
|         while not self._stop_ev.wait(self.interval):
 | |
|             if self._last_recv + self.heartbeat_timeout < time.perf_counter():
 | |
|                 _log.warning('Shard ID %s has stopped responding to the gateway. Closing and restarting.', self.shard_id)
 | |
|                 coro = self.ws.close(4000)
 | |
|                 f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
 | |
| 
 | |
|                 try:
 | |
|                     f.result()
 | |
|                 except Exception:
 | |
|                     _log.exception('An error occurred while stopping the gateway. Ignoring.')
 | |
|                 except BaseException as exc:
 | |
|                     _log.debug('A BaseException was raised while stopping the gateway', exc_info=exc)
 | |
|                 finally:
 | |
|                     self.stop()
 | |
|                 return
 | |
| 
 | |
|             data = self.get_payload()
 | |
|             _log.debug(self.msg, self.shard_id, data['d'])
 | |
|             coro = self.ws.send_heartbeat(data)
 | |
|             f = asyncio.run_coroutine_threadsafe(coro, loop=self.ws.loop)
 | |
|             try:
 | |
|                 # block until sending is complete
 | |
|                 total = 0
 | |
|                 while True:
 | |
|                     try:
 | |
|                         f.result(10)
 | |
|                         break
 | |
|                     except concurrent.futures.TimeoutError:
 | |
|                         total += 10
 | |
|                         try:
 | |
|                             frame = sys._current_frames()[self._main_thread_id]
 | |
|                         except KeyError:
 | |
|                             msg = self.block_msg
 | |
|                         else:
 | |
|                             stack = ''.join(traceback.format_stack(frame))
 | |
|                             msg = f'{self.block_msg}\nLoop thread traceback (most recent call last):\n{stack}'
 | |
|                         _log.warning(msg, self.shard_id, total)
 | |
| 
 | |
|             except Exception:
 | |
|                 self.stop()
 | |
|             else:
 | |
|                 self._last_send = time.perf_counter()
 | |
| 
 | |
|     def get_payload(self) -> Dict[str, Any]:
 | |
|         return {
 | |
|             'op': self.ws.HEARTBEAT,
 | |
|             'd': self.ws.sequence,
 | |
|         }
 | |
| 
 | |
|     def stop(self) -> None:
 | |
|         self._stop_ev.set()
 | |
| 
 | |
|     def tick(self) -> None:
 | |
|         self._last_recv = time.perf_counter()
 | |
| 
 | |
|     def ack(self) -> None:
 | |
|         ack_time = time.perf_counter()
 | |
|         self._last_ack = ack_time
 | |
|         self.latency = ack_time - self._last_send
 | |
|         if self.latency > 10:
 | |
|             _log.warning(self.behind_msg, self.shard_id, self.latency)
 | |
| 
 | |
| 
 | |
| class VoiceKeepAliveHandler(KeepAliveHandler):
 | |
|     if TYPE_CHECKING:
 | |
|         ws: DiscordVoiceWebSocket
 | |
| 
 | |
|     def __init__(self, *args: Any, **kwargs: Any) -> None:
 | |
|         name: str = kwargs.pop('name', f'voice-keep-alive-handler:{id(self):#x}')
 | |
|         super().__init__(*args, name=name, **kwargs)
 | |
|         self.recent_ack_latencies: Deque[float] = deque(maxlen=20)
 | |
|         self.msg: str = 'Keeping shard ID %s voice websocket alive with timestamp %s.'
 | |
|         self.block_msg: str = 'Shard ID %s voice heartbeat blocked for more than %s seconds'
 | |
|         self.behind_msg: str = 'High socket latency, shard ID %s heartbeat is %.1fs behind'
 | |
| 
 | |
|     def get_payload(self) -> Dict[str, Any]:
 | |
|         return {
 | |
|             'op': self.ws.HEARTBEAT,
 | |
|             'd': {
 | |
|                 't': int(time.time() * 1000),
 | |
|                 'seq_ack': self.ws.seq_ack,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|     def ack(self) -> None:
 | |
|         ack_time = time.perf_counter()
 | |
|         self._last_ack = ack_time
 | |
|         self._last_recv = ack_time
 | |
|         self.latency: float = ack_time - self._last_send
 | |
|         self.recent_ack_latencies.append(self.latency)
 | |
| 
 | |
| 
 | |
| class DiscordClientWebSocketResponse(aiohttp.ClientWebSocketResponse):
 | |
|     async def close(self, *, code: int = 4000, message: bytes = b'') -> bool:
 | |
|         return await super().close(code=code, message=message)
 | |
| 
 | |
| 
 | |
| DWS = TypeVar('DWS', bound='DiscordWebSocket')
 | |
| 
 | |
| 
 | |
| class DiscordWebSocket:
 | |
|     """Implements a WebSocket for Discord's gateway v10.
 | |
| 
 | |
|     Attributes
 | |
|     -----------
 | |
|     DISPATCH
 | |
|         Receive only. Denotes an event to be sent to Discord, such as READY.
 | |
|     HEARTBEAT
 | |
|         When received tells Discord to keep the connection alive.
 | |
|         When sent asks if your connection is currently alive.
 | |
|     IDENTIFY
 | |
|         Send only. Starts a new session.
 | |
|     PRESENCE
 | |
|         Send only. Updates your presence.
 | |
|     VOICE_STATE
 | |
|         Send only. Starts a new connection to a voice guild.
 | |
|     VOICE_PING
 | |
|         Send only. Checks ping time to a voice guild, do not use.
 | |
|     RESUME
 | |
|         Send only. Resumes an existing connection.
 | |
|     RECONNECT
 | |
|         Receive only. Tells the client to reconnect to a new gateway.
 | |
|     REQUEST_MEMBERS
 | |
|         Send only. Asks for the full member list of a guild.
 | |
|     INVALIDATE_SESSION
 | |
|         Receive only. Tells the client to optionally invalidate the session
 | |
|         and IDENTIFY again.
 | |
|     HELLO
 | |
|         Receive only. Tells the client the heartbeat interval.
 | |
|     HEARTBEAT_ACK
 | |
|         Receive only. Confirms receiving of a heartbeat. Not having it implies
 | |
|         a connection issue.
 | |
|     GUILD_SYNC
 | |
|         Send only. Requests a guild sync.
 | |
|     gateway
 | |
|         The gateway we are currently connected to.
 | |
|     token
 | |
|         The authentication token for discord.
 | |
|     """
 | |
| 
 | |
|     if TYPE_CHECKING:
 | |
|         token: Optional[str]
 | |
|         _connection: ConnectionState
 | |
|         _discord_parsers: Dict[str, Callable[..., Any]]
 | |
|         call_hooks: Callable[..., Any]
 | |
|         _initial_identify: bool
 | |
|         shard_id: Optional[int]
 | |
|         shard_count: Optional[int]
 | |
|         gateway: yarl.URL
 | |
|         _max_heartbeat_timeout: float
 | |
| 
 | |
|     # fmt: off
 | |
|     DEFAULT_GATEWAY    = yarl.URL('wss://gateway.discord.gg/')
 | |
|     DISPATCH                    = 0
 | |
|     HEARTBEAT                   = 1
 | |
|     IDENTIFY                    = 2
 | |
|     PRESENCE                    = 3
 | |
|     VOICE_STATE                 = 4
 | |
|     VOICE_PING                  = 5
 | |
|     RESUME                      = 6
 | |
|     RECONNECT                   = 7
 | |
|     REQUEST_MEMBERS             = 8
 | |
|     INVALIDATE_SESSION          = 9
 | |
|     HELLO                       = 10
 | |
|     HEARTBEAT_ACK               = 11
 | |
|     GUILD_SYNC                  = 12
 | |
|     # fmt: on
 | |
| 
 | |
|     def __init__(self, socket: aiohttp.ClientWebSocketResponse, *, loop: asyncio.AbstractEventLoop) -> None:
 | |
|         self.socket: aiohttp.ClientWebSocketResponse = socket
 | |
|         self.loop: asyncio.AbstractEventLoop = loop
 | |
| 
 | |
|         # an empty dispatcher to prevent crashes
 | |
|         self._dispatch: Callable[..., Any] = lambda *args: None
 | |
|         # generic event listeners
 | |
|         self._dispatch_listeners: List[EventListener] = []
 | |
|         # the keep alive
 | |
|         self._keep_alive: Optional[KeepAliveHandler] = None
 | |
|         self.thread_id: int = threading.get_ident()
 | |
| 
 | |
|         # ws related stuff
 | |
|         self.session_id: Optional[str] = None
 | |
|         self.sequence: Optional[int] = None
 | |
|         self._decompressor: utils._DecompressionContext = utils._ActiveDecompressionContext()
 | |
|         self._close_code: Optional[int] = None
 | |
|         self._rate_limiter: GatewayRatelimiter = GatewayRatelimiter()
 | |
| 
 | |
|     @property
 | |
|     def open(self) -> bool:
 | |
|         return not self.socket.closed
 | |
| 
 | |
|     def is_ratelimited(self) -> bool:
 | |
|         return self._rate_limiter.is_ratelimited()
 | |
| 
 | |
|     def debug_log_receive(self, data: Dict[str, Any], /) -> None:
 | |
|         self._dispatch('socket_raw_receive', data)
 | |
| 
 | |
|     def log_receive(self, _: Dict[str, Any], /) -> None:
 | |
|         pass
 | |
| 
 | |
|     @classmethod
 | |
|     async def from_client(
 | |
|         cls,
 | |
|         client: Client,
 | |
|         *,
 | |
|         initial: bool = False,
 | |
|         gateway: Optional[yarl.URL] = None,
 | |
|         shard_id: Optional[int] = None,
 | |
|         session: Optional[str] = None,
 | |
|         sequence: Optional[int] = None,
 | |
|         resume: bool = False,
 | |
|         encoding: str = 'json',
 | |
|         compress: bool = True,
 | |
|     ) -> Self:
 | |
|         """Creates a main websocket for Discord from a :class:`Client`.
 | |
| 
 | |
|         This is for internal use only.
 | |
|         """
 | |
|         # Circular import
 | |
|         from .http import INTERNAL_API_VERSION
 | |
| 
 | |
|         gateway = gateway or cls.DEFAULT_GATEWAY
 | |
| 
 | |
|         if not compress:
 | |
|             url = gateway.with_query(v=INTERNAL_API_VERSION, encoding=encoding)
 | |
|         else:
 | |
|             url = gateway.with_query(
 | |
|                 v=INTERNAL_API_VERSION, encoding=encoding, compress=utils._ActiveDecompressionContext.COMPRESSION_TYPE
 | |
|             )
 | |
| 
 | |
|         socket = await client.http.ws_connect(str(url))
 | |
|         ws = cls(socket, loop=client.loop)
 | |
| 
 | |
|         # dynamically add attributes needed
 | |
|         ws.token = client.http.token
 | |
|         ws._connection = client._connection
 | |
|         ws._discord_parsers = client._connection.parsers
 | |
|         ws._dispatch = client.dispatch
 | |
|         ws.gateway = gateway
 | |
|         ws.call_hooks = client._connection.call_hooks
 | |
|         ws._initial_identify = initial
 | |
|         ws.shard_id = shard_id
 | |
|         ws._rate_limiter.shard_id = shard_id
 | |
|         ws.shard_count = client._connection.shard_count
 | |
|         ws.session_id = session
 | |
|         ws.sequence = sequence
 | |
|         ws._max_heartbeat_timeout = client._connection.heartbeat_timeout
 | |
| 
 | |
|         if client._enable_debug_events:
 | |
|             ws.send = ws.debug_send
 | |
|             ws.log_receive = ws.debug_log_receive
 | |
| 
 | |
|         client._connection._update_references(ws)
 | |
| 
 | |
|         _log.debug('Created websocket connected to %s', gateway)
 | |
| 
 | |
|         # poll event for OP Hello
 | |
|         await ws.poll_event()
 | |
| 
 | |
|         if not resume:
 | |
|             await ws.identify()
 | |
|             return ws
 | |
| 
 | |
|         await ws.resume()
 | |
|         return ws
 | |
| 
 | |
|     def wait_for(
 | |
|         self,
 | |
|         event: str,
 | |
|         predicate: Callable[[Dict[str, Any]], bool],
 | |
|         result: Optional[Callable[[Dict[str, Any]], Any]] = None,
 | |
|     ) -> asyncio.Future[Any]:
 | |
|         """Waits for a DISPATCH'd event that meets the predicate.
 | |
| 
 | |
|         Parameters
 | |
|         -----------
 | |
|         event: :class:`str`
 | |
|             The event name in all upper case to wait for.
 | |
|         predicate
 | |
|             A function that takes a data parameter to check for event
 | |
|             properties. The data parameter is the 'd' key in the JSON message.
 | |
|         result
 | |
|             A function that takes the same data parameter and executes to send
 | |
|             the result to the future. If ``None``, returns the data.
 | |
| 
 | |
|         Returns
 | |
|         --------
 | |
|         asyncio.Future
 | |
|             A future to wait for.
 | |
|         """
 | |
| 
 | |
|         future = self.loop.create_future()
 | |
|         entry = EventListener(event=event, predicate=predicate, result=result, future=future)
 | |
|         self._dispatch_listeners.append(entry)
 | |
|         return future
 | |
| 
 | |
|     async def identify(self) -> None:
 | |
|         """Sends the IDENTIFY packet."""
 | |
|         payload = {
 | |
|             'op': self.IDENTIFY,
 | |
|             'd': {
 | |
|                 'token': self.token,
 | |
|                 'properties': {
 | |
|                     'os': sys.platform,
 | |
|                     'browser': 'discord.py',
 | |
|                     'device': 'discord.py',
 | |
|                 },
 | |
|                 'compress': True,
 | |
|                 'large_threshold': 250,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         if self.shard_id is not None and self.shard_count is not None:
 | |
|             payload['d']['shard'] = [self.shard_id, self.shard_count]
 | |
| 
 | |
|         state = self._connection
 | |
|         if state._activity is not None or state._status is not None:
 | |
|             payload['d']['presence'] = {
 | |
|                 'status': state._status,
 | |
|                 'game': state._activity,
 | |
|                 'since': 0,
 | |
|                 'afk': False,
 | |
|             }
 | |
| 
 | |
|         if state._intents is not None:
 | |
|             payload['d']['intents'] = state._intents.value
 | |
| 
 | |
|         await self.call_hooks('before_identify', self.shard_id, initial=self._initial_identify)
 | |
|         await self.send_as_json(payload)
 | |
|         _log.debug('Shard ID %s has sent the IDENTIFY payload.', self.shard_id)
 | |
| 
 | |
|     async def resume(self) -> None:
 | |
|         """Sends the RESUME packet."""
 | |
|         payload = {
 | |
|             'op': self.RESUME,
 | |
|             'd': {
 | |
|                 'seq': self.sequence,
 | |
|                 'session_id': self.session_id,
 | |
|                 'token': self.token,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         await self.send_as_json(payload)
 | |
|         _log.debug('Shard ID %s has sent the RESUME payload.', self.shard_id)
 | |
| 
 | |
|     async def received_message(self, msg: Any, /) -> None:
 | |
|         if type(msg) is bytes:
 | |
|             msg = self._decompressor.decompress(msg)
 | |
| 
 | |
|             # Received a partial gateway message
 | |
|             if msg is None:
 | |
|                 return
 | |
| 
 | |
|         self.log_receive(msg)
 | |
|         msg = utils._from_json(msg)
 | |
| 
 | |
|         _log.debug('For Shard ID %s: WebSocket Event: %s', self.shard_id, msg)
 | |
|         event = msg.get('t')
 | |
|         if event:
 | |
|             self._dispatch('socket_event_type', event)
 | |
| 
 | |
|         op = msg.get('op')
 | |
|         data = msg.get('d')
 | |
|         seq = msg.get('s')
 | |
|         if seq is not None:
 | |
|             self.sequence = seq
 | |
| 
 | |
|         if self._keep_alive:
 | |
|             self._keep_alive.tick()
 | |
| 
 | |
|         if op != self.DISPATCH:
 | |
|             if op == self.RECONNECT:
 | |
|                 # "reconnect" can only be handled by the Client
 | |
|                 # so we terminate our connection and raise an
 | |
|                 # internal exception signalling to reconnect.
 | |
|                 _log.debug('Received RECONNECT opcode.')
 | |
|                 await self.close()
 | |
|                 raise ReconnectWebSocket(self.shard_id)
 | |
| 
 | |
|             if op == self.HEARTBEAT_ACK:
 | |
|                 if self._keep_alive:
 | |
|                     self._keep_alive.ack()
 | |
|                 return
 | |
| 
 | |
|             if op == self.HEARTBEAT:
 | |
|                 if self._keep_alive:
 | |
|                     beat = self._keep_alive.get_payload()
 | |
|                     await self.send_as_json(beat)
 | |
|                 return
 | |
| 
 | |
|             if op == self.HELLO:
 | |
|                 interval = data['heartbeat_interval'] / 1000.0
 | |
|                 self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id)
 | |
|                 # send a heartbeat immediately
 | |
|                 await self.send_as_json(self._keep_alive.get_payload())
 | |
|                 self._keep_alive.start()
 | |
|                 return
 | |
| 
 | |
|             if op == self.INVALIDATE_SESSION:
 | |
|                 if data is True:
 | |
|                     await self.close()
 | |
|                     raise ReconnectWebSocket(self.shard_id)
 | |
| 
 | |
|                 self.sequence = None
 | |
|                 self.session_id = None
 | |
|                 self.gateway = self.DEFAULT_GATEWAY
 | |
|                 _log.info('Shard ID %s session has been invalidated.', self.shard_id)
 | |
|                 await self.close(code=1000)
 | |
|                 raise ReconnectWebSocket(self.shard_id, resume=False)
 | |
| 
 | |
|             _log.warning('Unknown OP code %s.', op)
 | |
|             return
 | |
| 
 | |
|         if event == 'READY':
 | |
|             self.sequence = msg['s']
 | |
|             self.session_id = data['session_id']
 | |
|             self.gateway = yarl.URL(data['resume_gateway_url'])
 | |
|             _log.info('Shard ID %s has connected to Gateway (Session ID: %s).', self.shard_id, self.session_id)
 | |
| 
 | |
|         elif event == 'RESUMED':
 | |
|             # pass back the shard ID to the resumed handler
 | |
|             data['__shard_id__'] = self.shard_id
 | |
|             _log.info('Shard ID %s has successfully RESUMED session %s.', self.shard_id, self.session_id)
 | |
| 
 | |
|         try:
 | |
|             func = self._discord_parsers[event]
 | |
|         except KeyError:
 | |
|             _log.debug('Unknown event %s.', event)
 | |
|         else:
 | |
|             func(data)
 | |
| 
 | |
|         # remove the dispatched listeners
 | |
|         removed = []
 | |
|         for index, entry in enumerate(self._dispatch_listeners):
 | |
|             if entry.event != event:
 | |
|                 continue
 | |
| 
 | |
|             future = entry.future
 | |
|             if future.cancelled():
 | |
|                 removed.append(index)
 | |
|                 continue
 | |
| 
 | |
|             try:
 | |
|                 valid = entry.predicate(data)
 | |
|             except Exception as exc:
 | |
|                 future.set_exception(exc)
 | |
|                 removed.append(index)
 | |
|             else:
 | |
|                 if valid:
 | |
|                     ret = data if entry.result is None else entry.result(data)
 | |
|                     future.set_result(ret)
 | |
|                     removed.append(index)
 | |
| 
 | |
|         for index in reversed(removed):
 | |
|             del self._dispatch_listeners[index]
 | |
| 
 | |
|     @property
 | |
|     def latency(self) -> float:
 | |
|         """:class:`float`: Measures latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds."""
 | |
|         heartbeat = self._keep_alive
 | |
|         return float('inf') if heartbeat is None else heartbeat.latency
 | |
| 
 | |
|     def _can_handle_close(self) -> bool:
 | |
|         code = self._close_code or self.socket.close_code
 | |
|         # If the socket is closed remotely with 1000 and it's not our own explicit close
 | |
|         # then it's an improper close that should be handled and reconnected
 | |
|         is_improper_close = self._close_code is None and self.socket.close_code == 1000
 | |
|         return is_improper_close or code not in (1000, 4004, 4010, 4011, 4012, 4013, 4014)
 | |
| 
 | |
|     async def poll_event(self) -> None:
 | |
|         """Polls for a DISPATCH event and handles the general gateway loop.
 | |
| 
 | |
|         Raises
 | |
|         ------
 | |
|         ConnectionClosed
 | |
|             The websocket connection was terminated for unhandled reasons.
 | |
|         """
 | |
|         try:
 | |
|             msg = await self.socket.receive(timeout=self._max_heartbeat_timeout)
 | |
|             if msg.type is aiohttp.WSMsgType.TEXT:
 | |
|                 await self.received_message(msg.data)
 | |
|             elif msg.type is aiohttp.WSMsgType.BINARY:
 | |
|                 await self.received_message(msg.data)
 | |
|             elif msg.type is aiohttp.WSMsgType.ERROR:
 | |
|                 _log.debug('Received error %s', msg)
 | |
|                 raise WebSocketClosure
 | |
|             elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSING, aiohttp.WSMsgType.CLOSE):
 | |
|                 _log.debug('Received %s', msg)
 | |
|                 raise WebSocketClosure
 | |
|         except (asyncio.TimeoutError, WebSocketClosure) as e:
 | |
|             # Ensure the keep alive handler is closed
 | |
|             if self._keep_alive:
 | |
|                 self._keep_alive.stop()
 | |
|                 self._keep_alive = None
 | |
| 
 | |
|             if isinstance(e, asyncio.TimeoutError):
 | |
|                 _log.debug('Timed out receiving packet. Attempting a reconnect.')
 | |
|                 raise ReconnectWebSocket(self.shard_id) from None
 | |
| 
 | |
|             code = self._close_code or self.socket.close_code
 | |
|             if self._can_handle_close():
 | |
|                 _log.debug('Websocket closed with %s, attempting a reconnect.', code)
 | |
|                 raise ReconnectWebSocket(self.shard_id) from None
 | |
|             else:
 | |
|                 _log.debug('Websocket closed with %s, cannot reconnect.', code)
 | |
|                 raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None
 | |
| 
 | |
|     async def debug_send(self, data: str, /) -> None:
 | |
|         await self._rate_limiter.block()
 | |
|         self._dispatch('socket_raw_send', data)
 | |
|         await self.socket.send_str(data)
 | |
| 
 | |
|     async def send(self, data: str, /) -> None:
 | |
|         await self._rate_limiter.block()
 | |
|         await self.socket.send_str(data)
 | |
| 
 | |
|     async def send_as_json(self, data: Any) -> None:
 | |
|         try:
 | |
|             await self.send(utils._to_json(data))
 | |
|         except RuntimeError as exc:
 | |
|             if not self._can_handle_close():
 | |
|                 raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
 | |
| 
 | |
|     async def send_heartbeat(self, data: Any) -> None:
 | |
|         # This bypasses the rate limit handling code since it has a higher priority
 | |
|         try:
 | |
|             await self.socket.send_str(utils._to_json(data))
 | |
|         except RuntimeError as exc:
 | |
|             if not self._can_handle_close():
 | |
|                 raise ConnectionClosed(self.socket, shard_id=self.shard_id) from exc
 | |
| 
 | |
|     async def change_presence(
 | |
|         self,
 | |
|         *,
 | |
|         activity: Optional[BaseActivity] = None,
 | |
|         status: Optional[str] = None,
 | |
|         since: float = 0.0,
 | |
|     ) -> None:
 | |
|         if activity is not None:
 | |
|             if not isinstance(activity, BaseActivity):
 | |
|                 raise TypeError('activity must derive from BaseActivity.')
 | |
|             activities = [activity.to_dict()]
 | |
|         else:
 | |
|             activities = []
 | |
| 
 | |
|         if status == 'idle':
 | |
|             since = int(time.time() * 1000)
 | |
| 
 | |
|         payload = {
 | |
|             'op': self.PRESENCE,
 | |
|             'd': {
 | |
|                 'activities': activities,
 | |
|                 'afk': False,
 | |
|                 'since': since,
 | |
|                 'status': status,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         sent = utils._to_json(payload)
 | |
|         _log.debug('Sending "%s" to change status', sent)
 | |
|         await self.send(sent)
 | |
| 
 | |
|     async def request_chunks(
 | |
|         self,
 | |
|         guild_id: int,
 | |
|         query: Optional[str] = None,
 | |
|         *,
 | |
|         limit: int,
 | |
|         user_ids: Optional[List[int]] = None,
 | |
|         presences: bool = False,
 | |
|         nonce: Optional[str] = None,
 | |
|     ) -> None:
 | |
|         payload = {
 | |
|             'op': self.REQUEST_MEMBERS,
 | |
|             'd': {
 | |
|                 'guild_id': guild_id,
 | |
|                 'presences': presences,
 | |
|                 'limit': limit,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         if nonce:
 | |
|             payload['d']['nonce'] = nonce
 | |
| 
 | |
|         if user_ids:
 | |
|             payload['d']['user_ids'] = user_ids
 | |
| 
 | |
|         if query is not None:
 | |
|             payload['d']['query'] = query
 | |
| 
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     async def voice_state(
 | |
|         self,
 | |
|         guild_id: int,
 | |
|         channel_id: Optional[int],
 | |
|         self_mute: bool = False,
 | |
|         self_deaf: bool = False,
 | |
|     ) -> None:
 | |
|         payload = {
 | |
|             'op': self.VOICE_STATE,
 | |
|             'd': {
 | |
|                 'guild_id': guild_id,
 | |
|                 'channel_id': channel_id,
 | |
|                 'self_mute': self_mute,
 | |
|                 'self_deaf': self_deaf,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         _log.debug('Updating our voice state to %s.', payload)
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     async def close(self, code: int = 4000) -> None:
 | |
|         if self._keep_alive:
 | |
|             self._keep_alive.stop()
 | |
|             self._keep_alive = None
 | |
| 
 | |
|         self._close_code = code
 | |
|         await self.socket.close(code=code)
 | |
| 
 | |
| 
 | |
| DVWS = TypeVar('DVWS', bound='DiscordVoiceWebSocket')
 | |
| 
 | |
| 
 | |
| class DiscordVoiceWebSocket:
 | |
|     """Implements the websocket protocol for handling voice connections.
 | |
| 
 | |
|     Attributes
 | |
|     -----------
 | |
|     IDENTIFY
 | |
|         Send only. Starts a new voice session.
 | |
|     SELECT_PROTOCOL
 | |
|         Send only. Tells discord what encryption mode and how to connect for voice.
 | |
|     READY
 | |
|         Receive only. Tells the websocket that the initial connection has completed.
 | |
|     HEARTBEAT
 | |
|         Send only. Keeps your websocket connection alive.
 | |
|     SESSION_DESCRIPTION
 | |
|         Receive only. Gives you the secret key required for voice.
 | |
|     SPEAKING
 | |
|         Send only. Notifies the client if you are currently speaking.
 | |
|     HEARTBEAT_ACK
 | |
|         Receive only. Tells you your heartbeat has been acknowledged.
 | |
|     RESUME
 | |
|         Sent only. Tells the client to resume its session.
 | |
|     HELLO
 | |
|         Receive only. Tells you that your websocket connection was acknowledged.
 | |
|     RESUMED
 | |
|         Sent only. Tells you that your RESUME request has succeeded.
 | |
|     CLIENT_CONNECT
 | |
|         Indicates a user has connected to voice.
 | |
|     CLIENT_DISCONNECT
 | |
|         Receive only.  Indicates a user has disconnected from voice.
 | |
|     """
 | |
| 
 | |
|     if TYPE_CHECKING:
 | |
|         thread_id: int
 | |
|         _connection: VoiceConnectionState
 | |
|         gateway: str
 | |
|         _max_heartbeat_timeout: float
 | |
| 
 | |
|     # fmt: off
 | |
|     IDENTIFY            = 0
 | |
|     SELECT_PROTOCOL     = 1
 | |
|     READY               = 2
 | |
|     HEARTBEAT           = 3
 | |
|     SESSION_DESCRIPTION = 4
 | |
|     SPEAKING            = 5
 | |
|     HEARTBEAT_ACK       = 6
 | |
|     RESUME              = 7
 | |
|     HELLO               = 8
 | |
|     RESUMED             = 9
 | |
|     CLIENT_CONNECT      = 12
 | |
|     CLIENT_DISCONNECT   = 13
 | |
|     # fmt: on
 | |
| 
 | |
|     def __init__(
 | |
|         self,
 | |
|         socket: aiohttp.ClientWebSocketResponse,
 | |
|         loop: asyncio.AbstractEventLoop,
 | |
|         *,
 | |
|         hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
 | |
|     ) -> None:
 | |
|         self.ws: aiohttp.ClientWebSocketResponse = socket
 | |
|         self.loop: asyncio.AbstractEventLoop = loop
 | |
|         self._keep_alive: Optional[VoiceKeepAliveHandler] = None
 | |
|         self._close_code: Optional[int] = None
 | |
|         self.secret_key: Optional[List[int]] = None
 | |
|         # defaulting to -1
 | |
|         self.seq_ack: int = -1
 | |
|         if hook:
 | |
|             self._hook = hook  # type: ignore
 | |
| 
 | |
|     async def _hook(self, *args: Any) -> None:
 | |
|         pass
 | |
| 
 | |
|     async def send_as_json(self, data: Any) -> None:
 | |
|         _log.debug('Sending voice websocket frame: %s.', data)
 | |
|         await self.ws.send_str(utils._to_json(data))
 | |
| 
 | |
|     send_heartbeat = send_as_json
 | |
| 
 | |
|     async def resume(self) -> None:
 | |
|         state = self._connection
 | |
|         payload = {
 | |
|             'op': self.RESUME,
 | |
|             'd': {
 | |
|                 'token': state.token,
 | |
|                 'server_id': str(state.server_id),
 | |
|                 'session_id': state.session_id,
 | |
|                 'seq_ack': self.seq_ack,
 | |
|             },
 | |
|         }
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     async def identify(self) -> None:
 | |
|         state = self._connection
 | |
|         payload = {
 | |
|             'op': self.IDENTIFY,
 | |
|             'd': {
 | |
|                 'server_id': str(state.server_id),
 | |
|                 'user_id': str(state.user.id),
 | |
|                 'session_id': state.session_id,
 | |
|                 'token': state.token,
 | |
|             },
 | |
|         }
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     @classmethod
 | |
|     async def from_connection_state(
 | |
|         cls,
 | |
|         state: VoiceConnectionState,
 | |
|         *,
 | |
|         resume: bool = False,
 | |
|         hook: Optional[Callable[..., Coroutine[Any, Any, Any]]] = None,
 | |
|         seq_ack: int = -1,
 | |
|     ) -> Self:
 | |
|         """Creates a voice websocket for the :class:`VoiceClient`."""
 | |
|         gateway = f'wss://{state.endpoint}/?v=8'
 | |
|         client = state.voice_client
 | |
|         http = client._state.http
 | |
|         socket = await http.ws_connect(gateway, compress=15)
 | |
|         ws = cls(socket, loop=client.loop, hook=hook)
 | |
|         ws.gateway = gateway
 | |
|         ws.seq_ack = seq_ack
 | |
|         ws._connection = state
 | |
|         ws._max_heartbeat_timeout = 60.0
 | |
|         ws.thread_id = threading.get_ident()
 | |
| 
 | |
|         if resume:
 | |
|             await ws.resume()
 | |
|         else:
 | |
|             await ws.identify()
 | |
| 
 | |
|         return ws
 | |
| 
 | |
|     async def select_protocol(self, ip: str, port: int, mode: str) -> None:
 | |
|         payload = {
 | |
|             'op': self.SELECT_PROTOCOL,
 | |
|             'd': {
 | |
|                 'protocol': 'udp',
 | |
|                 'data': {
 | |
|                     'address': ip,
 | |
|                     'port': port,
 | |
|                     'mode': mode,
 | |
|                 },
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     async def client_connect(self) -> None:
 | |
|         payload = {
 | |
|             'op': self.CLIENT_CONNECT,
 | |
|             'd': {
 | |
|                 'audio_ssrc': self._connection.ssrc,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     async def speak(self, state: SpeakingState = SpeakingState.voice) -> None:
 | |
|         payload = {
 | |
|             'op': self.SPEAKING,
 | |
|             'd': {
 | |
|                 'speaking': int(state),
 | |
|                 'delay': 0,
 | |
|                 'ssrc': self._connection.ssrc,
 | |
|             },
 | |
|         }
 | |
| 
 | |
|         await self.send_as_json(payload)
 | |
| 
 | |
|     async def received_message(self, msg: Dict[str, Any]) -> None:
 | |
|         _log.debug('Voice websocket frame received: %s', msg)
 | |
|         op = msg['op']
 | |
|         data = msg['d']  # According to Discord this key is always given
 | |
|         self.seq_ack = msg.get('seq', self.seq_ack)  # this key could not be given
 | |
| 
 | |
|         if op == self.READY:
 | |
|             await self.initial_connection(data)
 | |
|         elif op == self.HEARTBEAT_ACK:
 | |
|             if self._keep_alive:
 | |
|                 self._keep_alive.ack()
 | |
|         elif op == self.RESUMED:
 | |
|             _log.debug('Voice RESUME succeeded.')
 | |
|         elif op == self.SESSION_DESCRIPTION:
 | |
|             self._connection.mode = data['mode']
 | |
|             await self.load_secret_key(data)
 | |
|         elif op == self.HELLO:
 | |
|             interval = data['heartbeat_interval'] / 1000.0
 | |
|             self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
 | |
|             self._keep_alive.start()
 | |
| 
 | |
|         await self._hook(self, msg)
 | |
| 
 | |
|     async def initial_connection(self, data: Dict[str, Any]) -> None:
 | |
|         state = self._connection
 | |
|         state.ssrc = data['ssrc']
 | |
|         state.voice_port = data['port']
 | |
|         state.endpoint_ip = data['ip']
 | |
| 
 | |
|         _log.debug('Connecting to voice socket')
 | |
|         await self.loop.sock_connect(state.socket, (state.endpoint_ip, state.voice_port))
 | |
| 
 | |
|         state.ip, state.port = await self.discover_ip()
 | |
|         # there *should* always be at least one supported mode (xsalsa20_poly1305)
 | |
|         modes = [mode for mode in data['modes'] if mode in self._connection.supported_modes]
 | |
|         _log.debug('received supported encryption modes: %s', ', '.join(modes))
 | |
| 
 | |
|         mode = modes[0]
 | |
|         await self.select_protocol(state.ip, state.port, mode)
 | |
|         _log.debug('selected the voice protocol for use (%s)', mode)
 | |
| 
 | |
|     async def discover_ip(self) -> Tuple[str, int]:
 | |
|         state = self._connection
 | |
|         packet = bytearray(74)
 | |
|         struct.pack_into('>H', packet, 0, 1)  # 1 = Send
 | |
|         struct.pack_into('>H', packet, 2, 70)  # 70 = Length
 | |
|         struct.pack_into('>I', packet, 4, state.ssrc)
 | |
| 
 | |
|         _log.debug('Sending ip discovery packet')
 | |
|         await self.loop.sock_sendall(state.socket, packet)
 | |
| 
 | |
|         fut: asyncio.Future[bytes] = self.loop.create_future()
 | |
| 
 | |
|         def get_ip_packet(data: bytes):
 | |
|             if data[1] == 0x02 and len(data) == 74:
 | |
|                 self.loop.call_soon_threadsafe(fut.set_result, data)
 | |
| 
 | |
|         fut.add_done_callback(lambda f: state.remove_socket_listener(get_ip_packet))
 | |
|         state.add_socket_listener(get_ip_packet)
 | |
|         recv = await fut
 | |
| 
 | |
|         _log.debug('Received ip discovery packet: %s', recv)
 | |
| 
 | |
|         # the ip is ascii starting at the 8th byte and ending at the first null
 | |
|         ip_start = 8
 | |
|         ip_end = recv.index(0, ip_start)
 | |
|         ip = recv[ip_start:ip_end].decode('ascii')
 | |
| 
 | |
|         port = struct.unpack_from('>H', recv, len(recv) - 2)[0]
 | |
|         _log.debug('detected ip: %s port: %s', ip, port)
 | |
| 
 | |
|         return ip, port
 | |
| 
 | |
|     @property
 | |
|     def latency(self) -> float:
 | |
|         """:class:`float`: Latency between a HEARTBEAT and its HEARTBEAT_ACK in seconds."""
 | |
|         heartbeat = self._keep_alive
 | |
|         return float('inf') if heartbeat is None else heartbeat.latency
 | |
| 
 | |
|     @property
 | |
|     def average_latency(self) -> float:
 | |
|         """:class:`float`: Average of last 20 HEARTBEAT latencies."""
 | |
|         heartbeat = self._keep_alive
 | |
|         if heartbeat is None or not heartbeat.recent_ack_latencies:
 | |
|             return float('inf')
 | |
| 
 | |
|         return sum(heartbeat.recent_ack_latencies) / len(heartbeat.recent_ack_latencies)
 | |
| 
 | |
|     async def load_secret_key(self, data: Dict[str, Any]) -> None:
 | |
|         _log.debug('received secret key for voice connection')
 | |
|         self.secret_key = self._connection.secret_key = data['secret_key']
 | |
| 
 | |
|         # Send a speak command with the "not speaking" state.
 | |
|         # This also tells Discord our SSRC value, which Discord requires before
 | |
|         # sending any voice data (and is the real reason why we call this here).
 | |
|         await self.speak(SpeakingState.none)
 | |
| 
 | |
|     async def poll_event(self) -> None:
 | |
|         # This exception is handled up the chain
 | |
|         msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
 | |
|         if msg.type is aiohttp.WSMsgType.TEXT:
 | |
|             await self.received_message(utils._from_json(msg.data))
 | |
|         elif msg.type is aiohttp.WSMsgType.ERROR:
 | |
|             _log.debug('Received voice %s', msg)
 | |
|             raise ConnectionClosed(self.ws, shard_id=None) from msg.data
 | |
|         elif msg.type in (aiohttp.WSMsgType.CLOSED, aiohttp.WSMsgType.CLOSE, aiohttp.WSMsgType.CLOSING):
 | |
|             _log.debug('Received voice %s', msg)
 | |
|             raise ConnectionClosed(self.ws, shard_id=None, code=self._close_code)
 | |
| 
 | |
|     async def close(self, code: int = 1000) -> None:
 | |
|         if self._keep_alive is not None:
 | |
|             self._keep_alive.stop()
 | |
| 
 | |
|         self._close_code = code
 | |
|         await self.ws.close(code=code)
 |