Add DAVE protocol support

This commit is contained in:
Snazzah
2026-01-07 16:05:46 -05:00
committed by GitHub
parent 0052878983
commit bd37844be7
4 changed files with 210 additions and 15 deletions

View File

@@ -44,6 +44,11 @@ from .activity import BaseActivity
from .enums import SpeakingState
from .errors import ConnectionClosed
try:
import davey # type: ignore
except ImportError:
pass
_log = logging.getLogger(__name__)
__all__ = (
@@ -812,18 +817,30 @@ class DiscordVoiceWebSocket:
_max_heartbeat_timeout: float
# fmt: off
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
IDENTIFY = 0
SELECT_PROTOCOL = 1
READY = 2
HEARTBEAT = 3
SESSION_DESCRIPTION = 4
SPEAKING = 5
HEARTBEAT_ACK = 6
RESUME = 7
HELLO = 8
RESUMED = 9
CLIENTS_CONNECT = 11
CLIENT_CONNECT = 12
CLIENT_DISCONNECT = 13
DAVE_PREPARE_TRANSITION = 21
DAVE_EXECUTE_TRANSITION = 22
DAVE_TRANSITION_READY = 23
DAVE_PREPARE_EPOCH = 24
MLS_EXTERNAL_SENDER = 25
MLS_KEY_PACKAGE = 26
MLS_PROPOSALS = 27
MLS_COMMIT_WELCOME = 28
MLS_ANNOUNCE_COMMIT_TRANSITION = 29
MLS_WELCOME = 30
MLS_INVALID_COMMIT_WELCOME = 31
# fmt: on
def __init__(
@@ -850,6 +867,10 @@ class DiscordVoiceWebSocket:
_log.debug('Sending voice websocket frame: %s.', data)
await self.ws.send_str(utils._to_json(data))
async def send_binary(self, opcode: int, data: bytes) -> None:
_log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, len(data))
await self.ws.send_bytes(bytes([opcode]) + data)
send_heartbeat = send_as_json
async def resume(self) -> None:
@@ -874,6 +895,7 @@ class DiscordVoiceWebSocket:
'user_id': str(state.user.id),
'session_id': state.session_id,
'token': state.token,
'max_dave_protocol_version': state.max_dave_protocol_version,
},
}
await self.send_as_json(payload)
@@ -943,6 +965,16 @@ class DiscordVoiceWebSocket:
await self.send_as_json(payload)
async def send_transition_ready(self, transition_id: int):
payload = {
'op': DiscordVoiceWebSocket.DAVE_TRANSITION_READY,
'd': {
'transition_id': transition_id,
},
}
await self.send_as_json(payload)
async def received_message(self, msg: Dict[str, Any]) -> None:
_log.debug('Voice websocket frame received: %s', msg)
op = msg['op']
@@ -959,13 +991,85 @@ class DiscordVoiceWebSocket:
elif op == self.SESSION_DESCRIPTION:
self._connection.mode = data['mode']
await self.load_secret_key(data)
self._connection.dave_protocol_version = data['dave_protocol_version']
if data['dave_protocol_version'] > 0:
await self._connection.reinit_dave_session()
elif op == self.HELLO:
interval = data['heartbeat_interval'] / 1000.0
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
self._keep_alive.start()
elif self._connection.dave_session:
state = self._connection
if op == self.DAVE_PREPARE_TRANSITION:
_log.debug(
'Preparing for DAVE transition id %d for protocol version %d',
data['transition_id'],
data['protocol_version'],
)
state.dave_pending_transitions[data['transition_id']] = data['protocol_version']
if data['transition_id'] == 0:
await state._execute_transition(data['transition_id'])
else:
if data['protocol_version'] == 0 and state.dave_session:
state.dave_session.set_passthrough_mode(True, 120)
await self.send_transition_ready(data['transition_id'])
elif op == self.DAVE_EXECUTE_TRANSITION:
_log.debug('Executing DAVE transition id %d', data['transition_id'])
await state._execute_transition(data['transition_id'])
elif op == self.DAVE_PREPARE_EPOCH:
_log.debug('Preparing for DAVE epoch %d', data['epoch'])
# When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version.
if data['epoch'] == 1:
state.dave_protocol_version = data['protocol_version']
await state.reinit_dave_session()
await self._hook(self, msg)
async def received_binary_message(self, msg: bytes) -> None:
self.seq_ack = struct.unpack_from('>H', msg, 0)[0]
op = msg[2]
_log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op)
state = self._connection
if state.dave_session is None:
return
if op == self.MLS_EXTERNAL_SENDER:
state.dave_session.set_external_sender(msg[3:])
_log.debug('Set MLS external sender')
elif op == self.MLS_PROPOSALS:
optype = msg[3]
result = state.dave_session.process_proposals(
davey.ProposalsOperationType.append if optype == 0 else davey.ProposalsOperationType.revoke, msg[4:]
)
if isinstance(result, davey.CommitWelcome):
await self.send_binary(
DiscordVoiceWebSocket.MLS_COMMIT_WELCOME,
result.commit + result.welcome if result.welcome else result.commit,
)
_log.debug('MLS proposals processed')
elif op == self.MLS_ANNOUNCE_COMMIT_TRANSITION:
transition_id = struct.unpack_from('>H', msg, 3)[0]
try:
state.dave_session.process_commit(msg[5:])
if transition_id != 0:
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
await self.send_transition_ready(transition_id)
_log.debug('MLS commit processed for transition id %d', transition_id)
except Exception:
await state._recover_from_invalid_commit(transition_id)
elif op == self.MLS_WELCOME:
transition_id = struct.unpack_from('>H', msg, 3)[0]
try:
state.dave_session.process_welcome(msg[5:])
if transition_id != 0:
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
await self.send_transition_ready(transition_id)
_log.debug('MLS welcome processed for transition id %d', transition_id)
except Exception:
await state._recover_from_invalid_commit(transition_id)
async def initial_connection(self, data: Dict[str, Any]) -> None:
state = self._connection
state.ssrc = data['ssrc']
@@ -1045,6 +1149,8 @@ class DiscordVoiceWebSocket:
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
if msg.type is aiohttp.WSMsgType.TEXT:
await self.received_message(utils._from_json(msg.data))
elif msg.type is aiohttp.WSMsgType.BINARY:
await self.received_binary_message(msg.data)
elif msg.type is aiohttp.WSMsgType.ERROR:
_log.debug('Received voice %s', msg)
raise ConnectionClosed(self.ws, shard_id=None) from msg.data