Skip to main content

saorsa_mls/
protocol.rs

1//! MLS protocol messages and state machine
2
3use crate::{
4    crypto::{AeadCipher, CipherSuite, CipherSuiteId, DebugMlDsaSignature, Hash},
5    member::*,
6    EpochNumber, MessageSequence, MlsError, Result,
7};
8// postcard serialization (bincode removed)
9use saorsa_pqc::api::{
10    MlDsa, MlDsaPublicKey, MlDsaSecretKey, MlKem, MlKemCiphertext, MlKemSecretKey,
11};
12use serde::{Deserialize, Serialize};
13
/// MLS message types
///
/// Top-level wire message, encoded with postcard by [`MlsMessage::to_bytes`].
/// serde's derived encoding identifies a variant by its 0-based position, so the
/// declaration order below is part of the wire format and must not be changed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MlsMessage {
    /// Handshake message for group operations
    Handshake(HandshakeMessage),
    /// Application message with encrypted content
    Application(ApplicationMessage),
    /// Welcome message for new members
    Welcome(WelcomeMessage),
}
24
25impl MlsMessage {
26    /// Get the epoch number for this message
27    pub fn epoch(&self) -> EpochNumber {
28        match self {
29            Self::Handshake(msg) => msg.epoch,
30            Self::Application(msg) => msg.epoch,
31            Self::Welcome(msg) => msg.epoch,
32        }
33    }
34
35    /// Get the sender of this message
36    pub fn sender(&self) -> MemberId {
37        match self {
38            Self::Handshake(msg) => msg.sender,
39            Self::Application(msg) => msg.sender,
40            Self::Welcome(msg) => msg.sender,
41        }
42    }
43
44    /// Verify the message signature
45    pub fn verify_signature(
46        &self,
47        verifying_key: &MlDsaPublicKey,
48        suite: CipherSuite,
49    ) -> Result<bool> {
50        let (data, signature, suite_for_message) = match self {
51            Self::Handshake(msg) => (&msg.content, &msg.signature.0, suite),
52            Self::Application(msg) => (&msg.ciphertext, &msg.signature.0, suite),
53            Self::Welcome(msg) => (&msg.group_info, &msg.signature.0, msg.cipher_suite),
54        };
55
56        let ml_dsa = MlDsa::new(suite_for_message.ml_dsa_variant());
57        ml_dsa
58            .verify(verifying_key, data, signature)
59            .map_err(|e| MlsError::InvalidMessage(format!("invalid signature: {e:?}")))
60    }
61}
62
/// Handshake message content types
///
/// NOTE(review): [`HandshakeMessage`] stores its payload as opaque bytes. This
/// enum is presumably what gets serialized into `HandshakeMessage::content`;
/// confirm against callers. Variant order is part of the serde wire encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum HandshakeContent {
    /// Add a new member to the group
    Add(AddProposal),
    /// Remove a member from the group
    Remove(RemoveProposal),
    /// Update member's key material
    Update(UpdateProposal),
    /// Commit pending proposals
    Commit(CommitMessage),
}
75
/// Handshake message for group operations
///
/// [`HandshakeMessage::new_signed`] signs exactly the `content` bytes, so `epoch`
/// and `sender` are not covered by `signature`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HandshakeMessage {
    /// Epoch number carried by the message (returned by `MlsMessage::epoch`).
    pub epoch: EpochNumber,
    /// Sender's member ID (returned by `MlsMessage::sender`).
    pub sender: MemberId,
    /// Signed payload bytes.
    ///
    /// `validate` requires this to be non-empty and at most
    /// `constants::MAX_MESSAGE_SIZE` bytes.
    pub content: Vec<u8>,
    /// ML-DSA signature over `content`.
    pub signature: DebugMlDsaSignature,
}
84
85impl HandshakeMessage {
86    /// Create a signed handshake message for the given content
87    pub fn new_signed(
88        epoch: EpochNumber,
89        sender: MemberId,
90        content: Vec<u8>,
91        signing_key: &MlDsaSecretKey,
92        suite: CipherSuite,
93    ) -> Result<Self> {
94        let ml_dsa = MlDsa::new(suite.ml_dsa_variant());
95        let signature = ml_dsa
96            .sign(signing_key, &content)
97            .map_err(|e| MlsError::CryptoError(format!("Signing failed: {e:?}")))?;
98
99        Ok(Self {
100            epoch,
101            sender,
102            content,
103            signature: DebugMlDsaSignature(signature),
104        })
105    }
106
107    /// Verify the handshake message signature using the provided suite
108    pub fn verify_signature(
109        &self,
110        verifying_key: &MlDsaPublicKey,
111        suite: CipherSuite,
112    ) -> Result<bool> {
113        let ml_dsa = MlDsa::new(suite.ml_dsa_variant());
114        ml_dsa
115            .verify(verifying_key, &self.content, &self.signature.0)
116            .map_err(|e| MlsError::InvalidMessage(format!("invalid signature: {e:?}")))
117    }
118}
119
/// Application message with encrypted payload
///
/// `MlsMessage::verify_signature` checks `signature` against `ciphertext` only.
/// The other fields are not part of the bytes it verifies.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ApplicationMessage {
    /// Epoch number carried by the message.
    pub epoch: EpochNumber,
    /// Sender's member ID.
    pub sender: MemberId,
    /// Generation counter. Presumably the sender's ratchet generation; confirm.
    pub generation: u32,
    /// Sequence number. NOTE(review): semantics are defined outside this module.
    pub sequence: MessageSequence,
    /// Encrypted payload.
    ///
    /// `validate` requires it to be non-empty and at most
    /// `constants::MAX_MESSAGE_SIZE` bytes.
    pub ciphertext: Vec<u8>,
    /// ML-DSA signature, verified over `ciphertext`.
    pub signature: DebugMlDsaSignature,
}
130
/// Welcome message for new members
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WelcomeMessage {
    /// Epoch number carried by the message.
    pub epoch: EpochNumber,
    /// Member ID of the message creator.
    pub sender: MemberId,
    /// Cipher suite carried in the message.
    ///
    /// Its ML-DSA variant is used to verify `signature`, regardless of any suite
    /// the caller passes to `MlsMessage::verify_signature`.
    pub cipher_suite: CipherSuite,
    /// Opaque group-info bytes; these are the bytes covered by `signature`.
    pub group_info: Vec<u8>,
    /// Per-recipient encrypted secrets. `validate` requires at least one.
    pub secrets: Vec<EncryptedGroupSecrets>,
    /// ML-DSA signature over `group_info`.
    pub signature: DebugMlDsaSignature,
}
141
/// Proposal to add a new member
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddProposal {
    /// Key package of the member being added.
    pub key_package: KeyPackage,
}
147
/// Proposal to remove a member
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RemoveProposal {
    /// Member ID of the member being removed.
    pub removed: MemberId,
}
153
/// Proposal to update member's keys
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdateProposal {
    /// Replacement key package for the proposing member.
    pub key_package: KeyPackage,
    /// ML-DSA signature. NOTE(review): the signed bytes are not defined in this
    /// module; presumably they cover the serialized `key_package`. Confirm.
    pub signature: DebugMlDsaSignature,
}
160
/// Commit message containing proposals and path updates
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CommitMessage {
    /// Proposals being committed, either by reference or inline.
    pub proposals: Vec<ProposalRef>,
    /// Optional update path accompanying the commit.
    pub path: Option<UpdatePath>,
}
167
/// Reference to a proposal
///
/// Variant order is part of the serde wire encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProposalRef {
    /// Reference to a proposal by hash
    Reference(Vec<u8>),
    /// Inline proposal
    Inline(ProposalContent),
}
176
/// Proposal content wrapper
///
/// Variant order is part of the serde wire encoding.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ProposalContent {
    /// Add a member.
    Add(AddProposal),
    /// Remove a member.
    Remove(RemoveProposal),
    /// Update a member's key material.
    Update(UpdateProposal),
}
184
/// Update path for tree operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdatePath {
    /// The committer's new leaf key package.
    pub leaf_key_package: KeyPackage,
    /// Path nodes. NOTE(review): ordering (leaf-to-root or otherwise) is not
    /// fixed in this module; confirm against the tree code.
    pub nodes: Vec<UpdatePathNode>,
}
191
/// Node in an update path
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpdatePathNode {
    /// Serialized public key for this node.
    pub public_key: Vec<u8>,
    /// Path secret encrypted separately for each recipient.
    pub encrypted_path_secret: Vec<EncryptedPathSecret>,
}
198
/// Encrypted group secrets for welcome messages
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedGroupSecrets {
    /// Hash identifying the recipient's key package. Not read in this module;
    /// presumably used by recipients to find their entry. Confirm.
    pub recipient_key_package_hash: Vec<u8>,
    /// Serialized ML-KEM ciphertext, parsed by `EncryptedGroupSecrets::ciphertext`.
    pub kem_ciphertext: Vec<u8>,
    /// AEAD output under a key and nonce derived from the KEM shared secret.
    ///
    /// `decapsulate_path_secret` requires it to begin with the derived nonce.
    pub encrypted_path_secret: Vec<u8>,
}
206
/// Message framing with metadata
///
/// NOTE(review): this type is neither built nor parsed in this module. The field
/// descriptions below follow the field names; confirm against the framing code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageFrame {
    /// Frame schema version.
    pub schema_version: u8,
    /// Kind of payload carried.
    pub message_type: MessageType,
    /// Epoch number of the framed message.
    pub epoch: EpochNumber,
    /// Sender's member ID.
    pub sender: MemberId,
    /// Additional authenticated data.
    pub authenticated_data: Vec<u8>,
    /// Framed payload bytes.
    pub payload: Vec<u8>,
    /// ML-DSA signature; the covered bytes are defined by the framing code.
    pub signature: DebugMlDsaSignature,
}
218
/// Message types in the protocol
///
/// The explicit discriminants (1..=5) only apply to `as` casts. serde's derived
/// `Serialize`/`Deserialize` encode the 0-based variant index instead, so postcard
/// writes `Handshake` as 0, not 1. Variant order is therefore part of the wire
/// format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessageType {
    Handshake = 1,
    Application = 2,
    Welcome = 3,
    GroupInfo = 4,
    KeyPackage = 5,
}
228
/// Group information for synchronization
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GroupInfo {
    /// Raw group identifier bytes.
    pub group_id: Vec<u8>,
    /// Epoch the information describes.
    pub epoch: EpochNumber,
    /// Tree hash bytes. NOTE(review): the hash algorithm is chosen elsewhere.
    pub tree_hash: Vec<u8>,
    /// Confirmed transcript hash bytes.
    pub confirmed_transcript_hash: Vec<u8>,
    /// Group extensions.
    pub extensions: Vec<Extension>,
    /// Confirmation tag bytes.
    pub confirmation_tag: Vec<u8>,
    /// Member that produced this group info.
    pub signer: MemberId,
}
240
/// Tree structure for key management
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TreeKemState {
    /// Tree nodes. NOTE(review): the indexing scheme (e.g. array-encoded binary
    /// tree) is not defined here; confirm against the TreeKEM code.
    pub nodes: Vec<TreeNode>,
    /// Epoch this tree state belongs to.
    pub epoch: EpochNumber,
}
247
/// Node in the TreeKEM structure
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum TreeNode {
    /// A member (leaf) position.
    Leaf(LeafNode),
    /// An interior (parent) position.
    Parent(ParentNode),
}
254
/// Leaf node containing member information
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LeafNode {
    /// Key package of the occupying member; `None` for an empty leaf.
    pub key_package: Option<KeyPackage>,
    /// Unmerged leaves recorded at this node.
    pub unmerged_leaves: Vec<MemberId>,
}
261
/// Parent node in the tree
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParentNode {
    /// Serialized public key; `None` when the node holds no key.
    pub public_key: Option<Vec<u8>>,
    /// Unmerged leaves recorded at this node.
    pub unmerged_leaves: Vec<MemberId>,
}
268
/// Encrypted path secret for tree operations
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EncryptedPathSecret {
    /// Recipient of this encrypted secret
    pub recipient: MemberId,
    /// Encrypted path secret using ML-KEM (serialized as bytes)
    pub ciphertext: Vec<u8>,
}
277
/// Protocol-wide limits and timing constants.
pub mod constants {
    /// Maximum number of members in a group.
    pub const MAX_GROUP_SIZE: usize = 1_000;

