Fix disconnect when trying to move to another voice channel.
Not overly proud of this implementation but this allows the library to differentiate between a 4014 that means "move to another channel" or "move nowhere". Sometimes the VOICE_STATE_UPDATE comes before the actual websocket disconnect so special care had to be taken in that case. Fix #5904
This commit is contained in:
		| @@ -719,6 +719,7 @@ class DiscordVoiceWebSocket: | |||||||
|         self.loop = loop |         self.loop = loop | ||||||
|         self._keep_alive = None |         self._keep_alive = None | ||||||
|         self._close_code = None |         self._close_code = None | ||||||
|  |         self.secret_key = None | ||||||
|  |  | ||||||
|     async def send_as_json(self, data): |     async def send_as_json(self, data): | ||||||
|         log.debug('Sending voice websocket frame: %s.', data) |         log.debug('Sending voice websocket frame: %s.', data) | ||||||
| @@ -872,7 +873,7 @@ class DiscordVoiceWebSocket: | |||||||
|  |  | ||||||
|     async def load_secret_key(self, data): |     async def load_secret_key(self, data): | ||||||
|         log.info('received secret key for voice connection') |         log.info('received secret key for voice connection') | ||||||
|         self._connection.secret_key = data.get('secret_key') |         self.secret_key = self._connection.secret_key = data.get('secret_key') | ||||||
|         await self.speak() |         await self.speak() | ||||||
|         await self.speak(False) |         await self.speak(False) | ||||||
|  |  | ||||||
|   | |||||||
| @@ -208,6 +208,7 @@ class VoiceClient(VoiceProtocol): | |||||||
|         self._connected = threading.Event() |         self._connected = threading.Event() | ||||||
|  |  | ||||||
|         self._handshaking = False |         self._handshaking = False | ||||||
|  |         self._potentially_reconnecting = False | ||||||
|         self._voice_state_complete = asyncio.Event() |         self._voice_state_complete = asyncio.Event() | ||||||
|         self._voice_server_complete = asyncio.Event() |         self._voice_server_complete = asyncio.Event() | ||||||
|  |  | ||||||
| @@ -250,8 +251,10 @@ class VoiceClient(VoiceProtocol): | |||||||
|         self.session_id = data['session_id'] |         self.session_id = data['session_id'] | ||||||
|         channel_id = data['channel_id'] |         channel_id = data['channel_id'] | ||||||
|  |  | ||||||
|         if not self._handshaking: |         if not self._handshaking or self._potentially_reconnecting: | ||||||
|             # If we're done handshaking then we just need to update ourselves |             # If we're done handshaking then we just need to update ourselves | ||||||
|  |             # If we're potentially reconnecting due to a 4014, then we need to differentiate | ||||||
|  |             # a channel move and an actual force disconnect | ||||||
|             if channel_id is None: |             if channel_id is None: | ||||||
|                 # We're being disconnected so cleanup |                 # We're being disconnected so cleanup | ||||||
|                 await self.disconnect() |                 await self.disconnect() | ||||||
| @@ -294,26 +297,39 @@ class VoiceClient(VoiceProtocol): | |||||||
|         self._voice_server_complete.set() |         self._voice_server_complete.set() | ||||||
|  |  | ||||||
|     async def voice_connect(self): |     async def voice_connect(self): | ||||||
|         self._connections += 1 |  | ||||||
|         await self.channel.guild.change_voice_state(channel=self.channel) |         await self.channel.guild.change_voice_state(channel=self.channel) | ||||||
|  |  | ||||||
|     async def voice_disconnect(self): |     async def voice_disconnect(self): | ||||||
|         log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) |         log.info('The voice handshake is being terminated for Channel ID %s (Guild ID %s)', self.channel.id, self.guild.id) | ||||||
|         await self.channel.guild.change_voice_state(channel=None) |         await self.channel.guild.change_voice_state(channel=None) | ||||||
|  |  | ||||||
|     async def connect(self, *, reconnect, timeout): |     def prepare_handshake(self): | ||||||
|         log.info('Connecting to voice...') |  | ||||||
|         self.timeout = timeout |  | ||||||
|         try: |  | ||||||
|             del self.secret_key |  | ||||||
|         except AttributeError: |  | ||||||
|             pass |  | ||||||
|  |  | ||||||
|  |  | ||||||
|         for i in range(5): |  | ||||||
|         self._voice_state_complete.clear() |         self._voice_state_complete.clear() | ||||||
|         self._voice_server_complete.clear() |         self._voice_server_complete.clear() | ||||||
|         self._handshaking = True |         self._handshaking = True | ||||||
|  |         log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) | ||||||
|  |         self._connections += 1 | ||||||
|  |  | ||||||
|  |     def finish_handshake(self): | ||||||
|  |         log.info('Voice handshake complete. Endpoint found %s', self.endpoint) | ||||||
|  |         self._handshaking = False | ||||||
|  |         self._voice_server_complete.clear() | ||||||
|  |         self._voice_state_complete.clear() | ||||||
|  |  | ||||||
|  |     async def connect_websocket(self): | ||||||
|  |         ws = await DiscordVoiceWebSocket.from_client(self) | ||||||
|  |         self._connected.clear() | ||||||
|  |         while ws.secret_key is None: | ||||||
|  |             await ws.poll_event() | ||||||
|  |         self._connected.set() | ||||||
|  |         return ws | ||||||
|  |  | ||||||
|  |     async def connect(self, *, reconnect, timeout): | ||||||
|  |         log.info('Connecting to voice...') | ||||||
|  |         self.timeout = timeout | ||||||
|  |  | ||||||
|  |         for i in range(5): | ||||||
|  |             self.prepare_handshake() | ||||||
|  |  | ||||||
|             # This has to be created before we start the flow. |             # This has to be created before we start the flow. | ||||||
|             futures = [ |             futures = [ | ||||||
| @@ -322,7 +338,6 @@ class VoiceClient(VoiceProtocol): | |||||||
|             ] |             ] | ||||||
|  |  | ||||||
|             # Start the connection flow |             # Start the connection flow | ||||||
|             log.info('Starting voice handshake... (connection attempt %d)', self._connections + 1) |  | ||||||
|             await self.voice_connect() |             await self.voice_connect() | ||||||
|  |  | ||||||
|             try: |             try: | ||||||
| @@ -331,17 +346,10 @@ class VoiceClient(VoiceProtocol): | |||||||
|                 await self.disconnect(force=True) |                 await self.disconnect(force=True) | ||||||
|                 raise |                 raise | ||||||
|  |  | ||||||
|             log.info('Voice handshake complete. Endpoint found %s', self.endpoint) |             self.finish_handshake() | ||||||
|             self._handshaking = False |  | ||||||
|             self._voice_server_complete.clear() |  | ||||||
|             self._voice_state_complete.clear() |  | ||||||
|  |  | ||||||
|             try: |             try: | ||||||
|                 self.ws = await DiscordVoiceWebSocket.from_client(self) |                 self.ws = await self.connect_websocket() | ||||||
|                 self._connected.clear() |  | ||||||
|                 while not hasattr(self, 'secret_key'): |  | ||||||
|                     await self.ws.poll_event() |  | ||||||
|                 self._connected.set() |  | ||||||
|                 break |                 break | ||||||
|             except (ConnectionClosed, asyncio.TimeoutError): |             except (ConnectionClosed, asyncio.TimeoutError): | ||||||
|                 if reconnect: |                 if reconnect: | ||||||
| @@ -355,6 +363,26 @@ class VoiceClient(VoiceProtocol): | |||||||
|         if self._runner is None: |         if self._runner is None: | ||||||
|             self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) |             self._runner = self.loop.create_task(self.poll_voice_ws(reconnect)) | ||||||
|  |  | ||||||
|  |     async def potential_reconnect(self): | ||||||
|  |         self.prepare_handshake() | ||||||
|  |         self._potentially_reconnecting = True | ||||||
|  |         try: | ||||||
|  |             # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected | ||||||
|  |             await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) | ||||||
|  |         except asyncio.TimeoutError: | ||||||
|  |             self._potentially_reconnecting = False | ||||||
|  |             await self.disconnect(force=True) | ||||||
|  |             return False | ||||||
|  |  | ||||||
|  |         self.finish_handshake() | ||||||
|  |         self._potentially_reconnecting = False | ||||||
|  |         try: | ||||||
|  |             self.ws = await self.connect_websocket() | ||||||
|  |         except (ConnectionClosed, asyncio.TimeoutError): | ||||||
|  |             return False | ||||||
|  |         else: | ||||||
|  |             return True | ||||||
|  |  | ||||||
|     @property |     @property | ||||||
|     def latency(self): |     def latency(self): | ||||||
|         """:class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. |         """:class:`float`: Latency between a HEARTBEAT and a HEARTBEAT_ACK in seconds. | ||||||
| @@ -387,10 +415,19 @@ class VoiceClient(VoiceProtocol): | |||||||
|                     # 1000 - normal closure (obviously) |                     # 1000 - normal closure (obviously) | ||||||
|                     # 4014 - voice channel has been deleted. |                     # 4014 - voice channel has been deleted. | ||||||
|                     # 4015 - voice server has crashed |                     # 4015 - voice server has crashed | ||||||
|                     if exc.code in (1000, 4014, 4015): |                     if exc.code in (1000, 4015): | ||||||
|                         log.info('Disconnecting from voice normally, close code %d.', exc.code) |                         log.info('Disconnecting from voice normally, close code %d.', exc.code) | ||||||
|                         await self.disconnect() |                         await self.disconnect() | ||||||
|                         break |                         break | ||||||
|  |                     if exc.code == 4014: | ||||||
|  |                         log.info('Disconnected from voice by force... potentially reconnecting.') | ||||||
|  |                         successful = await self.potential_reconnect() | ||||||
|  |                         if not successful: | ||||||
|  |                             log.info('Reconnect was unsuccessful, disconnecting from voice normally...') | ||||||
|  |                             await self.disconnect() | ||||||
|  |                             break | ||||||
|  |                         else: | ||||||
|  |                             continue | ||||||
|  |  | ||||||
|                 if not reconnect: |                 if not reconnect: | ||||||
|                     await self.disconnect() |                     await self.disconnect() | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user