Working multi-server voice support.
This commit is contained in:
		| @@ -90,10 +90,10 @@ class Client: | |||||||
|     ----------- |     ----------- | ||||||
|     user : Optional[:class:`User`] |     user : Optional[:class:`User`] | ||||||
|         Represents the connected client. None if not logged in. |         Represents the connected client. None if not logged in. | ||||||
|     voice : Optional[:class:`VoiceClient`] |     voice_clients : iterable of :class:`VoiceClient` | ||||||
|         Represents the current voice connection. None if you are not connected |         Represents a list of voice connections. To connect to voice use | ||||||
|         to a voice channel. To connect to voice use :meth:`join_voice_channel`. |         :meth:`join_voice_channel`. To query the voice connection state use | ||||||
|         To query the voice connection state use :meth:`is_voice_connected`. |         :meth:`is_voice_connected`. | ||||||
|     servers : iterable of :class:`Server` |     servers : iterable of :class:`Server` | ||||||
|         The servers that the connected client is a member of. |         The servers that the connected client is a member of. | ||||||
|     private_channels : iterable of :class:`PrivateChannel` |     private_channels : iterable of :class:`PrivateChannel` | ||||||
| @@ -114,7 +114,6 @@ class Client: | |||||||
|     def __init__(self, *, loop=None, **options): |     def __init__(self, *, loop=None, **options): | ||||||
|         self.ws = None |         self.ws = None | ||||||
|         self.token = None |         self.token = None | ||||||
|         self.voice = None |  | ||||||
|         self.loop = asyncio.get_event_loop() if loop is None else loop |         self.loop = asyncio.get_event_loop() if loop is None else loop | ||||||
|         self._listeners = [] |         self._listeners = [] | ||||||
|         self.cache_auth = options.get('cache_auth', True) |         self.cache_auth = options.get('cache_auth', True) | ||||||
| @@ -227,14 +226,14 @@ class Client: | |||||||
|             raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') |             raise InvalidArgument('Destination must be Channel, PrivateChannel, User, or Object') | ||||||
|  |  | ||||||
|     def __getattr__(self, name): |     def __getattr__(self, name): | ||||||
|         if name in ('user', 'servers', 'private_channels', 'messages'): |         if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'): | ||||||
|             return getattr(self.connection, name) |             return getattr(self.connection, name) | ||||||
|         else: |         else: | ||||||
|             msg = "'{}' object has no attribute '{}'" |             msg = "'{}' object has no attribute '{}'" | ||||||
|             raise AttributeError(msg.format(self.__class__, name)) |             raise AttributeError(msg.format(self.__class__, name)) | ||||||
|  |  | ||||||
|     def __setattr__(self, name, value): |     def __setattr__(self, name, value): | ||||||
|         if name in ('user', 'servers', 'private_channels', 'messages'): |         if name in ('user', 'servers', 'private_channels', 'messages', 'voice_clients'): | ||||||
|             return setattr(self.connection, name, value) |             return setattr(self.connection, name, value) | ||||||
|         else: |         else: | ||||||
|             object.__setattr__(self, name, value) |             object.__setattr__(self, name, value) | ||||||
| @@ -418,13 +417,13 @@ class Client: | |||||||
|         if self.is_closed: |         if self.is_closed: | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         if self.is_voice_connected(): |  | ||||||
|             yield from self.voice.disconnect() |  | ||||||
|             self.voice = None |  | ||||||
|  |  | ||||||
|         if self.ws is not None and self.ws.open: |         if self.ws is not None and self.ws.open: | ||||||
|             yield from self.ws.close() |             yield from self.ws.close() | ||||||
|  |  | ||||||
|  |         for voice in list(self.voice_clients): | ||||||
|  |             yield from voice.disconnect() | ||||||
|  |             self.connection._remove_voice_client(voice.server.id) | ||||||
|  |  | ||||||
|         yield from self.session.close() |         yield from self.session.close() | ||||||
|         self._closed.set() |         self._closed.set() | ||||||
|         self._is_ready.clear() |         self._is_ready.clear() | ||||||
| @@ -2415,15 +2414,17 @@ class Client: | |||||||
|         :class:`VoiceClient` |         :class:`VoiceClient` | ||||||
|             A voice client that is fully connected to the voice server. |             A voice client that is fully connected to the voice server. | ||||||
|         """ |         """ | ||||||
|         if self.is_voice_connected(): |  | ||||||
|             raise ClientException('Already connected to a voice channel') |  | ||||||
|  |  | ||||||
|         if isinstance(channel, Object): |         if isinstance(channel, Object): | ||||||
|             channel = self.get_channel(channel.id) |             channel = self.get_channel(channel.id) | ||||||
|  |  | ||||||
|         if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: |         if getattr(channel, 'type', ChannelType.text) != ChannelType.voice: | ||||||
|             raise InvalidArgument('Channel passed must be a voice channel') |             raise InvalidArgument('Channel passed must be a voice channel') | ||||||
|  |  | ||||||
|  |         server = channel.server | ||||||
|  |  | ||||||
|  |         if self.is_voice_connected(server): | ||||||
|  |             raise ClientException('Already connected to a voice channel in this server') | ||||||
|  |  | ||||||
|         log.info('attempting to join voice channel {0.name}'.format(channel)) |         log.info('attempting to join voice channel {0.name}'.format(channel)) | ||||||
|  |  | ||||||
|         def session_id_found(data): |         def session_id_found(data): | ||||||
| @@ -2435,14 +2436,10 @@ class Client: | |||||||
|         voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True) |         voice_data_future = self.ws.wait_for('VOICE_SERVER_UPDATE', lambda d: True) | ||||||
|  |  | ||||||
|         # request joining |         # request joining | ||||||
|         yield from self.ws.voice_state(channel.server.id, channel.id) |         yield from self.ws.voice_state(server.id, channel.id) | ||||||
|         session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop) |         session_id_data = yield from asyncio.wait_for(session_id_future, timeout=10.0, loop=self.loop) | ||||||
|         data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop) |         data = yield from asyncio.wait_for(voice_data_future, timeout=10.0, loop=self.loop) | ||||||
|  |  | ||||||
|         # todo: multivoice |  | ||||||
|         if self.is_voice_connected(): |  | ||||||
|             self.voice.channel = self.get_channel(session_id_data.get('channel_id')) |  | ||||||
|  |  | ||||||
|         kwargs = { |         kwargs = { | ||||||
|             'user': self.user, |             'user': self.user, | ||||||
|             'channel': channel, |             'channel': channel, | ||||||
| @@ -2452,10 +2449,36 @@ class Client: | |||||||
|             'main_ws': self.ws |             'main_ws': self.ws | ||||||
|         } |         } | ||||||
|  |  | ||||||
|         self.voice = VoiceClient(**kwargs) |         voice = VoiceClient(**kwargs) | ||||||
|         yield from self.voice.connect() |         yield from voice.connect() | ||||||
|         return self.voice |         self.connection._add_voice_client(server.id, voice) | ||||||
|  |         return voice | ||||||
|  |  | ||||||
|     def is_voice_connected(self): |     def is_voice_connected(self, server): | ||||||
|         """bool : Indicates if we are currently connected to a voice channel.""" |         """Indicates if we are currently connected to a voice channel in the | ||||||
|         return self.voice is not None and self.voice.is_connected() |         specified server. | ||||||
|  |  | ||||||
|  |         Parameters | ||||||
|  |         ----------- | ||||||
|  |         server : :class:`Server` | ||||||
|  |             The server to query if we're connected to it. | ||||||
|  |         """ | ||||||
|  |         voice = self.voice_client_in(server) | ||||||
|  |         return voice is not None | ||||||
|  |  | ||||||
|  |     def voice_client_in(self, server): | ||||||
|  |         """Returns the voice client associated with a server. | ||||||
|  |  | ||||||
|  |         If no voice client is found then ``None`` is returned. | ||||||
|  |  | ||||||
|  |         Parameters | ||||||
|  |         ----------- | ||||||
|  |         server : :class:`Server` | ||||||
|  |             The server to query if we have a voice client for. | ||||||
|  |  | ||||||
|  |         Returns | ||||||
|  |         -------- | ||||||
|  |         :class:`VoiceClient` | ||||||
|  |             The voice client associated with the server. | ||||||
|  |         """ | ||||||
|  |         return self.connection._get_voice_client(server.id) | ||||||
|   | |||||||
| @@ -179,35 +179,21 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|         # the keep alive |         # the keep alive | ||||||
|         self._keep_alive = None |         self._keep_alive = None | ||||||
|  |  | ||||||
|  |  | ||||||
|     @classmethod |     @classmethod | ||||||
|     @asyncio.coroutine |     @asyncio.coroutine | ||||||
|     def connect(cls, dispatch, *, token=None, connection=None, loop=None): |     def from_client(cls, client): | ||||||
|         """Creates a main websocket for Discord used for the client. |         """Creates a main websocket for Discord from a :class:`Client`. | ||||||
|  |  | ||||||
|         Parameters |         This is for internal use only. | ||||||
|         ---------- |  | ||||||
|         token : str |  | ||||||
|             The token for Discord authentication. |  | ||||||
|         connection |  | ||||||
|             The ConnectionState for the client. |  | ||||||
|         dispatch |  | ||||||
|             The function that dispatches events. |  | ||||||
|         loop |  | ||||||
|             The event loop to use. |  | ||||||
|  |  | ||||||
|         Returns |  | ||||||
|         ------- |  | ||||||
|         DiscordWebSocket |  | ||||||
|             A websocket connected to Discord. |  | ||||||
|         """ |         """ | ||||||
|  |         gateway = yield from get_gateway(client.token, loop=client.loop) | ||||||
|         gateway = yield from get_gateway(token, loop=loop) |         ws = yield from websockets.connect(gateway, loop=client.loop, klass=cls) | ||||||
|         ws = yield from websockets.connect(gateway, loop=loop, klass=cls) |  | ||||||
|  |  | ||||||
|         # dynamically add attributes needed |         # dynamically add attributes needed | ||||||
|         ws.token = token |         ws.token = client.token | ||||||
|         ws._connection = connection |         ws._connection = client.connection | ||||||
|         ws._dispatch = dispatch |         ws._dispatch = client.dispatch | ||||||
|         ws.gateway = gateway |         ws.gateway = gateway | ||||||
|  |  | ||||||
|         log.info('Created websocket connected to {}'.format(gateway)) |         log.info('Created websocket connected to {}'.format(gateway)) | ||||||
| @@ -215,16 +201,6 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|         log.info('sent the identify payload to create the websocket') |         log.info('sent the identify payload to create the websocket') | ||||||
|         return ws |         return ws | ||||||
|  |  | ||||||
|     @classmethod |  | ||||||
|     def from_client(cls, client): |  | ||||||
|         """Creates a main websocket for Discord from a :class:`Client`. |  | ||||||
|  |  | ||||||
|         This is for internal use only. |  | ||||||
|         """ |  | ||||||
|         return cls.connect(client.dispatch, token=client.token, |  | ||||||
|                                             connection=client.connection, |  | ||||||
|                                             loop=client.loop) |  | ||||||
|  |  | ||||||
|     def wait_for(self, event, predicate, result=None): |     def wait_for(self, event, predicate, result=None): | ||||||
|         """Waits for a DISPATCH'd event that meets the predicate. |         """Waits for a DISPATCH'd event that meets the predicate. | ||||||
|  |  | ||||||
| @@ -280,6 +256,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|             msg = msg.decode('utf-8') |             msg = msg.decode('utf-8') | ||||||
|  |  | ||||||
|         msg = json.loads(msg) |         msg = json.loads(msg) | ||||||
|  |         state = self._connection | ||||||
|  |  | ||||||
|         log.debug('WebSocket Event: {}'.format(msg)) |         log.debug('WebSocket Event: {}'.format(msg)) | ||||||
|         self._dispatch('socket_response', msg) |         self._dispatch('socket_response', msg) | ||||||
| @@ -288,7 +265,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|         data = msg.get('d') |         data = msg.get('d') | ||||||
|  |  | ||||||
|         if 's' in msg: |         if 's' in msg: | ||||||
|             self._connection.sequence = msg['s'] |             state.sequence = msg['s'] | ||||||
|  |  | ||||||
|         if op == self.RECONNECT: |         if op == self.RECONNECT: | ||||||
|             # "reconnect" can only be handled by the Client |             # "reconnect" can only be handled by the Client | ||||||
| @@ -299,8 +276,8 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|             raise ReconnectWebSocket() |             raise ReconnectWebSocket() | ||||||
|  |  | ||||||
|         if op == self.INVALIDATE_SESSION: |         if op == self.INVALIDATE_SESSION: | ||||||
|             self._connection.sequence = None |             state.sequence = None | ||||||
|             self._connection.session_id = None |             state.session_id = None | ||||||
|             return |             return | ||||||
|  |  | ||||||
|         if op != self.DISPATCH: |         if op != self.DISPATCH: | ||||||
| @@ -311,9 +288,9 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|         is_ready = event == 'READY' |         is_ready = event == 'READY' | ||||||
|  |  | ||||||
|         if is_ready: |         if is_ready: | ||||||
|             self._connection.clear() |             state.clear() | ||||||
|             self._connection.sequence = msg['s'] |             state.sequence = msg['s'] | ||||||
|             self._connection.session_id = data['session_id'] |             state.session_id = data['session_id'] | ||||||
|  |  | ||||||
|         if is_ready or event == 'RESUMED': |         if is_ready or event == 'RESUMED': | ||||||
|             interval = data['heartbeat_interval'] / 1000.0 |             interval = data['heartbeat_interval'] / 1000.0 | ||||||
| @@ -366,7 +343,7 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|             msg = yield from self.recv() |             msg = yield from self.recv() | ||||||
|             yield from self.received_message(msg) |             yield from self.received_message(msg) | ||||||
|         except websockets.exceptions.ConnectionClosed as e: |         except websockets.exceptions.ConnectionClosed as e: | ||||||
|             if e.code in (4008, 4009) or e.code in range(1001, 1015): |             if e.code in (4006, 4008, 4009) or e.code in range(1001, 1015): | ||||||
|                 log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) |                 log.info('Websocket closed with {0.code}, attempting a reconnect.'.format(e)) | ||||||
|                 raise ReconnectWebSocket() from e |                 raise ReconnectWebSocket() from e | ||||||
|             else: |             else: | ||||||
| @@ -424,6 +401,10 @@ class DiscordWebSocket(websockets.client.WebSocketClientProtocol): | |||||||
|  |  | ||||||
|         yield from self.send_as_json(payload) |         yield from self.send_as_json(payload) | ||||||
|  |  | ||||||
|  |         # we're leaving a voice channel so remove it from the client list | ||||||
|  |         if channel_id is None: | ||||||
|  |             self._connection._remove_voice_client(guild_id) | ||||||
|  |  | ||||||
|     @asyncio.coroutine |     @asyncio.coroutine | ||||||
|     def close(self, code=1000, reason=''): |     def close(self, code=1000, reason=''): | ||||||
|         if self._keep_alive: |         if self._keep_alive: | ||||||
|   | |||||||
| @@ -62,6 +62,7 @@ class ConnectionState: | |||||||
|         self.sequence = None |         self.sequence = None | ||||||
|         self.session_id = None |         self.session_id = None | ||||||
|         self._servers = {} |         self._servers = {} | ||||||
|  |         self._voice_clients = {} | ||||||
|         self._private_channels = {} |         self._private_channels = {} | ||||||
|         # extra dict to look up private channels by user id |         # extra dict to look up private channels by user id | ||||||
|         self._private_channels_by_user = {} |         self._private_channels_by_user = {} | ||||||
| @@ -93,6 +94,19 @@ class ConnectionState: | |||||||
|         for index in reversed(removed): |         for index in reversed(removed): | ||||||
|             del self._listeners[index] |             del self._listeners[index] | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def voice_clients(self): | ||||||
|  |         return self._voice_clients.values() | ||||||
|  |  | ||||||
|  |     def _get_voice_client(self, guild_id): | ||||||
|  |         return self._voice_clients.get(guild_id) | ||||||
|  |  | ||||||
|  |     def _add_voice_client(self, guild_id, voice): | ||||||
|  |         self._voice_clients[guild_id] = voice | ||||||
|  |  | ||||||
|  |     def _remove_voice_client(self, guild_id): | ||||||
|  |         self._voice_clients.pop(guild_id, None) | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def servers(self): |     def servers(self): | ||||||
|         return self._servers.values() |         return self._servers.values() | ||||||
| @@ -130,6 +144,7 @@ class ConnectionState: | |||||||
|     def _add_server_from_data(self, guild): |     def _add_server_from_data(self, guild): | ||||||
|         server = Server(**guild) |         server = Server(**guild) | ||||||
|         Server.me = property(lambda s: s.get_member(self.user.id)) |         Server.me = property(lambda s: s.get_member(self.user.id)) | ||||||
|  |         Server.voice_client = property(lambda s: self._get_voice_client(s.id)) | ||||||
|         self._add_server(server) |         self._add_server(server) | ||||||
|         return server |         return server | ||||||
|  |  | ||||||
| @@ -489,7 +504,13 @@ class ConnectionState: | |||||||
|  |  | ||||||
|     def parse_voice_state_update(self, data): |     def parse_voice_state_update(self, data): | ||||||
|         server = self._get_server(data.get('guild_id')) |         server = self._get_server(data.get('guild_id')) | ||||||
|  |         user_id = data.get('user_id') | ||||||
|         if server is not None: |         if server is not None: | ||||||
|  |             if user_id == self.user.id: | ||||||
|  |                 voice = self._get_voice_client(server.id) | ||||||
|  |                 if voice is not None: | ||||||
|  |                     voice.channel = server.get_channel(data.get('channel_id')) | ||||||
|  |  | ||||||
|             updated_members = server._update_voice_state(data) |             updated_members = server._update_voice_state(data) | ||||||
|             self.dispatch('voice_state_update', *updated_members) |             self.dispatch('voice_state_update', *updated_members) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -158,6 +158,9 @@ class VoiceClient: | |||||||
|         The endpoint we are connecting to. |         The endpoint we are connecting to. | ||||||
|     channel : :class:`Channel` |     channel : :class:`Channel` | ||||||
|         The voice channel connected to. |         The voice channel connected to. | ||||||
|  |     server : :class:`Server` | ||||||
|  |         The server the voice channel is connected to. | ||||||
|  |         Shorthand for ``channel.server``. | ||||||
|     loop |     loop | ||||||
|         The event loop that the voice client is running on. |         The event loop that the voice client is running on. | ||||||
|     """ |     """ | ||||||
| @@ -176,6 +179,10 @@ class VoiceClient: | |||||||
|         self.encoder = OpusEncoder(48000, 2) |         self.encoder = OpusEncoder(48000, 2) | ||||||
|         log.info('created opus encoder with {0.__dict__}'.format(self.encoder)) |         log.info('created opus encoder with {0.__dict__}'.format(self.encoder)) | ||||||
|  |  | ||||||
|  |     @property | ||||||
|  |     def server(self): | ||||||
|  |         return self.channel.server | ||||||
|  |  | ||||||
|     def checked_add(self, attr, value, limit): |     def checked_add(self, attr, value, limit): | ||||||
|         val = getattr(self, attr) |         val = getattr(self, attr) | ||||||
|         if val + value > limit: |         if val + value > limit: | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user