shield_core/
shield.rs

//! Core Shield encryption implementation.
//!
//! Matches Python `shield.py` byte-for-byte for interoperability.

// Keystream block counters are intentionally u32 (one counter per 32-byte SHA-256 block),
// so a counter could only be truncated past 2^32 blocks (128 GiB) of data.
6#![allow(clippy::cast_possible_truncation)]
7
8use ring::{
9    hmac, pbkdf2,
10    rand::{SecureRandom, SystemRandom},
11};
12use std::num::NonZeroU32;
13use subtle::ConstantTimeEq;
14use zeroize::{Zeroize, ZeroizeOnDrop};
15
16use crate::error::{Result, ShieldError};
17
/// Number of PBKDF2-HMAC-SHA256 rounds used to stretch a password into a key.
/// Must stay identical to the Python implementation.
const PBKDF2_ITERATIONS: u32 = 100_000;

/// Length in bytes of the random nonce that prefixes every message.
const NONCE_SIZE: usize = 16;

/// Length in bytes of the (truncated) HMAC-SHA256 tag that ends every message.
const MAC_SIZE: usize = 16;

/// Length in bytes of the counter prefix that is encrypted ahead of the plaintext.
const COUNTER_PREFIX_SIZE: usize = 8;

/// Smallest well-formed message: `nonce || counter || mac` around an empty plaintext.
const MIN_CIPHERTEXT_SIZE: usize = NONCE_SIZE + COUNTER_PREFIX_SIZE + MAC_SIZE;