    /// Maximum message payload size, in bytes (1 MiB).
    pub const MAX_MESSAGE_SIZE: usize = 1 << 20;

    /// Default epoch lifetime, in seconds (24 hours).
    pub const EPOCH_LIFETIME: u64 = 24 * 60 * 60;
}
287
288/// Validation functions for protocol messages
289impl HandshakeMessage {
290    /// Validate the handshake message
291    pub fn validate(&self) -> Result<()> {
292        if self.content.is_empty() {
293            return Err(MlsError::InvalidMessage(
294                "Empty handshake content".to_string(),
295            ));
296        }
297        if self.content.len() > constants::MAX_MESSAGE_SIZE {
298            return Err(MlsError::InvalidMessage("Message too large".to_string()));
299        }
300        Ok(())
301    }
302}
303
304impl ApplicationMessage {
305    /// Validate the application message
306    pub fn validate(&self) -> Result<()> {
307        if self.ciphertext.is_empty() {
308            return Err(MlsError::InvalidMessage("Empty ciphertext".to_string()));
309        }
310        if self.ciphertext.len() > constants::MAX_MESSAGE_SIZE {
311            return Err(MlsError::InvalidMessage("Message too large".to_string()));
312        }
313        Ok(())
314    }
315}
316
317impl WelcomeMessage {
318    /// Validate the welcome message
319    pub fn validate(&self) -> Result<()> {
320        if self.group_info.is_empty() {
321            return Err(MlsError::InvalidMessage("Empty group info".to_string()));
322        }
323        if self.secrets.is_empty() {
324            return Err(MlsError::InvalidMessage("No encrypted secrets".to_string()));
325        }
326        Ok(())
327    }
328
329    /// Verify the welcome message signature against the creator's public key
330    pub fn verify_signature(&self, verifying_key: &MlDsaPublicKey) -> Result<bool> {
331        let ml_dsa = MlDsa::new(self.cipher_suite.ml_dsa_variant());
332        ml_dsa
333            .verify(verifying_key, &self.group_info, &self.signature.0)
334            .map_err(|e| MlsError::InvalidMessage(format!("invalid signature: {e:?}")))
335    }
336}
337
impl EncryptedGroupSecrets {
    /// Return the ML-KEM ciphertext for this recipient
    ///
    /// Parses `kem_ciphertext` for the ML-KEM variant of `suite`.
    ///
    /// # Errors
    /// Returns [`MlsError::CryptoError`] if the bytes are not a valid ciphertext
    /// for that variant.
    pub fn ciphertext(&self, suite: &CipherSuite) -> Result<MlKemCiphertext> {
        MlKemCiphertext::from_bytes(suite.ml_kem_variant(), &self.kem_ciphertext)
            .map_err(|e| MlsError::CryptoError(format!("Invalid ML-KEM ciphertext: {e:?}")))
    }

