1#![allow(clippy::cast_possible_truncation)]
8
9use ring::{digest, hmac};
10use subtle::ConstantTimeEq;
11use zeroize::{Zeroize, ZeroizeOnDrop};
12
13use crate::error::{Result, ShieldError};
14
/// Plaintext bytes per chunk used by `StreamCipher::new` (64 KiB).
const DEFAULT_CHUNK_SIZE: usize = 1 << 16;
17
18#[derive(Zeroize, ZeroizeOnDrop)]
25pub struct StreamCipher {
26 key: [u8; 32],
27 #[zeroize(skip)]
28 chunk_size: usize,
29}
30
31impl StreamCipher {
32 #[must_use]
34 pub fn new(key: [u8; 32]) -> Self {
35 Self {
36 key,
37 chunk_size: DEFAULT_CHUNK_SIZE,
38 }
39 }
40
41 #[must_use]
43 pub fn with_chunk_size(key: [u8; 32], chunk_size: usize) -> Self {
44 Self { key, chunk_size }
45 }
46
47 #[must_use]
49 pub fn chunk_size(&self) -> usize {
50 self.chunk_size
51 }
52
53 pub fn encrypt_stream<'a>(&'a self, data: &'a [u8]) -> Result<StreamEncryptor<'a>> {
58 StreamEncryptor::new(&self.key, data, self.chunk_size)
59 }
60
61 pub fn decrypt_stream(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
63 if encrypted.len() < 20 {
64 return Err(ShieldError::StreamError("encrypted data too short".into()));
65 }
66
67 let _chunk_size =
69 u32::from_le_bytes([encrypted[0], encrypted[1], encrypted[2], encrypted[3]]) as usize;
70 let stream_salt = &encrypted[4..20];
71
72 let mut output = Vec::new();
73 let mut pos = 20;
74 let mut chunk_num: u64 = 0;
75
76 while pos < encrypted.len() {
77 if pos + 4 > encrypted.len() {
79 return Err(ShieldError::StreamError("truncated chunk length".into()));
80 }
81
82 let chunk_len = u32::from_le_bytes([
83 encrypted[pos],
84 encrypted[pos + 1],
85 encrypted[pos + 2],
86 encrypted[pos + 3],
87 ]) as usize;
88 pos += 4;
89
90 if chunk_len == 0 {
92 break;
93 }
94
95 if pos + chunk_len > encrypted.len() {
96 return Err(ShieldError::StreamError("truncated chunk data".into()));
97 }
98
99 let chunk_data = &encrypted[pos..pos + chunk_len];
100 pos += chunk_len;
101
102 let chunk_key = derive_chunk_key(&self.key, stream_salt, chunk_num);
104
105 let decrypted = decrypt_chunk(&chunk_key, chunk_data)?;
107 output.extend_from_slice(&decrypted);
108
109 chunk_num += 1;
110 }
111
112 Ok(output)
113 }
114
115 pub fn encrypt(&self, data: &[u8]) -> Result<Vec<u8>> {
117 let encryptor = self.encrypt_stream(data)?;
118 let mut result = Vec::new();
119
120 for chunk in encryptor {
121 result.extend_from_slice(&chunk?);
122 }
123
124 Ok(result)
125 }
126
127 pub fn decrypt(&self, encrypted: &[u8]) -> Result<Vec<u8>> {
129 self.decrypt_stream(encrypted)
130 }
131}
132
133pub struct StreamEncryptor<'a> {
135 key: &'a [u8; 32],
136 data: &'a [u8],
137 stream_salt: [u8; 16],
138 chunk_size: usize,
139 position: usize,
140 chunk_num: u64,
141 header_sent: bool,
142 finished: bool,
143}
144
145impl<'a> StreamEncryptor<'a> {
146 fn new(key: &'a [u8; 32], data: &'a [u8], chunk_size: usize) -> Result<Self> {
147 let stream_salt: [u8; 16] = crate::random::random_bytes()?;
148
149 Ok(Self {
150 key,
151 data,
152 stream_salt,
153 chunk_size,
154 position: 0,
155 chunk_num: 0,
156 header_sent: false,
157 finished: false,
158 })
159 }
160}
161
162impl Iterator for StreamEncryptor<'_> {
163 type Item = Result<Vec<u8>>;
164
165 fn next(&mut self) -> Option<Self::Item> {
166 if self.finished {
167 return None;
168 }
169
170 if !self.header_sent {
172 self.header_sent = true;
173 let mut header = Vec::with_capacity(20);
174 header.extend_from_slice(&(self.chunk_size as u32).to_le_bytes());
175 header.extend_from_slice(&self.stream_salt);
176 return Some(Ok(header));
177 }
178
179 if self.position >= self.data.len() {
181 self.finished = true;
182 return Some(Ok(vec![0, 0, 0, 0]));
184 }
185
186 let end = std::cmp::min(self.position + self.chunk_size, self.data.len());
188 let chunk_data = &self.data[self.position..end];
189 self.position = end;
190
191 let chunk_key = derive_chunk_key(self.key, &self.stream_salt, self.chunk_num);
193 self.chunk_num += 1;
194
195 match encrypt_chunk(&chunk_key, chunk_data) {
197 Ok(encrypted) => {
198 let mut result = Vec::with_capacity(4 + encrypted.len());
199 result.extend_from_slice(&(encrypted.len() as u32).to_le_bytes());
200 result.extend_from_slice(&encrypted);
201 Some(Ok(result))
202 }
203 Err(e) => Some(Err(e)),
204 }
205 }
206}
207
208fn derive_chunk_key(key: &[u8], stream_salt: &[u8], chunk_num: u64) -> [u8; 32] {
210 let mut data = Vec::with_capacity(key.len() + stream_salt.len() + 8);
211 data.extend_from_slice(key);
212 data.extend_from_slice(stream_salt);
213 data.extend_from_slice(&chunk_num.to_le_bytes());
214
215 let hash = digest::digest(&digest::SHA256, &data);
216 let mut result = [0u8; 32];
217 result.copy_from_slice(hash.as_ref());
218 result
219}
220
221fn derive_chunk_subkeys(key: &[u8; 32]) -> ([u8; 32], [u8; 32]) {
223 let hmac_key = hmac::Key::new(hmac::HMAC_SHA256, key);
224
225 let enc_tag = hmac::sign(&hmac_key, b"shield-stream-encrypt");
226 let mut enc_key = [0u8; 32];
227 enc_key.copy_from_slice(&enc_tag.as_ref()[..32]);
228
229 let mac_tag = hmac::sign(&hmac_key, b"shield-stream-authenticate");
230 let mut mac_key = [0u8; 32];
231 mac_key.copy_from_slice(&mac_tag.as_ref()[..32]);
232
233 (enc_key, mac_key)
234}
235
236fn encrypt_chunk(key: &[u8; 32], data: &[u8]) -> Result<Vec<u8>> {
238 let (enc_key, mac_key) = derive_chunk_subkeys(key);
239
240 let nonce: [u8; 16] = crate::random::random_bytes()?;
242
243 let num_blocks = data.len().div_ceil(32);
245 if u32::try_from(num_blocks).is_err() {
246 return Err(ShieldError::StreamError(
247 "keystream too long: counter overflow".into(),
248 ));
249 }
250 let mut keystream = Vec::with_capacity(num_blocks * 32);
251 for i in 0..num_blocks {
252 let counter = (i as u32).to_le_bytes();
253 let mut hash_input = Vec::with_capacity(enc_key.len() + nonce.len() + 4);
254 hash_input.extend_from_slice(&enc_key);
255 hash_input.extend_from_slice(&nonce);
256 hash_input.extend_from_slice(&counter);
257 let hash = digest::digest(&digest::SHA256, &hash_input);
258 keystream.extend_from_slice(hash.as_ref());
259 }
260
261 let ciphertext: Vec<u8> = data
263 .iter()
264 .zip(keystream.iter())
265 .map(|(p, k)| p ^ k)
266 .collect();
267
268 let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
270 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
271 hmac_data.extend_from_slice(&nonce);
272 hmac_data.extend_from_slice(&ciphertext);
273 let tag = hmac::sign(&hmac_signing_key, &hmac_data);
274
275 let mut result = Vec::with_capacity(16 + ciphertext.len() + 16);
277 result.extend_from_slice(&nonce);
278 result.extend_from_slice(&ciphertext);
279 result.extend_from_slice(&tag.as_ref()[..16]);
280
281 Ok(result)
282}
283
284fn decrypt_chunk(key: &[u8; 32], encrypted: &[u8]) -> Result<Vec<u8>> {
286 if encrypted.len() < 32 {
287 return Err(ShieldError::StreamError("chunk too short".into()));
288 }
289
290 let (enc_key, mac_key) = derive_chunk_subkeys(key);
291
292 let nonce = &encrypted[..16];
293 let ciphertext = &encrypted[16..encrypted.len() - 16];
294 let mac = &encrypted[encrypted.len() - 16..];
295
296 let hmac_signing_key = hmac::Key::new(hmac::HMAC_SHA256, &mac_key);
298 let mut hmac_data = Vec::with_capacity(16 + ciphertext.len());
299 hmac_data.extend_from_slice(nonce);
300 hmac_data.extend_from_slice(ciphertext);
301 let expected = hmac::sign(&hmac_signing_key, &hmac_data);
302
303 if mac.ct_eq(&expected.as_ref()[..16]).unwrap_u8() != 1 {
304 return Err(ShieldError::AuthenticationFailed);
305 }
306
307 let num_blocks = ciphertext.len().div_ceil(32);
309 if u32::try_from(num_blocks).is_err() {
310 return Err(ShieldError::StreamError(
311 "keystream too long: counter overflow".into(),
312 ));
313 }
314 let mut keystream = Vec::with_capacity(num_blocks * 32);
315 for i in 0..num_blocks {
316 let counter = (i as u32).to_le_bytes();
317 let mut hash_input = Vec::with_capacity(enc_key.len() + nonce.len() + 4);
318 hash_input.extend_from_slice(&enc_key);
319 hash_input.extend_from_slice(nonce);
320 hash_input.extend_from_slice(&counter);
321 let hash = digest::digest(&digest::SHA256, &hash_input);
322 keystream.extend_from_slice(hash.as_ref());
323 }
324
325 Ok(ciphertext
327 .iter()
328 .zip(keystream.iter())
329 .map(|(c, k)| c ^ k)
330 .collect())
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    /// Header (4-byte chunk size + 16-byte salt).
    const HEADER_LEN: usize = 20;

    #[test]
    fn test_stream_roundtrip() {
        let key = [0x42u8; 32];
        let cipher = StreamCipher::new(key);

        let data = b"Hello, streaming world!";
        let encrypted = cipher.encrypt(data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data.as_slice(), decrypted.as_slice());
    }

    #[test]
    fn test_stream_large_data() {
        let key = [0x42u8; 32];
        let cipher = StreamCipher::with_chunk_size(key, 1024);

        let data: Vec<u8> = (0..10000_u32).map(|i| (i % 256) as u8).collect();
        let encrypted = cipher.encrypt(&data).unwrap();
        let decrypted = cipher.decrypt(&encrypted).unwrap();

        assert_eq!(data, decrypted);
    }

    #[test]
    fn test_stream_empty_roundtrip() {
        let cipher = StreamCipher::new([0x42u8; 32]);

        let encrypted = cipher.encrypt(b"").unwrap();
        // Header followed directly by the 4-byte end-of-stream marker.
        assert_eq!(encrypted.len(), HEADER_LEN + 4);
        assert!(cipher.decrypt(&encrypted).unwrap().is_empty());
    }

    #[test]
    fn test_stream_exact_chunk_multiple() {
        let cipher = StreamCipher::with_chunk_size([0x42u8; 32], 16);

        let data = [0xA5u8; 64];
        let encrypted = cipher.encrypt(&data).unwrap();

        assert_eq!(cipher.decrypt(&encrypted).unwrap(), data);
    }

    #[test]
    fn test_stream_tamper_detection() {
        let key = [0x42u8; 32];
        let cipher = StreamCipher::new(key);

        let mut encrypted = cipher.encrypt(b"test data").unwrap();
        // Byte 30 lies in the first chunk's nonce (after the header and the 4-byte
        // length prefix). Assert rather than skip, so the test cannot pass vacuously.
        assert!(encrypted.len() > 30);
        encrypted[30] ^= 0xFF;

        assert!(cipher.decrypt(&encrypted).is_err());
    }

    #[test]
    fn test_stream_tampered_tag_rejected() {
        let cipher = StreamCipher::new([0x42u8; 32]);

        let mut encrypted = cipher.encrypt(b"test data").unwrap();
        // Last tag byte of the only chunk sits just before the 4-byte end marker.
        let idx = encrypted.len() - 5;
        encrypted[idx] ^= 0x01;

        assert!(matches!(
            cipher.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_wrong_key_rejected() {
        let encrypted = StreamCipher::new([0x42u8; 32]).encrypt(b"secret").unwrap();
        let other = StreamCipher::new([0x24u8; 32]);

        assert!(matches!(
            other.decrypt(&encrypted),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_chunk_reorder_rejected() {
        let cipher = StreamCipher::with_chunk_size([0x42u8; 32], 16);
        let data: Vec<u8> = (0..32u8).collect();
        let encrypted = cipher.encrypt(&data).unwrap();

        // Each frame: 4-byte length + 16 nonce + 16 ciphertext + 16 tag = 52 bytes.
        let (header, rest) = encrypted.split_at(HEADER_LEN);
        let (first, rest) = rest.split_at(52);
        let (second, end_marker) = rest.split_at(52);
        assert_eq!(end_marker, &[0, 0, 0, 0]);

        let mut swapped = header.to_vec();
        swapped.extend_from_slice(second);
        swapped.extend_from_slice(first);
        swapped.extend_from_slice(end_marker);

        assert!(matches!(
            cipher.decrypt(&swapped),
            Err(ShieldError::AuthenticationFailed)
        ));
    }

    #[test]
    fn test_stream_encryption_is_randomized() {
        let cipher = StreamCipher::new([0x42u8; 32]);

        let a = cipher.encrypt(b"same input").unwrap();
        let b = cipher.encrypt(b"same input").unwrap();

        // Fresh stream salt and nonces per call.
        assert_ne!(a, b);
    }
}