// shadow_crypt_core/v1/header_ops.rs

1use crate::{errors::HeaderError, v1::key::KeyDerivationParams};
2
3use super::header::FileHeader;
4
5pub fn serialize(header: &FileHeader) -> Vec<u8> {
6    let mut bytes = Vec::new();
7
8    bytes.extend_from_slice(header.magic.as_slice());
9    bytes.push(header.version);
10    bytes.extend_from_slice(header.header_length.to_le_bytes().as_slice());
11    bytes.extend_from_slice(header.salt.as_slice());
12    bytes.extend_from_slice(header.kdf_memory.to_le_bytes().as_slice());
13    bytes.extend_from_slice(header.kdf_iterations.to_le_bytes().as_slice());
14    bytes.extend_from_slice(header.kdf_parallelism.to_le_bytes().as_slice());
15    bytes.push(header.kdf_key_length);
16    bytes.extend_from_slice(header.content_nonce.as_slice());
17    bytes.extend_from_slice(header.filename_nonce.as_slice());
18    bytes.extend_from_slice(header.filename_ciphertext_length.to_le_bytes().as_slice());
19    bytes.extend_from_slice(header.filename_ciphertext.as_slice());
20
21    bytes
22}
23
24pub fn is_shadow_file(bytes: &[u8]) -> Result<bool, HeaderError> {
25    let magic = get_magic_from_bytes(bytes)?;
26    Ok(&magic == b"SHADOW")
27}
28
29fn get_magic_from_bytes(bytes: &[u8]) -> Result<[u8; 6], HeaderError> {
30    if bytes.len() < 6 {
31        return Err(HeaderError::InsufficientBytes);
32    }
33    let magic_bytes = &bytes[0..6];
34    let mut magic = [0u8; 6];
35    magic.copy_from_slice(magic_bytes);
36    Ok(magic)
37}
38
39pub fn get_version_from_bytes(bytes: &[u8]) -> Result<u8, HeaderError> {
40    if bytes.len() < 7 {
41        return Err(HeaderError::InsufficientBytes);
42    }
43    Ok(bytes[6])
44}
45
46pub fn get_length_from_bytes(bytes: &[u8]) -> Result<u32, HeaderError> {
47    if bytes.len() < 11 {
48        return Err(HeaderError::InsufficientBytes);
49    }
50    let length_bytes = &bytes[7..11];
51    let length = u32::from_le_bytes(
52        length_bytes
53            .try_into()
54            .map_err(|_| HeaderError::InvalidData)?,
55    );
56    Ok(length)
57}
58
59pub fn get_kdf_params(header: &FileHeader) -> KeyDerivationParams {
60    KeyDerivationParams {
61        memory_cost: header.kdf_memory,
62        time_cost: header.kdf_iterations,
63        parallelism: header.kdf_parallelism,
64        key_size: header.kdf_key_length,
65    }
66}
67
68pub fn try_deserialize(bytes: &[u8]) -> Result<FileHeader, HeaderError> {
69    if bytes.len() < FileHeader::min_length() {
70        return Err(HeaderError::InsufficientBytes);
71    }
72
73    let length: u32 = get_length_from_bytes(bytes)?;
74
75    if bytes.len() < length as usize {
76        return Err(HeaderError::InsufficientBytes);
77    }
78
79    match deserialize(bytes) {
80        Some(header) => Ok(header),
81        None => Err(HeaderError::InvalidData),
82    }
83}
84
85fn deserialize(bytes: &[u8]) -> Option<FileHeader> {
86    if bytes.len() < FileHeader::min_length() {
87        return None;
88    }
89    let magic = bytes[0..6].try_into().ok()?;
90    let version = bytes[6];
91    let header_length = u32::from_le_bytes(bytes[7..11].try_into().ok()?);
92    let salt = bytes[11..27].try_into().ok()?;
93    let kdf_memory = u32::from_le_bytes(bytes[27..31].try_into().ok()?);
94    let kdf_iterations = u32::from_le_bytes(bytes[31..35].try_into().ok()?);
95    let kdf_parallelism = u32::from_le_bytes(bytes[35..39].try_into().ok()?);
96    let kdf_key_length = bytes[39];
97    let content_nonce = bytes[40..64].try_into().ok()?;
98    let filename_nonce = bytes[64..88].try_into().ok()?;
99    let filename_ciphertext_length = u16::from_le_bytes(bytes[88..90].try_into().ok()?);
100
101    let expected_length: usize = FileHeader::min_length() + filename_ciphertext_length as usize;
102
103    if header_length != expected_length as u32 {
104        return None;
105    }
106
107    if bytes.len() < expected_length {
108        return None;
109    }
110
111    let filename_ciphertext = bytes[FileHeader::min_length()
112        ..(FileHeader::min_length() + filename_ciphertext_length as usize)]
113        .to_vec();
114
115    Some(FileHeader {
116        magic,
117        version,
118        header_length,
119        salt,
120        kdf_memory,
121        kdf_iterations,
122        kdf_parallelism,
123        kdf_key_length,
124        content_nonce,
125        filename_nonce,
126        filename_ciphertext_length,
127        filename_ciphertext,
128    })
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::profile;
    use crate::v1::key::KeyDerivationParams;

    /// Builds a header with fixed salt and nonces, the `Test` security
    /// profile, and the supplied filename ciphertext.
    fn header_with_filename(filename_ciphertext: Vec<u8>) -> FileHeader {
        let salt = [1u8; 16];
        let kdf_params = KeyDerivationParams::from(profile::SecurityProfile::Test);
        let content_nonce = [2u8; 24];
        let filename_nonce = [3u8; 24];

        FileHeader::new(
            salt,
            kdf_params,
            content_nonce,
            filename_nonce,
            filename_ciphertext,
        )
    }

    /// Default fixture: a header with a short, non-empty filename ciphertext.
    fn create_test_header() -> FileHeader {
        header_with_filename(vec![4, 5, 6, 7, 8])
    }

    /// Compares two headers field by field, so a failure names the field that
    /// differs and no `PartialEq` impl on `FileHeader` is required.
    fn assert_headers_eq(expected: &FileHeader, actual: &FileHeader) {
        assert_eq!(actual.magic, expected.magic, "magic");
        assert_eq!(actual.version, expected.version, "version");
        assert_eq!(actual.header_length, expected.header_length, "header_length");
        assert_eq!(actual.salt, expected.salt, "salt");
        assert_eq!(actual.kdf_memory, expected.kdf_memory, "kdf_memory");
        assert_eq!(actual.kdf_iterations, expected.kdf_iterations, "kdf_iterations");
        assert_eq!(actual.kdf_parallelism, expected.kdf_parallelism, "kdf_parallelism");
        assert_eq!(actual.kdf_key_length, expected.kdf_key_length, "kdf_key_length");
        assert_eq!(actual.content_nonce, expected.content_nonce, "content_nonce");
        assert_eq!(actual.filename_nonce, expected.filename_nonce, "filename_nonce");
        assert_eq!(
            actual.filename_ciphertext_length,
            expected.filename_ciphertext_length,
            "filename_ciphertext_length"
        );
        assert_eq!(
            actual.filename_ciphertext,
            expected.filename_ciphertext,
            "filename_ciphertext"
        );
    }

    /// Serializes `header` and parses it back, panicking if parsing fails.
    fn round_trip(header: &FileHeader) -> FileHeader {
        try_deserialize(&serialize(header)).expect("serialized header should deserialize")
    }

    /// Reads a little-endian `u32` starting at byte offset `start`.
    fn u32_le_at(bytes: &[u8], start: usize) -> u32 {
        u32::from_le_bytes(bytes[start..start + 4].try_into().unwrap())
    }

    #[test]
    fn test_serialize() {
        let header = create_test_header();
        let serialized = serialize(&header);

        // The encoding must be exactly as long as the header claims.
        assert_eq!(serialized.len(), header.header_length as usize);

        assert_eq!(&serialized[0..6], b"SHADOW");
        assert_eq!(serialized[6], 1);
        assert_eq!(u32_le_at(&serialized, 7), header.header_length);
        assert_eq!(&serialized[11..27], &header.salt);

        assert_eq!(u32_le_at(&serialized, 27), header.kdf_memory);
        assert_eq!(u32_le_at(&serialized, 31), header.kdf_iterations);
        assert_eq!(u32_le_at(&serialized, 35), header.kdf_parallelism);
        assert_eq!(serialized[39], header.kdf_key_length);

        assert_eq!(&serialized[40..64], &header.content_nonce);
        assert_eq!(&serialized[64..88], &header.filename_nonce);

        let filename_len = u16::from_le_bytes(serialized[88..90].try_into().unwrap());
        assert_eq!(filename_len, header.filename_ciphertext_length);

        // Everything after the fixed part is the filename ciphertext.
        assert_eq!(
            &serialized[FileHeader::min_length()..],
            &header.filename_ciphertext[..]
        );
    }

    #[test]
    fn test_is_shadow_file_valid() {
        let serialized = serialize(&create_test_header());
        assert!(matches!(is_shadow_file(&serialized), Ok(true)));
    }

    #[test]
    fn test_is_shadow_file_invalid_magic() {
        let mut bytes = vec![0u8; 100];
        bytes[0..6].copy_from_slice(b"NOTSHD");

        assert!(matches!(is_shadow_file(&bytes), Ok(false)));
    }

    #[test]
    fn test_is_shadow_file_insufficient_bytes() {
        let bytes = vec![0u8; 5]; // Less than 6 bytes

        assert!(matches!(
            is_shadow_file(&bytes),
            Err(HeaderError::InsufficientBytes)
        ));
    }

    #[test]
    fn test_field_readers_match_header() {
        let header = create_test_header();
        let serialized = serialize(&header);

        assert_eq!(get_version_from_bytes(&serialized).unwrap(), header.version);
        assert_eq!(
            get_length_from_bytes(&serialized).unwrap(),
            header.header_length
        );

        // One byte short of each field.
        assert!(matches!(
            get_version_from_bytes(&serialized[..6]),
            Err(HeaderError::InsufficientBytes)
        ));
        assert!(matches!(
            get_length_from_bytes(&serialized[..10]),
            Err(HeaderError::InsufficientBytes)
        ));
    }

    #[test]
    fn test_get_kdf_params_copies_header_fields() {
        let header = create_test_header();
        let params = get_kdf_params(&header);

        assert_eq!(params.memory_cost, header.kdf_memory);
        assert_eq!(params.time_cost, header.kdf_iterations);
        assert_eq!(params.parallelism, header.kdf_parallelism);
        assert_eq!(params.key_size, header.kdf_key_length);
    }

    #[test]
    fn test_try_deserialize_valid() {
        let original_header = create_test_header();
        let serialized = serialize(&original_header);

        let deserialized_header =
            try_deserialize(&serialized).expect("valid header should deserialize");
        assert_headers_eq(&original_header, &deserialized_header);
    }

    #[test]
    fn test_try_deserialize_ignores_trailing_bytes() {
        let header = create_test_header();
        let mut bytes = serialize(&header);
        // Extra data after the header in the same buffer must not be parsed.
        bytes.extend_from_slice(&[9u8; 32]);

        let deserialized_header = try_deserialize(&bytes)
            .expect("header followed by extra bytes should deserialize");
        assert_headers_eq(&header, &deserialized_header);
    }

    #[test]
    fn test_try_deserialize_insufficient_bytes() {
        let bytes = vec![0u8; 50]; // Less than min_length

        assert!(matches!(
            try_deserialize(&bytes),
            Err(HeaderError::InsufficientBytes)
        ));
    }

    #[test]
    fn test_try_deserialize_invalid_data() {
        let mut bytes = vec![0u8; 100];
        // Header length smaller than the fixed header size.
        bytes[7..11].copy_from_slice(&50u32.to_le_bytes());

        assert!(matches!(
            try_deserialize(&bytes),
            Err(HeaderError::InvalidData)
        ));
    }

    #[test]
    fn test_try_deserialize_insufficient_bytes_for_filename() {
        let mut bytes = vec![0u8; 95]; // FileHeader::min_length() is 90, but we need more for filename
        bytes[0..6].copy_from_slice(b"SHADOW");
        bytes[6] = 1; // version
        bytes[7..11].copy_from_slice(&100u32.to_le_bytes()); // header_length = 100 (90 + 10)
        // salt, kdf params, nonces, etc. stay zero for simplicity
        bytes[88..90].copy_from_slice(&10u16.to_le_bytes()); // filename_ciphertext_length = 10
        // Only 95 of the declared 100 bytes are present.

        assert!(matches!(
            try_deserialize(&bytes),
            Err(HeaderError::InsufficientBytes)
        ));
    }

    #[test]
    fn test_round_trip_serialization() {
        let original_header = create_test_header();
        assert_headers_eq(&original_header, &round_trip(&original_header));
    }

    #[test]
    fn test_empty_filename_ciphertext() {
        let header = header_with_filename(vec![]);
        assert_headers_eq(&header, &round_trip(&header));
    }

    #[test]
    fn test_large_filename_ciphertext() {
        let header = header_with_filename(vec![4u8; 1000]);
        assert_headers_eq(&header, &round_trip(&header));
    }
}