shadow_crypt_shell/encryption/
file_ops.rs

1use std::{
2    io::{Read, Write},
3    path::PathBuf,
4};
5
6use rand::distr::{Alphabetic, SampleString};
7use shadow_crypt_core::{
8    memory::{SecureBytes, SecureString},
9    v1::file::{EncryptedFile, PlaintextFile},
10};
11
12use crate::{
13    encryption::file::{EncryptionInputFile, EncryptionOutputFile},
14    errors::{WorkflowError, WorkflowResult},
15};
16
17pub fn store_encrypted_file(
18    encrypted_file: &EncryptedFile,
19    output_dir: &std::path::Path,
20) -> WorkflowResult<EncryptionOutputFile> {
21    let output_file = create_encryption_output_file(output_dir)?;
22    let mut f = std::fs::File::create(&output_file.path)?;
23    let serialized_header: Vec<u8> =
24        shadow_crypt_core::v1::header_ops::serialize(encrypted_file.header());
25    f.write_all(&serialized_header)?;
26    f.write_all(encrypted_file.ciphertext())?;
27
28    Ok(output_file)
29}
30
31pub fn load_plaintext_file(file: &EncryptionInputFile) -> WorkflowResult<PlaintextFile> {
32    let filename = SecureString::new(file.filename.clone());
33    let size: usize = file.size as usize;
34
35    let mut f = std::fs::File::open(&file.path)?;
36    let mut buffer: Vec<u8> = Vec::with_capacity(size);
37
38    f.read_to_end(&mut buffer)?;
39
40    let content = SecureBytes::new(buffer);
41
42    Ok(PlaintextFile::new(filename, content))
43}
44
45fn generate_output_filename() -> WorkflowResult<String> {
46    let mut rng = rand::rng();
47    let len = 16;
48
49    Ok(Alphabetic.sample_string(&mut rng, len))
50}
51
52fn create_encryption_output_file(
53    output_dir: &std::path::Path,
54) -> WorkflowResult<EncryptionOutputFile> {
55    let mut counter = 0;
56    loop {
57        let base = generate_output_filename()?;
58        let filename = if counter == 0 {
59            base
60        } else {
61            format!("{}_{}", base, counter)
62        };
63
64        let mut path = PathBuf::from(&filename);
65        path.set_extension("shadow");
66
67        let full_path = output_dir.join(&path);
68
69        if !full_path.exists() {
70            let filename_str = path
71                .to_str()
72                .ok_or_else(|| WorkflowError::File("Invalid output filename".to_string()))?
73                .to_string();
74
75            return Ok(EncryptionOutputFile {
76                path: full_path,
77                filename: filename_str,
78            });
79        }
80
81        counter += 1;
82        if counter > 1000 {
83            return Err(WorkflowError::File(
84                "Unable to generate a unique output filename after 1000 attempts".to_string(),
85            ));
86        }
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;
    use shadow_crypt_core::{
        profile::SecurityProfile,
        v1::{file::EncryptedFile, header::FileHeader, key::KeyDerivationParams},
    };
    use std::fs;
    use tempfile::TempDir;

    /// Builds a header whose fields use fixed, distinguishable byte patterns.
    fn create_test_header() -> FileHeader {
        let salt = [1u8; 16];
        let kdf_params = KeyDerivationParams::from(SecurityProfile::Test);
        let content_nonce = [2u8; 24];
        let filename_nonce = [3u8; 24];
        let filename_ciphertext = vec![4, 5, 6, 7, 8];

        FileHeader::new(
            salt,
            kdf_params,
            content_nonce,
            filename_nonce,
            filename_ciphertext,
        )
    }

    /// Wraps the test header together with a small fixed ciphertext.
    fn create_test_encrypted_file() -> EncryptedFile {
        let header = create_test_header();
        let ciphertext = vec![10, 11, 12, 13, 14];
        EncryptedFile::new(header, ciphertext)
    }

    #[test]
    fn test_generate_output_filename() {
        let filename = generate_output_filename().unwrap();
        assert_eq!(filename.len(), 16);
        assert!(filename.chars().all(|c| c.is_ascii_alphabetic()));
    }

    #[test]
    fn test_create_output_file() {
        let temp_dir = TempDir::new().unwrap();

        let output_file = create_encryption_output_file(temp_dir.path()).unwrap();

        // In a fresh, empty directory the first candidate is always free, so the
        // name is exactly 16 letters plus the ".shadow" extension.
        let stem = output_file
            .filename
            .strip_suffix(".shadow")
            .expect("output filename should end with .shadow");
        assert_eq!(stem.len(), 16);
        assert!(stem.chars().all(|c| c.is_ascii_alphabetic()));

        // The path is the directory joined with the reported filename, and the
        // function only chooses a name; it must not create the file.
        assert_eq!(output_file.path, temp_dir.path().join(&output_file.filename));
        assert!(!output_file.path.exists());
    }

    #[test]
    fn test_load_file() {
        let temp_dir = TempDir::new().unwrap();
        let test_content = b"Hello, World!";
        let test_filename = "test.txt";
        let file_path = temp_dir.path().join(test_filename);

        fs::write(&file_path, test_content).unwrap();

        let input_file = EncryptionInputFile {
            path: file_path.clone(),
            filename: test_filename.to_string(),
            size: test_content.len() as u64,
        };

        let plaintext_file = load_plaintext_file(&input_file).unwrap();

        assert_eq!(plaintext_file.filename().as_str(), test_filename);
        assert_eq!(plaintext_file.content().as_slice(), test_content);
    }

    #[test]
    fn test_store_encrypted_file() -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let encrypted_file = create_test_encrypted_file();

        let output_file = store_encrypted_file(&encrypted_file, temp_dir.path())?;

        assert!(output_file.path.exists());

        // Canonicalize both sides to handle the macOS /private symlink.
        let canonical_output = fs::canonicalize(&output_file.path)?;
        let canonical_temp = fs::canonicalize(temp_dir.path())?;
        assert!(canonical_output.starts_with(canonical_temp));

        // On-disk layout: serialized header immediately followed by ciphertext.
        let written_content = fs::read(&output_file.path)?;
        let mut expected_content =
            shadow_crypt_core::v1::header_ops::serialize(encrypted_file.header());
        expected_content.extend_from_slice(encrypted_file.ciphertext());

        assert_eq!(written_content, expected_content);
        Ok(())
    }

    #[test]
    fn test_store_encrypted_file_twice_creates_distinct_files(
    ) -> Result<(), Box<dyn std::error::Error>> {
        let temp_dir = TempDir::new()?;
        let encrypted_file = create_test_encrypted_file();

        let first = store_encrypted_file(&encrypted_file, temp_dir.path())?;
        let second = store_encrypted_file(&encrypted_file, temp_dir.path())?;

        assert_ne!(first.path, second.path);
        assert!(first.path.exists());
        assert!(second.path.exists());
        Ok(())
    }
}