1#![allow(clippy::cast_possible_truncation)]
7
8use ring::{hmac, pbkdf2};
9use std::num::NonZeroU32;
10use subtle::ConstantTimeEq;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13use crate::error::{Result, ShieldError};
14
/// Current wall-clock time in milliseconds since the Unix epoch.
///
/// Native targets read `std::time::SystemTime`; a clock set before the epoch
/// yields `0`. On `wasm32` the value comes from JavaScript's `Date.now()`,
/// because `SystemTime` is not usable there.
fn current_timestamp_ms() -> u64 {
    #[cfg(target_arch = "wasm32")]
    {
        #[wasm_bindgen::prelude::wasm_bindgen]
        extern "C" {
            #[wasm_bindgen::prelude::wasm_bindgen(js_namespace = Date, js_name = now)]
            fn date_now() -> f64;
        }
        // Float-to-int `as` saturates: NaN/negative become 0.
        date_now() as u64
    }
    #[cfg(not(target_arch = "wasm32"))]
    {
        use std::time::{SystemTime, UNIX_EPOCH};

        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_or(0, |since_epoch| since_epoch.as_millis() as u64)
    }
}
35
/// PBKDF2-HMAC-SHA256 iteration count used when deriving a key from a password.
const PBKDF2_ITERATIONS: u32 = 100_000;

/// Length in bytes of the random nonce that prefixes every ciphertext.
const NONCE_SIZE: usize = 16;

/// Length in bytes of the truncated HMAC-SHA256 tag appended to every ciphertext.
const MAC_SIZE: usize = 16;

/// Length in bytes of the little-endian `u64` counter that starts every plaintext frame.
const COUNTER_SIZE: usize = 8;

/// Length in bytes of the little-endian `u64` millisecond timestamp in a v2 frame.
const TIMESTAMP_SIZE: usize = 8;

/// Smallest accepted encrypted message: nonce, counter, and tag.
const MIN_CIPHERTEXT_SIZE: usize = NONCE_SIZE + COUNTER_SIZE + MAC_SIZE;

/// v2 frame header: counter, timestamp, and a one-byte padding length.
const V2_HEADER_SIZE: usize = COUNTER_SIZE + TIMESTAMP_SIZE + 1;

/// Inclusive lower bound on the random padding inserted into v2 frames.
const MIN_PADDING: usize = 32;

/// Inclusive upper bound on the random padding inserted into v2 frames.
const MAX_PADDING: usize = 128;

// The padding length is serialized as a single byte, so the whole range must fit in a u8.
const _: () = assert!(MIN_PADDING <= MAX_PADDING && MAX_PADDING <= u8::MAX as usize);