    /// Derive `length` bytes from `shared_secret_bytes` with HKDF-SHA3-256.
    ///
    /// No salt is passed (`None`). `label` is passed as the third argument,
    /// presumably the HKDF `info` string; confirm against saorsa_pqc's `Kdf::derive`.
    fn hkdf_expand(shared_secret_bytes: &[u8], label: &[u8], length: usize) -> Result<Vec<u8>> {
        use saorsa_pqc::api::{kdf::HkdfSha3_256, traits::Kdf};

        let mut output = vec![0u8; length];
        HkdfSha3_256::derive(shared_secret_bytes, None, label, &mut output)
            .map_err(|e| MlsError::CryptoError(format!("HKDF error: {e:?}")))?;
        Ok(output)
    }

    /// AEAD-encrypt `application_secret` under a key and nonce derived from the
    /// shared secret, with empty associated data.
    ///
    /// NOTE(review): the nonce is a deterministic function of the shared secret.
    /// Each shared secret must therefore be used for at most one encryption. Fresh
    /// ML-KEM encapsulations presumably guarantee this; confirm at call sites.
    fn encrypt_application_secret(
        suite: CipherSuite,
        shared_secret_bytes: &[u8],
        application_secret: &[u8],
    ) -> Result<Vec<u8>> {
        let key = Self::hkdf_expand(shared_secret_bytes, b"saorsa aead key", suite.key_size())?;
        let nonce = Self::hkdf_expand(
            shared_secret_bytes,
            b"saorsa aead nonce",
            suite.nonce_size(),
        )?;
        let cipher = AeadCipher::new(key, suite)?;
        cipher
            .encrypt(&nonce, application_secret, &[])
            .map_err(|e| MlsError::CryptoError(format!("Path secret encrypt failed: {e:?}")))
    }

