shield_core/
stream.rs

1//! Streaming encryption for large files.
2//!
3//! Processes data in chunks with per-chunk authentication.
4//! Matches Python `StreamCipher` from `shield_enterprise.py`.
5
6// Crypto block/chunk counters are intentionally u32 - data >4GB would have other issues
7#![allow(clippy::cast_possible_truncation)]
8
9use ring::{
10    digest, hmac,
11    rand::{SecureRandom, SystemRandom},
12};
13use subtle::ConstantTimeEq;
14
15use crate::error::{Result, ShieldError};
16
/// Chunk size used by [`StreamCipher::new`]: 64 KiB of plaintext per chunk.
const DEFAULT_CHUNK_SIZE: usize = 1 << 16;
19
/// Chunked cipher for encrypting large inputs.
///
/// The input is split into pieces of at most `chunk_size` bytes, and every
/// piece is encrypted and authenticated on its own. This gives:
/// - early detection of tampering (a bad chunk fails as soon as it is reached)
/// - encrypted chunks that are produced lazily, one at a time
/// - the potential for parallel processing
pub struct StreamCipher {
    /// 32-byte master key; per-chunk keys are derived from it.
    key: [u8; 32],
    /// Maximum number of plaintext bytes placed in a single chunk.
    chunk_size: usize,
}
30
31impl StreamCipher {
32    /// Create a new stream cipher with the given key.
33    #[must_use]
34    pub fn new(key: [u8; 32]) -> Self {
35        Self {
36            key,
37            chunk_size: DEFAULT_CHUNK_SIZE,
38        }
39    }
40
41    /// Create with custom chunk size.
42    #[must_use]
43    pub fn with_chunk_size(key: [u8; 32], chunk_size: usize) -> Self {
44        Self { key, chunk_size }
45    }
46
47    /// Get the chunk size.
48    #[must_use]
49    pub fn chunk_size(&self) -> usize {
50        self.chunk_size
51    }
52
53    /// Encrypt a stream of data.
54    ///
55    /// Returns an iterator over encrypted chunks.
56    /// First chunk is the header containing stream salt.
57    pub fn encrypt_stream<'a>(&'a self, data: &'a [u8]) -> Result<StreamEncryptor<'a>> {
58        StreamEncryptor::new(&self.key, data, self.chunk_size)
59    }
60
61    /// Decrypt a stream of encrypted chunks.
62    pub fn decrypt_stream(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
63        if encrypted.len() < 20 {
64            return Err(ShieldError::StreamError("encrypted data too short".into()));
65        }
66
67        // Parse header: chunk_size(4) || stream_salt(16)
68        let _chunk_size =
69            u32::from_le_bytes([encrypted[0], encrypted[1], encrypted[2], encrypted[3]]) as usize;
70        let stream_salt = &encrypted[4..20];
71
72        let mut output = Vec::new();
73        let mut pos = 20;
74        let mut chunk_num: u64 = 0;
75
76        while pos < encrypted.len() {
77            // Read chunk length
78            if pos + 4 > encrypted.len() {
79                return Err(ShieldError::StreamError("truncated chunk length".into()));
80            }
81
82            let chunk_len = u32::from_le_bytes([
83                encrypted[pos],
84                encrypted[pos + 1],
85                encrypted[pos + 2],
86                encrypted[pos + 3],
87            ]) as usize;
88            pos += 4;
89
90            // End marker
91            if chunk_len == 0 {
92                break;
93            }
94
95            if pos + chunk_len > encrypted.len() {
96                return Err(ShieldError::StreamError("truncated chunk data".into()));
97            }
98
99            let chunk_data = &encrypted[pos..pos + chunk_len];
100            pos += chunk_len;
101
102            // Derive chunk key
103            let chunk_key = derive_chunk_key(&self.key, stream_salt, chunk_num);
104
105            // Decrypt chunk
106            let decrypted = decrypt_chunk(&chunk_key, chunk_data)?;
107            output.extend_from_slice(&decrypted);
108
109            chunk_num += 1;
110        }
111
112        Ok(output)
113    }
114
115    /// Encrypt entire data at once (convenience method).
116    pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
117        let encryptor = self.encrypt_stream(data)?;
118        let mut result = Vec::new();
119
120        for chunk in encryptor {
121            result.extend_from_slice(&chunk?);
122        }
123
124        Ok(result)
125    }
126
127    /// Decrypt entire data at once (convenience method).
128    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
129        self.decrypt_stream(encrypted)
130    }
131}
132
/// Iterator that yields the encoded pieces of an encrypted stream.
///
/// Created by [`StreamCipher::encrypt_stream`].
pub struct StreamEncryptor<'a> {
    /// Master key from which the per-chunk keys are derived.
    key: &'a [u8; 32],
    /// Plaintext being encrypted.
    data: &'a [u8],
    /// Random per-stream salt; written into the header.
    stream_salt: [u8; 16],
    /// Maximum number of plaintext bytes per chunk.
    chunk_size: usize,
    /// Offset into `data` of the next byte that has not been encrypted yet.
    position: usize,
    /// Index of the next chunk; mixed into that chunk's key.
    chunk_num: u64,
    /// Whether the header has already been yielded.
    header_sent: bool,
    /// Set once iteration is over; `next` returns `None` afterwards.
    finished: bool,
}
144
145impl<'a> StreamEncryptor<'a> {
146    fn new(key: &'a [u8; 32], data: &'a [u8], chunk_size: usize) -> Result<Self> {
147        let rng = SystemRandom::new();
148        let mut stream_salt = [0u8; 16];
149        rng.fill(&mut stream_salt)
150            .map_err(|_| ShieldError::RandomFailed)?;
151
152        Ok(Self {
153            key,
154            data,
155            stream_salt,
156            chunk_size,
157            position: 0,
158            chunk_num: 0,
159            header_sent: false,
160            finished: false,
161        })
162    }
163}
164
165impl Iterator for StreamEncryptor<'_> {
166    type Item = Result<Vec<u8>>;
167
168    fn next(&mut self) -> Option<Self::Item> {
169        if self.finished {
170            return None;
171        }
172
173        // First, send header
174        if !self.header_sent {
175            self.header_sent = true;
176            let mut header = Vec::with_capacity(20);
177            header.extend_from_slice(&(self.chunk_size as u32).to_le_bytes());
178            header.extend_from_slice(&self.stream_salt);
179            return Some(Ok(header));
180        }
181
182        // Check if we have more data
183        if self.position >= self.data.len() {
184            self.finished = true;
185            // Send end marker
186            return Some(Ok(vec![0, 0, 0, 0]));
187        }
188
189        // Get next chunk
190        let end = std::cmp::min(self.position + self.chunk_size, self.data.len());
191        let chunk_data = &self.data[self.position..end];
192        self.position = end;
193
194        // Derive chunk key
195        let chunk_key = derive_chunk_key(self.key, &self.stream_salt, self.chunk_num);
196        self.chunk_num += 1;
197
198        // Encrypt chunk
199        match encrypt_chunk(&chunk_key, chunk_data) {
200            Ok(encrypted) => {
201                let mut result = Vec::with_capacity(4 + encrypted.len());
202                result.extend_from_slice(&(encrypted.len() as u32).to_le_bytes());
203                result.extend_from_slice(&encrypted);
204                Some(Ok(result))
205            }
206            Err(e) => Some(Err(e)),
207        }
208    }
209}
210
211/// Derive per-chunk key from master key and stream salt.
212fn derive_chunk_key(key: &[u8], stream_salt: &[u8], chunk_num: u64) -> [u8; 32] {
213    let mut data = Vec::with_capacity(key.len() + stream_salt.len() + 8);
214    data.extend_from_slice(key);
215    data.extend_from_slice(stream_salt);
216    data.extend_from_slice(&chunk_num.to_le_bytes());
217
218    let hash = digest::digest(&digest::SHA256, &data);
219    let mut result = [0u8; 32];
220    result.copy_from_slice(hash.as_ref());
221    result
222}
223
224/// Encrypt a single chunk.
225fn encrypt_chunk(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
226    let rng = SystemRandom::new();
227
228    // Generate nonce
229    let mut nonce = [0u8; 16];
230    rng.fill(&mut nonce)
231        .map_err(|_| ShieldError::RandomFailed)?;
232
233    // Generate keystream
234    let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
235    for i in 0..data.len().div_ceil(32) {
236        let counter = (i as u32).to_le_bytes();
237        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
238        hash_input.extend_from_slice(key);
239        hash_input.extend_from_slice(&nonce);
240        hash_input.extend_from_slice(&counter);
241        let hash = digest::digest(&digest::SHA256, &hash_input);
242        keystream.extend_from_slice(hash.as_ref());
243    }
244
245    // XOR encrypt
246    let ciphertext: Vec<u8> = data
247        .iter()
248        .zip(keystream.iter())
249        .map(|(p, k)| p ^ k)
250        .collect();
251
252    // HMAC
253    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
254    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
255    hmac_data.extend_from_slice(&nonce);
256    hmac_data.extend_from_slice(&ciphertext);
257    let tag = hmac::sign(&hmac_key, &hmac_data);
258
259    // Format: nonce || ciphertext || mac(16)
260    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
261    result.extend_from_slice(&nonce);
262    result.extend_from_slice(&ciphertext);
263    result.extend_from_slice(&tag.as_ref()[..16]);
264
265    Ok(result)
266}
267
268/// Decrypt a single chunk.
269fn decrypt_chunk(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
270    if encrypted.len() < 32 {
271        return Err(ShieldError::StreamError("chunk too short".into()));
272    }
273
274    let nonce = &encrypted[..16];
275    let ciphertext = &encrypted[16..encrypted.len() - 16];
276    let mac = &encrypted[encrypted.len() - 16..];
277
278    // Verify MAC
279    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
280    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
281    hmac_data.extend_from_slice(nonce);
282    hmac_data.extend_from_slice(ciphertext);
283    let expected = hmac::sign(&hmac_key, &hmac_data);
284
285    if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
286        return Err(ShieldError::AuthenticationFailed);
287    }
288
289    // Generate keystream
290    let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
291    for i in 0..ciphertext.len().div_ceil(32) {
292        let counter = (i as u32).to_le_bytes();
293        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
294        hash_input.extend_from_slice(key);
295        hash_input.extend_from_slice(nonce);
296        hash_input.extend_from_slice(&counter);
297        let hash = digest::digest(&digest::SHA256, &hash_input);
298        keystream.extend_from_slice(hash.as_ref());
299    }
300
301    // XOR decrypt
302    Ok(ciphertext
303        .iter()
304        .zip(keystream.iter())
305        .map(|(c, k)| c ^ k)
306        .collect())
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key shared by the tests.
    const KEY: [u8; 32] = [0x42u8; 32];

    #[test]
    fn test_stream_roundtrip() {
        let cipher = StreamCipher::new(KEY);

        let data = b"Hello, streaming world!";
        let encrypted = cipher.encrypt(data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_stream_empty_input() {
        let cipher = StreamCipher::new(KEY);

        // Empty input still produces a header and an end marker.
        let encrypted = cipher.encrypt(b"").unwrap();
        assert_eq!(encrypted.len(), 20 + 4);

        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert!(decrypted.is_empty());
    }

    #[test]
    fn test_stream_large_data() {
        let cipher = StreamCipher::with_chunk_size(KEY, 1024);

        let data: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let encrypted = cipher.encrypt(&data).unwrap();

        // 10 chunks, each with 4 (length) + 16 (nonce) + 16 (tag) bytes of
        // overhead, plus the 20-byte header and the 4-byte end marker.
        assert_eq!(encrypted.len(), 20 + data.len() + 10 * 36 + 4);

        let decrypted = cipher.decrypt(&encrypted).unwrap();
        assert_eq!(data, decrypted);
    }

    #[test]
    fn test_stream_tamper_detection() {
        let cipher = StreamCipher::new(KEY);

        let mut encrypted = cipher.encrypt(b"test data").unwrap();

        // Byte 30 lies inside the first chunk (header is 20 bytes, the length
        // prefix 4). Assert instead of silently skipping the mutation, so the
        // test cannot pass without actually tampering with anything.
        assert!(encrypted.len() > 30, "ciphertext shorter than expected");
        encrypted[30] ^= 0xFF;

        assert!(cipher.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_stream_wrong_key() {
        let encrypted = StreamCipher::new(KEY).encrypt(b"test data").unwrap();

        let other = StreamCipher::new([0x43u8; 32]);
        assert!(other.decrypt(&encrypted).is_err());
    }
}
350}