mirror of
https://github.com/Rapptz/discord.py.git
synced 2026-01-15 18:51:41 +00:00
Add DAVE protocol support
This commit is contained in:
@@ -44,6 +44,11 @@ from .activity import BaseActivity
|
|||||||
from .enums import SpeakingState
|
from .enums import SpeakingState
|
||||||
from .errors import ConnectionClosed
|
from .errors import ConnectionClosed
|
||||||
|
|
||||||
|
try:
|
||||||
|
import davey # type: ignore
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
_log = logging.getLogger(__name__)
|
_log = logging.getLogger(__name__)
|
||||||
|
|
||||||
__all__ = (
|
__all__ = (
|
||||||
@@ -822,8 +827,20 @@ class DiscordVoiceWebSocket:
|
|||||||
RESUME = 7
|
RESUME = 7
|
||||||
HELLO = 8
|
HELLO = 8
|
||||||
RESUMED = 9
|
RESUMED = 9
|
||||||
|
CLIENTS_CONNECT = 11
|
||||||
CLIENT_CONNECT = 12
|
CLIENT_CONNECT = 12
|
||||||
CLIENT_DISCONNECT = 13
|
CLIENT_DISCONNECT = 13
|
||||||
|
DAVE_PREPARE_TRANSITION = 21
|
||||||
|
DAVE_EXECUTE_TRANSITION = 22
|
||||||
|
DAVE_TRANSITION_READY = 23
|
||||||
|
DAVE_PREPARE_EPOCH = 24
|
||||||
|
MLS_EXTERNAL_SENDER = 25
|
||||||
|
MLS_KEY_PACKAGE = 26
|
||||||
|
MLS_PROPOSALS = 27
|
||||||
|
MLS_COMMIT_WELCOME = 28
|
||||||
|
MLS_ANNOUNCE_COMMIT_TRANSITION = 29
|
||||||
|
MLS_WELCOME = 30
|
||||||
|
MLS_INVALID_COMMIT_WELCOME = 31
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -850,6 +867,10 @@ class DiscordVoiceWebSocket:
|
|||||||
_log.debug('Sending voice websocket frame: %s.', data)
|
_log.debug('Sending voice websocket frame: %s.', data)
|
||||||
await self.ws.send_str(utils._to_json(data))
|
await self.ws.send_str(utils._to_json(data))
|
||||||
|
|
||||||
|
async def send_binary(self, opcode: int, data: bytes) -> None:
|
||||||
|
_log.debug('Sending voice websocket binary frame: opcode=%s size=%d', opcode, len(data))
|
||||||
|
await self.ws.send_bytes(bytes([opcode]) + data)
|
||||||
|
|
||||||
send_heartbeat = send_as_json
|
send_heartbeat = send_as_json
|
||||||
|
|
||||||
async def resume(self) -> None:
|
async def resume(self) -> None:
|
||||||
@@ -874,6 +895,7 @@ class DiscordVoiceWebSocket:
|
|||||||
'user_id': str(state.user.id),
|
'user_id': str(state.user.id),
|
||||||
'session_id': state.session_id,
|
'session_id': state.session_id,
|
||||||
'token': state.token,
|
'token': state.token,
|
||||||
|
'max_dave_protocol_version': state.max_dave_protocol_version,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
await self.send_as_json(payload)
|
await self.send_as_json(payload)
|
||||||
@@ -943,6 +965,16 @@ class DiscordVoiceWebSocket:
|
|||||||
|
|
||||||
await self.send_as_json(payload)
|
await self.send_as_json(payload)
|
||||||
|
|
||||||
|
async def send_transition_ready(self, transition_id: int):
|
||||||
|
payload = {
|
||||||
|
'op': DiscordVoiceWebSocket.DAVE_TRANSITION_READY,
|
||||||
|
'd': {
|
||||||
|
'transition_id': transition_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.send_as_json(payload)
|
||||||
|
|
||||||
async def received_message(self, msg: Dict[str, Any]) -> None:
|
async def received_message(self, msg: Dict[str, Any]) -> None:
|
||||||
_log.debug('Voice websocket frame received: %s', msg)
|
_log.debug('Voice websocket frame received: %s', msg)
|
||||||
op = msg['op']
|
op = msg['op']
|
||||||
@@ -959,13 +991,85 @@ class DiscordVoiceWebSocket:
|
|||||||
elif op == self.SESSION_DESCRIPTION:
|
elif op == self.SESSION_DESCRIPTION:
|
||||||
self._connection.mode = data['mode']
|
self._connection.mode = data['mode']
|
||||||
await self.load_secret_key(data)
|
await self.load_secret_key(data)
|
||||||
|
self._connection.dave_protocol_version = data['dave_protocol_version']
|
||||||
|
if data['dave_protocol_version'] > 0:
|
||||||
|
await self._connection.reinit_dave_session()
|
||||||
elif op == self.HELLO:
|
elif op == self.HELLO:
|
||||||
interval = data['heartbeat_interval'] / 1000.0
|
interval = data['heartbeat_interval'] / 1000.0
|
||||||
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
|
self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0))
|
||||||
self._keep_alive.start()
|
self._keep_alive.start()
|
||||||
|
elif self._connection.dave_session:
|
||||||
|
state = self._connection
|
||||||
|
if op == self.DAVE_PREPARE_TRANSITION:
|
||||||
|
_log.debug(
|
||||||
|
'Preparing for DAVE transition id %d for protocol version %d',
|
||||||
|
data['transition_id'],
|
||||||
|
data['protocol_version'],
|
||||||
|
)
|
||||||
|
state.dave_pending_transitions[data['transition_id']] = data['protocol_version']
|
||||||
|
if data['transition_id'] == 0:
|
||||||
|
await state._execute_transition(data['transition_id'])
|
||||||
|
else:
|
||||||
|
if data['protocol_version'] == 0 and state.dave_session:
|
||||||
|
state.dave_session.set_passthrough_mode(True, 120)
|
||||||
|
|
||||||
|
await self.send_transition_ready(data['transition_id'])
|
||||||
|
elif op == self.DAVE_EXECUTE_TRANSITION:
|
||||||
|
_log.debug('Executing DAVE transition id %d', data['transition_id'])
|
||||||
|
await state._execute_transition(data['transition_id'])
|
||||||
|
elif op == self.DAVE_PREPARE_EPOCH:
|
||||||
|
_log.debug('Preparing for DAVE epoch %d', data['epoch'])
|
||||||
|
# When the epoch ID is equal to 1, this message indicates that a new MLS group is to be created for the given protocol version.
|
||||||
|
if data['epoch'] == 1:
|
||||||
|
state.dave_protocol_version = data['protocol_version']
|
||||||
|
await state.reinit_dave_session()
|
||||||
|
|
||||||
await self._hook(self, msg)
|
await self._hook(self, msg)
|
||||||
|
|
||||||
|
async def received_binary_message(self, msg: bytes) -> None:
|
||||||
|
self.seq_ack = struct.unpack_from('>H', msg, 0)[0]
|
||||||
|
op = msg[2]
|
||||||
|
_log.debug('Voice websocket binary frame received: %d bytes; seq=%s op=%s', len(msg), self.seq_ack, op)
|
||||||
|
state = self._connection
|
||||||
|
|
||||||
|
if state.dave_session is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
if op == self.MLS_EXTERNAL_SENDER:
|
||||||
|
state.dave_session.set_external_sender(msg[3:])
|
||||||
|
_log.debug('Set MLS external sender')
|
||||||
|
elif op == self.MLS_PROPOSALS:
|
||||||
|
optype = msg[3]
|
||||||
|
result = state.dave_session.process_proposals(
|
||||||
|
davey.ProposalsOperationType.append if optype == 0 else davey.ProposalsOperationType.revoke, msg[4:]
|
||||||
|
)
|
||||||
|
if isinstance(result, davey.CommitWelcome):
|
||||||
|
await self.send_binary(
|
||||||
|
DiscordVoiceWebSocket.MLS_COMMIT_WELCOME,
|
||||||
|
result.commit + result.welcome if result.welcome else result.commit,
|
||||||
|
)
|
||||||
|
_log.debug('MLS proposals processed')
|
||||||
|
elif op == self.MLS_ANNOUNCE_COMMIT_TRANSITION:
|
||||||
|
transition_id = struct.unpack_from('>H', msg, 3)[0]
|
||||||
|
try:
|
||||||
|
state.dave_session.process_commit(msg[5:])
|
||||||
|
if transition_id != 0:
|
||||||
|
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
|
||||||
|
await self.send_transition_ready(transition_id)
|
||||||
|
_log.debug('MLS commit processed for transition id %d', transition_id)
|
||||||
|
except Exception:
|
||||||
|
await state._recover_from_invalid_commit(transition_id)
|
||||||
|
elif op == self.MLS_WELCOME:
|
||||||
|
transition_id = struct.unpack_from('>H', msg, 3)[0]
|
||||||
|
try:
|
||||||
|
state.dave_session.process_welcome(msg[5:])
|
||||||
|
if transition_id != 0:
|
||||||
|
state.dave_pending_transitions[transition_id] = state.dave_protocol_version
|
||||||
|
await self.send_transition_ready(transition_id)
|
||||||
|
_log.debug('MLS welcome processed for transition id %d', transition_id)
|
||||||
|
except Exception:
|
||||||
|
await state._recover_from_invalid_commit(transition_id)
|
||||||
|
|
||||||
async def initial_connection(self, data: Dict[str, Any]) -> None:
|
async def initial_connection(self, data: Dict[str, Any]) -> None:
|
||||||
state = self._connection
|
state = self._connection
|
||||||
state.ssrc = data['ssrc']
|
state.ssrc = data['ssrc']
|
||||||
@@ -1045,6 +1149,8 @@ class DiscordVoiceWebSocket:
|
|||||||
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
msg = await asyncio.wait_for(self.ws.receive(), timeout=30.0)
|
||||||
if msg.type is aiohttp.WSMsgType.TEXT:
|
if msg.type is aiohttp.WSMsgType.TEXT:
|
||||||
await self.received_message(utils._from_json(msg.data))
|
await self.received_message(utils._from_json(msg.data))
|
||||||
|
elif msg.type is aiohttp.WSMsgType.BINARY:
|
||||||
|
await self.received_binary_message(msg.data)
|
||||||
elif msg.type is aiohttp.WSMsgType.ERROR:
|
elif msg.type is aiohttp.WSMsgType.ERROR:
|
||||||
_log.debug('Received voice %s', msg)
|
_log.debug('Received voice %s', msg)
|
||||||
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
raise ConnectionClosed(self.ws, shard_id=None) from msg.data
|
||||||
|
|||||||
@@ -284,6 +284,17 @@ class VoiceClient(VoiceProtocol):
|
|||||||
def timeout(self) -> float:
|
def timeout(self) -> float:
|
||||||
return self._connection.timeout
|
return self._connection.timeout
|
||||||
|
|
||||||
|
@property
|
||||||
|
def voice_privacy_code(self) -> Optional[str]:
|
||||||
|
""":class:`str`: Get the voice privacy code of this E2EE session's group.
|
||||||
|
|
||||||
|
A new privacy code is created and cached each time a new transition is executed.
|
||||||
|
This can be None if there is no active DAVE session happening.
|
||||||
|
|
||||||
|
.. versionadded:: 2.7
|
||||||
|
"""
|
||||||
|
return self._connection.dave_session.voice_privacy_code if self._connection.dave_session else None
|
||||||
|
|
||||||
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
def checked_add(self, attr: str, value: int, limit: int) -> None:
|
||||||
val = getattr(self, attr)
|
val = getattr(self, attr)
|
||||||
if val + value > limit:
|
if val + value > limit:
|
||||||
@@ -368,7 +379,12 @@ class VoiceClient(VoiceProtocol):
|
|||||||
|
|
||||||
# audio related
|
# audio related
|
||||||
|
|
||||||
def _get_voice_packet(self, data):
|
def _get_voice_packet(self, data: bytes):
|
||||||
|
packet = (
|
||||||
|
self._connection.dave_session.encrypt_opus(data)
|
||||||
|
if self._connection.dave_session and self._connection.can_encrypt
|
||||||
|
else data
|
||||||
|
)
|
||||||
header = bytearray(12)
|
header = bytearray(12)
|
||||||
|
|
||||||
# Formulate rtp header
|
# Formulate rtp header
|
||||||
@@ -379,7 +395,7 @@ class VoiceClient(VoiceProtocol):
|
|||||||
struct.pack_into('>I', header, 8, self.ssrc)
|
struct.pack_into('>I', header, 8, self.ssrc)
|
||||||
|
|
||||||
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
|
encrypt_packet = getattr(self, '_encrypt_' + self.mode)
|
||||||
return encrypt_packet(header, data)
|
return encrypt_packet(header, packet)
|
||||||
|
|
||||||
def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes:
|
def _encrypt_aead_xchacha20_poly1305_rtpsize(self, header: bytes, data) -> bytes:
|
||||||
# Esentially the same as _lite
|
# Esentially the same as _lite
|
||||||
|
|||||||
@@ -69,6 +69,14 @@ if TYPE_CHECKING:
|
|||||||
WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]]
|
WebsocketHook = Optional[Callable[[DiscordVoiceWebSocket, Dict[str, Any]], Coroutine[Any, Any, Any]]]
|
||||||
SocketReaderCallback = Callable[[bytes], Any]
|
SocketReaderCallback = Callable[[bytes], Any]
|
||||||
|
|
||||||
|
has_dave: bool
|
||||||
|
|
||||||
|
try:
|
||||||
|
import davey # type: ignore
|
||||||
|
|
||||||
|
has_dave = True
|
||||||
|
except ImportError:
|
||||||
|
has_dave = False
|
||||||
|
|
||||||
__all__ = ('VoiceConnectionState',)
|
__all__ = ('VoiceConnectionState',)
|
||||||
|
|
||||||
@@ -208,6 +216,10 @@ class VoiceConnectionState:
|
|||||||
self.mode: SupportedModes = MISSING
|
self.mode: SupportedModes = MISSING
|
||||||
self.socket: socket.socket = MISSING
|
self.socket: socket.socket = MISSING
|
||||||
self.ws: DiscordVoiceWebSocket = MISSING
|
self.ws: DiscordVoiceWebSocket = MISSING
|
||||||
|
self.dave_session: Optional[davey.DaveSession] = None
|
||||||
|
self.dave_protocol_version: int = 0
|
||||||
|
self.dave_pending_transitions: Dict[int, int] = {}
|
||||||
|
self.dave_downgraded: bool = False
|
||||||
|
|
||||||
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
|
self._state: ConnectionFlowState = ConnectionFlowState.disconnected
|
||||||
self._expecting_disconnect: bool = False
|
self._expecting_disconnect: bool = False
|
||||||
@@ -252,6 +264,64 @@ class VoiceConnectionState:
|
|||||||
def self_voice_state(self) -> Optional[VoiceState]:
|
def self_voice_state(self) -> Optional[VoiceState]:
|
||||||
return self.guild.me.voice
|
return self.guild.me.voice
|
||||||
|
|
||||||
|
@property
|
||||||
|
def max_dave_protocol_version(self) -> int:
|
||||||
|
return davey.DAVE_PROTOCOL_VERSION if has_dave else 0
|
||||||
|
|
||||||
|
@property
|
||||||
|
def can_encrypt(self) -> bool:
|
||||||
|
return self.dave_protocol_version != 0 and self.dave_session != None and self.dave_session.ready
|
||||||
|
|
||||||
|
async def reinit_dave_session(self) -> None:
|
||||||
|
if self.dave_protocol_version > 0:
|
||||||
|
if not has_dave:
|
||||||
|
raise RuntimeError('davey library needed in order to use E2EE voice')
|
||||||
|
if self.dave_session is not None:
|
||||||
|
self.dave_session.reinit(self.dave_protocol_version, self.user.id, self.voice_client.channel.id)
|
||||||
|
else:
|
||||||
|
self.dave_session = davey.DaveSession(self.dave_protocol_version, self.user.id, self.voice_client.channel.id)
|
||||||
|
|
||||||
|
if self.dave_session is not None:
|
||||||
|
await self.voice_client.ws.send_binary(
|
||||||
|
DiscordVoiceWebSocket.MLS_KEY_PACKAGE, self.dave_session.get_serialized_key_package()
|
||||||
|
)
|
||||||
|
elif self.dave_session:
|
||||||
|
self.dave_session.reset()
|
||||||
|
self.dave_session.set_passthrough_mode(True, 10)
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def _recover_from_invalid_commit(self, transition_id: int) -> None:
|
||||||
|
payload = {
|
||||||
|
'op': DiscordVoiceWebSocket.MLS_INVALID_COMMIT_WELCOME,
|
||||||
|
'd': {
|
||||||
|
'transition_id': transition_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
await self.voice_client.ws.send_as_json(payload)
|
||||||
|
await self.reinit_dave_session()
|
||||||
|
|
||||||
|
async def _execute_transition(self, transition_id: int) -> None:
|
||||||
|
_log.debug('Executing transition id %d', transition_id)
|
||||||
|
if transition_id not in self.dave_pending_transitions:
|
||||||
|
_log.warning("Received execute transition, but we don't have a pending transition for id %d", transition_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
old_version = self.dave_protocol_version
|
||||||
|
self.dave_protocol_version = self.dave_pending_transitions.pop(transition_id)
|
||||||
|
|
||||||
|
if old_version != self.dave_protocol_version and self.dave_protocol_version == 0:
|
||||||
|
self.dave_downgraded = True
|
||||||
|
_log.debug('DAVE Session downgraded')
|
||||||
|
elif transition_id > 0 and self.dave_downgraded:
|
||||||
|
self.dave_downgraded = False
|
||||||
|
if self.dave_session:
|
||||||
|
self.dave_session.set_passthrough_mode(True, 10)
|
||||||
|
_log.debug('DAVE Session upgraded')
|
||||||
|
|
||||||
|
# In the future, the session should be signaled too, but for now theres just v1
|
||||||
|
_log.debug('Transition id %d executed', transition_id)
|
||||||
|
|
||||||
async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
async def voice_state_update(self, data: GuildVoiceStatePayload) -> None:
|
||||||
channel_id = data['channel_id']
|
channel_id = data['channel_id']
|
||||||
|
|
||||||
|
|||||||
@@ -36,7 +36,10 @@ Documentation = "https://discordpy.readthedocs.io/en/latest/"
|
|||||||
dependencies = { file = "requirements.txt" }
|
dependencies = { file = "requirements.txt" }
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
voice = ["PyNaCl>=1.5.0,<1.6"]
|
voice = [
|
||||||
|
"PyNaCl>=1.5.0,<1.6",
|
||||||
|
"davey==0.1.0"
|
||||||
|
]
|
||||||
docs = [
|
docs = [
|
||||||
"sphinx==4.4.0",
|
"sphinx==4.4.0",
|
||||||
"sphinxcontrib_trio==1.1.2",
|
"sphinxcontrib_trio==1.1.2",
|
||||||
|
|||||||
Reference in New Issue
Block a user