    /// Decapsulate `kem_ciphertext` with the recipient's ML-KEM secret key and
    /// return the raw shared-secret bytes.
    fn decapsulate_shared_bytes(
        &self,
        suite: &CipherSuite,
        kem_secret: &MlKemSecretKey,
    ) -> Result<Vec<u8>> {
        let ciphertext = self.ciphertext(suite)?;
        let ml_kem = MlKem::new(suite.ml_kem_variant());
        let shared = ml_kem
            .decapsulate(kem_secret, &ciphertext)
            .map_err(|e| MlsError::CryptoError(format!("Decapsulation failed: {e:?}")))?;
        Ok(shared.to_bytes().to_vec())
    }

    /// Decapsulate the path secret using the recipient's private key
    ///
    /// Steps:
    /// 1. Recover the KEM shared secret.
    /// 2. Re-derive the AEAD key and nonce exactly as `encrypt_application_secret` does.
    /// 3. Require `encrypted_path_secret` to begin with that nonce.
    /// 4. Decrypt.
    ///
    /// # Errors
    /// - [`MlsError::CryptoError`] on KEM, HKDF, or AEAD failure.
    /// - [`MlsError::InvalidMessage`] if the buffer is shorter than a nonce or its
    ///   nonce prefix does not match.
    pub fn decapsulate_path_secret(
        &self,
        suite: &CipherSuite,
        kem_secret: &MlKemSecretKey,
    ) -> Result<Vec<u8>> {
        let shared_bytes = self.decapsulate_shared_bytes(suite, kem_secret)?;
        let key = Self::hkdf_expand(&shared_bytes, b"saorsa aead key", suite.key_size())?;
        let expected_nonce =
            Self::hkdf_expand(&shared_bytes, b"saorsa aead nonce", suite.nonce_size())?;

        if self.encrypted_path_secret.len() < suite.nonce_size() {
            return Err(MlsError::InvalidMessage(
                "Invalid encrypted path secret".to_string(),
            ));
        }

        // The prefix check assumes `AeadCipher::encrypt` prepends the nonce to its
        // output. NOTE(review): confirm against crate::crypto::AeadCipher.
        let stored_nonce = &self.encrypted_path_secret[..suite.nonce_size()];

        if stored_nonce != expected_nonce.as_slice() {
            return Err(MlsError::InvalidMessage(
                "Encrypted path secret nonce mismatch".to_string(),
            ));
        }

        let cipher = AeadCipher::new(key, *suite)?;
        // NOTE(review): the full buffer, nonce prefix included, is passed to
        // `decrypt`. Presumably `AeadCipher::decrypt` strips the prefix itself;
        // confirm.
        cipher
            .decrypt(&expected_nonce, &self.encrypted_path_secret, &[])
            .map_err(|e| MlsError::CryptoError(format!("Path secret decrypt failed: {e:?}")))
    }

