1#![allow(clippy::cast_possible_truncation)]
8
9use ring::{
10 digest, hmac,
11 rand::{SecureRandom, SystemRandom},
12};
13use subtle::ConstantTimeEq;
14
15use crate::error::{Result, ShieldError};
16
/// Plaintext bytes per chunk used by [`StreamCipher::new`] (64 KiB).
const DEFAULT_CHUNK_SIZE: usize = 1 << 16;
19
/// Chunked, authenticated stream cipher keyed by a 256-bit master key.
///
/// Plaintext is split into chunks of `chunk_size` bytes, and every chunk is encrypted under its
/// own key derived from the master key, a random per-stream salt and the chunk's index.
/// No `Debug` is derived, which keeps the key out of formatted output.
pub struct StreamCipher {
    /// Maximum number of plaintext bytes placed in each encrypted chunk.
    chunk_size: usize,
    /// 256-bit master key from which the per-chunk keys are derived.
    key: [u8; 32],
}
30
31impl StreamCipher {
32 #[must_use]
34 pub fn new(key: [u8; 32]) -> Self {
35 Self {
36 key,
37 chunk_size: DEFAULT_CHUNK_SIZE,
38 }
39 }
40
41 #[must_use]
43 pub fn with_chunk_size(key: [u8; 32], chunk_size: usize) -> Self {
44 Self { key, chunk_size }
45 }
46
47 #[must_use]
49 pub fn chunk_size(&self) -> usize {
50 self.chunk_size
51 }
52
53 pub fn encrypt_stream<'a>(&'a self, data: &'a [u8]) -> Result<StreamEncryptor<'a>> {
58 StreamEncryptor::new(&self.key, data, self.chunk_size)
59 }
60
61 pub fn decrypt_stream(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
63 if encrypted.len() < 20 {
64 return Err(ShieldError::StreamError("encrypted data too short".into()));
65 }
66
67 let _chunk_size =
69 u32::from_le_bytes([encrypted[0], encrypted[1], encrypted[2], encrypted[3]]) as usize;
70 let stream_salt = &encrypted[4..20];
71
72 let mut output = Vec::new();
73 let mut pos = 20;
74 let mut chunk_num: u64 = 0;
75
76 while pos < encrypted.len() {
77 if pos + 4 > encrypted.len() {
79 return Err(ShieldError::StreamError("truncated chunk length".into()));
80 }
81
82 let chunk_len = u32::from_le_bytes([
83 encrypted[pos],
84 encrypted[pos + 1],
85 encrypted[pos + 2],
86 encrypted[pos + 3],
87 ]) as usize;
88 pos += 4;
89
90 if chunk_len == 0 {
92 break;
93 }
94
95 if pos + chunk_len > encrypted.len() {
96 return Err(ShieldError::StreamError("truncated chunk data".into()));
97 }
98
99 let chunk_data = &encrypted[pos..pos + chunk_len];
100 pos += chunk_len;
101
102 let chunk_key = derive_chunk_key(&self.key, stream_salt, chunk_num);
104
105 let decrypted = decrypt_chunk(&chunk_key, chunk_data)?;
107 output.extend_from_slice(&decrypted);
108
109 chunk_num += 1;
110 }
111
112 Ok(output)
113 }
114
115 pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
117 let encryptor = self.encrypt_stream(data)?;
118 let mut result = Vec::new();
119
120 for chunk in encryptor {
121 result.extend_from_slice(&chunk?);
122 }
123
124 Ok(result)
125 }
126
127 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
129 self.decrypt_stream(encrypted)
130 }
131}
132
/// Lazily encrypts a borrowed plaintext, yielding the encrypted stream one frame at a time.
///
/// Created by `StreamCipher::encrypt_stream`.
pub struct StreamEncryptor<'a> {
    /// Master key borrowed from the owning cipher.
    key: &'a [u8; 32],
    /// Random per-stream salt, written into the header and mixed into every chunk key.
    stream_salt: [u8; 16],
    /// Plaintext being encrypted.
    data: &'a [u8],
    /// Offset in `data` of the next byte that has not been encrypted yet.
    position: usize,
    /// Maximum number of plaintext bytes per chunk.
    chunk_size: usize,
    /// Index of the next chunk, used when deriving its key.
    chunk_num: u64,
    /// Whether the header frame has already been yielded.
    header_sent: bool,
    /// Once set, the iterator yields nothing further.
    finished: bool,
}
144
145impl<'a> StreamEncryptor<'a> {
146 fn new(key: &'a [u8; 32], data: &'a [u8], chunk_size: usize) -> Result<Self> {
147 let rng = SystemRandom::new();
148 let mut stream_salt = [0u8; 16];
149 rng.fill(&mut stream_salt)
150 .map_err(|_| ShieldError::RandomFailed)?;
151
152 Ok(Self {
153 key,
154 data,
155 stream_salt,
156 chunk_size,
157 position: 0,
158 chunk_num: 0,
159 header_sent: false,
160 finished: false,
161 })
162 }
163}
164
165impl Iterator for StreamEncryptor<'_> {
166 type Item = Result<Vec<u8>>;
167
168 fn next(&mut self) -> Option<Self::Item> {
169 if self.finished {
170 return None;
171 }
172
173 if !self.header_sent {
175 self.header_sent = true;
176 let mut header = Vec::with_capacity(20);
177 header.extend_from_slice(&(self.chunk_size as u32).to_le_bytes());
178 header.extend_from_slice(&self.stream_salt);
179 return Some(Ok(header));
180 }
181
182 if self.position >= self.data.len() {
184 self.finished = true;
185 return Some(Ok(vec![0, 0, 0, 0]));
187 }
188
189 let end = std::cmp::min(self.position + self.chunk_size, self.data.len());
191 let chunk_data = &self.data[self.position..end];
192 self.position = end;
193
194 let chunk_key = derive_chunk_key(self.key, &self.stream_salt, self.chunk_num);
196 self.chunk_num += 1;
197
198 match encrypt_chunk(&chunk_key, chunk_data) {
200 Ok(encrypted) => {
201 let mut result = Vec::with_capacity(4 + encrypted.len());
202 result.extend_from_slice(&(encrypted.len() as u32).to_le_bytes());
203 result.extend_from_slice(&encrypted);
204 Some(Ok(result))
205 }
206 Err(e) => Some(Err(e)),
207 }
208 }
209}
210
211fn derive_chunk_key(key: &[u8], stream_salt: &[u8], chunk_num: u64) -> [u8; 32] {
213 let mut data = Vec::with_capacity(key.len() + stream_salt.len() + 8);
214 data.extend_from_slice(key);
215 data.extend_from_slice(stream_salt);
216 data.extend_from_slice(&chunk_num.to_le_bytes());
217
218 let hash = digest::digest(&digest::SHA256, &data);
219 let mut result = [0u8; 32];
220 result.copy_from_slice(hash.as_ref());
221 result
222}
223
224fn encrypt_chunk(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
226 let rng = SystemRandom::new();
227
228 let mut nonce = [0u8; 16];
230 rng.fill(&mut nonce)
231 .map_err(|_| ShieldError::RandomFailed)?;
232
233 let mut keystream = Vec::with_capacity(data.len().div_ceil(32) * 32);
235 for i in 0..data.len().div_ceil(32) {
236 let counter = (i as u32).to_le_bytes();
237 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
238 hash_input.extend_from_slice(key);
239 hash_input.extend_from_slice(&nonce);
240 hash_input.extend_from_slice(&counter);
241 let hash = digest::digest(&digest::SHA256, &hash_input);
242 keystream.extend_from_slice(hash.as_ref());
243 }
244
245 let ciphertext: Vec<u8> = data
247 .iter()
248 .zip(keystream.iter())
249 .map(|(p, k)| p ^ k)
250 .collect();
251
252 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
254 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
255 hmac_data.extend_from_slice(&nonce);
256 hmac_data.extend_from_slice(&ciphertext);
257 let tag = hmac::sign(&hmac_key, &hmac_data);
258
259 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
261 result.extend_from_slice(&nonce);
262 result.extend_from_slice(&ciphertext);
263 result.extend_from_slice(&tag.as_ref()[..16]);
264
265 Ok(result)
266}
267
268fn decrypt_chunk(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
270 if encrypted.len() < 32 {
271 return Err(ShieldError::StreamError("chunk too short".into()));
272 }
273
274 let nonce = &encrypted[..16];
275 let ciphertext = &encrypted[16..encrypted.len() - 16];
276 let mac = &encrypted[encrypted.len() - 16..];
277
278 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
280 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
281 hmac_data.extend_from_slice(nonce);
282 hmac_data.extend_from_slice(ciphertext);
283 let expected = hmac::sign(&hmac_key, &hmac_data);
284
285 if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
286 return Err(ShieldError::AuthenticationFailed);
287 }
288
289 let mut keystream = Vec::with_capacity(ciphertext.len().div_ceil(32) * 32);
291 for i in 0..ciphertext.len().div_ceil(32) {
292 let counter = (i as u32).to_le_bytes();
293 let mut hash_input = Vec::with_capacity(key.len() + nonce.len() + 4);
294 hash_input.extend_from_slice(key);
295 hash_input.extend_from_slice(nonce);
296 hash_input.extend_from_slice(&counter);
297 let hash = digest::digest(&digest::SHA256, &hash_input);
298 keystream.extend_from_slice(hash.as_ref());
299 }
300
301 Ok(ciphertext
303 .iter()
304 .zip(keystream.iter())
305 .map(|(c, k)| c ^ k)
306 .collect())
307}
308
#[cfg(test)]
mod tests {
    use super::*;