30/// EXPTIME-secure symmetric encryption.
31///
32/// Uses password-derived keys with PBKDF2 and encrypts using
33/// a SHA256-based stream cipher with HMAC-SHA256 authentication.
34/// Breaking requires 2^256 operations - no shortcut exists.
35///
36/// Key material is securely zeroized from memory when dropped.
37#[derive(Zeroize, ZeroizeOnDrop)]
38pub struct Shield {
39    key: [u8; 32],
40    #[zeroize(skip)]
41    #[allow(dead_code)]
42    counter: u64,
43}
44
45impl Shield {
46    /// Create a new Shield instance from password and service name.
47    ///
48    /// # Arguments
49    /// * `password` - User's password
50    /// * `service` - Service identifier (e.g., "github.com")
51    ///
52    /// # Example
53    /// ```
54    /// use shield_core::Shield;
55    /// let shield = Shield::new("my_password", "example.com");
56    /// ```
57    #[must_use]
58    pub fn new(password: &str, service: &str) -> Self {
59        // Derive salt from service name (matches Python)
60        let salt = ring::digest::digest(&ring::digest::SHA256, service.as_bytes());
61
62        // Derive key using PBKDF2
63        let mut key = [0u8; 32];
64        pbkdf2::derive(
65            pbkdf2::PBKDF2_HMAC_SHA256,
66            NonZeroU32::new(PBKDF2_ITERATIONS).unwrap(),
67            salt.as_ref(),
68            password.as_bytes(),
69            &mut key,
70        );
71
72        Self { key, counter: 0 }
73    }
74
75    /// Create Shield with a pre-shared key (no password derivation).
76    #[must_use]
77    pub fn with_key(key: [u8; 32]) -> Self {
78        Self { key, counter: 0 }
79    }
80
81    /// Encrypt data.
82    ///
83    /// Returns: `nonce(16) || ciphertext || mac(16)`
84    ///
85    /// # Errors
86    /// Returns error if random generation fails.
87    pub fn encrypt(&self, plaintext: &[u8]) -> Result<Vec<u8>> {
88        Self::encrypt_with_key(&self.key, plaintext)
89    }
90
91    /// Encrypt with explicit key.
92    pub fn encrypt_with_key(key: &[u8; 32], plaintext: &[u8]) -> Result<Vec<u8>> {
93        let rng = SystemRandom::new();
94
95        // Generate random nonce
96        let mut nonce = [0u8; NONCE_SIZE];
97        rng.fill(&mut nonce)
98            .map_err(|_| ShieldError::RandomFailed)?;
99
100        // Counter prefix (matches Python format)
101        let counter_bytes = 0u64.to_le_bytes();
102
103        // Data to encrypt: counter || plaintext
104        let mut data_to_encrypt = Vec::with_capacity(8 + plaintext.len());
105        data_to_encrypt.extend_from_slice(&counter_bytes);
106        data_to_encrypt.extend_from_slice(plaintext);
107
108        // Generate keystream and XOR
109        let keystream = generate_keystream(key, &nonce, data_to_encrypt.len());
110        let ciphertext: Vec<u8> = data_to_encrypt
111            .iter()
112            .zip(keystream.iter())
113            .map(|(p, k)| p ^ k)
114            .collect();
115
116        // Compute HMAC over nonce || ciphertext
117        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
118        let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
119        hmac_data.extend_from_slice(&nonce);
120        hmac_data.extend_from_slice(&ciphertext);
121        let tag = hmac::sign(&hmac_key, &hmac_data);
122
123        // Format: nonce || ciphertext || mac(16)
124        let mut result = Vec::with_capacity(NONCE_SIZE + ciphertext.len() + MAC_SIZE);
125        result.extend_from_slice(&nonce);
126        result.extend_from_slice(&ciphertext);
127        result.extend_from_slice(&tag.as_ref()[..MAC_SIZE]);
128
129        Ok(result)
130    }
131
132    /// Decrypt and verify data.
133    ///
134    /// # Errors
135    /// Returns error if MAC verification fails or ciphertext is malformed.
136    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
137        Self::decrypt_with_key(&self.key, encrypted)
138    }
139
140    /// Decrypt with explicit key.
141    pub fn decrypt_with_key(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
142        if encrypted.len() < MIN_CIPHERTEXT_SIZE {
143            return Err(ShieldError::CiphertextTooShort {
144                expected: MIN_CIPHERTEXT_SIZE,
145                actual: encrypted.len(),
146            });
147        }
148
149        // Parse components
150        let nonce = &encrypted[..NONCE_SIZE];
151        let ciphertext = &encrypted[NONCE_SIZE..encrypted.len() - MAC_SIZE];
152        let mac = &encrypted[encrypted.len() - MAC_SIZE..];
153
154        // Verify MAC first (constant-time)
155        let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
156        let mut hmac_data = Vec::with_capacity(NONCE_SIZE + ciphertext.len());
157        hmac_data.extend_from_slice(nonce);
158        hmac_data.extend_from_slice(ciphertext);
159        let expected_tag = hmac::sign(&hmac_key, &hmac_data);
160
161        // Constant-time comparison
162        if mac.ct_eq(&expected_tag.as_ref()[..MAC_SIZE]).unwrap_u8() != 1 {
163            return Err(ShieldError::AuthenticationFailed);
164        }
165
166        // Decrypt
167        let keystream = generate_keystream(key, nonce, ciphertext.len());
168        let decrypted: Vec<u8> = ciphertext
169            .iter()
170            .zip(keystream.iter())
171            .map(|(c, k)| c ^ k)
172            .collect();
173
174        // Skip counter prefix (8 bytes)
175        Ok(decrypted[8..].to_vec())
176    }
177
178    /// Get the derived key (for testing/debugging).
179    #[must_use]
180    pub fn key(&self) -> &[u8; 32] {
181        &self.key
182    }
183}
184
185/// Generate keystream using SHA256 (matches Python implementation).
186fn generate_keystream(key: &[u8], nonce: &[u8], length: usize) -> Vec<u8> {
187    let mut keystream = Vec::with_capacity(length.div_ceil(32) * 32);
188    let num_blocks = length.div_ceil(32);
189
190    for i in 0..num_blocks {
191        let counter = (i as u32).to_le_bytes();
192
193        // SHA256(key || nonce || counter)
194        let mut data = Vec::with_capacity(key.len() + nonce.len() + 4);
195        data.extend_from_slice(key);
196        data.extend_from_slice(nonce);
197        data.extend_from_slice(&counter);
198
199        let hash = ring::digest::digest(&ring::digest::SHA256, &data);
200        keystream.extend_from_slice(hash.as_ref());
201    }
202
203    keystream.truncate(length);
204    keystream
205}
206
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed raw key so most tests skip the slow PBKDF2 derivation.
    const TEST_KEY: [u8; 32] = [7u8; 32];

    #[test]
    fn test_keystream_deterministic() {
        let key = [1u8; 32];
        let nonce = [2u8; 16];

        let ks1 = generate_keystream(&key, &nonce, 64);
        let ks2 = generate_keystream(&key, &nonce, 64);

        assert_eq!(ks1, ks2);
    }

    #[test]
    fn test_keystream_different_nonce() {
        let key = [1u8; 32];
        let nonce1 = [2u8; 16];
        let nonce2 = [3u8; 16];

        let ks1 = generate_keystream(&key, &nonce1, 32);
        let ks2 = generate_keystream(&key, &nonce2, 32);

        assert_ne!(ks1, ks2);
    }

    #[test]
    fn test_keystream_length_and_prefix() {
        let key = [1u8; 32];
        let nonce = [2u8; 16];
        let full = generate_keystream(&key, &nonce, 100);

        for len in [0, 1, 31, 32, 33, 64, 100] {
            let ks = generate_keystream(&key, &nonce, len);
            assert_eq!(ks.len(), len);
            // A shorter keystream is always a prefix of a longer one.
            assert_eq!(&ks[..], &full[..len]);
        }
    }

    #[test]
    fn test_encrypt_format() {
        let shield = Shield::new("password", "service");
        let encrypted = shield.encrypt(b"test").unwrap();

        // nonce(16) + counter(8) + plaintext(4) + mac(16) = 44
        assert_eq!(encrypted.len(), 16 + 8 + 4 + 16);
    }

    #[test]
    fn test_roundtrip() {
        let shield = Shield::with_key(TEST_KEY);
        let messages: [&[u8]; 4] = [b"", b"a", b"hello, world", &[0xAB; 100]];

        for msg in messages {
            let encrypted = shield.encrypt(msg).unwrap();
            assert_eq!(encrypted.len(), MIN_CIPHERTEXT_SIZE + msg.len());
            assert_eq!(shield.decrypt(&encrypted).unwrap(), msg);
        }
    }

    #[test]
    fn test_encrypt_uses_fresh_nonce() {
        let shield = Shield::with_key(TEST_KEY);
        let a = shield.encrypt(b"same message").unwrap();
        let b = shield.encrypt(b"same message").unwrap();

        assert_ne!(&a[..NONCE_SIZE], &b[..NONCE_SIZE]);
        assert_ne!(a, b);
    }

    #[test]
    fn test_tampering_detected() {
        let shield = Shield::with_key(TEST_KEY);
        let encrypted = shield.encrypt(b"attack at dawn").unwrap();

        for i in 0..encrypted.len() {
            let mut tampered = encrypted.clone();
            tampered[i] ^= 0x01;
            assert!(
                matches!(shield.decrypt(&tampered), Err(ShieldError::AuthenticationFailed)),
                "flipping byte {i} was not detected"
            );
        }
    }

    #[test]
    fn test_wrong_key_rejected() {
        let encrypted = Shield::with_key(TEST_KEY).encrypt(b"secret").unwrap();
        let other = Shield::with_key([8u8; 32]);

        assert!(matches!(
            other.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_too_short_rejected() {
        let shield = Shield::with_key(TEST_KEY);
        let short = [0u8; MIN_CIPHERTEXT_SIZE - 1];

        assert!(matches!(
            shield.decrypt(&short),
            Err(ShieldError::CiphertextTooShort { expected, actual })
                if expected == MIN_CIPHERTEXT_SIZE && actual == MIN_CIPHERTEXT_SIZE - 1
        ));
    }

    #[test]
    fn test_password_and_raw_key_interoperate() {
        let derived = Shield::new("password", "service");
        let raw = Shield::with_key(*derived.key());
        let encrypted = derived.encrypt(b"interop").unwrap();

        assert_eq!(raw.decrypt(&encrypted).unwrap(), b"interop");
    }
}