Skip to main content

shield_core/
stream.rs

//! Streaming encryption for large files.
//!
//! Processes data in chunks with per-chunk authentication.
//! Matches Python `StreamCipher` from `shield_enterprise.py`.

// Keystream block counters and on-the-wire length fields are intentionally u32;
// chunks whose encrypted size exceeds 4 GiB are not supported by the format.
7#![allow(clippy::cast_possible_truncation)]
8
9use ring::{digest, hmac};
10use subtle::ConstantTimeEq;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13use crate::error::{Result, ShieldError};
14
/// Plaintext bytes per chunk used by [`StreamCipher::new`] (64 KiB).
const DEFAULT_CHUNK_SIZE: usize = 1 << 16;

18/// Stream cipher for large file encryption.
19///
20/// Each chunk is independently authenticated, allowing:
21/// - Constant memory usage regardless of file size
22/// - Early detection of tampering
23/// - Potential for parallel processing
24#[derive(Zeroize, ZeroizeOnDrop)]
25pub struct StreamCipher {
26    key: [u8; 32],
27    #[zeroize(skip)]
28    chunk_size: usize,
29}
30
31impl StreamCipher {
32    /// Create a new stream cipher with the given key.
33    #[must_use]
34    pub fn new(key: [u8; 32]) -> Self {
35        Self {
36            key,
37            chunk_size: DEFAULT_CHUNK_SIZE,
38        }
39    }
40
41    /// Create with custom chunk size.
42    #[must_use]
43    pub fn with_chunk_size(key: [u8; 32], chunk_size: usize) -> Self {
44        Self { key, chunk_size }
45    }
46
47    /// Get the chunk size.
48    #[must_use]
49    pub fn chunk_size(&self) -> usize {
50        self.chunk_size
51    }
52
53    /// Encrypt a stream of data.
54    ///
55    /// Returns an iterator over encrypted chunks.
56    /// First chunk is the header containing stream salt.
57    pub fn encrypt_stream<'a>(&'a self, data: &'a [u8]) -> Result<StreamEncryptor<'a>> {
58        StreamEncryptor::new(&self.key, data, self.chunk_size)
59    }
60
61    /// Decrypt a stream of encrypted chunks.
62    pub fn decrypt_stream(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
63        if encrypted.len() < 20 {
64            return Err(ShieldError::StreamError("encrypted data too short".into()));
65        }
66
67        // Parse header: chunk_size(4) || stream_salt(16)
68        let _chunk_size =
69            u32::from_le_bytes([encrypted[0], encrypted[1], encrypted[2], encrypted[3]]) as usize;
70        let stream_salt = &encrypted[4..20];
71
72        let mut output = Vec::new();
73        let mut pos = 20;
74        let mut chunk_num: u64 = 0;
75
76        while pos < encrypted.len() {
77            // Read chunk length
78            if pos + 4 > encrypted.len() {
79                return Err(ShieldError::StreamError("truncated chunk length".into()));
80            }
81
82            let chunk_len = u32::from_le_bytes([
83                encrypted[pos],
84                encrypted[pos + 1],
85                encrypted[pos + 2],
86                encrypted[pos + 3],
87            ]) as usize;
88            pos += 4;
89
90            // End marker
91            if chunk_len == 0 {
92                break;
93            }
94
95            if pos + chunk_len > encrypted.len() {
96                return Err(ShieldError::StreamError("truncated chunk data".into()));
97            }
98
99            let chunk_data = &encrypted[pos..pos + chunk_len];
100            pos += chunk_len;
101
102            // Derive chunk key
103            let chunk_key = derive_chunk_key(&self.key, stream_salt, chunk_num);
104
105            // Decrypt chunk
106            let decrypted = decrypt_chunk(&chunk_key, chunk_data)?;
107            output.extend_from_slice(&decrypted);
108
109            chunk_num += 1;
110        }
111
112        Ok(output)
113    }
114
115    /// Encrypt entire data at once (convenience method).
116    pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
117        let encryptor = self.encrypt_stream(data)?;
118        let mut result = Vec::new();
119
120        for chunk in encryptor {
121            result.extend_from_slice(&chunk?);
122        }
123
124        Ok(result)
125    }
126
127    /// Decrypt entire data at once (convenience method).
128    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
129        self.decrypt_stream(encrypted)
130    }
131}
132
/// Iterator over the pieces of an encrypted stream.
///
/// Yields, in order:
/// 1. the 20-byte header: `chunk_size` as a little-endian `u32`, then the
///    16-byte stream salt;
/// 2. one `len(4 LE) || nonce || ciphertext || tag` frame per plaintext chunk;
/// 3. the 4-byte all-zero end marker.
///
/// Concatenating every item gives the format read by
/// [`StreamCipher::decrypt_stream`].
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct StreamEncryptor<'a> {
    /// Master key borrowed from the owning [`StreamCipher`].
    key: &'a [u8; 32],
    /// Full plaintext, consumed front to back in `chunk_size` slices.
    data: &'a [u8],
    /// Random per-stream salt mixed into every chunk key.
    stream_salt: [u8; 16],
    /// Plaintext bytes per chunk.
    chunk_size: usize,
    /// Offset of the next not-yet-encrypted byte in `data`.
    position: usize,
    /// Index of the next chunk, used for chunk-key derivation.
    chunk_num: u64,
    /// Whether the header has already been yielded.
    header_sent: bool,
    /// When `true`, the iterator is exhausted and yields `None`.
    finished: bool,
}

