1#![allow(clippy::cast_possible_truncation)]
8
9use ring::{hmac, digest, rand::{SecureRandom, SystemRandom}};
10use subtle::ConstantTimeEq;
11
12use crate::error::{Result, ShieldError};
13
/// Number of plaintext bytes per chunk used by [`StreamCipher::new`] (64 KiB).
const DEFAULT_CHUNK_SIZE: usize = 1 << 16;
16
/// Chunked authenticated encryption under a single 32-byte master key.
///
/// Input is split into chunks of at most `chunk_size` plaintext bytes. Each chunk is
/// encrypted under its own key, derived from the master key, a random per-stream salt
/// and the chunk index. Each chunk carries its own truncated HMAC-SHA256 tag.
pub struct StreamCipher {
    /// Master key; every per-chunk key is derived from it.
    key: [u8; 32],
    /// Maximum number of plaintext bytes placed in one chunk.
    chunk_size: usize,
}
27
28impl StreamCipher {
29 #[must_use]
31 pub fn new(key: [u8; 32]) -> Self {
32 Self {
33 key,
34 chunk_size: DEFAULT_CHUNK_SIZE,
35 }
36 }
37
38 #[must_use]
40 pub fn with_chunk_size(key: [u8; 32], chunk_size: usize) -> Self {
41 Self { key, chunk_size }
42 }
43
44 #[must_use]
46 pub fn chunk_size(&self) -> usize {
47 self.chunk_size
48 }
49
50 pub fn encrypt_stream<'a>(
55 &'a self,
56 data: &'a [u8],
57 ) -> Result<StreamEncryptor<'a>> {
58 StreamEncryptor::new(&self.key, data, self.chunk_size)
59 }
60
61 pub fn decrypt_stream(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
63 if encrypted.len() < 20 {
64 return Err(ShieldError::StreamError("encrypted data too short".into()));
65 }
66
67 let _chunk_size = u32::from_le_bytes([
69 encrypted[0], encrypted[1], encrypted[2], encrypted[3],
70 ]) as usize;
71 let stream_salt = &encrypted[4..20];
72
73 let mut output = Vec::new();
74 let mut pos = 20;
75 let mut chunk_num: u64 = 0;
76
77 while pos < encrypted.len() {
78 if pos + 4 > encrypted.len() {
80 return Err(ShieldError::StreamError("truncated chunk length".into()));
81 }
82
83 let chunk_len = u32::from_le_bytes([
84 encrypted[pos],
85 encrypted[pos + 1],
86 encrypted[pos + 2],
87 encrypted[pos + 3],
88 ]) as usize;
89 pos += 4;
90
91 if chunk_len == 0 {
93 break;
94 }
95
96 if pos + chunk_len > encrypted.len() {
97 return Err(ShieldError::StreamError("truncated chunk data".into()));
98 }
99
100 let chunk_data = &encrypted[pos..pos + chunk_len];
101 pos += chunk_len;
102
103 let chunk_key = derive_chunk_key(&self.key, stream_salt, chunk_num);
105
106 let decrypted = decrypt_chunk(&chunk_key, chunk_data)?;
108 output.extend_from_slice(&decrypted);
109
110 chunk_num += 1;
111 }
112
113 Ok(output)
114 }
115
116 pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
118 let encryptor = self.encrypt_stream(data)?;
119 let mut result = Vec::new();
120
121 for chunk in encryptor {
122 result.extend_from_slice(&chunk?);
123 }
124
125 Ok(result)
126 }
127
128 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
130 self.decrypt_stream(encrypted)
131 }
132}
133
/// Lazily produces the encrypted form of `data`, one piece per `next` call.
///
/// Created by [`StreamCipher::encrypt_stream`].
pub struct StreamEncryptor<'a> {
    /// Master key that the per-chunk keys are derived from.
    key: &'a [u8; 32],
    /// Plaintext being encrypted.
    data: &'a [u8],
    /// Random salt written into the header and mixed into every chunk key.
    stream_salt: [u8; 16],
    /// Maximum number of plaintext bytes per chunk.
    chunk_size: usize,
    /// Offset of the first plaintext byte not yet encrypted.
    position: usize,
    /// Index of the next chunk; it is an input to that chunk's key derivation.
    chunk_num: u64,
    /// Whether the header has already been yielded.
    header_sent: bool,
    /// Set once iteration is over; `next` then returns `None`.
    finished: bool,
}
145
146impl<'a> StreamEncryptor<'a> {
147 fn new(key: &'a [u8; 32], data: &'a [u8], chunk_size: usize) -> Result<Self> {
148 let rng = SystemRandom::new();
149 let mut stream_salt = [0u8; 16];
150 rng.fill(&mut stream_salt).map_err(|_| ShieldError::RandomFailed)?;
151
152 Ok(Self {
153 key,
154 data,
155 stream_salt,
156 chunk_size,
157 position: 0,
158 chunk_num: 0,
159 header_sent: false,
160 finished: false,
161 })
162 }
163}
164
165impl Iterator for StreamEncryptor<'_> {
166 type Item = Result<Vec<u8>>;
167
168 fn next(&mut self) -> Option<Self::Item> {
169 if self.finished {
170 return None;
171 }
172
173 if !self.header_sent {
175 self.header_sent = true;
176 let mut header = Vec::with_capacity(20);
177 header.extend_from_slice(&(self.chunk_size as u32).to_le_bytes());
178 header.extend_from_slice(&self.stream_salt);
179 return Some(Ok(header));
180 }
181
182 if self.position >= self.data.len() {
184 self.finished = true;
185 return Some(Ok(vec![0, 0, 0, 0]));
187 }
188
189 let end = std::cmp::min(self.position + self.chunk_size, self.data.len());
191 let chunk_data = &self.data[self.position..end];
192 self.position = end;
193
194 let chunk_key = derive_chunk_key(self.key, &self.stream_salt, self.chunk_num);
196 self.chunk_num += 1;
197
198 match encrypt_chunk(&chunk_key, chunk_data) {
200 Ok(encrypted) => {
201 let mut result = Vec::with_capacity(4 + encrypted.len());
202 result.extend_from_slice(&(encrypted.len() as u32).to_le_bytes());
203 result.extend_from_slice(&encrypted);
204 Some(Ok(result))
205 }
206 Err(e) => Some(Err(e)),
207 }
208 }
209}
210
211fn derive_chunk_key(key: &[u8], stream_salt: &[u8], chunk_num: u64) -> [u8; 32] {
213 let mut data = Vec::with_capacity(key.len() + stream_salt.len() + 8);
214 data.extend_from_slice(key);
215 data.extend_from_slice(stream_salt);
216 data.extend_from_slice(&chunk_num.to_le_bytes());
217
218 let hash = digest::digest(&digest::SHA256, &data);
219 let mut result = [0u8; 32];
220 result.copy_from_slice(hash.as_ref());
221 result
222}
223
224fn encrypt_chunk(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
226 let rng = SystemRandom::new();
227
228 let mut nonce = [0u8; 16];
230 rng.fill(&mut nonce).map_err(|_| ShieldError::RandomFailed)?;
231
232 let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
234 for i in 0..data.len().div_ceil(32) {
235 let counter = (i as u32).to_le_bytes();
236 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
237 hash_input.extend_from_slice(key);
238 hash_input.extend_from_slice(&nonce);
239 hash_input.extend_from_slice(&counter);
240 let hash = digest::digest(&digest::SHA256, &hash_input);
241 keystream.extend_from_slice(hash.as_ref());
242 }
243
244 let ciphertext: Vec<u8> = data
246 .iter()
247 .zip(keystream.iter())
248 .map(|(p, k)| p ^ k)
249 .collect();
250
251 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
253 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
254 hmac_data.extend_from_slice(&nonce);
255 hmac_data.extend_from_slice(&ciphertext);
256 let tag = hmac::sign(&hmac_key, &hmac_data);
257
258 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
260 result.extend_from_slice(&nonce);
261 result.extend_from_slice(&ciphertext);
262 result.extend_from_slice(&tag.as_ref()[..16]);
263
264 Ok(result)
265}
266
267fn decrypt_chunk(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
269 if encrypted.len() < 32 {
270 return Err(ShieldError::StreamError("chunk too short".into()));
271 }
272
273 let nonce = &encrypted[..16];
274 let ciphertext = &encrypted[16..encrypted.len() - 16];
275 let mac = &encrypted[encrypted.len() - 16..];
276
277 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
279 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
280 hmac_data.extend_from_slice(nonce);
281 hmac_data.extend_from_slice(ciphertext);
282 let expected = hmac::sign(&hmac_key, &hmac_data);
283
284 if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
285 return Err(ShieldError::AuthenticationFailed);
286 }
287
288 let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
290 for i in 0..ciphertext.len().div_ceil(32) {
291 let counter = (i as u32).to_le_bytes();
292 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
293 hash_input.extend_from_slice(key);
294 hash_input.extend_from_slice(nonce);
295 hash_input.extend_from_slice(&counter);
296 let hash = digest::digest(&digest::SHA256, &hash_input);
297 keystream.extend_from_slice(hash.as_ref());
298 }
299
300 Ok(ciphertext
302 .iter()
303 .zip(keystream.iter())
304 .map(|(c, k)| c ^ k)
305 .collect())
306}
307
/// Round-trip, format-size and integrity tests for the stream cipher.
#[cfg(test)]
mod tests {
    use super::*;

