1#![allow(clippy::cast_possible_truncation)]
7
8use base64::{engine::general_purpose::URL_SAFE_NO_PAD, Engine};
9use ring::hmac;
10use ring::rand::{SecureRandom, SystemRandom};
11use serde::{Deserialize, Serialize};
12use std::collections::HashMap;
13use subtle::ConstantTimeEq;
14
15use crate::error::{Result, ShieldError};
16
17fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
19 let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
20 let num_blocks = length.div_ceil(32);
21
22 for i in 0..num_blocks {
23 let counter = (i as u32).to_le_bytes();
24 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
25 data.extend_from_slice(key);
26 data.extend_from_slice(nonce);
27 data.extend_from_slice(&counter);
28
29 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
30 keystream.extend_from_slice(hash.as_ref());
31 }
32
33 keystream.truncate(length);
34 keystream
35}
36
37fn encrypt_block(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
39 let rng = SystemRandom::new();
40 let mut nonce = [0u8; 16];
41 rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
42
43 let keystream = generate_keystream(key, &nonce, data.len());
44 let ciphertext: Vec<u8> = data.iter().zip(keystream.iter()).map(|(p, k)| p ^ k).collect();
45
46 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
47 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
48 hmac_data.extend_from_slice(&nonce);
49 hmac_data.extend_from_slice(&ciphertext);
50 let tag = hmac::sign(&hmac_key, &hmac_data);
51
52 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
53 result.extend_from_slice(&nonce);
54 result.extend_from_slice(&ciphertext);
55 result.extend_from_slice(&tag.as_ref()[..16]);
56
57 Ok(result)
58}
59
60fn decrypt_block(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
62 if encrypted.len() < 32 {
63 return Err(ShieldError::CiphertextTooShort {
64 expected: 32,
65 actual: encrypted.len(),
66 });
67 }
68
69 let nonce = &encrypted[..16];
70 let ciphertext = &encrypted[16..encrypted.len() - 16];
71 let mac = &encrypted[encrypted.len() - 16..];
72
73 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
74 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
75 hmac_data.extend_from_slice(nonce);
76 hmac_data.extend_from_slice(ciphertext);
77 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
78
79 if mac.ct_eq(&expected_tag.as_ref()[..16]).unwrap_u8() != 1 {
80 return Err(ShieldError::AuthenticationFailed);
81 }
82
83 let keystream = generate_keystream(key, nonce, ciphertext.len());
84 let plaintext: Vec<u8> = ciphertext
85 .iter()
86 .zip(keystream.iter())
87 .map(|(c, k)| c ^ k)
88 .collect();
89
90 Ok(plaintext)
91}
92
93#[derive(Serialize, Deserialize)]
95pub struct EncryptedGroupMessage {
96 pub version: u8,
97 pub ciphertext: String,
98 pub keys: HashMap<String, String>,
99}
100
/// Shared-key group encryption.
///
/// Message content is encrypted under a single 32-byte group key. Each message also
/// carries that group key wrapped under every member's own 32-byte key, so any
/// listed member can recover it.
pub struct GroupEncryption {
    /// Key that encrypts message content; replaced by `rotate_key`.
    group_key: [u8; 32],
    /// Member id -> key used to wrap `group_key` for that member.
    members: HashMap<String, [u8; 32]>,
}
106
107impl GroupEncryption {
108 pub fn new(group_key: Option<[u8; 32]>) -> Result<Self> {
110 let key = if let Some(k) = group_key { k } else {
111 let rng = SystemRandom::new();
112 let mut k = [0u8; 32];
113 rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
114 k
115 };
116
117 Ok(Self {
118 group_key: key,
119 members: HashMap::new(),
120 })
121 }
122
123 pub fn add_member(&mut self, member_id: &str, shared_key: [u8; 32]) {
125 self.members.insert(member_id.to_string(), shared_key);
126 }
127
128 pub fn remove_member(&mut self, member_id: &str) -> bool {
130 self.members.remove(member_id).is_some()
131 }
132
133 pub fn members(&self) -> Vec<&str> {
135 self.members.keys().map(String::as_str).collect()
136 }
137
138 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedGroupMessage> {
140 let ciphertext = encrypt_block(&self.group_key, plaintext)?;
141
142 let mut keys = HashMap::new();
143 for (member_id, member_key) in &self.members {
144 let encrypted_key = encrypt_block(member_key, &self.group_key)?;
145 keys.insert(member_id.clone(), URL_SAFE_NO_PAD.encode(&encrypted_key));
146 }
147
148 Ok(EncryptedGroupMessage {
149 version: 1,
150 ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
151 keys,
152 })
153 }
154
155 pub fn decrypt_as_member(
157 encrypted: &EncryptedGroupMessage,
158 member_id: &str,
159 member_key: &[u8; 32],
160 ) -> Result<Vec<u8>> {
161 let encrypted_group_key = encrypted
162 .keys
163 .get(member_id)
164 .ok_or(ShieldError::MemberNotFound)?;
165
166 let encrypted_key_bytes = URL_SAFE_NO_PAD
167 .decode(encrypted_group_key)
168 .map_err(|_| ShieldError::InvalidFormat)?;
169
170 let group_key_vec = decrypt_block(member_key, &encrypted_key_bytes)?;
171 let mut group_key = [0u8; 32];
172 group_key.copy_from_slice(&group_key_vec);
173
174 let ciphertext = URL_SAFE_NO_PAD
175 .decode(&encrypted.ciphertext)
176 .map_err(|_| ShieldError::InvalidFormat)?;
177
178 decrypt_block(&group_key, &ciphertext)
179 }
180
181 pub fn rotate_key(&mut self) -> Result<[u8; 32]> {
183 let old_key = self.group_key;
184 let rng = SystemRandom::new();
185 rng.fill(&mut self.group_key)
186 .map_err(|_| ShieldError::RandomFailed)?;
187 Ok(old_key)
188 }
189
190 #[must_use]
192 pub fn group_key(&self) -> &[u8; 32] {
193 &self.group_key
194 }
195}
196
197#[derive(Serialize, Deserialize)]
199pub struct EncryptedBroadcast {
200 pub version: u8,
201 pub ciphertext: String,
202 pub subgroups: HashMap<String, String>,
203 pub members: HashMap<String, MemberKeyData>,
204}
205
206#[derive(Serialize, Deserialize, Clone)]
207pub struct MemberKeyData {
208 pub sg: u32,
209 pub key: String,
210}
211
/// Broadcast encryption with members partitioned into bounded-size subgroups.
///
/// Each message gets a fresh random key. That key is wrapped under every subgroup
/// key, and each subgroup key is wrapped under the keys of the members assigned to it.
pub struct BroadcastEncryption {
    // Accepted by `new` and stored, but not read anywhere in this file; the
    // allowance silences the resulting unused-field warning.
    #[allow(dead_code)]
    master_key: [u8; 32],
    /// Maximum members per subgroup; `new` substitutes 16 when given 0.
    subgroup_size: usize,
    /// Member id -> (assigned subgroup id, member's key).
    members: HashMap<String, (u32, [u8; 32])>,
    /// Subgroup id -> subgroup key.
    subgroup_keys: HashMap<u32, [u8; 32]>,
    /// Id that will be given to the next subgroup created.
    next_subgroup: u32,
}
221
222impl BroadcastEncryption {
223 pub fn new(master_key: Option<[u8; 32]>, subgroup_size: usize) -> Result<Self> {
225 let key = if let Some(k) = master_key { k } else {
226 let rng = SystemRandom::new();
227 let mut k = [0u8; 32];
228 rng.fill(&mut k).map_err(|_| ShieldError::RandomFailed)?;
229 k
230 };
231
232 Ok(Self {
233 master_key: key,
234 subgroup_size: if subgroup_size == 0 { 16 } else { subgroup_size },
235 members: HashMap::new(),
236 subgroup_keys: HashMap::new(),
237 next_subgroup: 0,
238 })
239 }
240
241 pub fn add_member(&mut self, member_id: &str, member_key: [u8; 32]) -> Result<u32> {
243 let rng = SystemRandom::new();
244
245 let mut subgroup_id = None;
247 for sg_id in self.subgroup_keys.keys() {
248 let count = self.members.values().filter(|(sg, _)| sg == sg_id).count();
249 if count < self.subgroup_size {
250 subgroup_id = Some(*sg_id);
251 break;
252 }
253 }
254
255 let sg_id = if let Some(id) = subgroup_id { id } else {
256 let id = self.next_subgroup;
257 let mut sg_key = [0u8; 32];
258 rng.fill(&mut sg_key)
259 .map_err(|_| ShieldError::RandomFailed)?;
260 self.subgroup_keys.insert(id, sg_key);
261 self.next_subgroup += 1;
262 id
263 };
264
265 self.members.insert(member_id.to_string(), (sg_id, member_key));
266 Ok(sg_id)
267 }
268
269 pub fn encrypt(&self, plaintext: &[u8]) -> Result<EncryptedBroadcast> {
271 let rng = SystemRandom::new();
272 let mut message_key = [0u8; 32];
273 rng.fill(&mut message_key)
274 .map_err(|_| ShieldError::RandomFailed)?;
275
276 let ciphertext = encrypt_block(&message_key, plaintext)?;
278
279 let mut subgroups = HashMap::new();
281 for (sg_id, sg_key) in &self.subgroup_keys {
282 let encrypted_msg_key = encrypt_block(sg_key, &message_key)?;
283 subgroups.insert(sg_id.to_string(), URL_SAFE_NO_PAD.encode(&encrypted_msg_key));
284 }
285
286 let mut members = HashMap::new();
288 for (member_id, (sg_id, member_key)) in &self.members {
289 let sg_key = self.subgroup_keys.get(sg_id).unwrap();
290 let encrypted_sg_key = encrypt_block(member_key, sg_key)?;
291 members.insert(
292 member_id.clone(),
293 MemberKeyData {
294 sg: *sg_id,
295 key: URL_SAFE_NO_PAD.encode(&encrypted_sg_key),
296 },
297 );
298 }
299
300 Ok(EncryptedBroadcast {
301 version: 1,
302 ciphertext: URL_SAFE_NO_PAD.encode(&ciphertext),
303 subgroups,
304 members,
305 })
306 }
307
308 pub fn decrypt_as_member(
310 encrypted: &EncryptedBroadcast,
311 member_id: &str,
312 member_key: &[u8; 32],
313 ) -> Result<Vec<u8>> {
314 let member_data = encrypted
315 .members
316 .get(member_id)
317 .ok_or(ShieldError::MemberNotFound)?;
318
319 let sg_key_enc = URL_SAFE_NO_PAD
321 .decode(&member_data.key)
322 .map_err(|_| ShieldError::InvalidFormat)?;
323 let sg_key_vec = decrypt_block(member_key, &sg_key_enc)?;
324 let mut sg_key = [0u8; 32];
325 sg_key.copy_from_slice(&sg_key_vec);
326
327 let msg_key_enc = URL_SAFE_NO_PAD
329 .decode(
330 encrypted
331 .subgroups
332 .get(&member_data.sg.to_string())
333 .ok_or(ShieldError::InvalidFormat)?,
334 )
335 .map_err(|_| ShieldError::InvalidFormat)?;
336 let msg_key_vec = decrypt_block(&sg_key, &msg_key_enc)?;
337 let mut msg_key = [0u8; 32];
338 msg_key.copy_from_slice(&msg_key_vec);
339
340 let ciphertext = URL_SAFE_NO_PAD
342 .decode(&encrypted.ciphertext)
343 .map_err(|_| ShieldError::InvalidFormat)?;
344 decrypt_block(&msg_key, &ciphertext)
345 }
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_group_encrypt_decrypt() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        group.add_member("alice", alice_key);
        group.add_member("bob", bob_key);

        let plaintext = b"Group message!";
        let encrypted = group.encrypt(plaintext).unwrap();

        let alice_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            GroupEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_group_non_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(result.is_err());
    }

    #[test]
    fn test_group_remove_member() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);
        group.add_member("bob", [2u8; 32]);

        assert_eq!(group.members().len(), 2);
        group.remove_member("bob");
        assert_eq!(group.members().len(), 1);
    }

    #[test]
    fn test_group_wrong_member_key_rejected() {
        let mut group = GroupEncryption::new(None).unwrap();
        group.add_member("alice", [1u8; 32]);

        let encrypted = group.encrypt(b"secret").unwrap();
        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &[9u8; 32]);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_tampered_ciphertext_rejected() {
        let mut group = GroupEncryption::new(None).unwrap();
        let alice_key = [1u8; 32];
        group.add_member("alice", alice_key);

        let mut encrypted = group.encrypt(b"integrity").unwrap();
        let mut raw = URL_SAFE_NO_PAD.decode(&encrypted.ciphertext).unwrap();
        // Byte 16 is the first ciphertext byte, right after the 16-byte nonce.
        raw[16] ^= 0x01;
        encrypted.ciphertext = URL_SAFE_NO_PAD.encode(&raw);

        let result = GroupEncryption::decrypt_as_member(&encrypted, "alice", &alice_key);
        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_group_rotate_key() {
        let initial = [7u8; 32];
        let alice_key = [1u8; 32];
        let mut group = GroupEncryption::new(Some(initial)).unwrap();
        group.add_member("alice", alice_key);

        let before = group.encrypt(b"before").unwrap();
        let old = group.rotate_key().unwrap();
        assert_eq!(old, initial);
        assert_ne!(group.group_key(), &initial);
        let after = group.encrypt(b"after").unwrap();

        // Each message carries its own wrapped key, so both stay readable.
        assert_eq!(
            GroupEncryption::decrypt_as_member(&before, "alice", &alice_key).unwrap(),
            b"before".to_vec()
        );
        assert_eq!(
            GroupEncryption::decrypt_as_member(&after, "alice", &alice_key).unwrap(),
            b"after".to_vec()
        );
    }

    #[test]
    fn test_block_roundtrip_empty() {
        let key = [5u8; 32];
        let sealed = encrypt_block(&key, b"").unwrap();
        assert_eq!(sealed.len(), 32);
        assert!(decrypt_block(&key, &sealed).unwrap().is_empty());
    }

    #[test]
    fn test_block_too_short() {
        let result = decrypt_block(&[5u8; 32], &[0u8; 31]);
        assert!(matches!(
            result,
            Err(ShieldError::CiphertextTooShort { actual: 31, .. })
        ));
    }

    #[test]
    fn test_broadcast_encrypt_decrypt() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        let alice_key = [1u8; 32];
        let bob_key = [2u8; 32];

        broadcast.add_member("alice", alice_key).unwrap();
        broadcast.add_member("bob", bob_key).unwrap();

        let plaintext = b"Broadcast message!";
        let encrypted = broadcast.encrypt(plaintext).unwrap();

        let alice_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "alice", &alice_key).unwrap();
        let bob_decrypted =
            BroadcastEncryption::decrypt_as_member(&encrypted, "bob", &bob_key).unwrap();

        assert_eq!(plaintext.as_slice(), alice_decrypted.as_slice());
        assert_eq!(plaintext.as_slice(), bob_decrypted.as_slice());
    }

    #[test]
    fn test_broadcast_non_member() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();
        broadcast.add_member("alice", [1u8; 32]).unwrap();

        let encrypted = broadcast.encrypt(b"secret").unwrap();
        let result = BroadcastEncryption::decrypt_as_member(&encrypted, "eve", &[3u8; 32]);
        assert!(matches!(result, Err(ShieldError::MemberNotFound)));
    }

    #[test]
    fn test_broadcast_subgroups() {
        let mut broadcast = BroadcastEncryption::new(None, 2).unwrap();

        let sg1 = broadcast.add_member("alice", [1u8; 32]).unwrap();
        let sg2 = broadcast.add_member("bob", [2u8; 32]).unwrap();
        let sg3 = broadcast.add_member("carol", [3u8; 32]).unwrap();

        assert_eq!(sg1, 0);
        assert_eq!(sg2, 0);
        assert_eq!(sg3, 1);
    }

    #[test]
    fn test_broadcast_default_subgroup_size() {
        let mut broadcast = BroadcastEncryption::new(None, 0).unwrap();
        for i in 0..16u8 {
            assert_eq!(broadcast.add_member(&format!("m{i}"), [i; 32]).unwrap(), 0);
        }
        assert_eq!(broadcast.add_member("m16", [16u8; 32]).unwrap(), 1);
    }
}