    /// Crate-visible entry point for producing `encrypted_path_secret` bytes.
    ///
    /// Delegates to `encrypt_application_secret`.
    pub(crate) fn encrypt_for_recipient(
        suite: CipherSuite,
        shared_secret_bytes: &[u8],
        application_secret: &[u8],
    ) -> Result<Vec<u8>> {
        Self::encrypt_application_secret(suite, shared_secret_bytes, application_secret)
    }
}
423
/// State machine for protocol message processing
#[derive(Debug, Clone)]
pub struct ProtocolSessionState {
    /// Current epoch number.
    pub epoch: EpochNumber,
    /// Proposals queued by `add_proposal` and dropped by `clear_proposals`.
    pub pending_proposals: Vec<ProposalContent>,
    /// Running transcript hash.
    ///
    /// Starts empty and is replaced by `H(previous || data)` on each
    /// `update_transcript` call.
    pub confirmed_transcript_hash: Vec<u8>,
}
431
432impl ProtocolSessionState {
433    /// Create a new protocol state
434    pub fn new(epoch: EpochNumber) -> Self {
435        Self {
436            epoch,
437            pending_proposals: Vec::new(),
438            confirmed_transcript_hash: Vec::new(),
439        }
440    }
441
442    /// Add a proposal to pending list
443    pub fn add_proposal(&mut self, proposal: ProposalContent) {
444        self.pending_proposals.push(proposal);
445    }
446
447    /// Clear pending proposals after commit
448    pub fn clear_proposals(&mut self) {
449        self.pending_proposals.clear();
450    }
451
452    /// Update transcript hash
453    pub fn update_transcript(&mut self, data: &[u8]) {
454        let hasher = Hash::new(CipherSuite::default());
455        let mut input = self.confirmed_transcript_hash.clone();
456        input.extend_from_slice(data);
457        self.confirmed_transcript_hash = hasher.hash(&input);
458    }
459}
460
461/// Serialization helpers
462impl MlsMessage {
463    /// Serialize message to bytes
464    pub fn to_bytes(&self) -> Result<Vec<u8>> {
465        postcard::to_stdvec(self).map_err(|e| MlsError::SerializationError(e.to_string()))
466    }
467
468    /// Deserialize message from bytes
469    pub fn from_bytes(data: &[u8]) -> Result<Self> {
470        postcard::from_bytes(data).map_err(|e| MlsError::DeserializationError(e.to_string()))
471    }
472}
473
/// Configuration for an MLS group
///
/// Field order is part of the serde encoding; do not reorder.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct GroupConfig {
    /// Protocol version
    pub protocol_version: u16,
    /// Cipher suite identifier
    pub cipher_suite: crate::crypto::CipherSuiteId,
    /// Maximum number of members (`None` = no limit set by this config)
    pub max_members: Option<u32>,
    /// Group lifetime in seconds
    pub lifetime: Option<u64>,
    /// Maximum epoch age in milliseconds (SPEC-2 §3: default 24 hours)
    pub max_epoch_age_millis: u64,
    /// Maximum messages per epoch (SPEC-2 §3: default 10,000)
    pub max_messages_per_epoch: u64,
    /// Schema version for forward compatibility
    pub schema_version: u8,
}
492
493impl GroupConfig {
494    /// Create a new group configuration
495    pub fn new(protocol_version: u16, cipher_suite: CipherSuiteId) -> Self {
496        Self {
497            protocol_version,
498            cipher_suite,
499            max_members: None,
500            lifetime: None,
501            max_epoch_age_millis: 24 * 3600 * 1000, // 24 hours in milliseconds per SPEC-2 §3
502            max_messages_per_epoch: 10_000,         // 10,000 messages per SPEC-2 §3
503            schema_version: 1,
504        }
505    }
506
507    /// Set cipher suite identifier
508    pub fn with_cipher_suite(mut self, cipher_suite: CipherSuiteId) -> Self {
509        self.cipher_suite = cipher_suite;
510        self
511    }
512
513    /// Set maximum number of members
514    pub fn with_max_members(mut self, max_members: u32) -> Self {
515        self.max_members = Some(max_members);
516        self
517    }
518
519    /// Set group lifetime
520    pub fn with_lifetime(mut self, lifetime: u64) -> Self {
521        self.lifetime = Some(lifetime);
522        self
523    }
524
525    /// Set maximum epoch age (SPEC-2 §3 requirement)
526    pub fn with_max_epoch_age(mut self, duration: std::time::Duration) -> Self {
527        self.max_epoch_age_millis = duration.as_millis() as u64;
528        self
529    }
530
531    /// Set maximum messages per epoch (SPEC-2 §3 requirement)
532    pub fn with_max_messages_per_epoch(mut self, count: u64) -> Self {
533        self.max_messages_per_epoch = count;
534        self
535    }
536
537    /// Get maximum epoch age as Duration
538    pub fn max_epoch_age(&self) -> std::time::Duration {
539        std::time::Duration::from_millis(self.max_epoch_age_millis)
540    }
541
542    /// Get maximum messages per epoch
543    pub fn max_messages_per_epoch(&self) -> u64 {
544        self.max_messages_per_epoch
545    }
546}
547
548impl Default for GroupConfig {
549    fn default() -> Self {
550        // SPEC-2 default: ChaCha20Poly1305 + SHA256 + ML-DSA-65 (0x0B01)
551        Self::new(
552            1,
553            CipherSuiteId::SPEC2_MLS_128_MLKEM768_CHACHA20POLY1305_SHA256_MLDSA65,
554        )
555    }
556}
557
/// Unique identifier for an MLS group
///
/// Wraps arbitrary bytes. No length is enforced, but `GroupId::generate`
/// produces 32 random bytes.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct GroupId(Vec<u8>);
561
562impl GroupId {
563    /// Create a new group ID from bytes
564    pub fn new(id: Vec<u8>) -> Self {
565        Self(id)
566    }
567
568    /// Generate a random group ID
569    pub fn generate() -> Self {
570        use rand_core::{OsRng, RngCore};
571        let mut id = vec![0u8; 32];
572        OsRng.fill_bytes(&mut id);
573        Self(id)
574    }
575
576    /// Get the group ID as bytes
577    pub fn as_bytes(&self) -> &[u8] {
578        &self.0
579    }
580
581    /// Convert to bytes vector
582    pub fn into_bytes(self) -> Vec<u8> {
583        self.0
584    }
585}
586
587impl From<Vec<u8>> for GroupId {
588    fn from(bytes: Vec<u8>) -> Self {
589        Self(bytes)
590    }
591}
592
593impl From<&[u8]> for GroupId {
594    fn from(bytes: &[u8]) -> Self {
595        Self(bytes.to_vec())
596    }
597}
598
599impl AsRef<[u8]> for GroupId {
600    fn as_ref(&self) -> &[u8] {
601        &self.0
602    }
603}
604
605impl std::fmt::Display for GroupId {
606    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
607        write!(f, "{}", hex::encode(&self.0))
608    }
609}
610
/// State machine for managing MLS protocol state
///
/// Fields are public and the type is `Deserialize`, so `epoch` and `state` can hold
/// any value, not only ones reached through the transition methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtocolStateMachine {
    /// Current epoch number
    pub epoch: u64,
    /// Current state
    pub state: ProtocolState,
    /// Schema version for forward compatibility
    pub schema_version: u8,
}
621
/// Protocol states
///
/// Transitions enforced by `ProtocolStateMachine`:
/// - `Initial` → `Creating` → `Active`
/// - `Active` ⇄ `Updating`
/// - any state except `Terminated` → `Terminated`
///
/// Variant order is part of the serde wire encoding.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ProtocolState {
    /// Initial state before group creation
    Initial,
    /// Group is being created
    Creating,
    /// Group is active and operational
    Active,
    /// Group is being updated
    Updating,
    /// Group has been terminated
    Terminated,
}
636
637impl ProtocolStateMachine {
638    /// Create a new protocol state machine
639    pub fn new(epoch: u64) -> Self {
640        Self {
641            epoch,
642            state: ProtocolState::Initial,
643            schema_version: 1,
644        }
645    }
646
647    /// Transition to creating state
648    pub fn start_creation(&mut self) -> Result<()> {
649        match self.state {
650            ProtocolState::Initial => {
651                self.state = ProtocolState::Creating;
652                Ok(())
653            }
654            _ => Err(MlsError::InvalidGroupState(format!(
655                "Cannot start creation from state {:?}",
656                self.state
657            ))),
658        }
659    }
660
661    /// Transition to active state
662    pub fn activate(&mut self) -> Result<()> {
663        match self.state {
664            ProtocolState::Creating => {
665                self.state = ProtocolState::Active;
666                Ok(())
667            }
668            _ => Err(MlsError::InvalidGroupState(format!(
669                "Cannot activate from state {:?}",
670                self.state
671            ))),
672        }
673    }
674
675    /// Start an update operation
676    pub fn start_update(&mut self) -> Result<()> {
677        match self.state {
678            ProtocolState::Active => {
679                self.state = ProtocolState::Updating;
680                Ok(())
681            }
682            _ => Err(MlsError::InvalidGroupState(format!(
683                "Cannot start update from state {:?}",
684                self.state
685            ))),
686        }
687    }
688
689    /// Complete an update operation
690    pub fn complete_update(&mut self) -> Result<()> {
691        match self.state {
692            ProtocolState::Updating => {
693                self.state = ProtocolState::Active;
694                self.epoch += 1;
695                Ok(())
696            }
697            _ => Err(MlsError::InvalidGroupState(format!(
698                "Cannot complete update from state {:?}",
699                self.state
700            ))),
701        }
702    }
703
704    /// Terminate the group
705    pub fn terminate(&mut self) -> Result<()> {
706        if matches!(self.state, ProtocolState::Terminated) {
707            return Err(MlsError::InvalidGroupState(
708                "Group is already terminated".to_string(),
709            ));
710        }
711
712        self.state = ProtocolState::Terminated;
713        Ok(())
714    }
715
716    /// Get current state
717    pub fn state(&self) -> &ProtocolState {
718        &self.state
719    }
720
721    /// Get current epoch
722    pub fn epoch(&self) -> u64 {
723        self.epoch
724    }
725
726    /// Check if group is active
727    pub fn is_active(&self) -> bool {
728        matches!(self.state, ProtocolState::Active)
729    }
730
731    /// Check if group is terminated
732    pub fn is_terminated(&self) -> bool {
733        matches!(self.state, ProtocolState::Terminated)
734    }
735
736    /// Set the epoch number (internal use)
737    pub fn set_epoch(&mut self, epoch: u64) {
738        self.epoch = epoch;
739    }
740}
741
#[cfg(test)]
mod tests {
    use super::*;
    use crate::crypto::KeyPair;