/// Earliest timestamp (2020-01-01T00:00:00Z, in ms) accepted as a v2 frame marker.
/// Frames whose timestamp falls outside `MIN_TIMESTAMP_MS..=MAX_TIMESTAMP_MS` are
/// decoded as legacy v1 frames.
const MIN_TIMESTAMP_MS: u64 = 1_577_836_800_000;
/// Latest timestamp (2100-01-01T00:00:00Z, in ms) accepted as a v2 frame marker.
const MAX_TIMESTAMP_MS: u64 = 4_102_444_800_000;
60
/// Symmetric encrypt/authenticate helper built around a 32-byte master key.
///
/// All key material is wiped when the value is dropped (`ZeroizeOnDrop`).
#[derive(Zeroize, ZeroizeOnDrop)]
pub struct Shield {
    // Master key: PBKDF2 output for password-based constructors, or the
    // caller-supplied key for `with_key`.
    key: [u8; 32],
    // Encryption subkey, HMAC-SHA256(key, "shield-encrypt").
    enc_key: [u8; 32],
    // Authentication subkey, HMAC-SHA256(key, "shield-authenticate").
    mac_key: [u8; 32],
    // Replay window in milliseconds used by `decrypt`; `None` disables the age
    // check. Not secret, so it is excluded from zeroization.
    #[zeroize(skip)]
    max_age_ms: Option<u64>,
}
84
85fn derive_subkeys(master_key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
91 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, master_key);
92
93 let enc_tag = hmac::sign(&hmac_key, b"shield-encrypt");
94 let mut enc_key = [0u8; 32];
95 enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
96
97 let mac_tag = hmac::sign(&hmac_key, b"shield-authenticate");
98 let mut mac_key = [0u8; 32];
99 mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
100
101 (enc_key, mac_key)
102}
103
104impl Shield {
105 #[must_use]
117 pub fn new(password: &str, service: &str) -> Self {
118 let salt = ring::digest::digest(&ring::digest::SHA256, service.as_bytes());
120
121 let mut key = [0u8; 32];
123 pbkdf2::derive(
124 pbkdf2::PBKDF2_HMAC_SHA256,
125 NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
126 salt.as_ref(),
127 password.as_bytes(),
128 &mut key,
129 );
130
131 let (enc_key, mac_key) = derive_subkeys(&key);
132 Self {
133 key,
134 enc_key,
135 mac_key,
136 max_age_ms: Some(60_000), }
138 }
139
140 #[must_use]
142 pub fn with_key(key: [u8; 32]) -> Self {
143 let (enc_key, mac_key) = derive_subkeys(&key);
144 Self {
145 key,
146 enc_key,
147 mac_key,
148 max_age_ms: Some(60_000),
149 }
150 }
151
152 pub fn with_fingerprint(
177 password: &str,
178 service: &str,
179 mode: crate::fingerprint::FingerprintMode,
180 ) -> Result<Self> {
181 let fingerprint = crate::fingerprint::collect_fingerprint(mode)?;
183
184 let combined_password = if fingerprint.is_empty() {
186 password.to_string()
187 } else {
188 format!("{password}:{fingerprint}")
189 };
190
191 let salt = ring::digest::digest(&ring::digest::SHA256, service.as_bytes());
193
194 let mut key = [0u8; 32];
196 pbkdf2::derive(
197 pbkdf2::PBKDF2_HMAC_SHA256,
198 NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
199 salt.as_ref(),
200 combined_password.as_bytes(),
201 &mut key,
202 );
203
204 let (enc_key, mac_key) = derive_subkeys(&key);
205 Ok(Self {
206 key,
207 enc_key,
208 mac_key,
209 max_age_ms: Some(60_000),
210 })
211 }
212
213 #[must_use]
218 pub fn with_max_age(mut self, max_age_ms: Option<u64>) -> Self {
219 self.max_age_ms = max_age_ms;
220 self
221 }
222
223 pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
232 Self::encrypt_with_separated_keys(&self.enc_key, &self.mac_key, plaintext)
233 }
234
235 pub fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
237 let (enc_key, mac_key) = derive_subkeys(key);
238 Self::encrypt_with_separated_keys(&enc_key, &mac_key, plaintext)
239 }
240
241 fn encrypt_with_separated_keys(
243 enc_key: &[u8; 32],
244 mac_key: &[u8; 32],
245 plaintext: &[u8],
246 ) -> Result<Vec<u8>> {
247 let nonce: [u8; NONCE_SIZE] = crate::random::random_bytes()?;
249
250 let counter_bytes = 0u64.to_le_bytes();
252
253 let timestamp_ms = current_timestamp_ms();
255 let timestamp_bytes = timestamp_ms.to_le_bytes();
256
257 let pad_range = MAX_PADDING - MIN_PADDING + 1; let pad_len = loop {
260 let rand_byte: [u8; 1] = crate::random::random_bytes()?;
261 let val = rand_byte[0] as usize;
262 if val < pad_range * (256 / pad_range) {
265 break (val % pad_range) + MIN_PADDING;
266 }
267 };
268 let padding = crate::random::random_vec(pad_len)?;
269
270 let mut data_to_encrypt = Vec::with_capacity(V2_HEADER_SIZE + pad_len + plaintext.len());
272 data_to_encrypt.extend_from_slice(&counter_bytes);
273 data_to_encrypt.extend_from_slice(×tamp_bytes);
274 data_to_encrypt.push(pad_len as u8);
275 data_to_encrypt.extend_from_slice(&padding);
276 data_to_encrypt.extend_from_slice(plaintext);
277
278 let keystream = generate_keystream(enc_key, &nonce, data_to_encrypt.len())?;
280 let ciphertext: Vec<u8> = data_to_encrypt
281 .iter()
282 .zip(keystream.iter())
283 .map(|(p, k)| p ^ k)
284 .collect();
285
286 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, mac_key);
288 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
289 hmac_data.extend_from_slice(&nonce);
290 hmac_data.extend_from_slice(&ciphertext);
291 let tag = hmac::sign(&hmac_key, &hmac_data);
292
293 let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len() + MAC_SIZE);
295 result.extend_from_slice(&nonce);
296 result.extend_from_slice(&ciphertext);
297 result.extend_from_slice(&tag.as_ref()[..MAC_SIZE]);
298
299 Ok(result)
300 }
301
302 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
310 Self::decrypt_with_max_age(&self.key, encrypted, self.max_age_ms)
311 }
312
313 pub fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
315 Self::decrypt_with_max_age(key, encrypted, Some(60_000))
316 }
317
318 pub fn decrypt_with_max_age(
320 key: &[u8; 32],
321 encrypted: &[u8],
322 max_age_ms: Option<u64>,
323 ) -> Result<Vec<u8>> {
324 if encrypted.len() < MIN_CIPHERTEXT_SIZE {
325 return Err(ShieldError::CiphertextTooShort {
326 expected: MIN_CIPHERTEXT_SIZE,
327 actual: encrypted.len(),
328 });
329 }
330
331 let nonce = &encrypted[..NONCE_SIZE];
333 let ciphertext = &encrypted[NONCE_SIZE..encrypted.len() - MAC_SIZE];
334 let mac = &encrypted[encrypted.len() - MAC_SIZE..];
335
336 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
337 hmac_data.extend_from_slice(nonce);
338 hmac_data.extend_from_slice(ciphertext);
339
340 let (enc_key, mac_key) = derive_subkeys(key);
342 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
343 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
344
345 if mac.ct_eq(&expected_tag.as_ref()[..MAC_SIZE]).unwrap_u8() != 1 {
346 return Err(ShieldError::AuthenticationFailed);
347 }
348
349 let keystream = generate_keystream(&enc_key, nonce, ciphertext.len())?;
351 let decrypted: Vec<u8> = ciphertext
352 .iter()
353 .zip(keystream.iter())
354 .map(|(c, k)| c ^ k)
355 .collect();
356
357 if decrypted.len() >= V2_HEADER_SIZE {
359 let timestamp_bytes = &decrypted[8..16];
360 let mut ts_bytes = [0u8; 8];
361 ts_bytes.copy_from_slice(timestamp_bytes);
362 let timestamp_ms = u64::from_le_bytes(ts_bytes);
363
364 if (MIN_TIMESTAMP_MS..=MAX_TIMESTAMP_MS).contains(×tamp_ms) {
366 let pad_len = decrypted[16] as usize;
368
369 if !(MIN_PADDING..=MAX_PADDING).contains(&pad_len) {
371 return Err(ShieldError::AuthenticationFailed);
372 }
373
374 let data_start = V2_HEADER_SIZE + pad_len;
375
376 if data_start > decrypted.len() {
377 return Err(ShieldError::InvalidFormat);
378 }
379
380 if let Some(max_age) = max_age_ms {
382 let now_ms = std::time::SystemTime::now()
383 .duration_since(std::time::UNIX_EPOCH)
384 .unwrap_or_default()
385 .as_millis() as u64;
386
387 let age = i64::try_from(now_ms).unwrap_or(i64::MAX)
389 - i64::try_from(timestamp_ms).unwrap_or(0);
390
391 if age < -5000 {
393 return Err(ShieldError::InvalidFormat);
394 }
395
396 if age > i64::try_from(max_age).unwrap_or(i64::MAX) {
398 return Err(ShieldError::InvalidFormat);
399 }
400 }
401
402 return Ok(decrypted[data_start..].to_vec());
403 }
404 }
405
406 Ok(decrypted[8..].to_vec())
408 }
409
410 pub fn decrypt_v1(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
414 Self::decrypt_v1_with_key(&self.key, encrypted)
415 }
416
417 pub fn decrypt_v1_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
419 let (enc_key, mac_key) = derive_subkeys(key);
420
421 if encrypted.len() < MIN_CIPHERTEXT_SIZE {
422 return Err(ShieldError::CiphertextTooShort {
423 expected: MIN_CIPHERTEXT_SIZE,
424 actual: encrypted.len(),
425 });
426 }
427
428 let nonce = &encrypted[..NONCE_SIZE];
430 let ciphertext = &encrypted[NONCE_SIZE..encrypted.len() - MAC_SIZE];
431 let mac = &encrypted[encrypted.len() - MAC_SIZE..];
432
433 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
435 let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
436 hmac_data.extend_from_slice(nonce);
437 hmac_data.extend_from_slice(ciphertext);
438 let expected_tag = hmac::sign(&hmac_key, &hmac_data);
439
440 if mac.ct_eq(&expected_tag.as_ref()[..MAC_SIZE]).unwrap_u8() != 1 {
441 return Err(ShieldError::AuthenticationFailed);
442 }
443
444 let keystream = generate_keystream(&enc_key, nonce, ciphertext.len())?;
446 let decrypted: Vec<u8> = ciphertext
447 .iter()
448 .zip(keystream.iter())
449 .map(|(c, k)| c ^ k)
450 .collect();
451
452 Ok(decrypted[8..].to_vec())
454 }
455
456 #[cfg(any(feature = "pgvector", feature = "confidential"))]
461 #[must_use]
462 pub(crate) fn master_key(&self) -> &[u8; 32] {
463 &self.key
464 }
465}
466
467fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Result<Vec<u8>> {
469 let num_blocks = length.div_ceil(32);
470 if u32::try_from(num_blocks).is_err() {
471 return Err(ShieldError::StreamError(
472 "keystream too long: counter overflow".into(),
473 ));
474 }
475 let mut keystream = Vec::with_capacity(num_blocks * 32);
476
477 for i in 0..num_blocks {
478 let counter = (i as u32).to_le_bytes();
479
480 let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
482 data.extend_from_slice(key);
483 data.extend_from_slice(nonce);
484 data.extend_from_slice(&counter);
485
486 let hash = ring::digest::digest(&ring::digest::SHA256, &data);
487 keystream.extend_from_slice(hash.as_ref());
488 }
489
490 keystream.truncate(length);
491 Ok(keystream)
492}
493
#[cfg(test)]
mod tests {
    use super::*;

