shield_core/
group.rs

1//! Multi-recipient encryption.
2//!
//! Encrypt once for multiple recipients; each recipient decrypts with their own key.
4
// The keystream block counter is a u32 (one counter value per 32-byte block), far more than any in-memory message needs.
6#![allow(clippy::cast_possible_truncation)]
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
9use ring::hmac;
10use ring::rand::{SecureRandom, SystemRandom};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use subtle::ConstantTimeEq;
14
15use crate::error::{Result, ShieldError};
16
17/// Generate keystream using SHA256.
18fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
19    let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
20    let num_blocks = length.div_ceil(32);
21
22    for i in 0..num_blocks {
23        let counter = (i as u32).to_le_bytes();
24        let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
25        data.extend_from_slice(key);
26        data.extend_from_slice(nonce);
27        data.extend_from_slice(&counter);
28
29        let hash = ring::digest::digest(&ring::digest::SHA256, &data);
30        keystream.extend_from_slice(hash.as_ref());
31    }
32
33    keystream.truncate(length);
34    keystream
35}
36
37/// Encrypt a block with HMAC authentication.
38fn encrypt_block(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
39    let rng = SystemRandom::new();
40    let mut nonce = [0u8; 16];
41    rng.fill(&mut nonce)
42        .map_err(|_| ShieldError::RandomFailed)?;
43
44    let keystream = generate_keystream(key, &nonce, data.len());
45    let ciphertext: Vec<u8> = data
46        .iter()
47        .zip(keystream.iter())
48        .map(|(p, k)| p ^ k)
49        .collect();
50
51    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
52    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
53    hmac_data.extend_from_slice(&nonce);
54    hmac_data.extend_from_slice(&ciphertext);
55    let tag = hmac::sign(&hmac_key, &hmac_data);
56
57    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
58    result.extend_from_slice(&nonce);
59    result.extend_from_slice(&ciphertext);
60    result.extend_from_slice(&tag.as_ref()[..16]);
61
62    Ok(result)
63}
64
65/// Decrypt a block with HMAC verification.
66fn decrypt_block(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
67    if encrypted.len() < 32 {
68        return Err(ShieldError::CiphertextTooShort {
69            expected: 32,
70            actual: encrypted.len(),
71        });
72    }
73
74    let nonce = &encrypted[..16];
75    let ciphertext = &encrypted[16..encrypted.len() - 16];
76    let mac = &encrypted[encrypted.len() - 16..];
77
78    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
79    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
80    hmac_data.extend_from_slice(nonce);
81    hmac_data.extend_from_slice(ciphertext);
82    let expected_tag = hmac::sign(&hmac_key, &hmac_data);
83
84    if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
85        return Err(ShieldError::AuthenticationFailed);
86    }
87
88    let keystream = generate_keystream(key, nonce, ciphertext.len());
89    let plaintext: Vec<u8> = ciphertext
90        .iter()
91        .zip(keystream.iter())
92        .map(|(c, k)| c ^ k)
93        .collect();
94
95    Ok(plaintext)
96}
97
98/// Encrypted group message format.
99#[derive(Serialize, Deserialize)]
100pub struct EncryptedGroupMessage {
101    pub version: u8,
102    pub ciphertext: String,
103    pub keys: HashMap<String, String>,
104}
105
106/// Multi-recipient encryption.
107pub struct GroupEncryption {
108    group_key: [u8; 32],
109    members: HashMap<String, [u8; 32]>,
110}
111
112impl GroupEncryption {
113    /// Create new group encryption.
114    pub fn new(group_key: Option<[u8; 32]>) -> Result<Self> {
115        let key = if let Some(k) = group_key {
116            k
117        } else {
118            let rng = SystemRandom::new();
119            let mut k = [0u8; 32];
120            rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
121            k
122        };
123
124        Ok(Self {
125            group_key: key,
126            members: HashMap::new(),
127        })
128    }
129
130    /// Add member to group.
131    pub fn add_member(&mut self, member_id: &str, shared_key: [u8; 32]) {
132        self.members.insert(member_id.to_string(), shared_key);
133    }
134
135    /// Remove member from group.
136    pub fn remove_member(&mut self, member_id: &str) -> bool {
137        self.members.remove(member_id).is_some()
138    }
139
140    /// Get member list.
141    pub fn members(&self) -> Vec<&str> {
142        self.members.keys().map(String::as_str).collect()
143    }
144
145    /// Encrypt for all group members.
146    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedGroupMessage> {
147        let ciphertext = encrypt_block(&self.group_key, plaintext)?;
148
149        let mut keys = HashMap::new();
150        for (member_id, member_key) in &self.members {
151            let encrypted_key = encrypt_block(member_key, &self.group_key)?;
152            keys.insert(member_id.clone(), URL_SAFE_NO_PAD.encode(&encrypted_key));
153        }
154
155        Ok(EncryptedGroupMessage {
156            version: 1,
157            ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
158            keys,
159        })
160    }
161
162    /// Decrypt as group member.
163    pub fn decrypt_as_member(
164        encrypted: &EncryptedGroupMessage,
165        member_id: &str,
166        member_key: &[u8; 32],
167    ) -> Result<Vec<u8>> {
168        let encrypted_group_key = encrypted
169            .keys
170            .get(member_id)
171            .ok_or(ShieldError::MemberNotFound)?;
172
173        let encrypted_key_bytes = URL_SAFE_NO_PAD
174            .decode(encrypted_group_key)
175            .map_err(|_| ShieldError::InvalidFormat)?;
176
177        let group_key_vec = decrypt_block(member_key, &encrypted_key_bytes)?;
178        let mut group_key = [0u8; 32];
179        group_key.copy_from_slice(&group_key_vec);
180
181        let ciphertext = URL_SAFE_NO_PAD
182            .decode(&encrypted.ciphertext)
183            .map_err(|_| ShieldError::InvalidFormat)?;
184
185        decrypt_block(&group_key, &ciphertext)
186    }
187
188    /// Rotate the group key.
189    pub fn rotate_key(&mut self) -> Result<[u8; 32]> {
190        let old_key = self.group_key;
191        let rng = SystemRandom::new();
192        rng.fill(&mut self.group_key)
193            .map_err(|_| ShieldError::RandomFailed)?;
194        Ok(old_key)
195    }
196
197    /// Get the group key.
198    #[must_use]
199    pub fn group_key(&self) -> &[u8; 32] {
200        &self.group_key
201    }
202}
203
204/// Encrypted broadcast message format.
205#[derive(Serialize, Deserialize)]
206pub struct EncryptedBroadcast {
207    pub version: u8,
208    pub ciphertext: String,
209    pub subgroups: HashMap<String, String>,
210    pub members: HashMap<String, MemberKeyData>,
211}
212
213#[derive(Serialize, Deserialize, Clone)]
214pub struct MemberKeyData {
215    pub sg: u32,
216    pub key: String,
217}
218
219/// Efficient broadcast encryption for large groups.
220pub struct BroadcastEncryption {
221    #[allow(dead_code)]
222    master_key: [u8; 32],
223    subgroup_size: usize,
224    members: HashMap<String, (u32, [u8; 32])>,
225    subgroup_keys: HashMap<u32, [u8; 32]>,
226    next_subgroup: u32,
227}
228
229impl BroadcastEncryption {
230    /// Create new broadcast encryption.
231    pub fn new(master_key: Option<[u8; 32]>, subgroup_size: usize) -> Result<Self> {
232        let key = if let Some(k) = master_key {
233            k
234        } else {
235            let rng = SystemRandom::new();
236            let mut k = [0u8; 32];
237            rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
238            k
239        };
240
241        Ok(Self {
242            master_key: key,
243            subgroup_size: if subgroup_size == 0 {
244                16
245            } else {
246                subgroup_size
247            },
248            members: HashMap::new(),
249            subgroup_keys: HashMap::new(),
250            next_subgroup: 0,
251        })
252    }
253
254    /// Add member to broadcast group.
255    pub fn add_member(&mut self, member_id: &str, member_key: [u8; 32]) -> Result<u32> {
256        let rng = SystemRandom::new();
257
258        // Find subgroup with space
259        let mut subgroup_id = None;
260        for sg_id in self.subgroup_keys.keys() {
261            let count = self.members.values().filter(|(sg, _)| sg == sg_id).count();
262            if count < self.subgroup_size {
263                subgroup_id = Some(*sg_id);
264                break;
265            }
266        }
267
268        let sg_id = if let Some(id) = subgroup_id {
269            id
270        } else {
271            let id = self.next_subgroup;
272            let mut sg_key = [0u8; 32];
273            rng.fill(&mut sg_key)
274                .map_err(|_| ShieldError::RandomFailed)?;
275            self.subgroup_keys.insert(id, sg_key);
276            self.next_subgroup += 1;
277            id
278        };
279
280        self.members
281            .insert(member_id.to_string(), (sg_id, member_key));
282        Ok(sg_id)
283    }
284
285    /// Encrypt for broadcast.
286    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedBroadcast> {
287        let rng = SystemRandom::new();
288        let mut message_key = [0u8; 32];
289        rng.fill(&mut message_key)
290            .map_err(|_| ShieldError::RandomFailed)?;
291
292        // Encrypt message
293        let ciphertext = encrypt_block(&message_key, plaintext)?;
294
295        // Encrypt message key for each subgroup
296        let mut subgroups = HashMap::new();
297        for (sg_id, sg_key) in &self.subgroup_keys {
298            let encrypted_msg_key = encrypt_block(sg_key, &message_key)?;
299            subgroups.insert(
300                sg_id.to_string(),
301                URL_SAFE_NO_PAD.encode(&encrypted_msg_key),
302            );
303        }
304
305        // Encrypt subgroup keys for each member
306        let mut members = HashMap::new();
307        for (member_id, (sg_id, member_key)) in &self.members {
308            let sg_key = self.subgroup_keys.get(sg_id).unwrap();
309            let encrypted_sg_key = encrypt_block(member_key, sg_key)?;
310            members.insert(
311                member_id.clone(),
312                MemberKeyData {
313                    sg: *sg_id,
314                    key: URL_SAFE_NO_PAD.encode(&encrypted_sg_key),
315                },
316            );
317        }
318
319        Ok(EncryptedBroadcast {
320            version: 1,
321            ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
322            subgroups,
323            members,
324        })
325    }
326
327    /// Decrypt broadcast as member.
328    pub fn decrypt_as_member(
329        encrypted: &EncryptedBroadcast,
330        member_id: &str,
331        member_key: &[u8; 32],
332    ) -> Result<Vec<u8>> {
333        let member_data = encrypted
334            .members
335            .get(member_id)
336            .ok_or(ShieldError::MemberNotFound)?;
337
338        // Decrypt subgroup key
339        let sg_key_enc = URL_SAFE_NO_PAD
340            .decode(&member_data.key)
341            .map_err(|_| ShieldError::InvalidFormat)?;
342        let sg_key_vec = decrypt_block(member_key, &sg_key_enc)?;
343        let mut sg_key = [0u8; 32];
344        sg_key.copy_from_slice(&sg_key_vec);
345
346        // Decrypt message key
347        let msg_key_enc = URL_SAFE_NO_PAD
348            .decode(
349                encrypted
350                    .subgroups
351                    .get(&member_data.sg.to_string())
352                    .ok_or(ShieldError::InvalidFormat)?,
353            )
354            .map_err(|_| ShieldError::InvalidFormat)?;
355        let msg_key_vec = decrypt_block(&sg_key, &msg_key_enc)?;
356        let mut msg_key = [0u8; 32];
357        msg_key.copy_from_slice(&msg_key_vec);
358
359        // Decrypt message
360        let ciphertext = URL_SAFE_NO_PAD
361            .decode(&encrypted.ciphertext)
362            .map_err(|_| ShieldError::InvalidFormat)?;
363        decrypt_block(&msg_key, &ciphertext)
364    }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_encrypt_decrypt() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        group.add_member("alice", alice_key);
        group.add_member("bob", bob_key);

        let plaintext = b"Group message!";
        let encrypted = group.encrypt(plaintext).unwrap();

        let alice_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_group_non_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(result.is_err());
    }

    #[test]
    fn test_group_remove_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        group.add_member("bob", [2u8; 32]);

        assert_eq!(group.members().len(), 2);
        group.remove_member("bob");
        assert_eq!(group.members().len(), 1);
    }

    #[test]
    fn test_broadcast_encrypt_decrypt() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        broadcast.add_member("alice", alice_key).unwrap();
        broadcast.add_member("bob", bob_key).unwrap();

        let plaintext = b"Broadcast message!";
        let encrypted = broadcast.encrypt(plaintext).unwrap();

        let alice_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_subgroups() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();

        let sg1 = broadcast.add_member("alice", [1u8; 32]).unwrap();
        let sg2 = broadcast.add_member("bob", [2u8; 32]).unwrap();
        let sg3 = broadcast.add_member("carol", [3u8; 32]).unwrap();

        // First two in same subgroup
        assert_eq!(sg1, 0);
        assert_eq!(sg2, 0);
        // Third in new subgroup
        assert_eq!(sg3, 1);
    }

    #[test]
    fn test_block_roundtrip_and_tamper_detection() {
        let key = [5u8; 32];
        let sealed = encrypt_block(&key, b"payload").unwrap();
        // nonce (16) + ciphertext (7) + tag (16)
        assert_eq!(sealed.len(), 16 + 7 + 16);
        assert_eq!(
            decrypt_block(&key, &sealed).unwrap().as_slice(),
            b"payload".as_slice()
        );

        let mut tampered = sealed.clone();
        tampered[16] ^= 0x01;
        assert!(matches!(
            decrypt_block(&key, &tampered),
            Err(ShieldError::AuthenticationFailed)
        ));
        assert!(matches!(
            decrypt_block(&[6u8; 32], &sealed),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_block_empty_plaintext() {
        let key = [8u8; 32];
        let sealed = encrypt_block(&key, b"").unwrap();
        assert_eq!(sealed.len(), 32);
        assert!(decrypt_block(&key, &sealed).unwrap().is_empty());
    }

    #[test]
    fn test_decrypt_block_too_short() {
        assert!(matches!(
            decrypt_block(&[0u8; 32], &[0u8; 31]),
            Err(ShieldError::CiphertextTooShort {
                expected: 32,
                actual: 31
            })
        ));
    }

    #[test]
    fn test_group_wrong_key_rejected() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_tampered_ciphertext_rejected() {
        let alice_key = [1u8; 32];
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", alice_key);

        let mut encrypted = group.encrypt(b"secret").unwrap();
        let mut raw = URL_SAFE_NO_PAD.decode(&encrypted.ciphertext).unwrap();
        // Byte 16 is the first ciphertext byte (right after the 16-byte nonce).
        raw[16] ^= 0x80;
        encrypted.ciphertext = URL_SAFE_NO_PAD.encode(&raw);

        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_rotate_key() {
        let mut group = GroupEncryption::new(Some([7u8; 32])).unwrap();
        let old = group.rotate_key().unwrap();
        assert_eq!(old, [7u8; 32]);
        assert_ne!(group.group_key(), &[7u8; 32]);
    }

    #[test]
    fn test_broadcast_non_member_and_wrong_key() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        broadcast.add_member("alice", [1u8; 32]).unwrap();

        let encrypted = broadcast.encrypt(b"hello").unwrap();
        assert!(matches!(
            BroadcastEncryption::decrypt_as_member(&encrypted, "eve", &[1u8; 32]),
            Err(ShieldError::MemberNotFound)
        ));
        assert!(matches!(
            BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]),
            Err(ShieldError::AuthenticationFailed)
        ));
    }
}