    // Round-trips a handshake message through postcard and checks that the
    // epoch/sender accessors agree before and after decoding.
    #[test]
    fn test_message_serialization() {
        let msg = HandshakeMessage {
            epoch: 0,
            sender: MemberId::generate(),
            content: vec![1, 2, 3],
            signature: create_test_signature(),
        };

        let mls_msg = MlsMessage::Handshake(msg);
        let bytes = mls_msg.to_bytes().unwrap();
        let decoded = MlsMessage::from_bytes(&bytes).unwrap();

        assert_eq!(mls_msg.epoch(), decoded.epoch());
        assert_eq!(mls_msg.sender(), decoded.sender());
    }

    // Non-empty content passes validation; empty content is rejected.
    #[test]
    fn test_handshake_validation() {
        let valid = HandshakeMessage {
            epoch: 0,
            sender: MemberId::generate(),
            content: vec![1, 2, 3],
            signature: create_test_signature(),
        };
        assert!(valid.validate().is_ok());

        let empty = HandshakeMessage {
            epoch: 0,
            sender: MemberId::generate(),
            content: vec![],
            signature: create_test_signature(),
        };
        assert!(empty.validate().is_err());
    }

    // Proposals accumulate via add_proposal and are dropped by clear_proposals.
    #[test]
    fn test_protocol_state() {
        let mut state = ProtocolSessionState::new(0);
        assert!(state.pending_proposals.is_empty());

        let proposal = ProposalContent::Remove(RemoveProposal {
            removed: MemberId::generate(),
        });
        state.add_proposal(proposal);
        assert_eq!(state.pending_proposals.len(), 1);

        state.clear_proposals();
        assert!(state.pending_proposals.is_empty());
    }

