Working multi-server voice support.
This commit is contained in:
parent
5fa715c350
commit
d9c780b8a8
@ -90,10 +90,10 @@ class Client:
|
|||||||
-----------
|
-----------
|
||||||
user : Optional[:class:`User`]
|
user : Optional[:class:`User`]
|
||||||
Represents the connected client. None if not logged in.
|
Represents the connected client. None if not logged in.
|
||||||
voice : Optional[:class:`VoiceClient`]
|
voice_clients : iterable of :class:`VoiceClient`
|
||||||
Represents the current voice connection. None if you are not connected
|
Represents a list of voice connections. To connect to voice use
|
||||||
to a voice channel. To connect to voice use :meth:`join_voice_channel`.
|
:meth:`join_voice_channel`. To query the voice connection state use
|
||||||
To query the voice connection state use :meth:`is_voice_connected`.
|
:meth:`is_voice_connected`.
|
||||||
servers : iterable of :class:`Server`
|
servers : iterable of :class:`Server`
|
||||||
The servers that the connected client is a member of.
|
The servers that the connected client is a member of.
|
||||||
private_channels : iterable of :class:`PrivateChannel`
|
private_channels : iterable of :class:`PrivateChannel`
|
||||||
@ -114,7 +114,6 @@ class Client:
|
|||||||
def __init__(self, *, loop=None, **options):
|
def __init__(self, *, loop=None, **options):
|
||||||
self.ws = None
|
self.ws = None
|
||||||
self.token = None
|
self.token = None
|
||||||
self.voice = None
|
|
||||||
self.loop = asyncio.get_event_loop() if loop is None else loop
|
self.loop = asyncio.get_event_loop() if loop is None else loop
|
||||||
self._listeners = []
|
self._listeners = []
|
||||||
self.cache_auth = options.get('cache_auth', True)
|
self.cache_auth = options.get('cache_auth', True)
|
||||||
@ -227,14 +226,14 @@ class Client:
|
|||||||
raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object')
|
raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object')
|
||||||
|
|
||||||
def __getattr__(self, name):
|
def __getattr__(self, name):
|
||||||
if name in ('user', 'servers', 'private_channels', 'messages'):
|
if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'):
|
||||||
return getattr(self.connection, name)
|
return getattr(self.connection, name)
|
||||||
else:
|
else:
|
||||||
msg = "'{}' object has no attribute '{}'"
|
msg = "'{}' object has no attribute '{}'"
|
||||||
raise AttributeError(msg.format(self.__class__, name))
|
raise AttributeError(msg.format(self.__class__, name))
|
||||||
|
|
||||||
def __setattr__(self, name, value):
|
def __setattr__(self, name, value):
|
||||||
if name in ('user', 'servers', 'private_channels', 'messages'):
|
if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'):
|
||||||
return setattr(self.connection, name, value)
|
return setattr(self.connection, name, value)
|
||||||
else:
|
else:
|
||||||
object.__setattr__(self, name, value)
|
object.__setattr__(self, name, value)
|
||||||
@ -418,13 +417,13 @@ class Client:
|
|||||||
if self.is_closed:
|
if self.is_closed:
|
||||||
return
|
return
|
||||||
|
|
||||||
if self.is_voice_connected():
|
|
||||||
yield from self.voice.disconnect()
|
|
||||||
self.voice = None
|
|
||||||
|
|
||||||
if self.ws is not None and self.ws.open:
|
if self.ws is not None and self.ws.open:
|
||||||
yield from self.ws.close()
|
yield from self.ws.close()
|
||||||
|
|
||||||
|
for voice in list(self.voice_clients):
|
||||||
|
yield from voice.disconnect()
|
||||||
|
self.connection._remove_voice_client(voice.server.id)
|
||||||
|
|
||||||
yield from self.session.close()
|
yield from self.session.close()
|
||||||
self._closed.set()
|
self._closed.set()
|
||||||
self._is_ready.clear()
|
self._is_ready.clear()
|
||||||
@ -2415,15 +2414,17 @@ class Client:
|
|||||||
:class:`VoiceClient`
|
:class:`VoiceClient`
|
||||||
A voice client that is fully connected to the voice server.
|
A voice client that is fully connected to the voice server.
|
||||||
"""
|
"""
|
||||||
if self.is_voice_connected():
|
|
||||||
raise ClientException('Already connected to a voice channel')
|
|
||||||
|
|
||||||
if isinstance(channel, Object):
|
if isinstance(channel, Object):
|
||||||
channel = self.get_channel(channel.id)
|
channel = self.get_channel(channel.id)
|
||||||
|
|
||||||
if getattr(channel, 'type', ChannelType.text) != ChannelType.voice:
|
if getattr(channel, 'type', ChannelType.text) != ChannelType.voice:
|
||||||
raise InvalidArgument('Channel passed must be a voice channel')
|
raise InvalidArgument('Channel passed must be a voice channel')
|
||||||
|
|
||||||
|
server = channel.server
|
||||||
|
|
||||||
|
if self.is_voice_connected(server):
|
||||||
|
raise ClientException('Already connected to a voice channel in this server')
|
||||||
|
|
||||||
log.info('attempting to join voice channel {0.name}'.format(channel))
|
log.info('attempting to join voice channel {0.name}'.format(channel))
|
||||||
|
|
||||||
def session_id_found(data):
|
def session_id_found(data):
|
||||||
@ -2435,14 +2436,10 @@ class Client:
|
|||||||
voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True)
|
voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True)
|
||||||
|
|
||||||
# request joining
|
# request joining
|
||||||
yield from self.ws.voice_state(channel.server.id, channel.id)
|
yield from self.ws.voice_state(server.id, channel.id)
|
||||||
session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop)
|
session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop)
|
||||||
data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop)
|
data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop)
|
||||||
|
|
||||||
# todo: multivoice
|
|
||||||
if self.is_voice_connected():
|
|
||||||
self.voice.channel = self.get_channel(session_id_data.get('channel_id'))
|
|
||||||
|
|
||||||
kwargs = {
|
kwargs = {
|
||||||
'user': self.user,
|
'user': self.user,
|
||||||
'channel': channel,
|
'channel': channel,
|
||||||
@ -2452,10 +2449,36 @@ class Client:
|
|||||||
'main_ws': self.ws
|
'main_ws': self.ws
|
||||||
}
|
}
|
||||||
|
|
||||||
self.voice = VoiceClient(**kwargs)
|
voice = VoiceClient(**kwargs)
|
||||||
yield from self.voice.connect()
|
yield from voice.connect()
|
||||||
return self.voice
|
self.connection._add_voice_client(server.id, voice)
|
||||||
|
return voice
|
||||||
|
|
||||||
def is_voice_connected(self):
|
def is_voice_connected(self, server):
|
||||||
"""bool : Indicates if we are currently connected to a voice channel."""
|
"""Indicates if we are currently connected to a voice channel in the
|
||||||
return self.voice is not None and self.voice.is_connected()
|
specified server.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
server : :class:`Server`
|
||||||
|
The server to query if we're connected to it.
|
||||||
|
"""
|
||||||
|
voice = self.voice_client_in(server)
|
||||||
|
return voice is not None
|
||||||
|
|
||||||
|
def voice_client_in(self, server):
|
||||||
|
"""Returns the voice client associated with a server.
|
||||||
|
|
||||||
|
If no voice client is found then ``None`` is returned.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
-----------
|
||||||
|
server : :class:`Server`
|
||||||
|
The server to query if we have a voice client for.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
--------
|
||||||
|
:class:`VoiceClient`
|
||||||
|
The voice client associated with the server.
|
||||||
|
"""
|
||||||
|
return self.connection._get_voice_client(server.id)
|
||||||
|
@ -179,35 +179,21 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
# the keep alive
|
# the keep alive
|
||||||
self._keep_alive = None
|
self._keep_alive = None
|
||||||
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def connect(cls, dispatch, *, token=None, connection=None, loop=None):
|
def from_client(cls, client):
|
||||||
"""Creates a main websocket for Discord used for the client.
|
"""Creates a main websocket for Discord from a :class:`Client`.
|
||||||
|
|
||||||
Parameters
|
This is for internal use only.
|
||||||
----------
|
|
||||||
token : str
|
|
||||||
The token for Discord authentication.
|
|
||||||
connection
|
|
||||||
The ConnectionState for the client.
|
|
||||||
dispatch
|
|
||||||
The function that dispatches events.
|
|
||||||
loop
|
|
||||||
The event loop to use.
|
|
||||||
|
|
||||||
Returns
|
|
||||||
-------
|
|
||||||
DiscordWebSocket
|
|
||||||
A websocket connected to Discord.
|
|
||||||
"""
|
"""
|
||||||
|
gateway = yield from get_gateway(client.token, loop=client.loop)
|
||||||
gateway = yield from get_gateway(token, loop=loop)
|
ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls)
|
||||||
ws = yield from websockets.connect(gateway, loop=loop, klass=cls)
|
|
||||||
|
|
||||||
# dynamically add attributes needed
|
# dynamically add attributes needed
|
||||||
ws.token = token
|
ws.token = client.token
|
||||||
ws._connection = connection
|
ws._connection = client.connection
|
||||||
ws._dispatch = dispatch
|
ws._dispatch = client.dispatch
|
||||||
ws.gateway = gateway
|
ws.gateway = gateway
|
||||||
|
|
||||||
log.info('Created websocket connected to {}'.format(gateway))
|
log.info('Created websocket connected to {}'.format(gateway))
|
||||||
@ -215,16 +201,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
log.info('sent the identify payload to create the websocket')
|
log.info('sent the identify payload to create the websocket')
|
||||||
return ws
|
return ws
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_client(cls, client):
|
|
||||||
"""Creates a main websocket for Discord from a :class:`Client`.
|
|
||||||
|
|
||||||
This is for internal use only.
|
|
||||||
"""
|
|
||||||
return cls.connect(client.dispatch, token=client.token,
|
|
||||||
connection=client.connection,
|
|
||||||
loop=client.loop)
|
|
||||||
|
|
||||||
def wait_for(self, event, predicate, result=None):
|
def wait_for(self, event, predicate, result=None):
|
||||||
"""Waits for a DISPATCH'd event that meets the predicate.
|
"""Waits for a DISPATCH'd event that meets the predicate.
|
||||||
|
|
||||||
@ -280,6 +256,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
msg = msg.decode('utf-8')
|
msg = msg.decode('utf-8')
|
||||||
|
|
||||||
msg = json.loads(msg)
|
msg = json.loads(msg)
|
||||||
|
state = self._connection
|
||||||
|
|
||||||
log.debug('WebSocket Event: {}'.format(msg))
|
log.debug('WebSocket Event: {}'.format(msg))
|
||||||
self._dispatch('socket_response', msg)
|
self._dispatch('socket_response', msg)
|
||||||
@ -288,7 +265,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
data = msg.get('d')
|
data = msg.get('d')
|
||||||
|
|
||||||
if 's' in msg:
|
if 's' in msg:
|
||||||
self._connection.sequence = msg['s']
|
state.sequence = msg['s']
|
||||||
|
|
||||||
if op == self.RECONNECT:
|
if op == self.RECONNECT:
|
||||||
# "reconnect" can only be handled by the Client
|
# "reconnect" can only be handled by the Client
|
||||||
@ -299,8 +276,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
raise ReconnectWebSocket()
|
raise ReconnectWebSocket()
|
||||||
|
|
||||||
if op == self.INVALIDATE_SESSION:
|
if op == self.INVALIDATE_SESSION:
|
||||||
self._connection.sequence = None
|
state.sequence = None
|
||||||
self._connection.session_id = None
|
state.session_id = None
|
||||||
return
|
return
|
||||||
|
|
||||||
if op != self.DISPATCH:
|
if op != self.DISPATCH:
|
||||||
@ -311,9 +288,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
is_ready = event == 'READY'
|
is_ready = event == 'READY'
|
||||||
|
|
||||||
if is_ready:
|
if is_ready:
|
||||||
self._connection.clear()
|
state.clear()
|
||||||
self._connection.sequence = msg['s']
|
state.sequence = msg['s']
|
||||||
self._connection.session_id = data['session_id']
|
state.session_id = data['session_id']
|
||||||
|
|
||||||
if is_ready or event == 'RESUMED':
|
if is_ready or event == 'RESUMED':
|
||||||
interval = data['heartbeat_interval'] / 1000.0
|
interval = data['heartbeat_interval'] / 1000.0
|
||||||
@ -366,7 +343,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
msg = yield from self.recv()
|
msg = yield from self.recv()
|
||||||
yield from self.received_message(msg)
|
yield from self.received_message(msg)
|
||||||
except websockets.exceptions.ConnectionClosed as e:
|
except websockets.exceptions.ConnectionClosed as e:
|
||||||
if e.code in (4008, 4009) or e.code in range(1001, 1015):
|
if e.code in (4006, 4008, 4009) or e.code in range(1001, 1015):
|
||||||
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
|
log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e))
|
||||||
raise ReconnectWebSocket() from e
|
raise ReconnectWebSocket() from e
|
||||||
else:
|
else:
|
||||||
@ -424,6 +401,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol):
|
|||||||
|
|
||||||
yield from self.send_as_json(payload)
|
yield from self.send_as_json(payload)
|
||||||
|
|
||||||
|
# we're leaving a voice channel so remove it from the client list
|
||||||
|
if channel_id is None:
|
||||||
|
self._connection._remove_voice_client(guild_id)
|
||||||
|
|
||||||
@asyncio.coroutine
|
@asyncio.coroutine
|
||||||
def close(self, code=1000, reason=''):
|
def close(self, code=1000, reason=''):
|
||||||
if self._keep_alive:
|
if self._keep_alive:
|
||||||
|
@ -62,6 +62,7 @@ class ConnectionState:
|
|||||||
self.sequence = None
|
self.sequence = None
|
||||||
self.session_id = None
|
self.session_id = None
|
||||||
self._servers = {}
|
self._servers = {}
|
||||||
|
self._voice_clients = {}
|
||||||
self._private_channels = {}
|
self._private_channels = {}
|
||||||
# extra dict to look up private channels by user id
|
# extra dict to look up private channels by user id
|
||||||
self._private_channels_by_user = {}
|
self._private_channels_by_user = {}
|
||||||
@ -93,6 +94,19 @@ class ConnectionState:
|
|||||||
for index in reversed(removed):
|
for index in reversed(removed):
|
||||||
del self._listeners[index]
|
del self._listeners[index]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def voice_clients(self):
|
||||||
|
return self._voice_clients.values()
|
||||||
|
|
||||||
|
def _get_voice_client(self, guild_id):
|
||||||
|
return self._voice_clients.get(guild_id)
|
||||||
|
|
||||||
|
def _add_voice_client(self, guild_id, voice):
|
||||||
|
self._voice_clients[guild_id] = voice
|
||||||
|
|
||||||
|
def _remove_voice_client(self, guild_id):
|
||||||
|
self._voice_clients.pop(guild_id, None)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def servers(self):
|
def servers(self):
|
||||||
return self._servers.values()
|
return self._servers.values()
|
||||||
@ -130,6 +144,7 @@ class ConnectionState:
|
|||||||
def _add_server_from_data(self, guild):
|
def _add_server_from_data(self, guild):
|
||||||
server = Server(**guild)
|
server = Server(**guild)
|
||||||
Server.me = property(lambda s: s.get_member(self.user.id))
|
Server.me = property(lambda s: s.get_member(self.user.id))
|
||||||
|
Server.voice_client = property(lambda s: self._get_voice_client(s.id))
|
||||||
self._add_server(server)
|
self._add_server(server)
|
||||||
return server
|
return server
|
||||||
|
|
||||||
@ -489,7 +504,13 @@ class ConnectionState:
|
|||||||
|
|
||||||
def parse_voice_state_update(self, data):
|
def parse_voice_state_update(self, data):
|
||||||
server = self._get_server(data.get('guild_id'))
|
server = self._get_server(data.get('guild_id'))
|
||||||
|
user_id = data.get('user_id')
|
||||||
if server is not None:
|
if server is not None:
|
||||||
|
if user_id == self.user.id:
|
||||||
|
voice = self._get_voice_client(server.id)
|
||||||
|
if voice is not None:
|
||||||
|
voice.channel = server.get_channel(data.get('channel_id'))
|
||||||
|
|
||||||
updated_members = server._update_voice_state(data)
|
updated_members = server._update_voice_state(data)
|
||||||
self.dispatch('voice_state_update', *updated_members)
|
self.dispatch('voice_state_update', *updated_members)
|
||||||
|
|
||||||
|
@ -158,6 +158,9 @@ class VoiceClient:
|
|||||||
The endpoint we are connecting to.
|
The endpoint we are connecting to.
|
||||||
channel : :class:`Channel`
|
channel : :class:`Channel`
|
||||||
The voice channel connected to.
|
The voice channel connected to.
|
||||||
|
server : :class:`Server`
|
||||||
|
The server the voice channel is connected to.
|
||||||
|
Shorthand for ``channel.server``.
|
||||||
loop
|
loop
|
||||||
The event loop that the voice client is running on.
|
The event loop that the voice client is running on.
|
||||||
"""
|
"""
|
||||||
@ -176,6 +179,10 @@ class VoiceClient:
|
|||||||
self.encoder = OpusEncoder(48000, 2)
|
self.encoder = OpusEncoder(48000, 2)
|
||||||
log.info('created opus encoder with {0.__dict__}'.format(self.encoder))
|
log.info('created opus encoder with {0.__dict__}'.format(self.encoder))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def server(self):
|
||||||
|
return self.channel.server
|
||||||
|
|
||||||
def checked_add(self, attr, value, limit):
|
def checked_add(self, attr, value, limit):
|
||||||
val = getattr(self, attr)
|
val = getattr(self, attr)
|
||||||
if val + value > limit:
|
if val + value > limit:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user