    /// Seals a raw plaintext `frame` the way the wire format expects:
    /// `nonce || keystream-XORed frame || truncated HMAC tag`.
    /// Tests use it to control header fields that `Shield::encrypt` fills in itself.
    fn seal_frame(key: &[u8; 32], nonce: [u8; NONCE_SIZE], frame: &[u8]) -> Vec<u8> {
        let (enc_key, mac_key) = derive_subkeys(key);
        let keystream = generate_keystream(&enc_key, &nonce, frame.len()).unwrap();

        let mut sealed = nonce.to_vec();
        sealed.extend(frame.iter().zip(&keystream).map(|(p, k)| p ^ k));

        let signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
        let tag = hmac::sign(&signing_key, &sealed);
        sealed.extend_from_slice(&tag.as_ref()[..MAC_SIZE]);
        sealed
    }

    /// Legacy v1 plaintext frame: zero counter followed by the message.
    fn v1_frame(plaintext: &[u8]) -> Vec<u8> {
        let mut frame = 0u64.to_le_bytes().to_vec();
        frame.extend_from_slice(plaintext);
        frame
    }

    /// v2 plaintext frame with the given timestamp and zero-filled padding.
    fn v2_frame(timestamp_ms: u64, pad_len: u8, plaintext: &[u8]) -> Vec<u8> {
        let mut frame = 0u64.to_le_bytes().to_vec();
        frame.extend_from_slice(&timestamp_ms.to_le_bytes());
        frame.push(pad_len);
        frame.resize(frame.len() + usize::from(pad_len), 0);
        frame.extend_from_slice(plaintext);
        frame
    }