    // Sanity check that leaf and parent nodes construct and match as expected.
    #[test]
    fn test_tree_node_types() {
        let leaf = TreeNode::Leaf(LeafNode {
            key_package: None,
            unmerged_leaves: vec![],
        });

        let parent = TreeNode::Parent(ParentNode {
            public_key: None,
            unmerged_leaves: vec![],
        });

        match leaf {
            TreeNode::Leaf(_) => (),
            TreeNode::Parent(_) => panic!("Expected leaf node"),
        }

        match parent {
            TreeNode::Parent(_) => (),
            TreeNode::Leaf(_) => panic!("Expected parent node"),
        }
    }

    #[test]
    fn test_message_type_equality() {
        assert_eq!(MessageType::Handshake, MessageType::Handshake);
        assert_ne!(MessageType::Handshake, MessageType::Application);
    }

    // Postcard round-trip of GroupInfo; only group_id and epoch are compared.
    #[test]
    fn test_group_info_serialization() {
        let info = GroupInfo {
            group_id: vec![1, 2, 3],
            epoch: 42,
            tree_hash: vec![4, 5, 6],
            confirmed_transcript_hash: vec![7, 8, 9],
            extensions: vec![],
            confirmation_tag: vec![10, 11, 12],
            signer: MemberId::generate(),
        };

        let bytes = postcard::to_stdvec(&info).unwrap();
        let decoded: GroupInfo = postcard::from_bytes(&bytes).unwrap();

        assert_eq!(info.group_id, decoded.group_id);
        assert_eq!(info.epoch, decoded.epoch);
    }

