mirror of
				https://github.com/Rapptz/discord.py.git
				synced 2025-10-30 21:12:58 +00:00 
			
		
		
		
	Implement VoiceProtocol lower level hooks.
This allows changing the connect flow and taking control of it without relying on internal events or tricks.
This commit is contained in:
		| @@ -54,7 +54,7 @@ from .mentions import AllowedMentions | ||||
| from .shard import AutoShardedClient, ShardInfo | ||||
| from .player import * | ||||
| from .webhook import * | ||||
| from .voice_client import VoiceClient | ||||
| from .voice_client import VoiceClient, VoiceProtocol | ||||
| from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff | ||||
| from .raw_models import * | ||||
| from .team import * | ||||
|   | ||||
| @@ -36,7 +36,7 @@ from .permissions import PermissionOverwrite, Permissions | ||||
| from .role import Role | ||||
| from .invite import Invite | ||||
| from .file import File | ||||
| from .voice_client import VoiceClient | ||||
| from .voice_client import VoiceClient, VoiceProtocol | ||||
| from . import utils | ||||
|  | ||||
| class _Undefined: | ||||
| @@ -1053,7 +1053,6 @@ class Messageable(metaclass=abc.ABCMeta): | ||||
|         """ | ||||
|         return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first) | ||||
|  | ||||
|  | ||||
| class Connectable(metaclass=abc.ABCMeta): | ||||
|     """An ABC that details the common operations on a channel that can | ||||
|     connect to a voice server. | ||||
| @@ -1072,7 +1071,7 @@ class Connectable(metaclass=abc.ABCMeta): | ||||
|     def _get_voice_state_pair(self): | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     async def connect(self, *, timeout=60.0, reconnect=True): | ||||
|     async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient): | ||||
|         """|coro| | ||||
|  | ||||
|         Connects to voice and creates a :class:`VoiceClient` to establish | ||||
| @@ -1086,6 +1085,9 @@ class Connectable(metaclass=abc.ABCMeta): | ||||
|             Whether the bot should automatically attempt | ||||
|             a reconnect if a part of the handshake fails | ||||
|             or the gateway goes down. | ||||
|         cls: Type[:class:`VoiceProtocol`] | ||||
|             A type that subclasses :class:`~discord.VoiceProtocol` to connect with. | ||||
|             Defaults to :class:`~discord.VoiceClient`. | ||||
|  | ||||
|         Raises | ||||
|         ------- | ||||
| @@ -1098,20 +1100,25 @@ class Connectable(metaclass=abc.ABCMeta): | ||||
|  | ||||
|         Returns | ||||
|         -------- | ||||
|         :class:`~discord.VoiceClient` | ||||
|         :class:`~discord.VoiceProtocol` | ||||
|             A voice client that is fully connected to the voice server. | ||||
|         """ | ||||
|  | ||||
|         if not issubclass(cls, VoiceProtocol): | ||||
|             raise TypeError('Type must meet VoiceProtocol abstract base class.') | ||||
|  | ||||
|         key_id, _ = self._get_voice_client_key() | ||||
|         state = self._state | ||||
|  | ||||
|         if state._get_voice_client(key_id): | ||||
|             raise ClientException('Already connected to a voice channel.') | ||||
|  | ||||
|         voice = VoiceClient(state=state, timeout=timeout, channel=self) | ||||
|         client = state._get_client() | ||||
|         voice = cls(client, self) | ||||
|         state._add_voice_client(key_id, voice) | ||||
|  | ||||
|         try: | ||||
|             await voice.connect(reconnect=reconnect) | ||||
|             await voice.connect(timeout=timeout, reconnect=reconnect) | ||||
|         except asyncio.TimeoutError: | ||||
|             try: | ||||
|                 await voice.disconnect(force=True) | ||||
|   | ||||
| @@ -238,6 +238,7 @@ class Client: | ||||
|         self._closed = False | ||||
|         self._ready = asyncio.Event() | ||||
|         self._connection._get_websocket = self._get_websocket | ||||
|         self._connection._get_client = lambda: self | ||||
|  | ||||
|         if VoiceClient.warn_nacl: | ||||
|             VoiceClient.warn_nacl = False | ||||
| @@ -299,7 +300,10 @@ class Client: | ||||
|  | ||||
|     @property | ||||
|     def voice_clients(self): | ||||
|         """List[:class:`.VoiceClient`]: Represents a list of voice connections.""" | ||||
|         """List[:class:`.VoiceProtocol`]: Represents a list of voice connections. | ||||
|  | ||||
|         These are usually :class:`.VoiceClient` instances. | ||||
|         """ | ||||
|         return self._connection.voice_clients | ||||
|  | ||||
|     def is_ready(self): | ||||
|   | ||||
| @@ -238,7 +238,7 @@ class Context(discord.abc.Messageable): | ||||
|  | ||||
|     @property | ||||
|     def voice_client(self): | ||||
|         r"""Optional[:class:`.VoiceClient`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" | ||||
|         r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable.""" | ||||
|         g = self.guild | ||||
|         return g.voice_client if g else None | ||||
|  | ||||
|   | ||||
| @@ -377,7 +377,7 @@ class Guild(Hashable): | ||||
|  | ||||
|     @property | ||||
|     def voice_client(self): | ||||
|         """Optional[:class:`VoiceClient`]: Returns the :class:`VoiceClient` associated with this guild, if any.""" | ||||
|         """Optional[:class:`VoiceProtocol`]: Returns the :class:`VoiceProtocol` associated with this guild, if any.""" | ||||
|         return self._state._get_voice_client(self.id) | ||||
|  | ||||
|     @property | ||||
|   | ||||
| @@ -292,6 +292,7 @@ class AutoShardedClient(Client): | ||||
|         # the key is the shard_id | ||||
|         self.__shards = {} | ||||
|         self._connection._get_websocket = self._get_websocket | ||||
|         self._connection._get_client = lambda: self | ||||
|         self.__queue = asyncio.PriorityQueue() | ||||
|  | ||||
|     def _get_websocket(self, guild_id=None, *, shard_id=None): | ||||
|   | ||||
| @@ -63,6 +63,12 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate')) | ||||
| log = logging.getLogger(__name__) | ||||
| ReadyState = namedtuple('ReadyState', ('launch', 'guilds')) | ||||
|  | ||||
| async def logging_coroutine(coroutine, *, info): | ||||
|     try: | ||||
|         await coroutine | ||||
|     except Exception: | ||||
|         log.exception('Exception occurred during %s', info) | ||||
|  | ||||
| class ConnectionState: | ||||
|     def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options): | ||||
|         self.loop = loop | ||||
| @@ -939,9 +945,8 @@ class ConnectionState: | ||||
|             if int(data['user_id']) == self.user.id: | ||||
|                 voice = self._get_voice_client(guild.id) | ||||
|                 if voice is not None: | ||||
|                     ch = guild.get_channel(channel_id) | ||||
|                     if ch is not None: | ||||
|                         voice.channel = ch | ||||
|                     coro = voice.on_voice_state_update(data) | ||||
|                     asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice state update handler')) | ||||
|  | ||||
|             member, before, after = guild._update_voice_state(data, channel_id) | ||||
|             if member is not None: | ||||
| @@ -962,7 +967,8 @@ class ConnectionState: | ||||
|  | ||||
|         vc = self._get_voice_client(key_id) | ||||
|         if vc is not None: | ||||
|             asyncio.ensure_future(vc._create_socket(key_id, data)) | ||||
|             coro = vc.on_voice_server_update(data) | ||||
|             asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice server update handler')) | ||||
|  | ||||
|     def parse_typing_start(self, data): | ||||
|         channel, guild = self._get_guild_channel(data) | ||||
|   | ||||
| @@ -45,7 +45,7 @@ import logging | ||||
| import struct | ||||
| import threading | ||||
|  | ||||
| from . import opus | ||||
| from . import opus, utils | ||||
| from .backoff import ExponentialBackoff | ||||
| from .gateway import * | ||||
| from .errors import ClientException, ConnectionClosed | ||||
| @@ -59,7 +59,110 @@ except ImportError: | ||||
|  | ||||
| log = logging.getLogger(__name__) | ||||
|  | ||||
| class VoiceClient: | ||||
| class VoiceProtocol: | ||||
|     """A class that represents the Discord voice protocol. | ||||
|  | ||||
|     This is an abstract class. The library provides a concrete implementation | ||||
|     under :class:`VoiceClient`. | ||||
|  | ||||
|     This class allows you to implement a protocol to allow for an external | ||||
|     method of sending voice, such as Lavalink_ or a native library implementation. | ||||
|  | ||||
|     These classes are passed to :meth:`abc.Connectable.connect`. | ||||
|  | ||||
|     .. _Lavalink: https://github.com/Frederikam/Lavalink | ||||
|  | ||||
|     Parameters | ||||
|     ------------ | ||||
|     client: :class:`Client` | ||||
|         The client (or its subclasses) that started the connection request. | ||||
|     channel: :class:`abc.Connectable` | ||||
|         The voice channel that is being connected to. | ||||
|     """ | ||||
|  | ||||
|     def __init__(self, client, channel): | ||||
|         self.client = client | ||||
|         self.channel = channel | ||||
|  | ||||
|     async def on_voice_state_update(self, data): | ||||
|         """|coro| | ||||
|  | ||||
|         An abstract method that is called when the client's voice state | ||||
|         has changed. This corresponds to ``VOICE_STATE_UPDATE``. | ||||
|  | ||||
|         Parameters | ||||
|         ------------ | ||||
|         data: :class:`dict` | ||||
|             The raw `voice state payload`_. | ||||
|  | ||||
|         .. _voice state payload: https://discord.com/developers/docs/resources/voice#voice-state-object | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     async def on_voice_server_update(self, data): | ||||
|         """|coro| | ||||
|  | ||||
|         An abstract method that is called when initially connecting to voice. | ||||
|         This corresponds to ``VOICE_SERVER_UPDATE``. | ||||
|  | ||||
|         Parameters | ||||
|         ------------ | ||||
|         data: :class:`dict` | ||||
|             The raw `voice server update payload`__. | ||||
|  | ||||
|             .. _VSU: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields | ||||
|  | ||||
|             __ VSU_ | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     async def connect(self, *, timeout, reconnect): | ||||
|         """|coro| | ||||
|  | ||||
|         An abstract method called when the client initiates the connection request. | ||||
|  | ||||
|         When a connection is requested initially, the library calls the following functions | ||||
|         in order: | ||||
|  | ||||
|         - ``__init__`` | ||||
|  | ||||
|         Parameters | ||||
|         ------------ | ||||
|         timeout: :class:`float` | ||||
|             The timeout for the connection. | ||||
|         reconnect: :class:`bool` | ||||
|             Whether reconnection is expected. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     async def disconnect(self, *, force): | ||||
|         """|coro| | ||||
|  | ||||
|         An abstract method called when the client terminates the connection. | ||||
|  | ||||
|         See :meth:`cleanup`. | ||||
|  | ||||
|         Parameters | ||||
|         ------------ | ||||
|         force: :class:`bool` | ||||
|             Whether the disconnection was forced. | ||||
|         """ | ||||
|         raise NotImplementedError | ||||
|  | ||||
|     def cleanup(self): | ||||
|         """This method *must* be called to ensure proper clean-up during a disconnect. | ||||
|  | ||||
|         It is advisable to call this from within :meth:`disconnect` when you are | ||||
|         completely done with the voice protocol instance. | ||||
|  | ||||
|         This method removes it from the internal state cache that keeps track of | ||||
|         currently alive voice clients. Failure to clean-up will cause subsequent | ||||
|         connections to report that it's still connected. | ||||
|         """ | ||||
|         key_id, _ = self.channel._get_voice_client_key() | ||||
|         self.client._connection._remove_voice_client(key_id) | ||||
|  | ||||
| class VoiceClient(VoiceProtocol): | ||||
|     """Represents a Discord voice connection. | ||||
|  | ||||
|     You do not create these, you typically get them from | ||||
| @@ -85,14 +188,13 @@ class VoiceClient: | ||||
|     loop: :class:`asyncio.AbstractEventLoop` | ||||
|         The event loop that the voice client is running on. | ||||
|     """ | ||||
|     def __init__(self, state, timeout, channel): | ||||
|     def __init__(self, client, channel): | ||||
|         if not has_nacl: | ||||
|             raise RuntimeError("PyNaCl library needed in order to use voice") | ||||
|  | ||||
|         self.channel = channel | ||||
|         self.main_ws = None | ||||
|         self.timeout = timeout | ||||
|         self.ws = None | ||||
|         super().__init__(client, channel) | ||||
|         state = client._connection | ||||
|         self.token = None | ||||
|         self.socket = None | ||||
|         self.loop = state.loop | ||||
|         self._state = state | ||||
| @@ -100,8 +202,8 @@ class VoiceClient: | ||||
|         self._connected = threading.Event() | ||||
|  | ||||
|         self._handshaking = False | ||||
|         self._handshake_check = asyncio.Lock() | ||||
|         self._handshake_complete = asyncio.Event() | ||||
|         self._voice_state_complete = asyncio.Event() | ||||
|         self._voice_server_complete = asyncio.Event() | ||||
|  | ||||
|         self.mode = None | ||||
|         self._connections = 0 | ||||
| @@ -138,48 +240,24 @@ class VoiceClient: | ||||
|  | ||||
|     # connection related | ||||
|  | ||||
|     async def start_handshake(self): | ||||
|         log.info('Starting voice handshake...') | ||||
|     async def on_voice_state_update(self, data): | ||||
|         self.session_id = data['session_id'] | ||||
|         channel_id = data['channel_id'] | ||||
|  | ||||
|         guild_id, channel_id = self.channel._get_voice_state_pair() | ||||
|         state = self._state | ||||
|         self.main_ws = ws = state._get_websocket(guild_id) | ||||
|         self._connections += 1 | ||||
|         if not self._handshaking: | ||||
|             # If we're done handshaking then we just need to update ourselves | ||||
|             guild = self.guild | ||||
|             self.channel = channel_id and guild and guild.get_channel(int(channel_id)) | ||||
|         else: | ||||
|             self._voice_state_complete.set() | ||||
|  | ||||
|         # request joining | ||||
|         await ws.voice_state(guild_id, channel_id) | ||||
|     async def on_voice_server_update(self, data): | ||||
|         if self._voice_server_complete.is_set(): | ||||
|             log.info('Ignoring extraneous voice server update.') | ||||
|             return | ||||
|  | ||||
|         try: | ||||
|             await asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout) | ||||
|         except asyncio.TimeoutError: | ||||
|             await self.terminate_handshake(remove=True) | ||||
|             raise | ||||
|  | ||||
|         log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip) | ||||
|  | ||||
|     async def terminate_handshake(self, *, remove=False): | ||||
|         guild_id, channel_id = self.channel._get_voice_state_pair() | ||||
|         self._handshake_complete.clear() | ||||
|         await self.main_ws.voice_state(guild_id, None, self_mute=True) | ||||
|         self._handshaking = False | ||||
|  | ||||
|         log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', channel_id, guild_id) | ||||
|         if remove: | ||||
|             log.info('The voice client has been removed for Channel ID %s (Guild ID %s)', channel_id, guild_id) | ||||
|             key_id, _ = self.channel._get_voice_client_key() | ||||
|             self._state._remove_voice_client(key_id) | ||||
|  | ||||
|     async def _create_socket(self, server_id, data): | ||||
|         async with self._handshake_check: | ||||
|             if self._handshaking: | ||||
|                 log.info("Ignoring voice server update while handshake is in progress") | ||||
|                 return | ||||
|             self._handshaking = True | ||||
|  | ||||
|         self._connected.clear() | ||||
|         self.session_id = self.main_ws.session_id | ||||
|         self.server_id = server_id | ||||
|         self.token = data.get('token') | ||||
|         self.server_id = int(data['guild_id']) | ||||
|         endpoint = data.get('endpoint') | ||||
|  | ||||
|         if endpoint is None or self.token is None: | ||||
| @@ -195,23 +273,77 @@ class VoiceClient: | ||||
|         # This gets set later | ||||
|         self.endpoint_ip = None | ||||
|  | ||||
|         if self.socket: | ||||
|             try: | ||||
|                 self.socket.close() | ||||
|             except Exception: | ||||
|                 pass | ||||
|  | ||||
|         self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | ||||
|         self.socket.setblocking(False) | ||||
|  | ||||
|         if self._handshake_complete.is_set(): | ||||
|             # terminate the websocket and handle the reconnect loop if necessary. | ||||
|             self._handshake_complete.clear() | ||||
|             self._handshaking = False | ||||
|         if not self._handshaking: | ||||
|             # If we're not handshaking then we need to terminate our previous connection in the websocket | ||||
|             await self.ws.close(4000) | ||||
|             return | ||||
|  | ||||
|         self._handshake_complete.set() | ||||
|         self._voice_server_complete.set() | ||||
|  | ||||
|     async def voice_connect(self): | ||||
|         self._connections += 1 | ||||
|         await self.channel.guild.change_voice_state(channel=self.channel) | ||||
|  | ||||
|     async def voice_disconnect(self): | ||||
|         log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) | ||||
|         await self.channel.guild.change_voice_state(channel=None) | ||||
|  | ||||
|     async def connect(self, *, reconnect, timeout): | ||||
|         log.info('Connecting to voice...') | ||||
|         self.timeout = timeout | ||||
|         try: | ||||
|             del self.secret_key | ||||
|         except AttributeError: | ||||
|             pass | ||||
|  | ||||
|  | ||||
|         for i in range(5): | ||||
|             self._voice_state_complete.clear() | ||||
|             self._voice_server_complete.clear() | ||||
|             self._handshaking = True | ||||
|  | ||||
|             # This has to be created before we start the flow. | ||||
|             futures = [ | ||||
|                 self._voice_state_complete.wait(), | ||||
|                 self._voice_server_complete.wait(), | ||||
|             ] | ||||
|  | ||||
|             # Start the connection flow | ||||
|             log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) | ||||
|             await self.voice_connect() | ||||
|  | ||||
|             try: | ||||
|                 await utils.sane_wait_for(futures, timeout=timeout) | ||||
|             except asyncio.TimeoutError: | ||||
|                 await self.disconnect(force=True) | ||||
|                 raise | ||||
|  | ||||
|             log.info('Voice handshake complete. Endpoint found %s', self.endpoint) | ||||
|             self._handshaking = False | ||||
|             self._voice_server_complete.clear() | ||||
|             self._voice_state_complete.clear() | ||||
|  | ||||
|             try: | ||||
|                 self.ws = await DiscordVoiceWebSocket.from_client(self) | ||||
|                 self._connected.clear() | ||||
|                 while not hasattr(self, 'secret_key'): | ||||
|                     await self.ws.poll_event() | ||||
|                 self._connected.set() | ||||
|                 break | ||||
|             except (ConnectionClosed, asyncio.TimeoutError): | ||||
|                 if reconnect: | ||||
|                     log.exception('Failed to connect to voice... Retrying...') | ||||
|                     await asyncio.sleep(1 + i * 2.0) | ||||
|                     await self.voice_disconnect() | ||||
|                     continue | ||||
|                 else: | ||||
|                     raise | ||||
|  | ||||
|         if self._runner is None: | ||||
|             self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) | ||||
|  | ||||
|     @property | ||||
|     def latency(self): | ||||
| @@ -234,35 +366,6 @@ class VoiceClient: | ||||
|         ws = self.ws | ||||
|         return float("inf") if not ws else ws.average_latency | ||||
|  | ||||
|     async def connect(self, *, reconnect=True, _tries=0, do_handshake=True): | ||||
|         log.info('Connecting to voice...') | ||||
|         try: | ||||
|             del self.secret_key | ||||
|         except AttributeError: | ||||
|             pass | ||||
|  | ||||
|         if do_handshake: | ||||
|             await self.start_handshake() | ||||
|  | ||||
|         try: | ||||
|             self.ws = await DiscordVoiceWebSocket.from_client(self) | ||||
|             self._handshaking = False | ||||
|             self._connected.clear() | ||||
|             while not hasattr(self, 'secret_key'): | ||||
|                 await self.ws.poll_event() | ||||
|             self._connected.set() | ||||
|         except (ConnectionClosed, asyncio.TimeoutError): | ||||
|             if reconnect and _tries < 5: | ||||
|                 log.exception('Failed to connect to voice... Retrying...') | ||||
|                 await asyncio.sleep(1 + _tries * 2.0) | ||||
|                 await self.terminate_handshake() | ||||
|                 await self.connect(reconnect=reconnect, _tries=_tries + 1) | ||||
|             else: | ||||
|                 raise | ||||
|  | ||||
|         if self._runner is None: | ||||
|             self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) | ||||
|  | ||||
|     async def poll_voice_ws(self, reconnect): | ||||
|         backoff = ExponentialBackoff() | ||||
|         while True: | ||||
| @@ -287,9 +390,9 @@ class VoiceClient: | ||||
|                 log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry) | ||||
|                 self._connected.clear() | ||||
|                 await asyncio.sleep(retry) | ||||
|                 await self.terminate_handshake() | ||||
|                 await self.voice_disconnect() | ||||
|                 try: | ||||
|                     await self.connect(reconnect=True) | ||||
|                     await self.connect(reconnect=True, timeout=self.timeout) | ||||
|                 except asyncio.TimeoutError: | ||||
|                     # at this point we've retried 5 times... let's continue the loop. | ||||
|                     log.warning('Could not connect to voice... Retrying...') | ||||
| @@ -310,8 +413,9 @@ class VoiceClient: | ||||
|             if self.ws: | ||||
|                 await self.ws.close() | ||||
|  | ||||
|             await self.terminate_handshake(remove=True) | ||||
|             await self.voice_disconnect() | ||||
|         finally: | ||||
|             self.cleanup() | ||||
|             if self.socket: | ||||
|                 self.socket.close() | ||||
|  | ||||
| @@ -325,8 +429,7 @@ class VoiceClient: | ||||
|         channel: :class:`abc.Snowflake` | ||||
|             The channel to move to. Must be a voice channel. | ||||
|         """ | ||||
|         guild_id, _ = self.channel._get_voice_state_pair() | ||||
|         await self.main_ws.voice_state(guild_id, channel.id) | ||||
|         await self.channel.guild.change_voice_state(channel=channel) | ||||
|  | ||||
|     def is_connected(self): | ||||
|         """Indicates if the voice client is connected to voice.""" | ||||
|   | ||||
| @@ -54,6 +54,9 @@ Voice | ||||
| .. autoclass:: VoiceClient() | ||||
|     :members: | ||||
|  | ||||
| .. autoclass:: VoiceProtocol | ||||
|     :members: | ||||
|  | ||||
| .. autoclass:: AudioSource | ||||
|     :members: | ||||
|  | ||||
|   | ||||
		Reference in New Issue
	
	Block a user