    #[test]
    fn test_keystream_deterministic() {
        let key = [1u8; 32];
        let nonce = [2u8; 16];

        let ks1 = generate_keystream(&key, &nonce, 64).unwrap();
        let ks2 = generate_keystream(&key, &nonce, 64).unwrap();

        assert_eq!(ks1, ks2);
    }

    #[test]
    fn test_keystream_different_nonce() {
        let key = [1u8; 32];

        let ks1 = generate_keystream(&key, &[2u8; 16], 32).unwrap();
        let ks2 = generate_keystream(&key, &[3u8; 16], 32).unwrap();

        assert_ne!(ks1, ks2);
    }

    #[test]
    fn test_keystream_truncates_to_requested_length() {
        let key = [1u8; 32];
        let nonce = [2u8; 16];

        let short = generate_keystream(&key, &nonce, 32).unwrap();
        let long = generate_keystream(&key, &nonce, 33).unwrap();

        assert_eq!(long.len(), 33);
        assert_eq!(&long[..32], short.as_slice());
        assert!(generate_keystream(&key, &nonce, 0).unwrap().is_empty());
    }

    #[test]
    fn test_encrypt_format_v2() {
        let shield = Shield::new("password", "service");
        let plaintext = b"test";
        let encrypted = shield.encrypt(plaintext).unwrap();

        let fixed = NONCE_SIZE + V2_HEADER_SIZE + plaintext.len() + MAC_SIZE;
        assert!(encrypted.len() >= fixed + MIN_PADDING);
        assert!(encrypted.len() <= fixed + MAX_PADDING);
    }