    const KEY: [u8; 32] = [0x42; 32];

    #[test]
    fn test_stream_roundtrip() {
        let cipher = StreamCipher::new(KEY);
        let plaintext = b"Hello, streaming world!";

        let sealed = cipher.encrypt(plaintext).unwrap();
        let opened = cipher.decrypt(&sealed).unwrap();

        assert_eq!(opened.as_slice(), plaintext.as_slice());
    }

    #[test]
    fn test_stream_empty_roundtrip() {
        let cipher = StreamCipher::new(KEY);

        let sealed = cipher.encrypt(&[]).unwrap();
        // The 20-byte header is followed directly by the 4-byte end-of-stream marker.
        assert_eq!(sealed.len(), 24);
        assert!(cipher.decrypt(&sealed).unwrap().is_empty());
    }

    #[test]
    fn test_stream_large_data() {
        let cipher = StreamCipher::with_chunk_size(KEY, 1024);
        assert_eq!(cipher.chunk_size(), 1024);

        let plaintext: Vec<u8> = (0..10_000).map(|i| (i % 256) as u8).collect();
        let sealed = cipher.encrypt(&plaintext).unwrap();

        // The input splits into 10 chunks (9 full, 1 partial). Each chunk adds a 4-byte
        // length, a 16-byte nonce and a 16-byte tag. The stream also has a 20-byte
        // header and a 4-byte terminator.
        assert_eq!(sealed.len(), 20 + plaintext.len() + 10 * 36 + 4);
        assert_eq!(cipher.decrypt(&sealed).unwrap(), plaintext);
    }

