Skip to main content

saorsa_mls/
api.rs

1// Copyright 2024 Saorsa Labs
2// SPDX-License-Identifier: AGPL-3.0-or-later
3
4//! Simplified API for MLS group messaging with QUIC stream integration
5
6use crate::{
7    crypto::{random_bytes, SecretBytes},
8    group::MlsGroup,
9    member::{MemberId, MemberIdentity},
10    protocol::{ApplicationMessage, CommitMessage},
11    MlsError, Result,
12};
13use bytes::Bytes;
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19/// Group identifier
20#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
21pub struct GroupId(pub Bytes);
22
23impl GroupId {
24    /// Generate a new random group ID
25    pub fn generate() -> Self {
26        Self(Bytes::from(random_bytes(32)))
27    }
28
29    /// Create from raw bytes
30    pub fn from_bytes(bytes: Vec<u8>) -> Self {
31        Self(Bytes::from(bytes))
32    }
33}
34
/// Options that influence how a commit is produced.
///
/// `Default` yields no padding (`padding == 0`).
#[derive(Default, Clone, Debug)]
pub struct CommitOptions {
    /// Padding size, intended to provide traffic-analysis resistance.
    pub padding: usize,
}
41
/// Encrypted ciphertext with metadata
///
/// Built by `GroupManager::send` from the `ApplicationMessage` returned by
/// encryption, and turned back into an `ApplicationMessage` by
/// `GroupManager::recv`.
///
/// NOTE(review): the message `generation` is not carried in this struct;
/// `recv` reconstructs it as `0`. Confirm decryption does not depend on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Ciphertext {
    /// The encrypted payload
    pub data: Bytes,
    /// Sender's member ID
    pub sender_id: MemberId,
    /// Message sequence number, copied from the encrypted message
    pub sequence: u64,
    /// Epoch number recorded in the encrypted message
    pub epoch: u64,
    /// Signature copied from the encrypted `ApplicationMessage` (what it
    /// covers is defined outside this file).
    ///
    /// Skipped by serde: after a serialize/deserialize round trip this field
    /// is `None`, and `GroupManager::recv` rejects a `None` signature with
    /// "Missing signature". NOTE(review): confirm how signatures are meant to
    /// travel alongside serialized ciphertexts.
    #[serde(skip)]
    pub signature: Option<crate::crypto::DebugMlDsaSignature>,
}
57
58/// Identity type for simplified API
59pub type Identity = MemberIdentity;
60
61/// Commit type for simplified API
62pub type Commit = CommitMessage;
63
/// Storage for group state persistence
///
/// Per-group bookkeeping kept by `GroupManager` next to each `MlsGroup`.
#[derive(Debug)]
struct GroupStorage {
    /// Current epoch number as tracked by the manager: starts at 0 and is
    /// incremented after each successful `add_member` / `remove_member`.
    /// NOTE(review): `send` stamps ciphertexts with the epoch from the
    /// encrypted message, not this counter, so the two may diverge.
    epoch: u64,
    /// Transcript hash for epoch (initialized with 32 random bytes).
    // Never read in this module; the allow keeps the field without a
    // dead-code warning until it is wired into the protocol.
    #[allow(dead_code)]
    transcript_hash: SecretBytes,
    /// Ratchet states per member: a random 32-byte secret is inserted when a
    /// member is added via `add_member` and removed on `remove_member`.
    /// Initial members of a group get no entry.
    ratchets: HashMap<MemberId, SecretBytes>,
}
75
/// Group manager for simplified API
///
/// Holds every group created through it, keyed by [`GroupId`], together with
/// per-group bookkeeping. Both maps sit behind `parking_lot` read/write locks
/// so all operations take `&self`.
#[derive(Debug)]
pub struct GroupManager {
    /// Live groups, keyed by identifier.
    groups: Arc<RwLock<HashMap<GroupId, Arc<MlsGroup>>>>,
    /// Per-group bookkeeping, keyed by the same identifiers as `groups`.
    storage: Arc<RwLock<HashMap<GroupId, GroupStorage>>>,
}
82
83impl Default for GroupManager {
84    fn default() -> Self {
85        Self::new()
86    }
87}
88
89impl GroupManager {
90    /// Create a new group manager
91    pub fn new() -> Self {
92        Self {
93            groups: Arc::new(RwLock::new(HashMap::new())),
94            storage: Arc::new(RwLock::new(HashMap::new())),
95        }
96    }
97
98    /// Create a new group with initial members
99    pub async fn group_new(&self, members: &[Identity]) -> Result<GroupId> {
100        self.group_new_with_config(members, crate::GroupConfig::default())
101            .await
102    }
103
104    /// Create a new group with explicit configuration (including cipher suite)
105    pub async fn group_new_with_config(
106        &self,
107        members: &[Identity],
108        config: crate::GroupConfig,
109    ) -> Result<GroupId> {
110        if members.is_empty() {
111            return Err(MlsError::InvalidGroupState(
112                "Group must have at least one member".to_string(),
113            ));
114        }
115
116        let suite = crate::CipherSuite::from_id(config.cipher_suite).ok_or_else(|| {
117            MlsError::InvalidGroupState(format!(
118                "unsupported cipher suite 0x{:04X}",
119                config.cipher_suite.as_u16()
120            ))
121        })?;
122
123        if let Some(mismatch) = members
124            .iter()
125            .find(|identity| identity.cipher_suite() != suite)
126        {
127            return Err(MlsError::InvalidGroupState(format!(
128                "member {} does not match group cipher suite",
129                mismatch.id
130            )));
131        }
132
133        let group_id = GroupId::generate();
134
135        // Use first member as creator
136        let creator = members[0].clone();
137        let mut group = MlsGroup::new(config, creator).await?;
138
139        // Add remaining members
140        for member in &members[1..] {
141            group.add_member(member).await?;
142        }
143
144        // Store group
145        let mut groups = self.groups.write();
146        groups.insert(group_id.clone(), Arc::new(group));
147
148        // Initialize storage
149        let storage = GroupStorage {
150            epoch: 0,
151            transcript_hash: SecretBytes::from(random_bytes(32)),
152            ratchets: HashMap::new(),
153        };
154        let mut storages = self.storage.write();
155        storages.insert(group_id.clone(), storage);
156
157        Ok(group_id)
158    }
159
160    /// Add a member to the group
161    pub async fn add_member(&self, group_id: &GroupId, id: Identity) -> Result<Commit> {
162        // We need to replace the group with a mutable one temporarily
163        // First, validate and remove under lock
164        let mut group = {
165            let mut groups = self.groups.write();
166            let existing = groups
167                .get(group_id)
168                .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
169
170            if id.cipher_suite() != existing.cipher_suite() {
171                return Err(MlsError::InvalidGroupState(
172                    "member identity does not match group cipher suite".to_string(),
173                ));
174            }
175
176            groups
177                .remove(group_id)
178                .expect("group must still exist after prior lookup")
179            // Lock is released here when `groups` goes out of scope
180        };
181
182        // Get mutable reference and add member (lock is not held during await)
183        let result = {
184            let group_mut = Arc::get_mut(&mut group).ok_or_else(|| {
185                MlsError::InvalidGroupState("Cannot modify shared group".to_string())
186            })?;
187
188            group_mut.add_member(&id).await
189        };
190
191        // Put the group back under a new lock
192        {
193            let mut groups = self.groups.write();
194            groups.insert(group_id.clone(), group);
195        }
196
197        let _welcome = result?;
198
199        // Update storage with new epoch
200        let mut storages = self.storage.write();
201        if let Some(storage) = storages.get_mut(group_id) {
202            storage.epoch += 1;
203            storage
204                .ratchets
205                .insert(id.id, SecretBytes::from(random_bytes(32)));
206        }
207        let _epoch = storages.get(group_id).map(|s| s.epoch).unwrap_or(0);
208
209        // Convert welcome to commit
210        // In a real implementation, this would contain actual proposals
211        Ok(Commit {
212            proposals: vec![crate::protocol::ProposalRef::Reference(vec![1, 2, 3])],
213            path: None,
214        })
215    }
216
217    /// Remove a member from the group
218    pub async fn remove_member(&self, group_id: &GroupId, id: Identity) -> Result<Commit> {
219        // We need to replace the group with a mutable one temporarily
220        // First, validate and remove under lock
221        let mut group = {
222            let mut groups = self.groups.write();
223            let existing = groups
224                .get(group_id)
225                .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
226
227            if id.cipher_suite() != existing.cipher_suite() {
228                return Err(MlsError::InvalidGroupState(
229                    "member identity does not match group cipher suite".to_string(),
230                ));
231            }
232
233            groups
234                .remove(group_id)
235                .expect("group must still exist after prior lookup")
236            // Lock is released here when `groups` goes out of scope
237        };
238
239        // Get mutable reference and remove member (lock is not held during await)
240        let result = {
241            let group_mut = Arc::get_mut(&mut group).ok_or_else(|| {
242                MlsError::InvalidGroupState("Cannot modify shared group".to_string())
243            })?;
244
245            group_mut.remove_member(&id.id).await
246        };
247
248        // Put the group back under a new lock
249        {
250            let mut groups = self.groups.write();
251            groups.insert(group_id.clone(), group);
252        }
253
254        result?;
255
256        // Update storage
257        let mut storages = self.storage.write();
258        if let Some(storage) = storages.get_mut(group_id) {
259            storage.epoch += 1;
260            storage.ratchets.remove(&id.id);
261        }
262        let _epoch = storages.get(group_id).map(|s| s.epoch).unwrap_or(0);
263
264        Ok(Commit {
265            proposals: vec![crate::protocol::ProposalRef::Reference(vec![4, 5, 6])],
266            path: None,
267        })
268    }
269
270    /// Send an encrypted message to the group
271    pub fn send(&self, group_id: &GroupId, app_data: &[u8]) -> Result<Ciphertext> {
272        let groups = self.groups.read();
273        let group = groups
274            .get(group_id)
275            .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
276
277        // Encrypt the message
278        let app_msg = group.encrypt_message(app_data)?;
279
280        // Get current state for metadata (not needed since we use message epoch)
281        let storages = self.storage.read();
282        let _storage = storages
283            .get(group_id)
284            .ok_or_else(|| MlsError::InvalidGroupState("Storage not found".to_string()))?;
285
286        Ok(Ciphertext {
287            data: Bytes::from(app_msg.ciphertext),
288            sender_id: app_msg.sender,
289            sequence: app_msg.sequence,
290            epoch: app_msg.epoch, // Use the actual epoch from the message
291            signature: Some(app_msg.signature),
292        })
293    }
294
295    /// Receive and decrypt a message from the group
296    pub fn recv(&self, group_id: &GroupId, ciphertext: &Ciphertext) -> Result<Vec<u8>> {
297        let groups = self.groups.read();
298        let group = groups
299            .get(group_id)
300            .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
301
302        // Reconstruct ApplicationMessage from Ciphertext
303        let app_msg = ApplicationMessage {
304            epoch: ciphertext.epoch,
305            sender: ciphertext.sender_id,
306            generation: 0, // Simplified
307            sequence: ciphertext.sequence,
308            ciphertext: ciphertext.data.to_vec(),
309            signature: ciphertext
310                .signature
311                .clone()
312                .ok_or_else(|| MlsError::InvalidMessage("Missing signature".to_string()))?,
313        };
314
315        // Decrypt the message
316        group.decrypt_message(&app_msg)
317    }
318}
319
320// Global instance for simplified API
321lazy_static::lazy_static! {
322    static ref MANAGER: GroupManager = GroupManager::new();
323}
324
325/// Create a new group with initial members
326///
327/// # Arguments
328/// * `members` - Initial group members
329///
330/// # Returns
331/// * `GroupId` - The identifier for the created group
332pub async fn group_new(members: &[Identity]) -> anyhow::Result<GroupId> {
333    MANAGER.group_new(members).await.map_err(Into::into)
334}
335
336/// Create a new group with explicit configuration
337pub async fn group_new_with_config(
338    members: &[Identity],
339    config: crate::GroupConfig,
340) -> anyhow::Result<GroupId> {
341    MANAGER
342        .group_new_with_config(members, config)
343        .await
344        .map_err(Into::into)
345}
346
347/// Add a member to an existing group
348///
349/// # Arguments
350/// * `g` - The group identifier
351/// * `id` - The identity of the member to add
352///
353/// # Returns
354/// * `Commit` - The commit message for the add operation
355pub async fn add_member(g: &GroupId, id: Identity) -> anyhow::Result<Commit> {
356    MANAGER.add_member(g, id).await.map_err(Into::into)
357}
358
359/// Remove a member from a group
360///
361/// # Arguments
362/// * `g` - The group identifier
363/// * `id` - The identity of the member to remove
364///
365/// # Returns
366/// * `Commit` - The commit message for the remove operation
367pub async fn remove_member(g: &GroupId, id: Identity) -> anyhow::Result<Commit> {
368    MANAGER.remove_member(g, id).await.map_err(Into::into)
369}
370
371/// Send an encrypted message to the group
372///
373/// # Arguments
374/// * `g` - The group identifier
375/// * `app` - The application data to encrypt
376///
377/// # Returns
378/// * `Ciphertext` - The encrypted message
379pub fn send(g: &GroupId, app: &[u8]) -> anyhow::Result<Ciphertext> {
380    MANAGER.send(g, app).map_err(Into::into)
381}
382
383/// Receive and decrypt a message from the group
384///
385/// # Arguments
386/// * `g` - The group identifier
387/// * `ct` - The encrypted ciphertext
388///
389/// # Returns
390/// * `Vec<u8>` - The decrypted application data
391pub fn recv(g: &GroupId, ct: &Ciphertext) -> anyhow::Result<Vec<u8>> {
392    MANAGER.recv(g, ct).map_err(Into::into)
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Generate a fresh identity with a random member ID.
    fn new_identity() -> Identity {
        MemberIdentity::generate(MemberId::generate()).unwrap()
    }

    #[tokio::test]
    async fn test_group_creation() {
        let members = vec![new_identity(), new_identity()];
        let group_id = group_new(&members).await.unwrap();

        assert!(!group_id.0.is_empty());
    }

    #[tokio::test]
    async fn test_group_new_rejects_empty_member_list() {
        assert!(group_new(&[]).await.is_err());
    }

    #[tokio::test]
    async fn test_add_remove_member() {
        let members = vec![new_identity(), new_identity()];
        let group_id = group_new(&members).await.unwrap();
        let joiner = new_identity();

        // Add member
        let commit = add_member(&group_id, joiner.clone()).await.unwrap();
        assert!(!commit.proposals.is_empty() || commit.path.is_some());

        // Remove member
        let commit = remove_member(&group_id, joiner).await.unwrap();
        assert!(!commit.proposals.is_empty() || commit.path.is_some());
    }

    #[tokio::test]
    async fn test_membership_change_on_unknown_group_fails() {
        let unknown = GroupId::generate();

        assert!(add_member(&unknown, new_identity()).await.is_err());
        assert!(remove_member(&unknown, new_identity()).await.is_err());
    }

    #[tokio::test]
    async fn test_send_recv() {
        let members = vec![new_identity(), new_identity()];
        let group_id = group_new(&members).await.unwrap();

        let message = b"Hello, MLS group!";
        let ciphertext = send(&group_id, message).unwrap();

        let decrypted = recv(&group_id, &ciphertext).unwrap();
        assert_eq!(decrypted, message);
    }

    #[test]
    fn test_send_to_unknown_group_fails() {
        assert!(send(&GroupId::generate(), b"payload").is_err());
    }

    #[tokio::test]
    async fn test_recv_rejects_missing_signature() {
        let group_id = group_new(&[new_identity()]).await.unwrap();

        let mut ciphertext = send(&group_id, b"payload").unwrap();
        ciphertext.signature = None;

        assert!(recv(&group_id, &ciphertext).is_err());
    }
}