1#![allow(clippy::cast_possible_truncation)]
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
9use ring::hmac;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use subtle::ConstantTimeEq;
13use zeroize::Zeroize;
14
15use crate::error::{Result, ShieldError};
16
17fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Result<Vec<u8>> {
19 let num_blocks = length.div_ceil(32);
20 if u32::try_from(num_blocks).is_err() {
21 return Err(ShieldError::StreamError(
22 "keystream too long: counter overflow".into(),
23 ));
24 }
25 let mut keystream = Vec::with_capacity(num_blocks * 32);
26 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
27
28 for i in 0..num_blocks {
29 let counter = (i as u32).to_le_bytes();
30 let mut data = Vec::with_capacity(nonce.len() + 4);
31 data.extend_from_slice(nonce);
32 data.extend_from_slice(&counter);
33
34 let tag = hmac::sign(&hmac_key, &data);
35 keystream.extend_from_slice(tag.as_ref());
36 }
37
38 keystream.truncate(length);
39 Ok(keystream)
40}
41
42fn derive_block_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
44 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
45
46 let enc_tag = hmac::sign(&hmac_key, b"shield-group-encrypt");
47 let mut enc_key = [0u8; 32];
48 enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
49
50 let mac_tag = hmac::sign(&hmac_key, b"shield-group-authenticate");
51 let mut mac_key = [0u8; 32];
52 mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
53
54 (enc_key, mac_key)
55}
56
57fn encrypt_block(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
59 let (enc_key, mac_key) = derive_block_subkeys(key);
60 let nonce: [u8; 16] = crate::random::random_bytes()?;
61
62 let keystream = generate_keystream(&enc_key, &nonce, data.len())?;
63 let ciphertext: Vec<u8> = data
64 .iter()
65 .zip(keystream.iter())
66 .map(|(p, k)| p ^ k)
67 .collect();
68
69 let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
70 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
71 hmac_data.extend_from_slice(&nonce);
72 hmac_data.extend_from_slice(&ciphertext);
73 let tag = hmac::sign(&hmac_signing_key, &hmac_data);
74
75 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
76 result.extend_from_slice(&nonce);
77 result.extend_from_slice(&ciphertext);
78 result.extend_from_slice(&tag.as_ref()[..16]);
79
80 Ok(result)
81}
82
83fn decrypt_block(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
85 if encrypted.len() < 32 {
86 return Err(ShieldError::CiphertextTooShort {
87 expected: 32,
88 actual: encrypted.len(),
89 });
90 }
91
92 let (enc_key, mac_key) = derive_block_subkeys(key);
93
94 let nonce = &encrypted[..16];
95 let ciphertext = &encrypted[16..encrypted.len() - 16];
96 let mac = &encrypted[encrypted.len() - 16..];
97
98 let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
99 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
100 hmac_data.extend_from_slice(nonce);
101 hmac_data.extend_from_slice(ciphertext);
102 let expected_tag = hmac::sign(&hmac_signing_key, &hmac_data);
103
104 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
105 return Err(ShieldError::AuthenticationFailed);
106 }
107
108 let keystream = generate_keystream(&enc_key, nonce, ciphertext.len())?;
109 let plaintext: Vec<u8> = ciphertext
110 .iter()
111 .zip(keystream.iter())
112 .map(|(c, k)| c ^ k)
113 .collect();
114
115 Ok(plaintext)
116}
117
/// Serialized output of [`GroupEncryption::encrypt`].
///
/// Binary fields are URL-safe base64 without padding.
#[derive(Serialize, Deserialize)]
pub struct EncryptedGroupMessage {
    /// Format version; `GroupEncryption::encrypt` always writes `1`.
    /// NOTE(review): `decrypt_as_member` does not check this field.
    pub version: u8,
    /// The message sealed under the group key (`nonce || ciphertext || tag`, base64).
    pub ciphertext: String,
    /// Member id -> the group key sealed under that member's key (base64).
    pub keys: HashMap<String, String>,
}
125
/// Encrypts messages for a set of members who share one symmetric group key.
///
/// Each message body is sealed once under the group key. The group key itself
/// is then sealed separately for every member under that member's 32-byte key.
/// All key material is zeroized when the value is dropped.
pub struct GroupEncryption {
    /// Symmetric key that seals message bodies.
    group_key: [u8; 32],
    /// Member id -> key used to wrap `group_key` for that member.
    members: HashMap<String, [u8; 32]>,
}
131
132impl GroupEncryption {
133 pub fn new(group_key: Option<[u8; 32]>) -> Result<Self> {
135 let key = if let Some(k) = group_key {
136 k
137 } else {
138 crate::random::random_bytes()?
139 };
140
141 Ok(Self {
142 group_key: key,
143 members: HashMap::new(),
144 })
145 }
146
147 pub fn add_member(&mut self, member_id: &str, shared_key: [u8; 32]) {
149 self.members.insert(member_id.to_string(), shared_key);
150 }
151
152 pub fn remove_member(&mut self, member_id: &str) -> bool {
154 if let Some(mut key) = self.members.remove(member_id) {
155 key.zeroize();
156 true
157 } else {
158 false
159 }
160 }
161
162 pub fn members(&self) -> Vec<&str> {
164 self.members.keys().map(String::as_str).collect()
165 }
166
167 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedGroupMessage> {
169 let ciphertext = encrypt_block(&self.group_key, plaintext)?;
170
171 let mut keys = HashMap::new();
172 for (member_id, member_key) in &self.members {
173 let encrypted_key = encrypt_block(member_key, &self.group_key)?;
174 keys.insert(member_id.clone(), URL_SAFE_NO_PAD.encode(&encrypted_key));
175 }
176
177 Ok(EncryptedGroupMessage {
178 version: 1,
179 ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
180 keys,
181 })
182 }
183
184 pub fn decrypt_as_member(
186 encrypted: &EncryptedGroupMessage,
187 member_id: &str,
188 member_key: &[u8; 32],
189 ) -> Result<Vec<u8>> {
190 let encrypted_group_key = encrypted
191 .keys
192 .get(member_id)
193 .ok_or(ShieldError::MemberNotFound)?;
194
195 let encrypted_key_bytes = URL_SAFE_NO_PAD
196 .decode(encrypted_group_key)
197 .map_err(|_| ShieldError::InvalidFormat)?;
198
199 let group_key_vec = decrypt_block(member_key, &encrypted_key_bytes)?;
200 let mut group_key = [0u8; 32];
201 group_key.copy_from_slice(&group_key_vec);
202
203 let ciphertext = URL_SAFE_NO_PAD
204 .decode(&encrypted.ciphertext)
205 .map_err(|_| ShieldError::InvalidFormat)?;
206
207 decrypt_block(&group_key, &ciphertext)
208 }
209
210 pub fn rotate_key(&mut self) -> Result<[u8; 32]> {
212 let old_key = self.group_key;
213 self.group_key = crate::random::random_bytes()?;
214 Ok(old_key)
215 }
216}
217
218impl Drop for GroupEncryption {
219 fn drop(&mut self) {
220 self.group_key.zeroize();
221 for key in self.members.values_mut() {
222 key.zeroize();
223 }
224 }
225}
226
/// Serialized output of [`BroadcastEncryption::encrypt`].
///
/// Binary fields are URL-safe base64 without padding. A reader decrypts in three steps:
/// 1. Unwrap its subgroup key from `members`.
/// 2. Use that key to unwrap the message key from `subgroups`.
/// 3. Decrypt `ciphertext` with the message key.
#[derive(Serialize, Deserialize)]
pub struct EncryptedBroadcast {
    /// Format version; `BroadcastEncryption::encrypt` always writes `1`.
    /// NOTE(review): `decrypt_as_member` does not check this field.
    pub version: u8,
    /// The message sealed under a per-message random key (base64).
    pub ciphertext: String,
    /// Subgroup id (as a decimal string) -> message key sealed under that subgroup's key.
    pub subgroups: HashMap<String, String>,
    /// Member id -> the member's subgroup id and that subgroup's key sealed under the member's key.
    pub members: HashMap<String, MemberKeyData>,
}
235
/// Per-member entry in an [`EncryptedBroadcast`].
#[derive(Serialize, Deserialize, Clone)]
pub struct MemberKeyData {
    /// Id of the member's subgroup. Its decimal string form is the lookup key
    /// into `EncryptedBroadcast::subgroups`.
    pub sg: u32,
    /// The subgroup key sealed under the member's key (URL-safe base64, no padding).
    pub key: String,
}
241
/// Broadcast encryption that partitions members into fixed-size subgroups.
///
/// Each message is sealed under a fresh message key. That key is sealed once per
/// subgroup, and each member receives their subgroup key sealed under their own key.
/// All key material is zeroized when the value is dropped.
pub struct BroadcastEncryption {
    /// Master key, supplied or randomly generated in `new`.
    /// NOTE(review): apart from `Drop`, no method in this module reads it; confirm whether it is still needed.
    master_key: [u8; 32],
    /// Maximum number of members per subgroup (`new` replaces 0 with 16).
    subgroup_size: usize,
    /// Member id -> (subgroup id, member key).
    members: HashMap<String, (u32, [u8; 32])>,
    /// Subgroup id -> subgroup key.
    subgroup_keys: HashMap<u32, [u8; 32]>,
    /// Id that will be assigned to the next subgroup created.
    next_subgroup: u32,
}
250
251impl BroadcastEncryption {
252 pub fn new(master_key: Option<[u8; 32]>, subgroup_size: usize) -> Result<Self> {
254 let key = if let Some(k) = master_key {
255 k
256 } else {
257 crate::random::random_bytes()?
258 };
259
260 Ok(Self {
261 master_key: key,
262 subgroup_size: if subgroup_size == 0 {
263 16
264 } else {
265 subgroup_size
266 },
267 members: HashMap::new(),
268 subgroup_keys: HashMap::new(),
269 next_subgroup: 0,
270 })
271 }
272
273 pub fn add_member(&mut self, member_id: &str, member_key: [u8; 32]) -> Result<u32> {
275 let mut subgroup_id = None;
277 for sg_id in self.subgroup_keys.keys() {
278 let count = self.members.values().filter(|(sg, _)| sg == sg_id).count();
279 if count < self.subgroup_size {
280 subgroup_id = Some(*sg_id);
281 break;
282 }
283 }
284
285 let sg_id = if let Some(id) = subgroup_id {
286 id
287 } else {
288 let id = self.next_subgroup;
289 let sg_key: [u8; 32] = crate::random::random_bytes()?;
290 self.subgroup_keys.insert(id, sg_key);
291 self.next_subgroup += 1;
292 id
293 };
294
295 self.members
296 .insert(member_id.to_string(), (sg_id, member_key));
297 Ok(sg_id)
298 }
299
300 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedBroadcast> {
302 let message_key: [u8; 32] = crate::random::random_bytes()?;
303
304 let ciphertext = encrypt_block(&message_key, plaintext)?;
306
307 let mut subgroups = HashMap::new();
309 for (sg_id, sg_key) in &self.subgroup_keys {
310 let encrypted_msg_key = encrypt_block(sg_key, &message_key)?;
311 subgroups.insert(
312 sg_id.to_string(),
313 URL_SAFE_NO_PAD.encode(&encrypted_msg_key),
314 );
315 }
316
317 let mut members = HashMap::new();
319 for (member_id, (sg_id, member_key)) in &self.members {
320 let Some(sg_key) = self.subgroup_keys.get(sg_id) else {
321 continue;
322 };
323 let encrypted_sg_key = encrypt_block(member_key, sg_key)?;
324 members.insert(
325 member_id.clone(),
326 MemberKeyData {
327 sg: *sg_id,
328 key: URL_SAFE_NO_PAD.encode(&encrypted_sg_key),
329 },
330 );
331 }
332
333 Ok(EncryptedBroadcast {
334 version: 1,
335 ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
336 subgroups,
337 members,
338 })
339 }
340
341 pub fn decrypt_as_member(
343 encrypted: &EncryptedBroadcast,
344 member_id: &str,
345 member_key: &[u8; 32],
346 ) -> Result<Vec<u8>> {
347 let member_data = encrypted
348 .members
349 .get(member_id)
350 .ok_or(ShieldError::MemberNotFound)?;
351
352 let sg_key_enc = URL_SAFE_NO_PAD
354 .decode(&member_data.key)
355 .map_err(|_| ShieldError::InvalidFormat)?;
356 let sg_key_vec = decrypt_block(member_key, &sg_key_enc)?;
357 let mut sg_key = [0u8; 32];
358 sg_key.copy_from_slice(&sg_key_vec);
359
360 let msg_key_enc = URL_SAFE_NO_PAD
362 .decode(
363 encrypted
364 .subgroups
365 .get(&member_data.sg.to_string())
366 .ok_or(ShieldError::InvalidFormat)?,
367 )
368 .map_err(|_| ShieldError::InvalidFormat)?;
369 let msg_key_vec = decrypt_block(&sg_key, &msg_key_enc)?;
370 let mut msg_key = [0u8; 32];
371 msg_key.copy_from_slice(&msg_key_vec);
372
373 let ciphertext = URL_SAFE_NO_PAD
375 .decode(&encrypted.ciphertext)
376 .map_err(|_| ShieldError::InvalidFormat)?;
377 decrypt_block(&msg_key, &ciphertext)
378 }
379}
380
381impl Drop for BroadcastEncryption {
382 fn drop(&mut self) {
383 self.master_key.zeroize();
384 for key in self.members.values_mut() {
385 key.1.zeroize();
386 }
387 for key in self.subgroup_keys.values_mut() {
388 key.zeroize();
389 }
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_encrypt_decrypt() {
        let participants = [("alice", [1u8; 32]), ("bob", [2u8; 32])];
        let mut group = GroupEncryption::new(None).unwrap();
        for (id, key) in participants {
            group.add_member(id, key);
        }

        let message = b"Group message!";
        let sealed = group.encrypt(message).unwrap();

        // Every registered member must recover the exact plaintext.
        for (id, key) in &participants {
            let opened = GroupEncryption::decrypt_as_member(&sealed, id, key).unwrap();
            assert_eq!(message.as_slice(), opened.as_slice());
        }
    }

    #[test]
    fn test_group_non_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let sealed = group.encrypt(b"secret").unwrap();
        let outcome = GroupEncryption::decrypt_as_member(&sealed, "eve", &[3u8; 32]);
        assert!(outcome.is_err(), "a non-member must not be able to decrypt");
    }

    #[test]
    fn test_group_remove_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        for (id, seed) in [("alice", 1u8), ("bob", 2)] {
            group.add_member(id, [seed; 32]);
        }

        assert_eq!(group.members().len(), 2);
        group.remove_member("bob");
        assert_eq!(group.members().len(), 1);
    }

    #[test]
    fn test_broadcast_encrypt_decrypt() {
        let recipients = [("alice", [1u8; 32]), ("bob", [2u8; 32])];
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        for (id, key) in recipients {
            broadcast.add_member(id, key).unwrap();
        }

        let message = b"Broadcast message!";
        let sealed = broadcast.encrypt(message).unwrap();

        // Every recipient must recover the exact plaintext.
        for (id, key) in &recipients {
            let opened = BroadcastEncryption::decrypt_as_member(&sealed, id, key).unwrap();
            assert_eq!(message.as_slice(), opened.as_slice());
        }
    }

    #[test]
    fn test_broadcast_subgroups() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();

        // With a subgroup size of 2, the third member spills into a second subgroup.
        let mut assigned = Vec::new();
        for (id, seed) in [("alice", 1u8), ("bob", 2), ("carol", 3)] {
            assigned.push(broadcast.add_member(id, [seed; 32]).unwrap());
        }
        assert_eq!(assigned, [0, 0, 1]);
    }
}