    #[test]
    fn test_v2_roundtrip() {
        let shield = Shield::new("password", "service");
        let plaintext = b"Hello, Shield v2!";

        let encrypted = shield.encrypt(plaintext).unwrap();
        let decrypted = shield.decrypt(&encrypted).unwrap();

        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_empty_plaintext_roundtrip() {
        let shield = Shield::with_key([9u8; 32]);

        let encrypted = shield.encrypt(b"").unwrap();

        assert!(shield.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_static_key_roundtrip() {
        let key = [8u8; 32];

        let encrypted = Shield::encrypt_with_key(&key, b"static").unwrap();

        assert_eq!(Shield::decrypt_with_key(&key, &encrypted).unwrap(), b"static");
        assert_eq!(Shield::with_key(key).decrypt(&encrypted).unwrap(), b"static");
    }

    #[test]
    fn test_v2_replay_protection_fresh() {
        let shield = Shield::new("password", "service");
        let encrypted = shield.encrypt(b"fresh message").unwrap();

        let decrypted = shield.decrypt(&encrypted).unwrap();
        assert_eq!(b"fresh message", decrypted.as_slice());
    }

    #[test]
    fn test_v2_replay_protection_disabled() {
        let shield = Shield::new("password", "service").with_max_age(None);
        let encrypted = shield.encrypt(b"no expiry").unwrap();

        let decrypted = shield.decrypt(&encrypted).unwrap();
        assert_eq!(b"no expiry", decrypted.as_slice());
    }

    #[test]
    fn test_expired_message_rejected() {
        let key = [5u8; 32];
        let frame = v2_frame(MIN_TIMESTAMP_MS, MIN_PADDING as u8, b"stale");
        let sealed = seal_frame(&key, [6u8; 16], &frame);

        assert!(matches!(
            Shield::decrypt_with_key(&key, &sealed),
            Err(ShieldError::InvalidFormat)
        ));
        assert_eq!(
            Shield::decrypt_with_max_age(&key, &sealed, None).unwrap(),
            b"stale"
        );
    }

    #[test]
    fn test_future_message_rejected() {
        let key = [5u8; 32];
        let one_hour_ahead = current_timestamp_ms() + 3_600_000;
        let frame = v2_frame(one_hour_ahead, MIN_PADDING as u8, b"later");
        let sealed = seal_frame(&key, [7u8; 16], &frame);

        assert!(matches!(
            Shield::decrypt_with_max_age(&key, &sealed, Some(u64::MAX)),
            Err(ShieldError::InvalidFormat)
        ));
    }

    #[test]
    fn test_invalid_padding_length_rejected() {
        let key = [5u8; 32];
        let frame = v2_frame(current_timestamp_ms(), 200, b"x");
        let sealed = seal_frame(&key, [8u8; 16], &frame);

        assert!(matches!(
            Shield::decrypt_with_key(&key, &sealed),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_truncated_padding_rejected() {
        let key = [5u8; 32];
        let mut frame = v2_frame(current_timestamp_ms(), 64, b"");
        frame.truncate(V2_HEADER_SIZE + 10);
        let sealed = seal_frame(&key, [9u8; 16], &frame);

        assert!(matches!(
            Shield::decrypt_with_key(&key, &sealed),
            Err(ShieldError::InvalidFormat)
        ));
    }

    #[test]
    fn test_v2_length_variation() {
        let shield = Shield::new("password", "service");
        let plaintext = b"same message";

        let lengths: std::collections::HashSet<usize> = (0..20)
            .map(|_| shield.encrypt(plaintext).unwrap().len())
            .collect();

        assert!(
            lengths.len() > 1,
            "Expected length variation due to random padding"
        );
    }

    #[test]
    fn test_v1_backward_compat() {
        let key = [1u8; 32];
        let plaintext = b"v1 message";
        let v1_encrypted = seal_frame(&key, [2u8; 16], &v1_frame(plaintext));

        let shield = Shield::with_key(key);
        let decrypted = shield.decrypt(&v1_encrypted).unwrap();
        assert_eq!(plaintext.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_v1_explicit_decrypt() {
        let key = [3u8; 32];
        let plaintext = b"explicit v1";
        let v1_encrypted = seal_frame(&key, [4u8; 16], &v1_frame(plaintext));

        let shield = Shield::with_key(key);
        assert_eq!(
            plaintext.as_slice(),
            shield.decrypt_v1(&v1_encrypted).unwrap().as_slice()
        );
        assert_eq!(
            plaintext.as_slice(),
            Shield::decrypt_v1_with_key(&key, &v1_encrypted).unwrap().as_slice()
        );
    }

    #[test]
    fn test_tamper_detection_v2() {
        let shield = Shield::new("password", "service");
        let mut encrypted = shield.encrypt(b"data").unwrap();

        encrypted[20] ^= 0xFF;

        assert!(shield.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_wrong_key_rejected() {
        let encrypted = Shield::with_key([1u8; 32]).encrypt(b"secret").unwrap();

        assert!(matches!(
            Shield::with_key([2u8; 32]).decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_ciphertext_too_short() {
        let shield = Shield::with_key([0u8; 32]);
        let short = vec![0u8; MIN_CIPHERTEXT_SIZE - 1];

        assert!(matches!(
            shield.decrypt(&short),
            Err(ShieldError::CiphertextTooShort { expected, actual })
                if expected == MIN_CIPHERTEXT_SIZE && actual == MIN_CIPHERTEXT_SIZE - 1
        ));
        assert!(matches!(
            shield.decrypt_v1(&short),
            Err(ShieldError::CiphertextTooShort { .. })
        ));
    }
}