Implement VoiceProtocol lower level hooks.
This allows changing the connect flow and taking control of it without relying on internal events or tricks.
This commit is contained in:
		@@ -54,7 +54,7 @@ from .mentions import AllowedMentions
 | 
				
			|||||||
from .shard import AutoShardedClient, ShardInfo
 | 
					from .shard import AutoShardedClient, ShardInfo
 | 
				
			||||||
from .player import *
 | 
					from .player import *
 | 
				
			||||||
from .webhook import *
 | 
					from .webhook import *
 | 
				
			||||||
from .voice_client import VoiceClient
 | 
					from .voice_client import VoiceClient, VoiceProtocol
 | 
				
			||||||
from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff
 | 
					from .audit_logs import AuditLogChanges, AuditLogEntry, AuditLogDiff
 | 
				
			||||||
from .raw_models import *
 | 
					from .raw_models import *
 | 
				
			||||||
from .team import *
 | 
					from .team import *
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -36,7 +36,7 @@ from .permissions import PermissionOverwrite, Permissions
 | 
				
			|||||||
from .role import Role
 | 
					from .role import Role
 | 
				
			||||||
from .invite import Invite
 | 
					from .invite import Invite
 | 
				
			||||||
from .file import File
 | 
					from .file import File
 | 
				
			||||||
from .voice_client import VoiceClient
 | 
					from .voice_client import VoiceClient, VoiceProtocol
 | 
				
			||||||
from . import utils
 | 
					from . import utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class _Undefined:
 | 
					class _Undefined:
 | 
				
			||||||
@@ -1053,7 +1053,6 @@ class Messageable(metaclass=abc.ABCMeta):
 | 
				
			|||||||
        """
 | 
					        """
 | 
				
			||||||
        return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
 | 
					        return HistoryIterator(self, limit=limit, before=before, after=after, around=around, oldest_first=oldest_first)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					 | 
				
			||||||
class Connectable(metaclass=abc.ABCMeta):
 | 
					class Connectable(metaclass=abc.ABCMeta):
 | 
				
			||||||
    """An ABC that details the common operations on a channel that can
 | 
					    """An ABC that details the common operations on a channel that can
 | 
				
			||||||
    connect to a voice server.
 | 
					    connect to a voice server.
 | 
				
			||||||
@@ -1072,7 +1071,7 @@ class Connectable(metaclass=abc.ABCMeta):
 | 
				
			|||||||
    def _get_voice_state_pair(self):
 | 
					    def _get_voice_state_pair(self):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def connect(self, *, timeout=60.0, reconnect=True):
 | 
					    async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient):
 | 
				
			||||||
        """|coro|
 | 
					        """|coro|
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Connects to voice and creates a :class:`VoiceClient` to establish
 | 
					        Connects to voice and creates a :class:`VoiceClient` to establish
 | 
				
			||||||
@@ -1086,6 +1085,9 @@ class Connectable(metaclass=abc.ABCMeta):
 | 
				
			|||||||
            Whether the bot should automatically attempt
 | 
					            Whether the bot should automatically attempt
 | 
				
			||||||
            a reconnect if a part of the handshake fails
 | 
					            a reconnect if a part of the handshake fails
 | 
				
			||||||
            or the gateway goes down.
 | 
					            or the gateway goes down.
 | 
				
			||||||
 | 
					        cls: Type[:class:`VoiceProtocol`]
 | 
				
			||||||
 | 
					            A type that subclasses :class:`~discord.VoiceProtocol` to connect with.
 | 
				
			||||||
 | 
					            Defaults to :class:`~discord.VoiceClient`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        Raises
 | 
					        Raises
 | 
				
			||||||
        -------
 | 
					        -------
 | 
				
			||||||
