shield_core/
group.rs

1//! Multi-recipient encryption.
2//!
//! Encrypt a payload once for multiple recipients; each recipient decrypts it with their own key.
4
// Permit narrowing `as` casts module-wide; any such cast here is expected to be bounded by practical message and group sizes.
6#![allow(clippy::cast_possible_truncation)]
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
9use ring::hmac;
10use ring::rand::{SecureRandom, SystemRandom};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use subtle::ConstantTimeEq;
14
15use crate::error::{Result, ShieldError};
16
17/// Generate keystream using SHA256.
18fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
19    let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
20    let num_blocks = length.div_ceil(32);
21
22    for i in 0..num_blocks {
23        let counter = (i as u32).to_le_bytes();
24        let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
25        data.extend_from_slice(key);
26        data.extend_from_slice(nonce);
27        data.extend_from_slice(&counter);
28
29        let hash = ring::digest::digest(&ring::digest::SHA256, &data);
30        keystream.extend_from_slice(hash.as_ref());
31    }
32
33    keystream.truncate(length);
34    keystream
35}
36
37/// Encrypt a block with HMAC authentication.
38fn encrypt_block(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
39    let rng = SystemRandom::new();
40    let mut nonce = [0u8; 16];
41    rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
42
43    let keystream = generate_keystream(key, &nonce, data.len());
44    let ciphertext: Vec<u8> = data.iter().zip(keystream.iter()).map(|(p, k)| p ^ k).collect();
45
46    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
47    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
48    hmac_data.extend_from_slice(&nonce);
49    hmac_data.extend_from_slice(&ciphertext);
50    let tag = hmac::sign(&hmac_key, &hmac_data);
51
52    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
53    result.extend_from_slice(&nonce);
54    result.extend_from_slice(&ciphertext);
55    result.extend_from_slice(&tag.as_ref()[..16]);
56
57    Ok(result)
58}
59
60/// Decrypt a block with HMAC verification.
61fn decrypt_block(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
62    if encrypted.len() < 32 {
63        return Err(ShieldError::CiphertextTooShort {
64            expected: 32,
65            actual: encrypted.len(),
66        });
67    }
68
69    let nonce = &encrypted[..16];
70    let ciphertext = &encrypted[16..encrypted.len() - 16];
71    let mac = &encrypted[encrypted.len() - 16..];
72
73    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
74    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
75    hmac_data.extend_from_slice(nonce);
76    hmac_data.extend_from_slice(ciphertext);
77    let expected_tag = hmac::sign(&hmac_key, &hmac_data);
78
79    if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
80        return Err(ShieldError::AuthenticationFailed);
81    }
82
83    let keystream = generate_keystream(key, nonce, ciphertext.len());
84    let plaintext: Vec<u8> = ciphertext
85        .iter()
86        .zip(keystream.iter())
87        .map(|(c, k)| c ^ k)
88        .collect();
89
90    Ok(plaintext)
91}
92
93/// Encrypted group message format.
94#[derive(Serialize, Deserialize)]
95pub struct EncryptedGroupMessage {
96    pub version: u8,
97    pub ciphertext: String,
98    pub keys: HashMap<String, String>,
99}
100
/// Multi-recipient encryption.
///
/// A single symmetric group key protects each payload, and a copy of that key
/// is wrapped separately for every registered member.
pub struct GroupEncryption {
    /// Key that `encrypt` seals the payload with.
    group_key: [u8; 32],
    /// Member id -> key used to wrap `group_key` for that member.
    members: HashMap<String, [u8; 32]>,
}
106
107impl GroupEncryption {
108    /// Create new group encryption.
109    pub fn new(group_key: Option<[u8; 32]>) -> Result<Self> {
110        let key = if let Some(k) = group_key { k } else {
111            let rng = SystemRandom::new();
112            let mut k = [0u8; 32];
113            rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
114            k
115        };
116
117        Ok(Self {
118            group_key: key,
119            members: HashMap::new(),
120        })
121    }
122
123    /// Add member to group.
124    pub fn add_member(&mut self, member_id: &str, shared_key: [u8; 32]) {
125        self.members.insert(member_id.to_string(), shared_key);
126    }
127
128    /// Remove member from group.
129    pub fn remove_member(&mut self, member_id: &str) -> bool {
130        self.members.remove(member_id).is_some()
131    }
132
133    /// Get member list.
134    pub fn members(&self) -> Vec<&str> {
135        self.members.keys().map(String::as_str).collect()
136    }
137
138    /// Encrypt for all group members.
139    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedGroupMessage> {
140        let ciphertext = encrypt_block(&self.group_key, plaintext)?;
141
142        let mut keys = HashMap::new();
143        for (member_id, member_key) in &self.members {
144            let encrypted_key = encrypt_block(member_key, &self.group_key)?;
145            keys.insert(member_id.clone(), URL_SAFE_NO_PAD.encode(&encrypted_key));
146        }
147
148        Ok(EncryptedGroupMessage {
149            version: 1,
150            ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
151            keys,
152        })
153    }
154
155    /// Decrypt as group member.
156    pub fn decrypt_as_member(
157        encrypted: &EncryptedGroupMessage,
158        member_id: &str,
159        member_key: &[u8; 32],
160    ) -> Result<Vec<u8>> {
161        let encrypted_group_key = encrypted
162            .keys
163            .get(member_id)
164            .ok_or(ShieldError::MemberNotFound)?;
165
166        let encrypted_key_bytes = URL_SAFE_NO_PAD
167            .decode(encrypted_group_key)
168            .map_err(|_| ShieldError::InvalidFormat)?;
169
170        let group_key_vec = decrypt_block(member_key, &encrypted_key_bytes)?;
171        let mut group_key = [0u8; 32];
172        group_key.copy_from_slice(&group_key_vec);
173
174        let ciphertext = URL_SAFE_NO_PAD
175            .decode(&encrypted.ciphertext)
176            .map_err(|_| ShieldError::InvalidFormat)?;
177
178        decrypt_block(&group_key, &ciphertext)
179    }
180
181    /// Rotate the group key.
182    pub fn rotate_key(&mut self) -> Result<[u8; 32]> {
183        let old_key = self.group_key;
184        let rng = SystemRandom::new();
185        rng.fill(&mut self.group_key)
186            .map_err(|_| ShieldError::RandomFailed)?;
187        Ok(old_key)
188    }
189
190    /// Get the group key.
191    #[must_use] 
192    pub fn group_key(&self) -> &[u8; 32] {
193        &self.group_key
194    }
195}
196
197/// Encrypted broadcast message format.
198#[derive(Serialize, Deserialize)]
199pub struct EncryptedBroadcast {
200    pub version: u8,
201    pub ciphertext: String,
202    pub subgroups: HashMap<String, String>,
203    pub members: HashMap<String, MemberKeyData>,
204}
205
206#[derive(Serialize, Deserialize, Clone)]
207pub struct MemberKeyData {
208    pub sg: u32,
209    pub key: String,
210}
211
/// Broadcast encryption for large groups.
///
/// Members are partitioned into subgroups of at most `subgroup_size` members.
/// Each message is sealed once under a fresh message key. That key is wrapped
/// once per subgroup, and each subgroup key is wrapped once per member.
pub struct BroadcastEncryption {
    // Stored at construction but never read in this module; the attribute
    // silences the resulting dead-code warning.
    #[allow(dead_code)]
    master_key: [u8; 32],
    /// Maximum number of members placed in one subgroup.
    subgroup_size: usize,
    /// Member id -> (subgroup id, member key).
    members: HashMap<String, (u32, [u8; 32])>,
    /// Subgroup id -> subgroup key.
    subgroup_keys: HashMap<u32, [u8; 32]>,
    /// Id assigned to the next subgroup that gets created.
    next_subgroup: u32,
}
221
222impl BroadcastEncryption {
223    /// Create new broadcast encryption.
224    pub fn new(master_key: Option<[u8; 32]>, subgroup_size: usize) -> Result<Self> {
225        let key = if let Some(k) = master_key { k } else {
226            let rng = SystemRandom::new();
227            let mut k = [0u8; 32];
228            rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
229            k
230        };
231
232        Ok(Self {
233            master_key: key,
234            subgroup_size: if subgroup_size == 0 { 16 } else { subgroup_size },
235            members: HashMap::new(),
236            subgroup_keys: HashMap::new(),
237            next_subgroup: 0,
238        })
239    }
240
241    /// Add member to broadcast group.
242    pub fn add_member(&mut self, member_id: &str, member_key: [u8; 32]) -> Result<u32> {
243        let rng = SystemRandom::new();
244
245        // Find subgroup with space
246        let mut subgroup_id = None;
247        for sg_id in self.subgroup_keys.keys() {
248            let count = self.members.values().filter(|(sg, _)| sg == sg_id).count();
249            if count < self.subgroup_size {
250                subgroup_id = Some(*sg_id);
251                break;
252            }
253        }
254
255        let sg_id = if let Some(id) = subgroup_id { id } else {
256            let id = self.next_subgroup;
257            let mut sg_key = [0u8; 32];
258            rng.fill(&mut sg_key)
259                .map_err(|_| ShieldError::RandomFailed)?;
260            self.subgroup_keys.insert(id, sg_key);
261            self.next_subgroup += 1;
262            id
263        };
264
265        self.members.insert(member_id.to_string(), (sg_id, member_key));
266        Ok(sg_id)
267    }
268
269    /// Encrypt for broadcast.
270    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedBroadcast> {
271        let rng = SystemRandom::new();
272        let mut message_key = [0u8; 32];
273        rng.fill(&mut message_key)
274            .map_err(|_| ShieldError::RandomFailed)?;
275
276        // Encrypt message
277        let ciphertext = encrypt_block(&message_key, plaintext)?;
278
279        // Encrypt message key for each subgroup
280        let mut subgroups = HashMap::new();
281        for (sg_id, sg_key) in &self.subgroup_keys {
282            let encrypted_msg_key = encrypt_block(sg_key, &message_key)?;
283            subgroups.insert(sg_id.to_string(), URL_SAFE_NO_PAD.encode(&encrypted_msg_key));
284        }
285
286        // Encrypt subgroup keys for each member
287        let mut members = HashMap::new();
288        for (member_id, (sg_id, member_key)) in &self.members {
289            let sg_key = self.subgroup_keys.get(sg_id).unwrap();
290            let encrypted_sg_key = encrypt_block(member_key, sg_key)?;
291            members.insert(
292                member_id.clone(),
293                MemberKeyData {
294                    sg: *sg_id,
295                    key: URL_SAFE_NO_PAD.encode(&encrypted_sg_key),
296                },
297            );
298        }
299
300        Ok(EncryptedBroadcast {
301            version: 1,
302            ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
303            subgroups,
304            members,
305        })
306    }
307
308    /// Decrypt broadcast as member.
309    pub fn decrypt_as_member(
310        encrypted: &EncryptedBroadcast,
311        member_id: &str,
312        member_key: &[u8; 32],
313    ) -> Result<Vec<u8>> {
314        let member_data = encrypted
315            .members
316            .get(member_id)
317            .ok_or(ShieldError::MemberNotFound)?;
318
319        // Decrypt subgroup key
320        let sg_key_enc = URL_SAFE_NO_PAD
321            .decode(&member_data.key)
322            .map_err(|_| ShieldError::InvalidFormat)?;
323        let sg_key_vec = decrypt_block(member_key, &sg_key_enc)?;
324        let mut sg_key = [0u8; 32];
325        sg_key.copy_from_slice(&sg_key_vec);
326
327        // Decrypt message key
328        let msg_key_enc = URL_SAFE_NO_PAD
329            .decode(
330                encrypted
331                    .subgroups
332                    .get(&member_data.sg.to_string())
333                    .ok_or(ShieldError::InvalidFormat)?,
334            )
335            .map_err(|_| ShieldError::InvalidFormat)?;
336        let msg_key_vec = decrypt_block(&sg_key, &msg_key_enc)?;
337        let mut msg_key = [0u8; 32];
338        msg_key.copy_from_slice(&msg_key_vec);
339
340        // Decrypt message
341        let ciphertext = URL_SAFE_NO_PAD
342            .decode(&encrypted.ciphertext)
343            .map_err(|_| ShieldError::InvalidFormat)?;
344        decrypt_block(&msg_key, &ciphertext)
345    }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_encrypt_decrypt() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        group.add_member("alice", alice_key);
        group.add_member("bob", bob_key);

        let plaintext = b"Group message!";
        let encrypted = group.encrypt(plaintext).unwrap();

        let alice_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_group_non_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(result.is_err());
    }

    #[test]
    fn test_group_wrong_key_rejected() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_tampered_ciphertext_rejected() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        group.add_member("alice", alice_key);

        let mut encrypted = group.encrypt(b"secret").unwrap();
        let mut raw = URL_SAFE_NO_PAD.decode(&encrypted.ciphertext).unwrap();
        // Flip the first ciphertext byte (right after the 16-byte nonce).
        raw[16] ^= 0x01;
        encrypted.ciphertext = URL_SAFE_NO_PAD.encode(&raw);

        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_remove_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        group.add_member("bob", [2u8; 32]);

        assert_eq!(group.members().len(), 2);
        group.remove_member("bob");
        assert_eq!(group.members().len(), 1);
    }

    #[test]
    fn test_group_rotate_key_returns_previous_key() {
        let mut group = GroupEncryption::new(Some([7u8; 32])).unwrap();
        let old = group.rotate_key().unwrap();
        assert_eq!(old, [7u8; 32]);
        assert_ne!(group.group_key(), &old);
    }

    #[test]
    fn test_broadcast_encrypt_decrypt() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        broadcast.add_member("alice", alice_key).unwrap();
        broadcast.add_member("bob", bob_key).unwrap();

        let plaintext = b"Broadcast message!";
        let encrypted = broadcast.encrypt(plaintext).unwrap();

        let alice_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_non_member() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        broadcast.add_member("alice", [1u8; 32]).unwrap();

        let encrypted = broadcast.encrypt(b"secret").unwrap();
        let result = BroadcastEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(matches!(result, Err(ShieldError::MemberNotFound)));
    }

    #[test]
    fn test_broadcast_wrong_key_rejected() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        broadcast.add_member("alice", [1u8; 32]).unwrap();

        let encrypted = broadcast.encrypt(b"secret").unwrap();
        let result = BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_broadcast_subgroups() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();

        let sg1 = broadcast.add_member("alice", [1u8; 32]).unwrap();
        let sg2 = broadcast.add_member("bob", [2u8; 32]).unwrap();
        let sg3 = broadcast.add_member("carol", [3u8; 32]).unwrap();

        // First two in same subgroup
        assert_eq!(sg1, 0);
        assert_eq!(sg2, 0);
        // Third in new subgroup
        assert_eq!(sg3, 1);
    }

    #[test]
    fn test_broadcast_default_subgroup_size() {
        // A subgroup size of 0 falls back to 16 members per subgroup.
        let mut broadcast = BroadcastEncryption::new(None, 0).unwrap();
        for i in 0u8..16 {
            let sg = broadcast.add_member(&format!("m{i}"), [i; 32]).unwrap();
            assert_eq!(sg, 0);
        }
        assert_eq!(broadcast.add_member("m16", [16u8; 32]).unwrap(), 1);
    }
}
426        // Third in new subgroup
427        assert_eq!(sg3, 1);
428    }
429}