1use crate::{
7 crypto::{random_bytes, SecretBytes},
8 group::MlsGroup,
9 member::{MemberId, MemberIdentity},
10 protocol::{ApplicationMessage, CommitMessage},
11 MlsError, Result,
12};
13use bytes::Bytes;
14use parking_lot::RwLock;
15use serde::{Deserialize, Serialize};
16use std::collections::HashMap;
17use std::sync::Arc;
18
19#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
21pub struct GroupId(pub Bytes);
22
23impl GroupId {
24 pub fn generate() -> Self {
26 Self(Bytes::from(random_bytes(32)))
27 }
28
29 pub fn from_bytes(bytes: Vec<u8>) -> Self {
31 Self(Bytes::from(bytes))
32 }
33}
34
/// Options for producing a commit.
///
/// `Default` yields `padding == 0`.
#[derive(Clone, Debug, Default)]
pub struct CommitOptions {
    /// Padding setting for the commit.
    /// NOTE(review): unit and semantics are not visible in this module — confirm.
    pub padding: usize,
}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct Ciphertext {
45 pub data: Bytes,
47 pub sender_id: MemberId,
49 pub sequence: u64,
51 pub epoch: u64,
53 #[serde(skip)]
55 pub signature: Option<crate::crypto::DebugMlDsaSignature>,
56}
57
58pub type Identity = MemberIdentity;
60
61pub type Commit = CommitMessage;
63
64#[derive(Debug)]
66struct GroupStorage {
67 epoch: u64,
69 #[allow(dead_code)]
71 transcript_hash: SecretBytes,
72 ratchets: HashMap<MemberId, SecretBytes>,
74}
75
76#[derive(Debug)]
78pub struct GroupManager {
79 groups: Arc<RwLock<HashMap<GroupId, Arc<MlsGroup>>>>,
80 storage: Arc<RwLock<HashMap<GroupId, GroupStorage>>>,
81}
82
83impl Default for GroupManager {
84 fn default() -> Self {
85 Self::new()
86 }
87}
88
89impl GroupManager {
90 pub fn new() -> Self {
92 Self {
93 groups: Arc::new(RwLock::new(HashMap::new())),
94 storage: Arc::new(RwLock::new(HashMap::new())),
95 }
96 }
97
98 pub async fn group_new(&self, members: &[Identity]) -> Result<GroupId> {
100 self.group_new_with_config(members, crate::GroupConfig::default())
101 .await
102 }
103
104 pub async fn group_new_with_config(
106 &self,
107 members: &[Identity],
108 config: crate::GroupConfig,
109 ) -> Result<GroupId> {
110 if members.is_empty() {
111 return Err(MlsError::InvalidGroupState(
112 "Group must have at least one member".to_string(),
113 ));
114 }
115
116 let suite = crate::CipherSuite::from_id(config.cipher_suite).ok_or_else(|| {
117 MlsError::InvalidGroupState(format!(
118 "unsupported cipher suite 0x{:04X}",
119 config.cipher_suite.as_u16()
120 ))
121 })?;
122
123 if let Some(mismatch) = members
124 .iter()
125 .find(|identity| identity.cipher_suite() != suite)
126 {
127 return Err(MlsError::InvalidGroupState(format!(
128 "member {} does not match group cipher suite",
129 mismatch.id
130 )));
131 }
132
133 let group_id = GroupId::generate();
134
135 let creator = members[0].clone();
137 let mut group = MlsGroup::new(config, creator).await?;
138
139 for member in &members[1..] {
141 group.add_member(member).await?;
142 }
143
144 let mut groups = self.groups.write();
146 groups.insert(group_id.clone(), Arc::new(group));
147
148 let storage = GroupStorage {
150 epoch: 0,
151 transcript_hash: SecretBytes::from(random_bytes(32)),
152 ratchets: HashMap::new(),
153 };
154 let mut storages = self.storage.write();
155 storages.insert(group_id.clone(), storage);
156
157 Ok(group_id)
158 }
159
160 pub async fn add_member(&self, group_id: &GroupId, id: Identity) -> Result<Commit> {
162 let mut group = {
165 let mut groups = self.groups.write();
166 let existing = groups
167 .get(group_id)
168 .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
169
170 if id.cipher_suite() != existing.cipher_suite() {
171 return Err(MlsError::InvalidGroupState(
172 "member identity does not match group cipher suite".to_string(),
173 ));
174 }
175
176 groups
177 .remove(group_id)
178 .expect("group must still exist after prior lookup")
179 };
181
182 let result = {
184 let group_mut = Arc::get_mut(&mut group).ok_or_else(|| {
185 MlsError::InvalidGroupState("Cannot modify shared group".to_string())
186 })?;
187
188 group_mut.add_member(&id).await
189 };
190
191 {
193 let mut groups = self.groups.write();
194 groups.insert(group_id.clone(), group);
195 }
196
197 let _welcome = result?;
198
199 let mut storages = self.storage.write();
201 if let Some(storage) = storages.get_mut(group_id) {
202 storage.epoch += 1;
203 storage
204 .ratchets
205 .insert(id.id, SecretBytes::from(random_bytes(32)));
206 }
207 let _epoch = storages.get(group_id).map(|s| s.epoch).unwrap_or(0);
208
209 Ok(Commit {
212 proposals: vec![crate::protocol::ProposalRef::Reference(vec![1, 2, 3])],
213 path: None,
214 })
215 }
216
217 pub async fn remove_member(&self, group_id: &GroupId, id: Identity) -> Result<Commit> {
219 let mut group = {
222 let mut groups = self.groups.write();
223 let existing = groups
224 .get(group_id)
225 .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
226
227 if id.cipher_suite() != existing.cipher_suite() {
228 return Err(MlsError::InvalidGroupState(
229 "member identity does not match group cipher suite".to_string(),
230 ));
231 }
232
233 groups
234 .remove(group_id)
235 .expect("group must still exist after prior lookup")
236 };
238
239 let result = {
241 let group_mut = Arc::get_mut(&mut group).ok_or_else(|| {
242 MlsError::InvalidGroupState("Cannot modify shared group".to_string())
243 })?;
244
245 group_mut.remove_member(&id.id).await
246 };
247
248 {
250 let mut groups = self.groups.write();
251 groups.insert(group_id.clone(), group);
252 }
253
254 result?;
255
256 let mut storages = self.storage.write();
258 if let Some(storage) = storages.get_mut(group_id) {
259 storage.epoch += 1;
260 storage.ratchets.remove(&id.id);
261 }
262 let _epoch = storages.get(group_id).map(|s| s.epoch).unwrap_or(0);
263
264 Ok(Commit {
265 proposals: vec![crate::protocol::ProposalRef::Reference(vec![4, 5, 6])],
266 path: None,
267 })
268 }
269
270 pub fn send(&self, group_id: &GroupId, app_data: &[u8]) -> Result<Ciphertext> {
272 let groups = self.groups.read();
273 let group = groups
274 .get(group_id)
275 .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
276
277 let app_msg = group.encrypt_message(app_data)?;
279
280 let storages = self.storage.read();
282 let _storage = storages
283 .get(group_id)
284 .ok_or_else(|| MlsError::InvalidGroupState("Storage not found".to_string()))?;
285
286 Ok(Ciphertext {
287 data: Bytes::from(app_msg.ciphertext),
288 sender_id: app_msg.sender,
289 sequence: app_msg.sequence,
290 epoch: app_msg.epoch, signature: Some(app_msg.signature),
292 })
293 }
294
295 pub fn recv(&self, group_id: &GroupId, ciphertext: &Ciphertext) -> Result<Vec<u8>> {
297 let groups = self.groups.read();
298 let group = groups
299 .get(group_id)
300 .ok_or_else(|| MlsError::InvalidGroupState("Group not found".to_string()))?;
301
302 let app_msg = ApplicationMessage {
304 epoch: ciphertext.epoch,
305 sender: ciphertext.sender_id,
306 generation: 0, sequence: ciphertext.sequence,
308 ciphertext: ciphertext.data.to_vec(),
309 signature: ciphertext
310 .signature
311 .clone()
312 .ok_or_else(|| MlsError::InvalidMessage("Missing signature".to_string()))?,
313 };
314
315 group.decrypt_message(&app_msg)
317 }
318}
319
320lazy_static::lazy_static! {
322 static ref MANAGER: GroupManager = GroupManager::new();
323}
324
325pub async fn group_new(members: &[Identity]) -> anyhow::Result<GroupId> {
333 MANAGER.group_new(members).await.map_err(Into::into)
334}
335
336pub async fn group_new_with_config(
338 members: &[Identity],
339 config: crate::GroupConfig,
340) -> anyhow::Result<GroupId> {
341 MANAGER
342 .group_new_with_config(members, config)
343 .await
344 .map_err(Into::into)
345}
346
347pub async fn add_member(g: &GroupId, id: Identity) -> anyhow::Result<Commit> {
356 MANAGER.add_member(g, id).await.map_err(Into::into)
357}
358
359pub async fn remove_member(g: &GroupId, id: Identity) -> anyhow::Result<Commit> {
368 MANAGER.remove_member(g, id).await.map_err(Into::into)
369}
370
371pub fn send(g: &GroupId, app: &[u8]) -> anyhow::Result<Ciphertext> {
380 MANAGER.send(g, app).map_err(Into::into)
381}
382
383pub fn recv(g: &GroupId, ct: &Ciphertext) -> anyhow::Result<Vec<u8>> {
392 MANAGER.recv(g, ct).map_err(Into::into)
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Two fresh identities, used by most tests.
    fn two_members() -> Vec<Identity> {
        vec![
            MemberIdentity::generate(MemberId::generate()).unwrap(),
            MemberIdentity::generate(MemberId::generate()).unwrap(),
        ]
    }

    #[tokio::test]
    async fn test_group_creation() {
        let members = two_members();
        let group_id = group_new(&members).await.unwrap();

        assert!(!group_id.0.is_empty());
    }

    #[tokio::test]
    async fn test_group_creation_rejects_empty_member_list() {
        assert!(group_new(&[]).await.is_err());
    }

    #[tokio::test]
    async fn test_add_remove_member() {
        let member3 = MemberIdentity::generate(MemberId::generate()).unwrap();

        let members = two_members();
        let group_id = group_new(&members).await.unwrap();

        let commit = add_member(&group_id, member3.clone()).await.unwrap();
        assert!(!commit.proposals.is_empty() || commit.path.is_some());

        // Succeeds only if `add_member` put the group back into the manager.
        let commit = remove_member(&group_id, member3).await.unwrap();
        assert!(!commit.proposals.is_empty() || commit.path.is_some());
    }

    #[tokio::test]
    async fn test_operations_on_unknown_group_fail() {
        let unknown = GroupId::from_bytes(vec![0u8; 32]);
        let member = MemberIdentity::generate(MemberId::generate()).unwrap();

        assert!(add_member(&unknown, member.clone()).await.is_err());
        assert!(remove_member(&unknown, member).await.is_err());
        assert!(send(&unknown, b"hello").is_err());

        let group_id = group_new(&two_members()).await.unwrap();
        let ciphertext = send(&group_id, b"hello").unwrap();
        assert!(recv(&unknown, &ciphertext).is_err());
    }

    #[tokio::test]
    async fn test_send_recv() {
        let members = two_members();
        let group_id = group_new(&members).await.unwrap();

        let message = b"Hello, MLS group!";
        let ciphertext = send(&group_id, message).unwrap();

        let decrypted = recv(&group_id, &ciphertext).unwrap();
        assert_eq!(decrypted, message);
    }

    #[tokio::test]
    async fn test_recv_rejects_missing_signature() {
        let group_id = group_new(&two_members()).await.unwrap();

        // Mirrors what a serde round trip does, since `signature` is skipped.
        let mut ciphertext = send(&group_id, b"unsigned").unwrap();
        ciphertext.signature = None;

        assert!(recv(&group_id, &ciphertext).is_err());
    }
}