@@ -1098,20 +1100,25 @@ class Connectable(metaclass=abc.ABCMeta):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        Returns
 | 
					        Returns
 | 
				
			||||||
        --------
 | 
					        --------
 | 
				
			||||||
        :class:`~discord.VoiceClient`
 | 
					        :class:`~discord.VoiceProtocol`
 | 
				
			||||||
            A voice client that is fully connected to the voice server.
 | 
					            A voice client that is fully connected to the voice server.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if not issubclass(cls, VoiceProtocol):
 | 
				
			||||||
 | 
					            raise TypeError('Type must meet VoiceProtocol abstract base class.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        key_id, _ = self._get_voice_client_key()
 | 
					        key_id, _ = self._get_voice_client_key()
 | 
				
			||||||
        state = self._state
 | 
					        state = self._state
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if state._get_voice_client(key_id):
 | 
					        if state._get_voice_client(key_id):
 | 
				
			||||||
            raise ClientException('Already connected to a voice channel.')
 | 
					            raise ClientException('Already connected to a voice channel.')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        voice = VoiceClient(state=state, timeout=timeout, channel=self)
 | 
					        client = state._get_client()
 | 
				
			||||||
 | 
					        voice = cls(client, self)
 | 
				
			||||||
        state._add_voice_client(key_id, voice)
 | 
					        state._add_voice_client(key_id, voice)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					        try:
 | 
				
			||||||
            await voice.connect(reconnect=reconnect)
 | 
					            await voice.connect(timeout=timeout, reconnect=reconnect)
 | 
				
			||||||
        except asyncio.TimeoutError:
 | 
					        except asyncio.TimeoutError:
 | 
				
			||||||
            try:
 | 
					            try:
 | 
				
			||||||
                await voice.disconnect(force=True)
 | 
					                await voice.disconnect(force=True)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -238,6 +238,7 @@ class Client:
 | 
				
			|||||||
        self._closed = False
 | 
					        self._closed = False
 | 
				
			||||||
        self._ready = asyncio.Event()
 | 
					        self._ready = asyncio.Event()
 | 
				
			||||||
        self._connection._get_websocket = self._get_websocket
 | 
					        self._connection._get_websocket = self._get_websocket
 | 
				
			||||||
 | 
					        self._connection._get_client = lambda: self
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if VoiceClient.warn_nacl:
 | 
					        if VoiceClient.warn_nacl:
 | 
				
			||||||
            VoiceClient.warn_nacl = False
 | 
					            VoiceClient.warn_nacl = False
 | 
				
			||||||
@@ -299,7 +300,10 @@ class Client:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def voice_clients(self):
 | 
					    def voice_clients(self):
 | 
				
			||||||
        """List[:class:`.VoiceClient`]: Represents a list of voice connections."""
 | 
					        """List[:class:`.VoiceProtocol`]: Represents a list of voice connections.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        These are usually :class:`.VoiceClient` instances.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
        return self._connection.voice_clients
 | 
					        return self._connection.voice_clients
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_ready(self):
 | 
					    def is_ready(self):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -238,7 +238,7 @@ class Context(discord.abc.Messageable):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def voice_client(self):
 | 
					    def voice_client(self):
 | 
				
			||||||
        r"""Optional[:class:`.VoiceClient`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
 | 
					        r"""Optional[:class:`.VoiceProtocol`]: A shortcut to :attr:`.Guild.voice_client`\, if applicable."""
 | 
				
			||||||
        g = self.guild
 | 
					        g = self.guild
 | 
				
			||||||
        return g.voice_client if g else None
 | 
					        return g.voice_client if g else None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -377,7 +377,7 @@ class Guild(Hashable):
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def voice_client(self):
 | 
					    def voice_client(self):
 | 
				
			||||||
        """Optional[:class:`VoiceClient`]: Returns the :class:`VoiceClient` associated with this guild, if any."""
 | 
					        """Optional[:class:`VoiceProtocol`]: Returns the :class:`VoiceProtocol` associated with this guild, if any."""
 | 
				
			||||||
        return self._state._get_voice_client(self.id)
 | 
					        return self._state._get_voice_client(self.id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -292,6 +292,7 @@ class AutoShardedClient(Client):
 | 
				
			|||||||
        # the key is the shard_id
 | 
					        # the key is the shard_id
 | 
				
			||||||
        self.__shards = {}
 | 
					        self.__shards = {}
 | 
				
			||||||
        self._connection._get_websocket = self._get_websocket
 | 
					        self._connection._get_websocket = self._get_websocket
 | 
				
			||||||
 | 
					        self._connection._get_client = lambda: self
 | 
				
			||||||
        self.__queue = asyncio.PriorityQueue()
 | 
					        self.__queue = asyncio.PriorityQueue()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _get_websocket(self, guild_id=None, *, shard_id=None):
 | 
					    def _get_websocket(self, guild_id=None, *, shard_id=None):
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -63,6 +63,12 @@ Listener = namedtuple('Listener', ('type', 'future', 'predicate'))
 | 
				
			|||||||
log = logging.getLogger(__name__)
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
 | 
					ReadyState = namedtuple('ReadyState', ('launch', 'guilds'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					async def logging_coroutine(coroutine, *, info):
 | 
				
			||||||
 | 
					    try:
 | 
				
			||||||
 | 
					        await coroutine
 | 
				
			||||||
 | 
					    except Exception:
 | 
				
			||||||
 | 
					        log.exception('Exception occurred during %s', info)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class ConnectionState:
 | 
					class ConnectionState:
 | 
				
			||||||
    def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
 | 
					    def __init__(self, *, dispatch, handlers, hooks, syncer, http, loop, **options):
 | 
				
			||||||
        self.loop = loop
 | 
					        self.loop = loop
 | 
				
			||||||
@@ -939,9 +945,8 @@ class ConnectionState:
 | 
				
			|||||||
            if int(data['user_id']) == self.user.id:
 | 
					            if int(data['user_id']) == self.user.id:
 | 
				
			||||||
                voice = self._get_voice_client(guild.id)
 | 
					                voice = self._get_voice_client(guild.id)
 | 
				
			||||||
                if voice is not None:
 | 
					                if voice is not None:
 | 
				
			||||||
                    ch = guild.get_channel(channel_id)
 | 
					                    coro = voice.on_voice_state_update(data)
 | 
				
			||||||
                    if ch is not None:
 | 
					                    asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice state update handler'))
 | 
				
			||||||
                        voice.channel = ch
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
            member, before, after = guild._update_voice_state(data, channel_id)
 | 
					            member, before, after = guild._update_voice_state(data, channel_id)
 | 
				
			||||||
            if member is not None:
 | 
					            if member is not None:
 | 
				
			||||||
@@ -962,7 +967,8 @@ class ConnectionState:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
        vc = self._get_voice_client(key_id)
 | 
					        vc = self._get_voice_client(key_id)
 | 
				
			||||||
        if vc is not None:
 | 
					        if vc is not None:
 | 
				
			||||||
            asyncio.ensure_future(vc._create_socket(key_id, data))
 | 
					            coro = vc.on_voice_server_update(data)
 | 
				
			||||||
 | 
					            asyncio.ensure_future(logging_coroutine(coro, info='Voice Protocol voice server update handler'))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def parse_typing_start(self, data):
 | 
					    def parse_typing_start(self, data):
 | 
				
			||||||
        channel, guild = self._get_guild_channel(data)
 | 
					        channel, guild = self._get_guild_channel(data)
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -45,7 +45,7 @@ import logging
 | 
				
			|||||||
import struct
 | 
					import struct
 | 
				
			||||||
import threading
 | 
					import threading
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from . import opus
 | 
					from . import opus, utils
 | 
				
			||||||
from .backoff import ExponentialBackoff
 | 
					from .backoff import ExponentialBackoff
 | 
				
			||||||
from .gateway import *
 | 
					from .gateway import *
 | 
				
			||||||
from .errors import ClientException, ConnectionClosed
 | 
					from .errors import ClientException, ConnectionClosed
 | 
				
			||||||
@@ -59,7 +59,110 @@ except ImportError:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
log = logging.getLogger(__name__)
 | 
					log = logging.getLogger(__name__)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
class VoiceClient:
 | 
					class VoiceProtocol:
 | 
				
			||||||
 | 
					    """A class that represents the Discord voice protocol.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This is an abstract class. The library provides a concrete implementation
 | 
				
			||||||
 | 
					    under :class:`VoiceClient`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    This class allows you to implement a protocol to allow for an external
 | 
				
			||||||
 | 
					    method of sending voice, such as Lavalink_ or a native library implementation.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    These classes are passed to :meth:`abc.Connectable.connect`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    .. _Lavalink: https://github.com/Frederikam/Lavalink
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    Parameters
 | 
				
			||||||
 | 
					    ------------
 | 
				
			||||||
 | 
					    client: :class:`Client`
 | 
				
			||||||
 | 
					        The client (or its subclasses) that started the connection request.
 | 
				
			||||||
 | 
					    channel: :class:`abc.Connectable`
 | 
				
			||||||
 | 
					        The voice channel that is being connected to.
 | 
				
			||||||
 | 
					    """
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __init__(self, client, channel):
 | 
				
			||||||
 | 
					        self.client = client
 | 
				
			||||||
 | 
					        self.channel = channel
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def on_voice_state_update(self, data):
 | 
				
			||||||
 | 
					        """|coro|
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        An abstract method that is called when the client's voice state
 | 
				
			||||||
 | 
					        has changed. This corresponds to ``VOICE_STATE_UPDATE``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ------------
 | 
				
			||||||
 | 
					        data: :class:`dict`
 | 
				
			||||||
 | 
					            The raw `voice state payload`_.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        .. _voice state payload: https://discord.com/developers/docs/resources/voice#voice-state-object
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def on_voice_server_update(self, data):
 | 
				
			||||||
 | 
					        """|coro|
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        An abstract method that is called when initially connecting to voice.
 | 
				
			||||||
 | 
					        This corresponds to ``VOICE_SERVER_UPDATE``.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ------------
 | 
				
			||||||
 | 
					        data: :class:`dict`
 | 
				
			||||||
 | 
					            The raw `voice server update payload`__.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            .. _VSU: https://discord.com/developers/docs/topics/gateway#voice-server-update-voice-server-update-event-fields
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            __ VSU_
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def connect(self, *, timeout, reconnect):
 | 
				
			||||||
 | 
					        """|coro|
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        An abstract method called when the client initiates the connection request.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        When a connection is requested initially, the library calls the following functions
 | 
				
			||||||
 | 
					        in order:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        - ``__init__``
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ------------
 | 
				
			||||||
 | 
					        timeout: :class:`float`
 | 
				
			||||||
 | 
					            The timeout for the connection.
 | 
				
			||||||
 | 
					        reconnect: :class:`bool`
 | 
				
			||||||
 | 
					            Whether reconnection is expected.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def disconnect(self, *, force):
 | 
				
			||||||
 | 
					        """|coro|
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        An abstract method called when the client terminates the connection.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        See :meth:`cleanup`.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        Parameters
 | 
				
			||||||
 | 
					        ------------
 | 
				
			||||||
 | 
					        force: :class:`bool`
 | 
				
			||||||
 | 
					            Whether the disconnection was forced.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        raise NotImplementedError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def cleanup(self):
 | 
				
			||||||
 | 
					        """This method *must* be called to ensure proper clean-up during a disconnect.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        It is advisable to call this from within :meth:`disconnect` when you are
 | 
				
			||||||
 | 
					        completely done with the voice protocol instance.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        This method removes it from the internal state cache that keeps track of
 | 
				
			||||||
 | 
					        currently alive voice clients. Failure to clean-up will cause subsequent
 | 
				
			||||||
 | 
					        connections to report that it's still connected.
 | 
				
			||||||
 | 
					        """
 | 
				
			||||||
 | 
					        key_id, _ = self.channel._get_voice_client_key()
 | 
				
			||||||
 | 
					        self.client._connection._remove_voice_client(key_id)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					class VoiceClient(VoiceProtocol):
 | 
				
			||||||
    """Represents a Discord voice connection.
 | 
					    """Represents a Discord voice connection.
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    You do not create these, you typically get them from
 | 
					    You do not create these, you typically get them from
 | 
				
			||||||
@@ -85,14 +188,13 @@ class VoiceClient:
 | 
				
			|||||||
    loop: :class:`asyncio.AbstractEventLoop`
 | 
					    loop: :class:`asyncio.AbstractEventLoop`
 | 
				
			||||||
        The event loop that the voice client is running on.
 | 
					        The event loop that the voice client is running on.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, state, timeout, channel):
 | 
					    def __init__(self, client, channel):
 | 
				
			||||||
        if not has_nacl:
 | 
					        if not has_nacl:
 | 
				
			||||||
            raise RuntimeError("PyNaCl library needed in order to use voice")
 | 
					            raise RuntimeError("PyNaCl library needed in order to use voice")
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.channel = channel
 | 
					        super().__init__(client, channel)
 | 
				
			||||||
        self.main_ws = None
 | 
					        state = client._connection
 | 
				
			||||||
        self.timeout = timeout
 | 
					        self.token = None
 | 
				
			||||||
        self.ws = None
 | 
					 | 
				
			||||||
        self.socket = None
 | 
					        self.socket = None
 | 
				
			||||||
        self.loop = state.loop
 | 
					        self.loop = state.loop
 | 
				
			||||||
        self._state = state
 | 
					        self._state = state
 | 
				
			||||||
@@ -100,8 +202,8 @@ class VoiceClient:
 | 
				
			|||||||
        self._connected = threading.Event()
 | 
					        self._connected = threading.Event()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._handshaking = False
 | 
					        self._handshaking = False
 | 
				
			||||||
        self._handshake_check = asyncio.Lock()
 | 
					        self._voice_state_complete = asyncio.Event()
 | 
				
			||||||
        self._handshake_complete = asyncio.Event()
 | 
					        self._voice_server_complete = asyncio.Event()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.mode = None
 | 
					        self.mode = None
 | 
				
			||||||
        self._connections = 0
 | 
					        self._connections = 0
 | 
				
			||||||
@@ -138,48 +240,24 @@ class VoiceClient:
 | 
				
			|||||||
 | 
					
 | 
				
			||||||
    # connection related
 | 
					    # connection related
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def start_handshake(self):
 | 
					    async def on_voice_state_update(self, data):
 | 
				
			||||||
        log.info('Starting voice handshake...')
 | 
					        self.session_id = data['session_id']
 | 
				
			||||||
 | 
					        channel_id = data['channel_id']
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        guild_id, channel_id = self.channel._get_voice_state_pair()
 | 
					        if not self._handshaking:
 | 
				
			||||||
        state = self._state
 | 
					            # If we're done handshaking then we just need to update ourselves
 | 
				
			||||||
        self.main_ws = ws = state._get_websocket(guild_id)
 | 
					            guild = self.guild
 | 
				
			||||||
        self._connections += 1
 | 
					            self.channel = channel_id and guild and guild.get_channel(int(channel_id))
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            self._voice_state_complete.set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        # request joining
 | 
					    async def on_voice_server_update(self, data):
 | 
				
			||||||
        await ws.voice_state(guild_id, channel_id)
 | 
					        if self._voice_server_complete.is_set():
 | 
				
			||||||
 | 
					            log.info('Ignoring extraneous voice server update.')
 | 
				
			||||||
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            await asyncio.wait_for(self._handshake_complete.wait(), timeout=self.timeout)
 | 
					 | 
				
			||||||
        except asyncio.TimeoutError:
 | 
					 | 
				
			||||||
            await self.terminate_handshake(remove=True)
 | 
					 | 
				
			||||||
            raise
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        log.info('Voice handshake complete. Endpoint found %s (IP: %s)', self.endpoint, self.endpoint_ip)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def terminate_handshake(self, *, remove=False):
 | 
					 | 
				
			||||||
        guild_id, channel_id = self.channel._get_voice_state_pair()
 | 
					 | 
				
			||||||
        self._handshake_complete.clear()
 | 
					 | 
				
			||||||
        await self.main_ws.voice_state(guild_id, None, self_mute=True)
 | 
					 | 
				
			||||||
        self._handshaking = False
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', channel_id, guild_id)
 | 
					 | 
				
			||||||
        if remove:
 | 
					 | 
				
			||||||
            log.info('The voice client has been removed for Channel ID %s (Guild ID %s)', channel_id, guild_id)
 | 
					 | 
				
			||||||
            key_id, _ = self.channel._get_voice_client_key()
 | 
					 | 
				
			||||||
            self._state._remove_voice_client(key_id)
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def _create_socket(self, server_id, data):
 | 
					 | 
				
			||||||
        async with self._handshake_check:
 | 
					 | 
				
			||||||
            if self._handshaking:
 | 
					 | 
				
			||||||
                log.info("Ignoring voice server update while handshake is in progress")
 | 
					 | 
				
			||||||
                return
 | 
					 | 
				
			||||||
            self._handshaking = True
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self._connected.clear()
 | 
					 | 
				
			||||||
        self.session_id = self.main_ws.session_id
 | 
					 | 
				
			||||||
        self.server_id = server_id
 | 
					 | 
				
			||||||
        self.token = data.get('token')
 | 
					        self.token = data.get('token')
 | 
				
			||||||
 | 
					        self.server_id = int(data['guild_id'])
 | 
				
			||||||
        endpoint = data.get('endpoint')
 | 
					        endpoint = data.get('endpoint')
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if endpoint is None or self.token is None:
 | 
					        if endpoint is None or self.token is None:
 | 
				
			||||||
@@ -195,23 +273,77 @@ class VoiceClient:
 | 
				
			|||||||
        # This gets set later
 | 
					        # This gets set later
 | 
				
			||||||
        self.endpoint_ip = None
 | 
					        self.endpoint_ip = None
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self.socket:
 | 
					 | 
				
			||||||
            try:
 | 
					 | 
				
			||||||
                self.socket.close()
 | 
					 | 
				
			||||||
            except Exception:
 | 
					 | 
				
			||||||
                pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
					        self.socket = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
 | 
				
			||||||
        self.socket.setblocking(False)
 | 
					        self.socket.setblocking(False)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        if self._handshake_complete.is_set():
 | 
					        if not self._handshaking:
 | 
				
			||||||
            # terminate the websocket and handle the reconnect loop if necessary.
 | 
					            # If we're not handshaking then we need to terminate our previous connection in the websocket
 | 
				
			||||||
            self._handshake_complete.clear()
 | 
					 | 
				
			||||||
            self._handshaking = False
 | 
					 | 
				
			||||||
            await self.ws.close(4000)
 | 
					            await self.ws.close(4000)
 | 
				
			||||||
            return
 | 
					            return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self._handshake_complete.set()
 | 
					        self._voice_server_complete.set()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def voice_connect(self):
 | 
				
			||||||
 | 
					        self._connections += 1
 | 
				
			||||||
 | 
					        await self.channel.guild.change_voice_state(channel=self.channel)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def voice_disconnect(self):
 | 
				
			||||||
 | 
					        log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id)
 | 
				
			||||||
 | 
					        await self.channel.guild.change_voice_state(channel=None)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    async def connect(self, *, reconnect, timeout):
 | 
				
			||||||
 | 
					        log.info('Connecting to voice...')
 | 
				
			||||||
 | 
					        self.timeout = timeout
 | 
				
			||||||
 | 
					        try:
 | 
				
			||||||
 | 
					            del self.secret_key
 | 
				
			||||||
 | 
					        except AttributeError:
 | 
				
			||||||
 | 
					            pass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        for i in range(5):
 | 
				
			||||||
 | 
					            self._voice_state_complete.clear()
 | 
				
			||||||
 | 
					            self._voice_server_complete.clear()
 | 
				
			||||||
 | 
					            self._handshaking = True
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # This has to be created before we start the flow.
 | 
				
			||||||
 | 
					            futures = [
 | 
				
			||||||
 | 
					                self._voice_state_complete.wait(),
 | 
				
			||||||
 | 
					                self._voice_server_complete.wait(),
 | 
				
			||||||
 | 
					            ]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            # Start the connection flow
 | 
				
			||||||
 | 
					            log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1)
 | 
				
			||||||
 | 
					            await self.voice_connect()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                await utils.sane_wait_for(futures, timeout=timeout)
 | 
				
			||||||
 | 
					            except asyncio.TimeoutError:
 | 
				
			||||||
 | 
					                await self.disconnect(force=True)
 | 
				
			||||||
 | 
					                raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            log.info('Voice handshake complete. Endpoint found %s', self.endpoint)
 | 
				
			||||||
 | 
					            self._handshaking = False
 | 
				
			||||||
 | 
					            self._voice_server_complete.clear()
 | 
				
			||||||
 | 
					            self._voice_state_complete.clear()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					            try:
 | 
				
			||||||
 | 
					                self.ws = await DiscordVoiceWebSocket.from_client(self)
 | 
				
			||||||
 | 
					                self._connected.clear()
 | 
				
			||||||
 | 
					                while not hasattr(self, 'secret_key'):
 | 
				
			||||||
 | 
					                    await self.ws.poll_event()
 | 
				
			||||||
 | 
					                self._connected.set()
 | 
				
			||||||
 | 
					                break
 | 
				
			||||||
 | 
					            except (ConnectionClosed, asyncio.TimeoutError):
 | 
				
			||||||
 | 
					                if reconnect:
 | 
				
			||||||
 | 
					                    log.exception('Failed to connect to voice... Retrying...')
 | 
				
			||||||
 | 
					                    await asyncio.sleep(1 + i * 2.0)
 | 
				
			||||||
 | 
					                    await self.voice_disconnect()
 | 
				
			||||||
 | 
					                    continue
 | 
				
			||||||
 | 
					                else:
 | 
				
			||||||
 | 
					                    raise
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					        if self._runner is None:
 | 
				
			||||||
 | 
					            self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def latency(self):
 | 
					    def latency(self):
 | 
				
			||||||
@@ -234,35 +366,6 @@ class VoiceClient:
 | 
				
			|||||||
        ws = self.ws
 | 
					        ws = self.ws
 | 
				
			||||||
        return float("inf") if not ws else ws.average_latency
 | 
					        return float("inf") if not ws else ws.average_latency
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    async def connect(self, *, reconnect=True, _tries=0, do_handshake=True):
 | 
					 | 
				
			||||||
        log.info('Connecting to voice...')
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            del self.secret_key
 | 
					 | 
				
			||||||
        except AttributeError:
 | 
					 | 
				
			||||||
            pass
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if do_handshake:
 | 
					 | 
				
			||||||
            await self.start_handshake()
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        try:
 | 
					 | 
				
			||||||
            self.ws = await DiscordVoiceWebSocket.from_client(self)
 | 
					 | 
				
			||||||
            self._handshaking = False
 | 
					 | 
				
			||||||
            self._connected.clear()
 | 
					 | 
				
			||||||
            while not hasattr(self, 'secret_key'):
 | 
					 | 
				
			||||||
                await self.ws.poll_event()
 | 
					 | 
				
			||||||
            self._connected.set()
 | 
					 | 
				
			||||||
        except (ConnectionClosed, asyncio.TimeoutError):
 | 
					 | 
				
			||||||
            if reconnect and _tries < 5:
 | 
					 | 
				
			||||||
                log.exception('Failed to connect to voice... Retrying...')
 | 
					 | 
				
			||||||
                await asyncio.sleep(1 + _tries * 2.0)
 | 
					 | 
				
			||||||
                await self.terminate_handshake()
 | 
					 | 
				
			||||||
                await self.connect(reconnect=reconnect, _tries=_tries + 1)
 | 
					 | 
				
			||||||
            else:
 | 
					 | 
				
			||||||
                raise
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        if self._runner is None:
 | 
					 | 
				
			||||||
            self._runner = self.loop.create_task(self.poll_voice_ws(reconnect))
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
    async def poll_voice_ws(self, reconnect):
 | 
					    async def poll_voice_ws(self, reconnect):
 | 
				
			||||||
        backoff = ExponentialBackoff()
 | 
					        backoff = ExponentialBackoff()
 | 
				
			||||||
        while True:
 | 
					        while True:
 | 
				
			||||||
@@ -287,9 +390,9 @@ class VoiceClient:
 | 
				
			|||||||
                log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
 | 
					                log.exception('Disconnected from voice... Reconnecting in %.2fs.', retry)
 | 
				
			||||||
                self._connected.clear()
 | 
					                self._connected.clear()
 | 
				
			||||||
                await asyncio.sleep(retry)
 | 
					                await asyncio.sleep(retry)
 | 
				
			||||||
                await self.terminate_handshake()
 | 
					                await self.voice_disconnect()
 | 
				
			||||||
                try:
 | 
					                try:
 | 
				
			||||||
                    await self.connect(reconnect=True)
 | 
					                    await self.connect(reconnect=True, timeout=self.timeout)
 | 
				
			||||||
                except asyncio.TimeoutError:
 | 
					                except asyncio.TimeoutError:
 | 
				
			||||||
                    # at this point we've retried 5 times... let's continue the loop.
 | 
					                    # at this point we've retried 5 times... let's continue the loop.
 | 
				
			||||||
                    log.warning('Could not connect to voice... Retrying...')
 | 
					                    log.warning('Could not connect to voice... Retrying...')
 | 
				
			||||||
@@ -310,8 +413,9 @@ class VoiceClient:
 | 
				
			|||||||
            if self.ws:
 | 
					            if self.ws:
 | 
				
			||||||
                await self.ws.close()
 | 
					                await self.ws.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
            await self.terminate_handshake(remove=True)
 | 
					            await self.voice_disconnect()
 | 
				
			||||||
        finally:
 | 
					        finally:
 | 
				
			||||||
 | 
					            self.cleanup()
 | 
				
			||||||
            if self.socket:
 | 
					            if self.socket:
 | 
				
			||||||
                self.socket.close()
 | 
					                self.socket.close()
 | 
				
			||||||
 | 
					
 | 
				
			||||||
@@ -325,8 +429,7 @@ class VoiceClient:
 | 
				
			|||||||
        channel: :class:`abc.Snowflake`
 | 
					        channel: :class:`abc.Snowflake`
 | 
				
			||||||
            The channel to move to. Must be a voice channel.
 | 
					            The channel to move to. Must be a voice channel.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        guild_id, _ = self.channel._get_voice_state_pair()
 | 
					        await self.channel.guild.change_voice_state(channel=channel)
 | 
				
			||||||
        await self.main_ws.voice_state(guild_id, channel.id)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def is_connected(self):
 | 
					    def is_connected(self):
 | 
				
			||||||
        """Indicates if the voice client is connected to voice."""
 | 
					        """Indicates if the voice client is connected to voice."""
 | 
				
			||||||
 
 | 
				
			|||||||
@@ -54,6 +54,9 @@ Voice
 | 
				
			|||||||
.. autoclass:: VoiceClient()
 | 
					.. autoclass:: VoiceClient()
 | 
				
			||||||
    :members:
 | 
					    :members:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					.. autoclass:: VoiceProtocol
 | 
				
			||||||
 | 
					    :members:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
.. autoclass:: AudioSource
 | 
					.. autoclass:: AudioSource
 | 
				
			||||||
    :members:
 | 
					    :members:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 
 | 
				
			|||||||
		Reference in New Issue
	
	Block a user