    #[test]
    fn test_stream_tamper_detection() {
        let cipher = StreamCipher::new(KEY);
        let mut sealed = cipher.encrypt(b"test data").unwrap();

        // Offsets: header 0..20, record length 20..24, first nonce 24..40.
        // Byte 30 therefore lies inside the first nonce.
        sealed[30] ^= 0xFF;

        assert!(matches!(
            cipher.decrypt(&sealed),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_wrong_key_rejected() {
        let sealed = StreamCipher::new(KEY).encrypt(b"secret").unwrap();
        let result = StreamCipher::new([0x24; 32]).decrypt(&sealed);

        assert!(matches!(result, Err(ShieldError::AuthenticationFailed)));
    }

    #[test]
    fn test_stream_chunk_reorder_detected() {
        let cipher = StreamCipher::with_chunk_size(KEY, 4);
        let mut sealed = cipher.encrypt(b"abcdefgh").unwrap();

        // Two records of 4 + 16 + 4 + 16 = 40 bytes sit between the header and the
        // terminator. Swap them.
        assert_eq!(sealed.len(), 20 + 2 * 40 + 4);
        sealed[20..100].rotate_left(40);

        assert!(matches!(
            cipher.decrypt(&sealed),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_too_short_rejected() {
        let cipher = StreamCipher::new(KEY);
        assert!(cipher.decrypt(&[0u8; 19]).is_err());
    }
}