mirror of
https://github.com/Rapptz/discord.py.git
synced 2026-01-15 18:51:41 +00:00
Add DAVE protocol support
This commit is contained in:
@@ -44,6 +44,11 @@ from .activity import BaseActivity
|
||||
from .enums import SpeakingState
|
||||
from .errors import ConnectionClosed
|
||||
|
||||
try:
|
||||
import davey # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
@@ -812,18 +817,30 @@ class DiscordVoiceWebSocket:
|
||||
_max_heartbeat_timeout: float
|
||||
|
||||
# fmt: off
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
SESSION_DESCRIPTION = 4
|
||||
SPEAKING = 5
|
||||
HEARTBEAT_ACK = 6
|
||||
RESUME = 7
|
||||
HELLO = 8
|
||||
RESUMED = 9
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
SESSION_DESCRIPTION = 4
|
||||
SPEAKING = 5
|
||||
HEARTBEAT_ACK = 6
|
||||
RESUME = 7
|
||||
HELLO = 8
|
||||
RESUMED = 9
|
||||
CLIENTS_CONNECT = 11
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
DAVE_PREPARE_TRANSITION = 21
|
||||
DAVE_EXECUTE_TRANSITION = 22
|
||||
DAVE_TRANSITION_READY = 23
|
||||
DAVE_PREPARE_EPOCH = 24
|
||||
MLS_EXTERNAL_SENDER = 25
|
||||
MLS_KEY_PACKAGE = 26
|
||||
MLS_PROPOSALS = 27
|
||||
MLS_COMMIT_WELCOME = 28
|
||||
MLS_ANNOUNCE_COMMIT_TRANSITION = 29
|
||||
MLS_WELCOME = 30
|
||||
MLS_INVALID_COMMIT_WELCOME = 31
|
||||
# fmt: on
|
||||
|
||||
def __init__(
|
||||
@@ -850,6 +867,10 @@ class DiscordVoiceWebSocket:
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
async def send_binary(self, opcode: int, data: bytes) -> None:
|
||||
_log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, len(data))
|
||||
await self.ws.send_bytes(bytes([opcode]) + data)
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
async def resume(self) -> None:
|
||||
@@ -874,6 +895,7 @@ class DiscordVoiceWebSocket:
|
||||
'user_id': str(state.user.id),
|
||||
'session_id': state.session_id,
|
||||
'token': state.token,
|
||||
'max_dave_protocol_version': state.max_dave_protocol_version,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
@@ -943,6 +965,16 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def send_transition_ready(self, transition_id: int):
|
||||
payload = {
|
||||
'op': DiscordVoiceWebSocket.DAVE_TRANSITION_READY,
|
||||
'd': {
|
||||
'transition_id': transition_id,
|
||||
},
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg: Dict[str, Any]) -> None:
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
@@ -959,13 +991,85 @@ class DiscordVoiceWebSocket:
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
self._connection.mode = data['mode']
|
||||
await self.load_secret_key(data)
|
||||
self._connection.dave_protocol_version = data['dave_protocol_version']
|
||||
if data['dave_protocol_version'] > 0:
|
||||
await self._connection.reinit_dave_session()
|
||||
elif op == self.HELLO:
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
|
||||
self._keep_alive.start()
|
||||
elif self._connection.dave_session:
|
||||
state = self._connection
|
||||
if op == self.DAVE_PREPARE_TRANSITION:
|
||||
_log.debug(
|
||||
'Preparing for DAVE transition id %d for protocol version %d',
|
||||
data['transition_id'],
|
||||
data['protocol_version'],
|
||||
)
|
||||
state.dave_pending_transitions[data['transition_id']] = data['protocol_version']
|
||||
if data['transition_id'] == 0:
|
||||
await state._execute_transition(data['transition_id'])
|
||||
else:
|
||||
if data['protocol_version'] == 0 and state.dave_session:
|
||||
state.dave_session.set_passthrough_mode(True, 120)
|
||||
|
||||
await self.send_transition_ready(data['transition_id'])
|
||||
elif op == self.DAVE_EXECUTE_TRANSITION:
|
||||
_log.debug('Executing DAVE transition id %d', data['transition_id'])
|
||||
await state._execute_transition(data['transition_id'])
|
||||
elif op == self.DAVE_PREPARE_EPOCH:
|
||||
_log.debug('Preparing for DAVE epoch %d', data['epoch'])
|
||||
# When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version.
|
||||
if data['epoch'] == 1:
|
||||
state.dave_protocol_version = data['protocol_version']
|
||||
await state.reinit_dave_session()
|
||||
|
||||
await self._hook(self, msg)
|
||||
|
||||
async def received_binary_message(self, msg: bytes) -> None:
|
||||
self.seq_ack = struct.unpack_from('>H', msg, 0)[0]
|
||||
op = msg[2]
|
||||
_log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op)
|
||||
state = self._connection
|
||||
|
||||
if state.dave_session is None:
|
||||
return
|
||||
|
||||
if op == self.MLS_EXTERNAL_SENDER:
|
||||
state.dave_session.set_external_sender(msg[3:])
|
||||
_log.debug('Set MLS external sender')
|
||||
elif op == self.MLS_PROPOSALS:
|
||||
optype = msg[3]
|
||||
result = state.dave_session.process_proposals(
|
||||
davey.ProposalsOperationType.append if optype == 0 else davey.ProposalsOperationType.revoke, msg[4:]
|
||||
)
|
||||
if isinstance(result, davey.CommitWelcome):
|
||||
await self.send_binary(
|
||||
DiscordVoiceWebSocket.MLS_COMMIT_WELCOME,
|
||||
result.commit + result.welcome if result.welcome else result.commit,
|
||||
)
|
||||
_log.debug('MLS proposals processed')
|
||||
elif op == self.MLS_ANNOUNCE_COMMIT_TRANSITION:
|
||||
transition_id = struct.unpack_from('>H', msg, 3)[0]
|
||||
try:
|
||||
state.dave_session.process_commit(msg[5:])
|
||||
if transition_id != 0:
|
||||
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
|
||||
await self.send_transition_ready(transition_id)
|
||||
_log.debug('MLS commit processed for transition id %d', transition_id)
|
||||
except Exception:
|
||||
await state._recover_from_invalid_commit(transition_id)
|
||||
elif op == self.MLS_WELCOME:
|
||||
transition_id = struct.unpack_from('>H', msg, 3)[0]
|
||||
try:
|
||||
state.dave_session.process_welcome(msg[5:])
|
||||
if transition_id != 0:
|
||||
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
|
||||
await self.send_transition_ready(transition_id)
|
||||
_log.debug('MLS welcome processed for transition id %d', transition_id)
|
||||
except Exception:
|
||||
await state._recover_from_invalid_commit(transition_id)
|
||||
|
||||
async def initial_connection(self, data: Dict[str, Any]) -> None:
|
||||
state = self._connection
|
||||
state.ssrc = data['ssrc']
|
||||
@@ -1045,6 +1149,8 @@ class DiscordVoiceWebSocket:
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||
await self.received_binary_message(msg.data)
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug('Received voice %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
|
||||
Reference in New Issue
Block a user