145impl<'a> StreamEncryptor<'a> {
146    fn new(key: &'a [u8; 32], data: &'a [u8], chunk_size: usize) -> Result<Self> {
147        let stream_salt: [u8; 16] = crate::random::random_bytes()?;
148
149        Ok(Self {
150            key,
151            data,
152            stream_salt,
153            chunk_size,
154            position: 0,
155            chunk_num: 0,
156            header_sent: false,
157            finished: false,
158        })
159    }
160}
161
162impl Iterator for StreamEncryptor<'_> {
163    type Item = Result<Vec<u8>>;
164
165    fn next(&mut self) -> Option<Self::Item> {
166        if self.finished {
167            return None;
168        }
169
170        // First, send header
171        if !self.header_sent {
172            self.header_sent = true;
173            let mut header = Vec::with_capacity(20);
174            header.extend_from_slice(&(self.chunk_size as u32).to_le_bytes());
175            header.extend_from_slice(&self.stream_salt);
176            return Some(Ok(header));
177        }
178
179        // Check if we have more data
180        if self.position >= self.data.len() {
181            self.finished = true;
182            // Send end marker
183            return Some(Ok(vec![0, 0, 0, 0]));
184        }
185
186        // Get next chunk
187        let end = std::cmp::min(self.position + self.chunk_size, self.data.len());
188        let chunk_data = &self.data[self.position..end];
189        self.position = end;
190
191        // Derive chunk key
192        let chunk_key = derive_chunk_key(self.key, &self.stream_salt, self.chunk_num);
193        self.chunk_num += 1;
194
195        // Encrypt chunk
196        match encrypt_chunk(&chunk_key, chunk_data) {
197            Ok(encrypted) => {
198                let mut result = Vec::with_capacity(4 + encrypted.len());
199                result.extend_from_slice(&(encrypted.len() as u32).to_le_bytes());
200                result.extend_from_slice(&encrypted);
201                Some(Ok(result))
202            }
203            Err(e) => Some(Err(e)),
204        }
205    }
206}
207
208/// Derive per-chunk key from master key and stream salt.
209fn derive_chunk_key(key: &[u8], stream_salt: &[u8], chunk_num: u64) -> [u8; 32] {
210    let mut data = Vec::with_capacity(key.len() + stream_salt.len() + 8);
211    data.extend_from_slice(key);
212    data.extend_from_slice(stream_salt);
213    data.extend_from_slice(&chunk_num.to_le_bytes());
214
215    let hash = digest::digest(&digest::SHA256, &data);
216    let mut result = [0u8; 32];
217    result.copy_from_slice(hash.as_ref());
218    result
219}
220
221/// Derive separated encryption and MAC subkeys from a chunk key using HMAC-SHA256.
222fn derive_chunk_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
223    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
224
225    let enc_tag = hmac::sign(&hmac_key, b"shield-stream-encrypt");
226    let mut enc_key = [0u8; 32];
227    enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
228
229    let mac_tag = hmac::sign(&hmac_key, b"shield-stream-authenticate");
230    let mut mac_key = [0u8; 32];
231    mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
232
233    (enc_key, mac_key)
234}
235
236/// Encrypt a single chunk.
237fn encrypt_chunk(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
238    let (enc_key, mac_key) = derive_chunk_subkeys(key);
239
240    // Generate nonce
241    let nonce: [u8; 16] = crate::random::random_bytes()?;
242
243    // Generate keystream with enc_key
244    let num_blocks = data.len().div_ceil(32);
245    if u32::try_from(num_blocks).is_err() {
246        return Err(ShieldError::StreamError(
247            "keystream too long: counter overflow".into(),
248        ));
249    }
250    let mut keystream = Vec::with_capacity(num_blocks * 32);
251    for i in 0..num_blocks {
252        let counter = (i as u32).to_le_bytes();
253        let mut hash_input = Vec::with_capacity(enc_key.len() + nonce.len() + 4);
254        hash_input.extend_from_slice(&enc_key);
255        hash_input.extend_from_slice(&nonce);
256        hash_input.extend_from_slice(&counter);
257        let hash = digest::digest(&digest::SHA256, &hash_input);
258        keystream.extend_from_slice(hash.as_ref());
259    }
260
261    // XOR encrypt
262    let ciphertext: Vec<u8> = data
263        .iter()
264        .zip(keystream.iter())
265        .map(|(p, k)| p ^ k)
266        .collect();
267
268    // HMAC with mac_key
269    let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
270    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
271    hmac_data.extend_from_slice(&nonce);
272    hmac_data.extend_from_slice(&ciphertext);
273    let tag = hmac::sign(&hmac_signing_key, &hmac_data);
274
275    // Format: nonce || ciphertext || mac(16)
276    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
277    result.extend_from_slice(&nonce);
278    result.extend_from_slice(&ciphertext);
279    result.extend_from_slice(&tag.as_ref()[..16]);
280
281    Ok(result)
282}
283
284/// Decrypt a single chunk.
285fn decrypt_chunk(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
286    if encrypted.len() < 32 {
287        return Err(ShieldError::StreamError("chunk too short".into()));
288    }
289
290    let (enc_key, mac_key) = derive_chunk_subkeys(key);
291
292    let nonce = &encrypted[..16];
293    let ciphertext = &encrypted[16..encrypted.len() - 16];
294    let mac = &encrypted[encrypted.len() - 16..];
295
296    // Verify MAC with mac_key
297    let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
298    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
299    hmac_data.extend_from_slice(nonce);
300    hmac_data.extend_from_slice(ciphertext);
301    let expected = hmac::sign(&hmac_signing_key, &hmac_data);
302
303    if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
304        return Err(ShieldError::AuthenticationFailed);
305    }
306
307    // Generate keystream with enc_key
308    let num_blocks = ciphertext.len().div_ceil(32);
309    if u32::try_from(num_blocks).is_err() {
310        return Err(ShieldError::StreamError(
311            "keystream too long: counter overflow".into(),
312        ));
313    }
314    let mut keystream = Vec::with_capacity(num_blocks * 32);
315    for i in 0..num_blocks {
316        let counter = (i as u32).to_le_bytes();
317        let mut hash_input = Vec::with_capacity(enc_key.len() + nonce.len() + 4);
318        hash_input.extend_from_slice(&enc_key);
319        hash_input.extend_from_slice(nonce);
320        hash_input.extend_from_slice(&counter);
321        let hash = digest::digest(&digest::SHA256, &hash_input);
322        keystream.extend_from_slice(hash.as_ref());
323    }
324
325    // XOR decrypt
326    Ok(ciphertext
327        .iter()
328        .zip(keystream.iter())
329        .map(|(c, k)| c ^ k)
330        .collect())
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    const KEY: [u8; 32] = [0x42; 32];

    #[test]
    fn test_stream_roundtrip() {
        let cipher = StreamCipher::new(KEY);

        let data = b"Hello, streaming world!";
        let encrypted = cipher.encrypt(data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_stream_large_data() {
        let cipher = StreamCipher::with_chunk_size(KEY, 1024);

        let data: Vec<u8> = (0..10000_u32).map(|i| (i % 256) as u8).collect();
        let encrypted = cipher.encrypt(&data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data, decrypted);
    }

    #[test]
    fn test_stream_empty_roundtrip() {
        let cipher = StreamCipher::new(KEY);
        let encrypted = cipher.encrypt(b"").unwrap();
        // Header (20) + end marker (4), no chunk frames.
        assert_eq!(encrypted.len(), 24);
        assert!(cipher.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_stream_exact_chunk_multiple() {
        let cipher = StreamCipher::with_chunk_size(KEY, 16);
        let data = [0xA5u8; 64];
        let encrypted = cipher.encrypt(&data).unwrap();
        assert_eq!(cipher.decrypt(&encrypted).unwrap(), data);
    }

    #[test]
    fn test_stream_layout() {
        let cipher = StreamCipher::with_chunk_size(KEY, 4);
        let encrypted = cipher.encrypt(b"0123456789").unwrap();
        // Header (20) + frames of len(4) + nonce(16) + n + tag(16) for
        // n = 4, 4, 2 + end marker (4).
        assert_eq!(encrypted.len(), 20 + (36 + 4) * 2 + (36 + 2) + 4);
        assert_eq!(&encrypted[..4], &4u32.to_le_bytes());
        assert_eq!(&encrypted[encrypted.len() - 4..], &[0u8; 4]);
    }

    #[test]
    fn test_stream_tamper_detection() {
        let cipher = StreamCipher::new(KEY);

        let mut encrypted = cipher.encrypt(b"test data").unwrap();
        // Byte 30 lies inside the first chunk's nonce.
        encrypted[30] ^= 0xFF;

        assert!(cipher.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_stream_detects_tampering_after_chunk_size_field() {
        let cipher = StreamCipher::new(KEY);
        let encrypted = cipher.encrypt(b"test data").unwrap();
        // The leading chunk-size field is informational and unauthenticated.
        // Every byte after it (salt, length prefix, chunk, end marker) must be checked.
        for i in 4..encrypted.len() {
            let mut tampered = encrypted.clone();
            tampered[i] ^= 0xFF;
            assert!(
                cipher.decrypt(&tampered).is_err(),
                "tampering byte {i} went undetected"
            );
        }
    }

    #[test]
    fn test_stream_wrong_key_fails() {
        let encrypted = StreamCipher::new(KEY).encrypt(b"secret").unwrap();
        let other = StreamCipher::new([0x24; 32]);
        assert!(matches!(
            other.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_rejects_short_input() {
        let cipher = StreamCipher::new(KEY);
        assert!(cipher.decrypt(&[0u8; 19]).is_err());
    }
}