shield_core/
stream.rs

1//! Streaming encryption for large files.
2//!
3//! Processes data in chunks with per-chunk authentication.
4//! Matches Python `StreamCipher` from `shield_enterprise.py`.
5
6// Crypto block/chunk counters are intentionally u32 - data >4GB would have other issues
7#![allow(clippy::cast_possible_truncation)]
8
9use ring::{hmac, digest, rand::{SecureRandom, SystemRandom}};
10use subtle::ConstantTimeEq;
11
12use crate::error::{Result, ShieldError};
13
/// Plaintext chunk size used by `StreamCipher::new`: 64 KiB (65 536 bytes).
const DEFAULT_CHUNK_SIZE: usize = 1 << 16;
16
/// Chunked cipher for encrypting large inputs.
///
/// The plaintext is cut into pieces of at most `chunk_size` bytes. Every
/// piece is encrypted and authenticated on its own, under a key derived from
/// the master key, a per-stream salt and the chunk's index, so:
/// - the encrypting iterator only has to materialise one chunk at a time,
/// - a corrupted chunk is rejected as soon as it is reached,
/// - chunks do not depend on each other's ciphertext.
pub struct StreamCipher {
    /// 32-byte master key from which every per-chunk key is derived.
    key: [u8; 32],
    /// Maximum number of plaintext bytes placed in a single chunk.
    chunk_size: usize,
}
27
28impl StreamCipher {
29    /// Create a new stream cipher with the given key.
30    #[must_use]
31    pub fn new(key: [u8; 32]) -> Self {
32        Self {
33            key,
34            chunk_size: DEFAULT_CHUNK_SIZE,
35        }
36    }
37
38    /// Create with custom chunk size.
39    #[must_use]
40    pub fn with_chunk_size(key: [u8; 32], chunk_size: usize) -> Self {
41        Self { key, chunk_size }
42    }
43
44    /// Get the chunk size.
45    #[must_use]
46    pub fn chunk_size(&self) -> usize {
47        self.chunk_size
48    }
49
50    /// Encrypt a stream of data.
51    ///
52    /// Returns an iterator over encrypted chunks.
53    /// First chunk is the header containing stream salt.
54    pub fn encrypt_stream<'a>(
55        &'a self,
56        data: &'a [u8],
57    ) -> Result<StreamEncryptor<'a>> {
58        StreamEncryptor::new(&self.key, data, self.chunk_size)
59    }
60
61    /// Decrypt a stream of encrypted chunks.
62    pub fn decrypt_stream(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
63        if encrypted.len() < 20 {
64            return Err(ShieldError::StreamError("encrypted data too short".into()));
65        }
66
67        // Parse header: chunk_size(4) || stream_salt(16)
68        let _chunk_size = u32::from_le_bytes([
69            encrypted[0], encrypted[1], encrypted[2], encrypted[3],
70        ]) as usize;
71        let stream_salt = &encrypted[4..20];
72
73        let mut output = Vec::new();
74        let mut pos = 20;
75        let mut chunk_num: u64 = 0;
76
77        while pos < encrypted.len() {
78            // Read chunk length
79            if pos + 4 > encrypted.len() {
80                return Err(ShieldError::StreamError("truncated chunk length".into()));
81            }
82
83            let chunk_len = u32::from_le_bytes([
84                encrypted[pos],
85                encrypted[pos + 1],
86                encrypted[pos + 2],
87                encrypted[pos + 3],
88            ]) as usize;
89            pos += 4;
90
91            // End marker
92            if chunk_len == 0 {
93                break;
94            }
95
96            if pos + chunk_len > encrypted.len() {
97                return Err(ShieldError::StreamError("truncated chunk data".into()));
98            }
99
100            let chunk_data = &encrypted[pos..pos + chunk_len];
101            pos += chunk_len;
102
103            // Derive chunk key
104            let chunk_key = derive_chunk_key(&self.key, stream_salt, chunk_num);
105
106            // Decrypt chunk
107            let decrypted = decrypt_chunk(&chunk_key, chunk_data)?;
108            output.extend_from_slice(&decrypted);
109
110            chunk_num += 1;
111        }
112
113        Ok(output)
114    }
115
116    /// Encrypt entire data at once (convenience method).
117    pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
118        let encryptor = self.encrypt_stream(data)?;
119        let mut result = Vec::new();
120
121        for chunk in encryptor {
122            result.extend_from_slice(&chunk?);
123        }
124
125        Ok(result)
126    }
127
128    /// Decrypt entire data at once (convenience method).
129    pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
130        self.decrypt_stream(encrypted)
131    }
132}
133
/// Lazy producer of the encoded stream: header first, then one item per
/// encrypted chunk, then the end marker.
pub struct StreamEncryptor<'a> {
    /// Master key borrowed from the owning cipher.
    key: &'a [u8; 32],
    /// Plaintext being encrypted.
    data: &'a [u8],
    /// Random salt generated once per stream and written into the header.
    stream_salt: [u8; 16],
    /// Maximum number of plaintext bytes per chunk.
    chunk_size: usize,
    /// Offset into `data` of the next byte that has not been encrypted yet.
    position: usize,
    /// Index of the next chunk; mixed into the per-chunk key.
    chunk_num: u64,
    /// Set once the header item has been yielded.
    header_sent: bool,
    /// Set once the end marker has been yielded; the iterator then returns `None`.
    finished: bool,
}
145
146impl<'a> StreamEncryptor<'a> {
147    fn new(key: &'a [u8; 32], data: &'a [u8], chunk_size: usize) -> Result<Self> {
148        let rng = SystemRandom::new();
149        let mut stream_salt = [0u8; 16];
150        rng.fill(&mut stream_salt).map_err(|_| ShieldError::RandomFailed)?;
151
152        Ok(Self {
153            key,
154            data,
155            stream_salt,
156            chunk_size,
157            position: 0,
158            chunk_num: 0,
159            header_sent: false,
160            finished: false,
161        })
162    }
163}
164
165impl Iterator for StreamEncryptor<'_> {
166    type Item = Result<Vec<u8>>;
167
168    fn next(&mut self) -> Option<Self::Item> {
169        if self.finished {
170            return None;
171        }
172
173        // First, send header
174        if !self.header_sent {
175            self.header_sent = true;
176            let mut header = Vec::with_capacity(20);
177            header.extend_from_slice(&(self.chunk_size as u32).to_le_bytes());
178            header.extend_from_slice(&self.stream_salt);
179            return Some(Ok(header));
180        }
181
182        // Check if we have more data
183        if self.position >= self.data.len() {
184            self.finished = true;
185            // Send end marker
186            return Some(Ok(vec![0, 0, 0, 0]));
187        }
188
189        // Get next chunk
190        let end = std::cmp::min(self.position + self.chunk_size, self.data.len());
191        let chunk_data = &self.data[self.position..end];
192        self.position = end;
193
194        // Derive chunk key
195        let chunk_key = derive_chunk_key(self.key, &self.stream_salt, self.chunk_num);
196        self.chunk_num += 1;
197
198        // Encrypt chunk
199        match encrypt_chunk(&chunk_key, chunk_data) {
200            Ok(encrypted) => {
201                let mut result = Vec::with_capacity(4 + encrypted.len());
202                result.extend_from_slice(&(encrypted.len() as u32).to_le_bytes());
203                result.extend_from_slice(&encrypted);
204                Some(Ok(result))
205            }
206            Err(e) => Some(Err(e)),
207        }
208    }
209}
210
211/// Derive per-chunk key from master key and stream salt.
212fn derive_chunk_key(key: &[u8], stream_salt: &[u8], chunk_num: u64) -> [u8; 32] {
213    let mut data = Vec::with_capacity(key.len() + stream_salt.len() + 8);
214    data.extend_from_slice(key);
215    data.extend_from_slice(stream_salt);
216    data.extend_from_slice(&chunk_num.to_le_bytes());
217
218    let hash = digest::digest(&digest::SHA256, &data);
219    let mut result = [0u8; 32];
220    result.copy_from_slice(hash.as_ref());
221    result
222}
223
224/// Encrypt a single chunk.
225fn encrypt_chunk(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
226    let rng = SystemRandom::new();
227
228    // Generate nonce
229    let mut nonce = [0u8; 16];
230    rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
231
232    // Generate keystream
233    let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
234    for i in 0..data.len().div_ceil(32) {
235        let counter = (i as u32).to_le_bytes();
236        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
237        hash_input.extend_from_slice(key);
238        hash_input.extend_from_slice(&nonce);
239        hash_input.extend_from_slice(&counter);
240        let hash = digest::digest(&digest::SHA256, &hash_input);
241        keystream.extend_from_slice(hash.as_ref());
242    }
243
244    // XOR encrypt
245    let ciphertext: Vec<u8> = data
246        .iter()
247        .zip(keystream.iter())
248        .map(|(p, k)| p ^ k)
249        .collect();
250
251    // HMAC
252    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
253    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
254    hmac_data.extend_from_slice(&nonce);
255    hmac_data.extend_from_slice(&ciphertext);
256    let tag = hmac::sign(&hmac_key, &hmac_data);
257
258    // Format: nonce || ciphertext || mac(16)
259    let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
260    result.extend_from_slice(&nonce);
261    result.extend_from_slice(&ciphertext);
262    result.extend_from_slice(&tag.as_ref()[..16]);
263
264    Ok(result)
265}
266
267/// Decrypt a single chunk.
268fn decrypt_chunk(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
269    if encrypted.len() < 32 {
270        return Err(ShieldError::StreamError("chunk too short".into()));
271    }
272
273    let nonce = &encrypted[..16];
274    let ciphertext = &encrypted[16..encrypted.len() - 16];
275    let mac = &encrypted[encrypted.len() - 16..];
276
277    // Verify MAC
278    let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
279    let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
280    hmac_data.extend_from_slice(nonce);
281    hmac_data.extend_from_slice(ciphertext);
282    let expected = hmac::sign(&hmac_key, &hmac_data);
283
284    if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
285        return Err(ShieldError::AuthenticationFailed);
286    }
287
288    // Generate keystream
289    let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
290    for i in 0..ciphertext.len().div_ceil(32) {
291        let counter = (i as u32).to_le_bytes();
292        let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
293        hash_input.extend_from_slice(key);
294        hash_input.extend_from_slice(nonce);
295        hash_input.extend_from_slice(&counter);
296        let hash = digest::digest(&digest::SHA256, &hash_input);
297        keystream.extend_from_slice(hash.as_ref());
298    }
299
300    // XOR decrypt
301    Ok(ciphertext
302        .iter()
303        .zip(keystream.iter())
304        .map(|(c, k)| c ^ k)
305        .collect())
306}
307
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed key shared by every test in this module.
    const TEST_KEY: [u8; 32] = [0x42u8; 32];

    /// A short message survives an encrypt/decrypt round trip.
    #[test]
    fn test_stream_roundtrip() {
        let cipher = StreamCipher::new(TEST_KEY);
        let plaintext: &[u8] = b"Hello, streaming world!";

        let sealed = cipher.encrypt(plaintext).unwrap();
        let opened = cipher.decrypt(&sealed).unwrap();

        assert_eq!(plaintext, opened.as_slice());
    }

    /// Input spanning several 1 KiB chunks round-trips unchanged.
    #[test]
    fn test_stream_large_data() {
        let cipher = StreamCipher::with_chunk_size(TEST_KEY, 1024);
        // 10 000 bytes cycling through 0..=255.
        let plaintext: Vec<u8> = (0u8..=255).cycle().take(10_000).collect();

        let sealed = cipher.encrypt(&plaintext).unwrap();
        let opened = cipher.decrypt(&sealed).unwrap();

        assert_eq!(plaintext, opened);
    }

    /// Flipping a byte inside the first chunk makes decryption fail.
    #[test]
    fn test_stream_tamper_detection() {
        let cipher = StreamCipher::new(TEST_KEY);
        let mut sealed = cipher.encrypt(b"test data").unwrap();

        // Offset 30 lies past the 20-byte header and 4-byte length prefix,
        // i.e. inside the first encrypted chunk.
        if let Some(byte) = sealed.get_mut(30) {
            *byte ^= 0xFF;
        }

        assert!(cipher.decrypt(&sealed).is_err());
    }
}