    // Builds an UpdatePath from a freshly generated key package.
    #[test]
    fn test_update_path_construction() {
        let keypair = KeyPair::generate(CipherSuite::default());
        let member_id = MemberId::generate();
        let cred = Credential::new_basic(member_id, None, &keypair, keypair.suite).unwrap();
        let key_package = KeyPackage::new(keypair, cred).unwrap();

        let path = UpdatePath {
            leaf_key_package: key_package,
            nodes: vec![],
        };

        assert!(path.nodes.is_empty());
    }

    // Helper function to create test signature
    //
    // The signature is over the fixed bytes b"test", not over any message content,
    // so tests using it do not exercise signature verification.
    fn create_test_signature() -> DebugMlDsaSignature {
        let keypair = KeyPair::generate(CipherSuite::default());
        let sig = keypair.sign(b"test").unwrap();
        match sig {
            crate::crypto::Signature::MlDsa(ml_dsa_sig) => DebugMlDsaSignature(ml_dsa_sig),
            _ => panic!("Expected ML-DSA signature for default suite"),
        }
    }

    // Wraps a real ML-KEM ciphertext in an EncryptedPathSecret and checks the
    // recipient field is preserved.
    #[test]
    fn test_encrypted_path_secret() {
        let keypair1 = KeyPair::generate(CipherSuite::default());
        let keypair2 = KeyPair::generate(CipherSuite::default());
        let member_id = MemberId::generate();

        // Create encrypted path secret using ML-KEM
        let (ciphertext, _shared_secret) = keypair1.encapsulate(keypair2.public_key()).unwrap();

        let eps = EncryptedPathSecret {
            recipient: member_id,
            ciphertext: ciphertext.to_bytes(),
        };

        assert_eq!(eps.recipient, member_id);
    }

    // A welcome with group info and one secret validates; one with no secrets
    // is rejected. The secret bytes are placeholders and are never decrypted.
    #[test]
    fn test_welcome_message_validation() {
        let valid = WelcomeMessage {
            epoch: 0,
            sender: MemberId::generate(),
            cipher_suite: CipherSuite::default(),
            group_info: vec![1, 2, 3],
            secrets: vec![EncryptedGroupSecrets {
                recipient_key_package_hash: vec![1],
                kem_ciphertext: vec![2],
                encrypted_path_secret: vec![3],
            }],
            signature: create_test_signature(),
        };
        assert!(valid.validate().is_ok());

        let no_secrets = WelcomeMessage {
            epoch: 0,
            sender: MemberId::generate(),
            cipher_suite: CipherSuite::default(),
            group_info: vec![1, 2, 3],
            secrets: vec![],
            signature: create_test_signature(),
        };
        assert!(no_secrets.validate().is_err());
    }
}
915
/// Audit log entry for group operations (SPEC-2 §8)
///
/// NOTE(review): serde's `SystemTime` implementation encodes the time relative to
/// `UNIX_EPOCH` and fails for earlier timestamps, so serializing an entry with a
/// pre-epoch `timestamp` returns an error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuditLogEntry {
    /// Timestamp of the event
    pub timestamp: std::time::SystemTime,
    /// Type of event (group_created, epoch_advanced, member_added, etc.)
    pub event_type: String,
    /// Cipher suite ID used
    pub cipher_suite_id: CipherSuiteId,
    /// Whether cipher suite is PQC-only
    pub is_pqc_only: bool,
    /// Whether cipher suite is deprecated
    pub is_deprecated: bool,
    /// Member ID involved (if applicable)
    pub member_id: Option<MemberId>,
    /// Old epoch (for epoch changes)
    pub old_epoch: Option<u64>,
    /// New epoch (for epoch changes)
    pub new_epoch: Option<u64>,
    /// Additional context
    pub context: Option<String>,
}