    const KEY: [u8; 32] = [0x42u8; 32];

    #[test]
    fn test_stream_roundtrip() {
        let cipher = StreamCipher::new(KEY);

        let data = b"Hello, streaming world!";
        let encrypted = cipher.encrypt(data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_stream_large_data() {
        let cipher = StreamCipher::with_chunk_size(KEY, 1024);

        let data: Vec<u8> = (0..10000).map(|i| (i % 256) as u8).collect();
        let encrypted = cipher.encrypt(&data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data, decrypted);
    }

    #[test]
    fn test_stream_empty_data() {
        let cipher = StreamCipher::new(KEY);

        let encrypted = cipher.encrypt(b"").unwrap();
        // 20-byte header (chunk size + salt) followed directly by the 4-byte end marker.
        assert_eq!(encrypted.len(), 24);
        assert!(cipher.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_stream_exact_chunk_multiple() {
        let cipher = StreamCipher::with_chunk_size(KEY, 1024);

        let data = vec![0xA5u8; 2048];
        let encrypted = cipher.encrypt(&data).unwrap();

        assert_eq!(cipher.decrypt(&encrypted).unwrap(), data);
    }

    #[test]
    fn test_stream_tamper_detection() {
        let cipher = StreamCipher::new(KEY);

        let mut encrypted = cipher.encrypt(b"test data").unwrap();
        // Byte 30 lies in the first chunk's nonce (20-byte header + 4-byte length prefix),
        // which the chunk tag covers.
        encrypted[30] ^= 0xFF;

        assert!(matches!(
            cipher.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_wrong_key_rejected() {
        let encrypted = StreamCipher::new(KEY).encrypt(b"secret").unwrap();

        assert!(matches!(
            StreamCipher::new([0x24u8; 32]).decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_chunk_reorder_rejected() {
        let cipher = StreamCipher::with_chunk_size(KEY, 16);

        let encrypted = cipher.encrypt(&[7u8; 32]).unwrap();
        // Layout: 20-byte header, two 52-byte frames (4-byte length + 16 nonce + 16 data +
        // 16 tag), then the 4-byte end marker.
        assert_eq!(encrypted.len(), 20 + 2 * 52 + 4);

        let mut swapped = encrypted[..20].to_vec();
        swapped.extend_from_slice(&encrypted[72..124]);
        swapped.extend_from_slice(&encrypted[20..72]);
        swapped.extend_from_slice(&encrypted[124..]);

        assert!(matches!(
            cipher.decrypt(&swapped),
            Err(ShieldError::AuthenticationFailed)
        ));
    }
}
350}