1#![allow(clippy::cast_possible_truncation)]
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
9use ring::hmac;
10use ring::rand::{SecureRandom, SystemRandom};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use subtle::ConstantTimeEq;
14
15use crate::error::{Result, ShieldError};
16
17fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
19 let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
20 let num_blocks = length.div_ceil(32);
21
22 for i in 0..num_blocks {
23 let counter = (i as u32).to_le_bytes();
24 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
25 data.extend_from_slice(key);
26 data.extend_from_slice(nonce);
27 data.extend_from_slice(&counter);
28
29 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
30 keystream.extend_from_slice(hash.as_ref());
31 }
32
33 keystream.truncate(length);
34 keystream
35}
36
37fn encrypt_block(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
39 let rng = SystemRandom::new();
40 let mut nonce = [0u8; 16];
41 rng.fill(&mut nonce)
42 .map_err(|_| ShieldError::RandomFailed)?;
43
44 let keystream = generate_keystream(key, &nonce, data.len());
45 let ciphertext: Vec<u8> = data
46 .iter()
47 .zip(keystream.iter())
48 .map(|(p, k)| p ^ k)
49 .collect();
50
51 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
52 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
53 hmac_data.extend_from_slice(&nonce);
54 hmac_data.extend_from_slice(&ciphertext);
55 let tag = hmac::sign(&hmac_key, &hmac_data);
56
57 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
58 result.extend_from_slice(&nonce);
59 result.extend_from_slice(&ciphertext);
60 result.extend_from_slice(&tag.as_ref()[..16]);
61
62 Ok(result)
63}
64
65fn decrypt_block(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
67 if encrypted.len() < 32 {
68 return Err(ShieldError::CiphertextTooShort {
69 expected: 32,
70 actual: encrypted.len(),
71 });
72 }
73
74 let nonce = &encrypted[..16];
75 let ciphertext = &encrypted[16..encrypted.len() - 16];
76 let mac = &encrypted[encrypted.len() - 16..];
77
78 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
79 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
80 hmac_data.extend_from_slice(nonce);
81 hmac_data.extend_from_slice(ciphertext);
82 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
83
84 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
85 return Err(ShieldError::AuthenticationFailed);
86 }
87
88 let keystream = generate_keystream(key, nonce, ciphertext.len());
89 let plaintext: Vec<u8> = ciphertext
90 .iter()
91 .zip(keystream.iter())
92 .map(|(c, k)| c ^ k)
93 .collect();
94
95 Ok(plaintext)
96}
97
98#[derive(Serialize, Deserialize)]
100pub struct EncryptedGroupMessage {
101 pub version: u8,
102 pub ciphertext: String,
103 pub keys: HashMap<String, String>,
104}
105
106pub struct GroupEncryption {
108 group_key: [u8; 32],
109 members: HashMap<String, [u8; 32]>,
110}
111
112impl GroupEncryption {
113 pub fn new(group_key: Option<[u8; 32]>) -> Result<Self> {
115 let key = if let Some(k) = group_key {
116 k
117 } else {
118 let rng = SystemRandom::new();
119 let mut k = [0u8; 32];
120 rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
121 k
122 };
123
124 Ok(Self {
125 group_key: key,
126 members: HashMap::new(),
127 })
128 }
129
130 pub fn add_member(&mut self, member_id: &str, shared_key: [u8; 32]) {
132 self.members.insert(member_id.to_string(), shared_key);
133 }
134
135 pub fn remove_member(&mut self, member_id: &str) -> bool {
137 self.members.remove(member_id).is_some()
138 }
139
140 pub fn members(&self) -> Vec<&str> {
142 self.members.keys().map(String::as_str).collect()
143 }
144
145 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedGroupMessage> {
147 let ciphertext = encrypt_block(&self.group_key, plaintext)?;
148
149 let mut keys = HashMap::new();
150 for (member_id, member_key) in &self.members {
151 let encrypted_key = encrypt_block(member_key, &self.group_key)?;
152 keys.insert(member_id.clone(), URL_SAFE_NO_PAD.encode(&encrypted_key));
153 }
154
155 Ok(EncryptedGroupMessage {
156 version: 1,
157 ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
158 keys,
159 })
160 }
161
162 pub fn decrypt_as_member(
164 encrypted: &EncryptedGroupMessage,
165 member_id: &str,
166 member_key: &[u8; 32],
167 ) -> Result<Vec<u8>> {
168 let encrypted_group_key = encrypted
169 .keys
170 .get(member_id)
171 .ok_or(ShieldError::MemberNotFound)?;
172
173 let encrypted_key_bytes = URL_SAFE_NO_PAD
174 .decode(encrypted_group_key)
175 .map_err(|_| ShieldError::InvalidFormat)?;
176
177 let group_key_vec = decrypt_block(member_key, &encrypted_key_bytes)?;
178 let mut group_key = [0u8; 32];
179 group_key.copy_from_slice(&group_key_vec);
180
181 let ciphertext = URL_SAFE_NO_PAD
182 .decode(&encrypted.ciphertext)
183 .map_err(|_| ShieldError::InvalidFormat)?;
184
185 decrypt_block(&group_key, &ciphertext)
186 }
187
188 pub fn rotate_key(&mut self) -> Result<[u8; 32]> {
190 let old_key = self.group_key;
191 let rng = SystemRandom::new();
192 rng.fill(&mut self.group_key)
193 .map_err(|_| ShieldError::RandomFailed)?;
194 Ok(old_key)
195 }
196
197 #[must_use]
199 pub fn group_key(&self) -> &[u8; 32] {
200 &self.group_key
201 }
202}
203
204#[derive(Serialize, Deserialize)]
206pub struct EncryptedBroadcast {
207 pub version: u8,
208 pub ciphertext: String,
209 pub subgroups: HashMap<String, String>,
210 pub members: HashMap<String, MemberKeyData>,
211}
212
213#[derive(Serialize, Deserialize, Clone)]
214pub struct MemberKeyData {
215 pub sg: u32,
216 pub key: String,
217}
218
219pub struct BroadcastEncryption {
221 #[allow(dead_code)]
222 master_key: [u8; 32],
223 subgroup_size: usize,
224 members: HashMap<String, (u32, [u8; 32])>,
225 subgroup_keys: HashMap<u32, [u8; 32]>,
226 next_subgroup: u32,
227}
228
229impl BroadcastEncryption {
230 pub fn new(master_key: Option<[u8; 32]>, subgroup_size: usize) -> Result<Self> {
232 let key = if let Some(k) = master_key {
233 k
234 } else {
235 let rng = SystemRandom::new();
236 let mut k = [0u8; 32];
237 rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
238 k
239 };
240
241 Ok(Self {
242 master_key: key,
243 subgroup_size: if subgroup_size == 0 {
244 16
245 } else {
246 subgroup_size
247 },
248 members: HashMap::new(),
249 subgroup_keys: HashMap::new(),
250 next_subgroup: 0,
251 })
252 }
253
254 pub fn add_member(&mut self, member_id: &str, member_key: [u8; 32]) -> Result<u32> {
256 let rng = SystemRandom::new();
257
258 let mut subgroup_id = None;
260 for sg_id in self.subgroup_keys.keys() {
261 let count = self.members.values().filter(|(sg, _)| sg == sg_id).count();
262 if count < self.subgroup_size {
263 subgroup_id = Some(*sg_id);
264 break;
265 }
266 }
267
268 let sg_id = if let Some(id) = subgroup_id {
269 id
270 } else {
271 let id = self.next_subgroup;
272 let mut sg_key = [0u8; 32];
273 rng.fill(&mut sg_key)
274 .map_err(|_| ShieldError::RandomFailed)?;
275 self.subgroup_keys.insert(id, sg_key);
276 self.next_subgroup += 1;
277 id
278 };
279
280 self.members
281 .insert(member_id.to_string(), (sg_id, member_key));
282 Ok(sg_id)
283 }
284
285 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedBroadcast> {
287 let rng = SystemRandom::new();
288 let mut message_key = [0u8; 32];
289 rng.fill(&mut message_key)
290 .map_err(|_| ShieldError::RandomFailed)?;
291
292 let ciphertext = encrypt_block(&message_key, plaintext)?;
294
295 let mut subgroups = HashMap::new();
297 for (sg_id, sg_key) in &self.subgroup_keys {
298 let encrypted_msg_key = encrypt_block(sg_key, &message_key)?;
299 subgroups.insert(
300 sg_id.to_string(),
301 URL_SAFE_NO_PAD.encode(&encrypted_msg_key),
302 );
303 }
304
305 let mut members = HashMap::new();
307 for (member_id, (sg_id, member_key)) in &self.members {
308 let sg_key = self.subgroup_keys.get(sg_id).unwrap();
309 let encrypted_sg_key = encrypt_block(member_key, sg_key)?;
310 members.insert(
311 member_id.clone(),
312 MemberKeyData {
313 sg: *sg_id,
314 key: URL_SAFE_NO_PAD.encode(&encrypted_sg_key),
315 },
316 );
317 }
318
319 Ok(EncryptedBroadcast {
320 version: 1,
321 ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
322 subgroups,
323 members,
324 })
325 }
326
327 pub fn decrypt_as_member(
329 encrypted: &EncryptedBroadcast,
330 member_id: &str,
331 member_key: &[u8; 32],
332 ) -> Result<Vec<u8>> {
333 let member_data = encrypted
334 .members
335 .get(member_id)
336 .ok_or(ShieldError::MemberNotFound)?;
337
338 let sg_key_enc = URL_SAFE_NO_PAD
340 .decode(&member_data.key)
341 .map_err(|_| ShieldError::InvalidFormat)?;
342 let sg_key_vec = decrypt_block(member_key, &sg_key_enc)?;
343 let mut sg_key = [0u8; 32];
344 sg_key.copy_from_slice(&sg_key_vec);
345
346 let msg_key_enc = URL_SAFE_NO_PAD
348 .decode(
349 encrypted
350 .subgroups
351 .get(&member_data.sg.to_string())
352 .ok_or(ShieldError::InvalidFormat)?,
353 )
354 .map_err(|_| ShieldError::InvalidFormat)?;
355 let msg_key_vec = decrypt_block(&sg_key, &msg_key_enc)?;
356 let mut msg_key = [0u8; 32];
357 msg_key.copy_from_slice(&msg_key_vec);
358
359 let ciphertext = URL_SAFE_NO_PAD
361 .decode(&encrypted.ciphertext)
362 .map_err(|_| ShieldError::InvalidFormat)?;
363 decrypt_block(&msg_key, &ciphertext)
364 }
365}
366
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_encrypt_decrypt() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        group.add_member("alice", alice_key);
        group.add_member("bob", bob_key);

        let plaintext = b"Group message!";
        let encrypted = group.encrypt(plaintext).unwrap();

        let alice_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_group_non_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(result.is_err());
    }

    #[test]
    fn test_group_remove_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        group.add_member("bob", [2u8; 32]);

        assert_eq!(group.members().len(), 2);
        group.remove_member("bob");
        assert_eq!(group.members().len(), 1);
    }

    #[test]
    fn test_broadcast_encrypt_decrypt() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        broadcast.add_member("alice", alice_key).unwrap();
        broadcast.add_member("bob", bob_key).unwrap();

        let plaintext = b"Broadcast message!";
        let encrypted = broadcast.encrypt(plaintext).unwrap();

        let alice_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_subgroups() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();

        let sg1 = broadcast.add_member("alice", [1u8; 32]).unwrap();
        let sg2 = broadcast.add_member("bob", [2u8; 32]).unwrap();
        let sg3 = broadcast.add_member("carol", [3u8; 32]).unwrap();

        // Subgroup size is 2, so the third member spills into a new subgroup.
        assert_eq!(sg1, 0);
        assert_eq!(sg2, 0);
        assert_eq!(sg3, 1);
    }

    #[test]
    fn test_keystream_length_and_prefix() {
        let long = generate_keystream(&[1u8; 32], &[2u8; 16], 70);
        let short = generate_keystream(&[1u8; 32], &[2u8; 16], 40);
        assert_eq!(long.len(), 70);
        assert_eq!(short.len(), 40);
        // The keystream is a deterministic stream, so a shorter request is
        // a prefix of a longer one.
        assert_eq!(&long[..40], short.as_slice());
        assert!(generate_keystream(&[1u8; 32], &[2u8; 16], 0).is_empty());
    }

    #[test]
    fn test_block_round_trip() {
        let key = [7u8; 32];
        let messages: [&[u8]; 3] = [b"", b"x", &[0xAA; 100]];
        for msg in messages.iter() {
            let sealed = encrypt_block(&key, msg).unwrap();
            assert_eq!(sealed.len(), 16 + msg.len() + 16);
            assert_eq!(decrypt_block(&key, &sealed).unwrap().as_slice(), *msg);
        }
    }

    #[test]
    fn test_block_tamper_detected() {
        let key = [7u8; 32];
        let sealed = encrypt_block(&key, b"integrity").unwrap();
        // Flip one bit in each region: nonce, ciphertext, tag.
        for &idx in &[0, 16, sealed.len() - 1] {
            let mut tampered = sealed.clone();
            tampered[idx] ^= 0x01;
            assert!(matches!(
                decrypt_block(&key, &tampered),
                Err(ShieldError::AuthenticationFailed)
            ));
        }
    }

    #[test]
    fn test_block_wrong_key() {
        let sealed = encrypt_block(&[1u8; 32], b"secret").unwrap();
        assert!(matches!(
            decrypt_block(&[2u8; 32], &sealed),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_block_too_short() {
        let result = decrypt_block(&[1u8; 32], &[0u8; 31]);
        assert!(matches!(
            result,
            Err(ShieldError::CiphertextTooShort {
                expected: 32,
                actual: 31
            })
        ));
    }

    #[test]
    fn test_group_wrong_member_key() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        let encrypted = group.encrypt(b"secret").unwrap();

        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_removed_member_not_in_new_messages() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        group.add_member("bob", [2u8; 32]);

        assert!(group.remove_member("bob"));
        assert!(!group.remove_member("bob"));

        let encrypted = group.encrypt(b"after removal").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "bob", &[2u8; 32]);
        assert!(matches!(result, Err(ShieldError::MemberNotFound)));
    }

    #[test]
    fn test_group_rotate_key() {
        let initial = [5u8; 32];
        let mut group = GroupEncryption::new(Some(initial)).unwrap();
        group.add_member("alice", [1u8; 32]);

        let old = group.rotate_key().unwrap();
        assert_eq!(old, initial);
        assert_ne!(group.group_key(), &initial);

        let plaintext = b"rotated";
        let encrypted = group.encrypt(plaintext).unwrap();
        let decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &[1u8; 32]).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_across_subgroups_and_non_member() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        let keys = [("alice", [1u8; 32]), ("bob", [2u8; 32]), ("carol", [3u8; 32])];
        for (id, key) in keys.iter() {
            broadcast.add_member(id, *key).unwrap();
        }

        let plaintext = b"to everyone";
        let encrypted = broadcast.encrypt(plaintext).unwrap();
        assert_eq!(encrypted.subgroups.len(), 2);

        for (id, key) in keys.iter() {
            let decrypted = BroadcastEncryption::decrypt_as_member(&encrypted, id, key).unwrap();
            assert_eq!(plaintext.as_slice(), decrypted.as_slice());
        }

        let result = BroadcastEncryption::decrypt_as_member(&encrypted, "eve", &[4u8; 32]);
        assert!(matches!(result, Err(ShieldError::MemberNotFound)));
    }

    #[test]
    fn test_broadcast_zero_subgroup_size_uses_default() {
        let mut broadcast = BroadcastEncryption::new(None, 0).unwrap();
        for i in 0..16u8 {
            assert_eq!(broadcast.add_member(&format!("m{i}"), [i; 32]).unwrap(), 0);
        }
        assert_eq!(broadcast.add_member("m16", [16u8; 32]).unwrap(), 1);
    }
}