Skip to main content

shield_core/
group.rs

1//! Multi-recipient encryption.
2//!
//! Encrypt a message once for multiple recipients; each recipient decrypts it with their own key.
4
// Narrowing to `u32` (the keystream block counter) is range-checked before use.
6#![allow(clippy::cast_possible_truncation)]
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
9use ring::hmac;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use subtle::ConstantTimeEq;
13use zeroize::Zeroize;
14
15use crate::error::{Result, ShieldError};
16
17/// Generate keystream using HMAC-SHA256 (keyed PRF).
18fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Result<Vec<u8>> {
19    let num_blocks = length.div_ceil(32);
20    if u32::try_from(num_blocks).is_err() {
21        return Err(ShieldError::StreamError(
22            "keystream too long: counter overflow".into(),
23        ));
24    }
25    let mut keystream = Vec::with_capacity(num_blocks * 32);
26    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
27
28    for i in 0..num_blocks {
29        let counter = (i as u32).to_le_bytes();
30        let mut data = Vec::with_capacity(nonce.len() + 4);
31        data.extend_from_slice(nonce);
32        data.extend_from_slice(&counter);
33
34        let tag = hmac::sign(&hmac_key, &data);
35        keystream.extend_from_slice(tag.as_ref());
36    }
37
38    keystream.truncate(length);
39    Ok(keystream)
40}
41
42/// Derive separated encryption and MAC subkeys from a block key using HMAC-SHA256.
43fn derive_block_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
44    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
45
46    let enc_tag = hmac::sign(&hmac_key, b"shield-group-encrypt");
47    let mut enc_key = [0u8; 32];
48    enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
49
50    let mac_tag = hmac::sign(&hmac_key, b"shield-group-authenticate");
51    let mut mac_key = [0u8; 32];
52    mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
53
54    (enc_key, mac_key)
55}
56
57/// Encrypt a block with HMAC authentication (separated enc/mac keys).
58fn encrypt_block(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
59    let (enc_key, mac_key) = derive_block_subkeys(key);
60    let nonce: [u8; 16] = crate::random::random_bytes()?;
61
62    let keystream = generate_keystream(&enc_key, &nonce, data.len())?;
63    let ciphertext: Vec<u8> = data
64        .iter()
65        .zip(keystream.iter())
66        .map(|(p, k)| p ^ k)
67        .collect();
68
69    let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
70    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
71    hmac_data.extend_from_slice(&nonce);
72    hmac_data.extend_from_slice(&ciphertext);
73    let tag = hmac::sign(&hmac_signing_key, &hmac_data);
74
75    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
76    result.extend_from_slice(&nonce);
77    result.extend_from_slice(&ciphertext);
78    result.extend_from_slice(&tag.as_ref()[..16]);
79
80    Ok(result)
81}
82
83/// Decrypt a block with HMAC verification (separated enc/mac keys).
84fn decrypt_block(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
85    if encrypted.len() < 32 {
86        return Err(ShieldError::CiphertextTooShort {
87            expected: 32,
88            actual: encrypted.len(),
89        });
90    }
91
92    let (enc_key, mac_key) = derive_block_subkeys(key);
93
94    let nonce = &encrypted[..16];
95    let ciphertext = &encrypted[16..encrypted.len() - 16];
96    let mac = &encrypted[encrypted.len() - 16..];
97
98    let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
99    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
100    hmac_data.extend_from_slice(nonce);
101    hmac_data.extend_from_slice(ciphertext);
102    let expected_tag = hmac::sign(&hmac_signing_key, &hmac_data);
103
104    if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
105        return Err(ShieldError::AuthenticationFailed);
106    }
107
108    let keystream = generate_keystream(&enc_key, nonce, ciphertext.len())?;
109    let plaintext: Vec<u8> = ciphertext
110        .iter()
111        .zip(keystream.iter())
112        .map(|(c, k)| c ^ k)
113        .collect();
114
115    Ok(plaintext)
116}
117
118/// Encrypted group message format.
119#[derive(Serialize, Deserialize)]
120pub struct EncryptedGroupMessage {
121    pub version: u8,
122    pub ciphertext: String,
123    pub keys: HashMap<String, String>,
124}
125
/// Multi-recipient encryption: seal a message once under a group key, then
/// wrap that group key separately for each member.
///
/// Key material is zeroized when the value is dropped.
pub struct GroupEncryption {
    /// Member id -> 32-byte key shared with that member (wraps `group_key`).
    members: HashMap<String, [u8; 32]>,
    /// Symmetric key that seals message payloads.
    group_key: [u8; 32],
}
131
132impl GroupEncryption {
133    /// Create new group encryption.
134    pub fn new(group_key: Option<[u8; 32]>) -> Result<Self> {
135        let key = if let Some(k) = group_key {
136            k
137        } else {
138            crate::random::random_bytes()?
139        };
140
141        Ok(Self {
142            group_key: key,
143            members: HashMap::new(),
144        })
145    }
146
147    /// Add member to group.
148    pub fn add_member(&mut self, member_id: &str, shared_key: [u8; 32]) {
149        self.members.insert(member_id.to_string(), shared_key);
150    }
151
152    /// Remove member from group.
153    pub fn remove_member(&mut self, member_id: &str) -> bool {
154        if let Some(mut key) = self.members.remove(member_id) {
155            key.zeroize();
156            true
157        } else {
158            false
159        }
160    }
161
162    /// Get member list.
163    pub fn members(&self) -> Vec<&str> {
164        self.members.keys().map(String::as_str).collect()
165    }
166
167    /// Encrypt for all group members.
168    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedGroupMessage> {
169        let ciphertext = encrypt_block(&self.group_key, plaintext)?;
170
171        let mut keys = HashMap::new();
172        for (member_id, member_key) in &self.members {
173            let encrypted_key = encrypt_block(member_key, &self.group_key)?;
174            keys.insert(member_id.clone(), URL_SAFE_NO_PAD.encode(&encrypted_key));
175        }
176
177        Ok(EncryptedGroupMessage {
178            version: 1,
179            ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
180            keys,
181        })
182    }
183
184    /// Decrypt as group member.
185    pub fn decrypt_as_member(
186        encrypted: &EncryptedGroupMessage,
187        member_id: &str,
188        member_key: &[u8; 32],
189    ) -> Result<Vec<u8>> {
190        let encrypted_group_key = encrypted
191            .keys
192            .get(member_id)
193            .ok_or(ShieldError::MemberNotFound)?;
194
195        let encrypted_key_bytes = URL_SAFE_NO_PAD
196            .decode(encrypted_group_key)
197            .map_err(|_| ShieldError::InvalidFormat)?;
198
199        let group_key_vec = decrypt_block(member_key, &encrypted_key_bytes)?;
200        let mut group_key = [0u8; 32];
201        group_key.copy_from_slice(&group_key_vec);
202
203        let ciphertext = URL_SAFE_NO_PAD
204            .decode(&encrypted.ciphertext)
205            .map_err(|_| ShieldError::InvalidFormat)?;
206
207        decrypt_block(&group_key, &ciphertext)
208    }
209
210    /// Rotate the group key.
211    pub fn rotate_key(&mut self) -> Result<[u8; 32]> {
212        let old_key = self.group_key;
213        self.group_key = crate::random::random_bytes()?;
214        Ok(old_key)
215    }
216}
217
218impl Drop for GroupEncryption {
219    fn drop(&mut self) {
220        self.group_key.zeroize();
221        for key in self.members.values_mut() {
222            key.zeroize();
223        }
224    }
225}
226
227/// Encrypted broadcast message format.
228#[derive(Serialize, Deserialize)]
229pub struct EncryptedBroadcast {
230    pub version: u8,
231    pub ciphertext: String,
232    pub subgroups: HashMap<String, String>,
233    pub members: HashMap<String, MemberKeyData>,
234}
235
236#[derive(Serialize, Deserialize, Clone)]
237pub struct MemberKeyData {
238    pub sg: u32,
239    pub key: String,
240}
241
/// Efficient broadcast encryption for large groups via a two-level key hierarchy.
///
/// Members are packed into subgroups of up to `subgroup_size` members. Each
/// message key is wrapped once per subgroup, and each subgroup key is wrapped
/// once per member of that subgroup.
pub struct BroadcastEncryption {
    /// Target members per subgroup (`new` substitutes a default for 0).
    subgroup_size: usize,
    /// Id assigned to the next newly created subgroup.
    next_subgroup: u32,
    /// Member id -> (subgroup id, key shared with the member).
    members: HashMap<String, (u32, [u8; 32])>,
    /// Subgroup id -> subgroup key.
    subgroup_keys: HashMap<u32, [u8; 32]>,
    /// NOTE(review): stored and zeroized on drop, but not read by any method visible here.
    master_key: [u8; 32],
}
250
251impl BroadcastEncryption {
252    /// Create new broadcast encryption.
253    pub fn new(master_key: Option<[u8; 32]>, subgroup_size: usize) -> Result<Self> {
254        let key = if let Some(k) = master_key {
255            k
256        } else {
257            crate::random::random_bytes()?
258        };
259
260        Ok(Self {
261            master_key: key,
262            subgroup_size: if subgroup_size == 0 {
263                16
264            } else {
265                subgroup_size
266            },
267            members: HashMap::new(),
268            subgroup_keys: HashMap::new(),
269            next_subgroup: 0,
270        })
271    }
272
273    /// Add member to broadcast group.
274    pub fn add_member(&mut self, member_id: &str, member_key: [u8; 32]) -> Result<u32> {
275        // Find subgroup with space
276        let mut subgroup_id = None;
277        for sg_id in self.subgroup_keys.keys() {
278            let count = self.members.values().filter(|(sg, _)| sg == sg_id).count();
279            if count < self.subgroup_size {
280                subgroup_id = Some(*sg_id);
281                break;
282            }
283        }
284
285        let sg_id = if let Some(id) = subgroup_id {
286            id
287        } else {
288            let id = self.next_subgroup;
289            let sg_key: [u8; 32] = crate::random::random_bytes()?;
290            self.subgroup_keys.insert(id, sg_key);
291            self.next_subgroup += 1;
292            id
293        };
294
295        self.members
296            .insert(member_id.to_string(), (sg_id, member_key));
297        Ok(sg_id)
298    }
299
300    /// Encrypt for broadcast.
301    pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedBroadcast> {
302        let message_key: [u8; 32] = crate::random::random_bytes()?;
303
304        // Encrypt message
305        let ciphertext = encrypt_block(&message_key, plaintext)?;
306
307        // Encrypt message key for each subgroup
308        let mut subgroups = HashMap::new();
309        for (sg_id, sg_key) in &self.subgroup_keys {
310            let encrypted_msg_key = encrypt_block(sg_key, &message_key)?;
311            subgroups.insert(
312                sg_id.to_string(),
313                URL_SAFE_NO_PAD.encode(&encrypted_msg_key),
314            );
315        }
316
317        // Encrypt subgroup keys for each member
318        let mut members = HashMap::new();
319        for (member_id, (sg_id, member_key)) in &self.members {
320            let Some(sg_key) = self.subgroup_keys.get(sg_id) else {
321                continue;
322            };
323            let encrypted_sg_key = encrypt_block(member_key, sg_key)?;
324            members.insert(
325                member_id.clone(),
326                MemberKeyData {
327                    sg: *sg_id,
328                    key: URL_SAFE_NO_PAD.encode(&encrypted_sg_key),
329                },
330            );
331        }
332
333        Ok(EncryptedBroadcast {
334            version: 1,
335            ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
336            subgroups,
337            members,
338        })
339    }
340
341    /// Decrypt broadcast as member.
342    pub fn decrypt_as_member(
343        encrypted: &EncryptedBroadcast,
344        member_id: &str,
345        member_key: &[u8; 32],
346    ) -> Result<Vec<u8>> {
347        let member_data = encrypted
348            .members
349            .get(member_id)
350            .ok_or(ShieldError::MemberNotFound)?;
351
352        // Decrypt subgroup key
353        let sg_key_enc = URL_SAFE_NO_PAD
354            .decode(&member_data.key)
355            .map_err(|_| ShieldError::InvalidFormat)?;
356        let sg_key_vec = decrypt_block(member_key, &sg_key_enc)?;
357        let mut sg_key = [0u8; 32];
358        sg_key.copy_from_slice(&sg_key_vec);
359
360        // Decrypt message key
361        let msg_key_enc = URL_SAFE_NO_PAD
362            .decode(
363                encrypted
364                    .subgroups
365                    .get(&member_data.sg.to_string())
366                    .ok_or(ShieldError::InvalidFormat)?,
367            )
368            .map_err(|_| ShieldError::InvalidFormat)?;
369        let msg_key_vec = decrypt_block(&sg_key, &msg_key_enc)?;
370        let mut msg_key = [0u8; 32];
371        msg_key.copy_from_slice(&msg_key_vec);
372
373        // Decrypt message
374        let ciphertext = URL_SAFE_NO_PAD
375            .decode(&encrypted.ciphertext)
376            .map_err(|_| ShieldError::InvalidFormat)?;
377        decrypt_block(&msg_key, &ciphertext)
378    }
379}
380
381impl Drop for BroadcastEncryption {
382    fn drop(&mut self) {
383        self.master_key.zeroize();
384        for key in self.members.values_mut() {
385            key.1.zeroize();
386        }
387        for key in self.subgroup_keys.values_mut() {
388            key.zeroize();
389        }
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_encrypt_decrypt() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        group.add_member("alice", alice_key);
        group.add_member("bob", bob_key);

        let plaintext = b"Group message!";
        let encrypted = group.encrypt(plaintext).unwrap();

        let alice_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_group_non_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(result.is_err());
    }

    #[test]
    fn test_group_wrong_key_rejected() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_tampered_ciphertext_rejected() {
        let alice_key = [1u8; 32];
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", alice_key);
        let mut encrypted = group.encrypt(b"secret").unwrap();

        let mut raw = URL_SAFE_NO_PAD.decode(&encrypted.ciphertext).unwrap();
        raw[16] ^= 0x01; // first ciphertext byte, right after the 16-byte nonce
        encrypted.ciphertext = URL_SAFE_NO_PAD.encode(&raw);

        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_remove_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        group.add_member("bob", [2u8; 32]);

        assert_eq!(group.members().len(), 2);
        assert!(group.remove_member("bob"));
        assert!(!group.remove_member("bob"));
        assert_eq!(group.members().len(), 1);

        // Messages encrypted after removal carry no key for the removed member.
        let encrypted = group.encrypt(b"after removal").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "bob", &[2u8; 32]);
        assert!(matches!(result, Err(ShieldError::MemberNotFound)));
    }

    #[test]
    fn test_group_rotate_key() {
        let original = [7u8; 32];
        let alice_key = [1u8; 32];
        let mut group = GroupEncryption::new(Some(original)).unwrap();
        group.add_member("alice", alice_key);

        assert_eq!(group.rotate_key().unwrap(), original);

        let encrypted = group.encrypt(b"after rotation").unwrap();
        let decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        assert_eq!(b"after rotation".as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_encrypt_decrypt() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        broadcast.add_member("alice", alice_key).unwrap();
        broadcast.add_member("bob", bob_key).unwrap();

        let plaintext = b"Broadcast message!";
        let encrypted = broadcast.encrypt(plaintext).unwrap();

        let alice_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_non_member() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        broadcast.add_member("alice", [1u8; 32]).unwrap();

        let encrypted = broadcast.encrypt(b"secret").unwrap();
        let result = BroadcastEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(matches!(result, Err(ShieldError::MemberNotFound)));
    }

    #[test]
    fn test_broadcast_decrypt_across_subgroups() {
        // Subgroup size 1 forces every member into its own subgroup.
        let mut broadcast = BroadcastEncryption::new(None, 1).unwrap();
        let keys = [[1u8; 32], [2u8; 32], [3u8; 32]];
        for (i, key) in keys.iter().enumerate() {
            broadcast.add_member(&format!("m{i}"), *key).unwrap();
        }

        let encrypted = broadcast.encrypt(b"fan-out").unwrap();
        assert_eq!(encrypted.subgroups.len(), 3);

        for (i, key) in keys.iter().enumerate() {
            let decrypted =
                BroadcastEncryption::decrypt_as_member(&encrypted, &format!("m{i}"), key)
                    .unwrap();
            assert_eq!(b"fan-out".as_slice(), decrypted.as_slice());
        }
    }

    #[test]
    fn test_broadcast_subgroups() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();

        let sg1 = broadcast.add_member("alice", [1u8; 32]).unwrap();
        let sg2 = broadcast.add_member("bob", [2u8; 32]).unwrap();
        let sg3 = broadcast.add_member("carol", [3u8; 32]).unwrap();

        // First two in same subgroup
        assert_eq!(sg1, 0);
        assert_eq!(sg2, 0);
        // Third in new subgroup
        assert_eq!(sg3, 1);
    }
}