mirror of
https://github.com/Rapptz/discord.py.git
synced 2026-01-15 18:51:41 +00:00
Add DAVE protocol support
This commit is contained in:
@@ -44,6 +44,11 @@ from .activity import BaseActivity
|
||||
from .enums import SpeakingState
|
||||
from .errors import ConnectionClosed
|
||||
|
||||
try:
|
||||
import davey # type: ignore
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
__all__ = (
|
||||
@@ -812,18 +817,30 @@ class DiscordVoiceWebSocket:
|
||||
_max_heartbeat_timeout: float
|
||||
|
||||
# fmt: off
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
SESSION_DESCRIPTION = 4
|
||||
SPEAKING = 5
|
||||
HEARTBEAT_ACK = 6
|
||||
RESUME = 7
|
||||
HELLO = 8
|
||||
RESUMED = 9
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
IDENTIFY = 0
|
||||
SELECT_PROTOCOL = 1
|
||||
READY = 2
|
||||
HEARTBEAT = 3
|
||||
SESSION_DESCRIPTION = 4
|
||||
SPEAKING = 5
|
||||
HEARTBEAT_ACK = 6
|
||||
RESUME = 7
|
||||
HELLO = 8
|
||||
RESUMED = 9
|
||||
CLIENTS_CONNECT = 11
|
||||
CLIENT_CONNECT = 12
|
||||
CLIENT_DISCONNECT = 13
|
||||
DAVE_PREPARE_TRANSITION = 21
|
||||
DAVE_EXECUTE_TRANSITION = 22
|
||||
DAVE_TRANSITION_READY = 23
|
||||
DAVE_PREPARE_EPOCH = 24
|
||||
MLS_EXTERNAL_SENDER = 25
|
||||
MLS_KEY_PACKAGE = 26
|
||||
MLS_PROPOSALS = 27
|
||||
MLS_COMMIT_WELCOME = 28
|
||||
MLS_ANNOUNCE_COMMIT_TRANSITION = 29
|
||||
MLS_WELCOME = 30
|
||||
MLS_INVALID_COMMIT_WELCOME = 31
|
||||
# fmt: on
|
||||
|
||||
def __init__(
|
||||
@@ -850,6 +867,10 @@ class DiscordVoiceWebSocket:
|
||||
_log.debug('Sending voice websocket frame: %s.', data)
|
||||
await self.ws.send_str(utils._to_json(data))
|
||||
|
||||
async def send_binary(self, opcode: int, data: bytes) -> None:
|
||||
_log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, len(data))
|
||||
await self.ws.send_bytes(bytes([opcode]) + data)
|
||||
|
||||
send_heartbeat = send_as_json
|
||||
|
||||
async def resume(self) -> None:
|
||||
@@ -874,6 +895,7 @@ class DiscordVoiceWebSocket:
|
||||
'user_id': str(state.user.id),
|
||||
'session_id': state.session_id,
|
||||
'token': state.token,
|
||||
'max_dave_protocol_version': state.max_dave_protocol_version,
|
||||
},
|
||||
}
|
||||
await self.send_as_json(payload)
|
||||
@@ -943,6 +965,16 @@ class DiscordVoiceWebSocket:
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def send_transition_ready(self, transition_id: int):
|
||||
payload = {
|
||||
'op': DiscordVoiceWebSocket.DAVE_TRANSITION_READY,
|
||||
'd': {
|
||||
'transition_id': transition_id,
|
||||
},
|
||||
}
|
||||
|
||||
await self.send_as_json(payload)
|
||||
|
||||
async def received_message(self, msg: Dict[str, Any]) -> None:
|
||||
_log.debug('Voice websocket frame received: %s', msg)
|
||||
op = msg['op']
|
||||
@@ -959,13 +991,85 @@ class DiscordVoiceWebSocket:
|
||||
elif op == self.SESSION_DESCRIPTION:
|
||||
self._connection.mode = data['mode']
|
||||
await self.load_secret_key(data)
|
||||
self._connection.dave_protocol_version = data['dave_protocol_version']
|
||||
if data['dave_protocol_version'] > 0:
|
||||
await self._connection.reinit_dave_session()
|
||||
elif op == self.HELLO:
|
||||
interval = data['heartbeat_interval'] / 1000.0
|
||||
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
|
||||
self._keep_alive.start()
|
||||
elif self._connection.dave_session:
|
||||
state = self._connection
|
||||
if op == self.DAVE_PREPARE_TRANSITION:
|
||||
_log.debug(
|
||||
'Preparing for DAVE transition id %d for protocol version %d',
|
||||
data['transition_id'],
|
||||
data['protocol_version'],
|
||||
)
|
||||
state.dave_pending_transitions[data['transition_id']] = data['protocol_version']
|
||||
if data['transition_id'] == 0:
|
||||
await state._execute_transition(data['transition_id'])
|
||||
else:
|
||||
if data['protocol_version'] == 0 and state.dave_session:
|
||||
state.dave_session.set_passthrough_mode(True, 120)
|
||||
|
||||
await self.send_transition_ready(data['transition_id'])
|
||||
elif op == self.DAVE_EXECUTE_TRANSITION:
|
||||
_log.debug('Executing DAVE transition id %d', data['transition_id'])
|
||||
await state._execute_transition(data['transition_id'])
|
||||
elif op == self.DAVE_PREPARE_EPOCH:
|
||||
_log.debug('Preparing for DAVE epoch %d', data['epoch'])
|
||||
# When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version.
|
||||
if data['epoch'] == 1:
|
||||
state.dave_protocol_version = data['protocol_version']
|
||||
await state.reinit_dave_session()
|
||||
|
||||
await self._hook(self, msg)
|
||||
|
||||
async def received_binary_message(self, msg: bytes) -> None:
|
||||
self.seq_ack = struct.unpack_from('>H', msg, 0)[0]
|
||||
op = msg[2]
|
||||
_log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op)
|
||||
state = self._connection
|
||||
|
||||
if state.dave_session is None:
|
||||
return
|
||||
|
||||
if op == self.MLS_EXTERNAL_SENDER:
|
||||
state.dave_session.set_external_sender(msg[3:])
|
||||
_log.debug('Set MLS external sender')
|
||||
elif op == self.MLS_PROPOSALS:
|
||||
optype = msg[3]
|
||||
result = state.dave_session.process_proposals(
|
||||
davey.ProposalsOperationType.append if optype == 0 else davey.ProposalsOperationType.revoke, msg[4:]
|
||||
)
|
||||
if isinstance(result, davey.CommitWelcome):
|
||||
await self.send_binary(
|
||||
DiscordVoiceWebSocket.MLS_COMMIT_WELCOME,
|
||||
result.commit + result.welcome if result.welcome else result.commit,
|
||||
)
|
||||
_log.debug('MLS proposals processed')
|
||||
elif op == self.MLS_ANNOUNCE_COMMIT_TRANSITION:
|
||||
transition_id = struct.unpack_from('>H', msg, 3)[0]
|
||||
try:
|
||||
state.dave_session.process_commit(msg[5:])
|
||||
if transition_id != 0:
|
||||
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
|
||||
await self.send_transition_ready(transition_id)
|
||||
_log.debug('MLS commit processed for transition id %d', transition_id)
|
||||
except Exception:
|
||||
await state._recover_from_invalid_commit(transition_id)
|
||||
elif op == self.MLS_WELCOME:
|
||||
transition_id = struct.unpack_from('>H', msg, 3)[0]
|
||||
try:
|
||||
state.dave_session.process_welcome(msg[5:])
|
||||
if transition_id != 0:
|
||||
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
|
||||
await self.send_transition_ready(transition_id)
|
||||
_log.debug('MLS welcome processed for transition id %d', transition_id)
|
||||
except Exception:
|
||||
await state._recover_from_invalid_commit(transition_id)
|
||||
|
||||
async def initial_connection(self, data: Dict[str, Any]) -> None:
|
||||
state = self._connection
|
||||
state.ssrc = data['ssrc']
|
||||
@@ -1045,6 +1149,8 @@ class DiscordVoiceWebSocket:
|
||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||
await self.received_message(utils._from_json(msg.data))
|
||||
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||
await self.received_binary_message(msg.data)
|
||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||
_log.debug('Received voice %s', msg)
|
||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||
|
||||
@@ -284,6 +284,17 @@ class VoiceClient(VoiceProtocol):
|
||||
def timeout(self) -> float:
|
||||
return self._connection.timeout
|
||||
|
||||
@property
|
||||
def voice_privacy_code(self) -> Optional[str]:
|
||||
""":class:`str`: Get the voice privacy code of this E2EE session's group.
|
||||
|
||||
A new privacy code is created and cached each time a new transition is executed.
|
||||
This can be None if there is no active DAVE session happening.
|
||||
|
||||
.. versionadded:: 2.7
|
||||
"""
|
||||
return self._connection.dave_session.voice_privacy_code if self._connection.dave_session else None
|
||||
|
||||
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
||||
val = getattr(self, attr)
|
||||
if val + value > limit:
|
||||
@@ -368,7 +379,12 @@ class VoiceClient(VoiceProtocol):
|
||||
|
||||
# audio related
|
||||
|
||||
def _get_voice_packet(self, data):
|
||||
def _get_voice_packet(self, data: bytes):
|
||||
packet = (
|
||||
self._connection.dave_session.encrypt_opus(data)
|
||||
if self._connection.dave_session and self._connection.can_encrypt
|
||||
else data
|
||||
)
|
||||
header = bytearray(12)
|
||||
|
||||
# Formulate rtp header
|
||||
@@ -379,7 +395,7 @@ class VoiceClient(VoiceProtocol):
|
||||
struct.pack_into('>I', header, 8, self.ssrc)
|
||||
|
||||
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
|
||||
return encrypt_packet(header, data)
|
||||
return encrypt_packet(header, packet)
|
||||
|
||||
def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes:
|
||||
# Esentially the same as _lite
|
||||
|
||||
@@ -69,6 +69,14 @@ if TYPE_CHECKING:
|
||||
WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]]
|
||||
SocketReaderCallback = Callable[[bytes], Any]
|
||||
|
||||
has_dave: bool
|
||||
|
||||
try:
|
||||
import davey # type: ignore
|
||||
|
||||
has_dave = True
|
||||
except ImportError:
|
||||
has_dave = False
|
||||
|
||||
__all__ = ('VoiceConnectionState',)
|
||||
|
||||
@@ -208,6 +216,10 @@ class VoiceConnectionState:
|
||||
self.mode: SupportedModes = MISSING
|
||||
self.socket: socket.socket = MISSING
|
||||
self.ws: DiscordVoiceWebSocket = MISSING
|
||||
self.dave_session: Optional[davey.DaveSession] = None
|
||||
self.dave_protocol_version: int = 0
|
||||
self.dave_pending_transitions: Dict[int, int] = {}
|
||||
self.dave_downgraded: bool = False
|
||||
|
||||
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
|
||||
self._expecting_disconnect: bool = False
|
||||
@@ -252,6 +264,64 @@ class VoiceConnectionState:
|
||||
def self_voice_state(self) -> Optional[VoiceState]:
|
||||
return self.guild.me.voice
|
||||
|
||||
@property
|
||||
def max_dave_protocol_version(self) -> int:
|
||||
return davey.DAVE_PROTOCOL_VERSION if has_dave else 0
|
||||
|
||||
@property
|
||||
def can_encrypt(self) -> bool:
|
||||
return self.dave_protocol_version != 0 and self.dave_session != None and self.dave_session.ready
|
||||
|
||||
async def reinit_dave_session(self) -> None:
|
||||
if self.dave_protocol_version > 0:
|
||||
if not has_dave:
|
||||
raise RuntimeError('davey library needed in order to use E2EE voice')
|
||||
if self.dave_session is not None:
|
||||
self.dave_session.reinit(self.dave_protocol_version, self.user.id, self.voice_client.channel.id)
|
||||
else:
|
||||
self.dave_session = davey.DaveSession(self.dave_protocol_version, self.user.id, self.voice_client.channel.id)
|
||||
|
||||
if self.dave_session is not None:
|
||||
await self.voice_client.ws.send_binary(
|
||||
DiscordVoiceWebSocket.MLS_KEY_PACKAGE, self.dave_session.get_serialized_key_package()
|
||||
)
|
||||
elif self.dave_session:
|
||||
self.dave_session.reset()
|
||||
self.dave_session.set_passthrough_mode(True, 10)
|
||||
pass
|
||||
|
||||
async def _recover_from_invalid_commit(self, transition_id: int) -> None:
|
||||
payload = {
|
||||
'op': DiscordVoiceWebSocket.MLS_INVALID_COMMIT_WELCOME,
|
||||
'd': {
|
||||
'transition_id': transition_id,
|
||||
},
|
||||
}
|
||||
|
||||
await self.voice_client.ws.send_as_json(payload)
|
||||
await self.reinit_dave_session()
|
||||
|
||||
async def _execute_transition(self, transition_id: int) -> None:
|
||||
_log.debug('Executing transition id %d', transition_id)
|
||||
if transition_id not in self.dave_pending_transitions:
|
||||
_log.warning("Received execute transition, but we don't have a pending transition for id %d", transition_id)
|
||||
return
|
||||
|
||||
old_version = self.dave_protocol_version
|
||||
self.dave_protocol_version = self.dave_pending_transitions.pop(transition_id)
|
||||
|
||||
if old_version != self.dave_protocol_version and self.dave_protocol_version == 0:
|
||||
self.dave_downgraded = True
|
||||
_log.debug('DAVE Session downgraded')
|
||||
elif transition_id > 0 and self.dave_downgraded:
|
||||
self.dave_downgraded = False
|
||||
if self.dave_session:
|
||||
self.dave_session.set_passthrough_mode(True, 10)
|
||||
_log.debug('DAVE Session upgraded')
|
||||
|
||||
# In the future, the session should be signaled too, but for now theres just v1
|
||||
_log.debug('Transition id %d executed', transition_id)
|
||||
|
||||
async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||
channel_id = data['channel_id']
|
||||
|
||||
|
||||
@@ -36,7 +36,10 @@ Documentation = "https://discordpy.readthedocs.io/en/latest/"
|
||||
dependencies = { file = "requirements.txt" }
|
||||
|
||||
[project.optional-dependencies]
|
||||
voice = ["PyNaCl>=1.5.0,<1.6"]
|
||||
voice = [
|
||||
"PyNaCl>=1.5.0,<1.6",
|
||||
"davey==0.1.0"
|
||||
]
|
||||
docs = [
|
||||
"sphinx==4.4.0",
|
||||
"sphinxcontrib_trio==1.1.2",
|
||||
|
||||
Reference in New Issue
Block a user