shadow_crypt_shell/decryption/
file_ops.rs

1use std::io::{Read, Write};
2
3use shadow_crypt_core::v1::{
4    file::{EncryptedFile, PlaintextFile},
5    file_ops::get_encrypted_file_from_bytes,
6};
7
8use crate::{
9    decryption::file::{DecryptionInputFile, DecryptionOutputFile},
10    errors::{WorkflowError, WorkflowResult},
11};
12
13pub fn store_plaintext_file(
14    file: &PlaintextFile,
15    output_dir: &std::path::Path,
16) -> WorkflowResult<DecryptionOutputFile> {
17    let output_file = DecryptionOutputFile {
18        path: output_dir.join(file.filename().as_str()),
19        filename: file.filename().as_str().to_string(),
20    };
21
22    if output_file.path.exists() {
23        return Err(WorkflowError::File(format!(
24            "Output file '{}' already exists",
25            output_file.filename
26        )));
27    }
28
29    let mut f = std::fs::File::create(output_file.path.as_path())?;
30    f.write_all(file.content().as_slice())?;
31
32    Ok(output_file)
33}
34
35pub fn load_encrypted_file(file: &DecryptionInputFile) -> WorkflowResult<EncryptedFile> {
36    let size: usize = file.size as usize;
37
38    let mut f = std::fs::File::open(&file.path)?;
39    let mut buffer: Vec<u8> = Vec::with_capacity(size);
40
41    f.read_to_end(&mut buffer)?;
42
43    Ok(get_encrypted_file_from_bytes(buffer.as_slice())?)
44}
45
#[cfg(test)]
mod tests {
    //! Unit tests for the decryption file operations.
    //!
    //! Every test works inside its own temporary file or directory and never touches the
    //! process-wide working directory: the test harness runs tests on parallel threads, so
    //! changing or relying on the current directory would let tests interfere with each
    //! other.

    use crate::utils::read_n_bytes_from_file;

    use super::*;
    use shadow_crypt_core::memory::{SecureBytes, SecureString};
    use std::fs;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Reading fewer bytes than the file holds returns exactly the requested prefix.
    #[test]
    fn test_read_n_bytes_from_file() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let data = b"hello world";
        temp_file.write_all(data).unwrap();
        temp_file.flush().unwrap();
        let path = temp_file.path();

        let result = read_n_bytes_from_file(path, 5).unwrap();
        assert_eq!(result.as_slice(), b"hello");
    }

    /// Asking for more bytes than the file holds returns the whole content.
    #[test]
    fn test_read_n_bytes_from_file_more_than_size() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let data = b"hi";
        temp_file.write_all(data).unwrap();
        temp_file.flush().unwrap();
        let path = temp_file.path();

        let result = read_n_bytes_from_file(path, 10).unwrap();
        assert_eq!(result.as_slice(), b"hi");
    }

    /// A plaintext file is written into the requested directory with its content intact.
    #[test]
    fn test_store_plaintext_file() {
        // A private directory keeps the repository checkout clean and cannot collide with
        // a leftover `test.txt` from an earlier, aborted run.
        let temp_dir = tempfile::TempDir::new().unwrap();

        let filename = SecureString::new("test.txt".to_string());
        let content = SecureBytes::new(b"test content".to_vec());
        let plaintext = PlaintextFile::new(filename, content);

        let output = store_plaintext_file(&plaintext, temp_dir.path()).unwrap();
        assert_eq!(output.filename, "test.txt");
        assert_eq!(output.path, temp_dir.path().join("test.txt"));

        let read_content = fs::read(&output.path).unwrap();
        assert_eq!(read_content, b"test content");
        // No manual clean-up: dropping `temp_dir` removes the directory and its content.
    }

    /// An already existing destination is reported as an error and left untouched.
    #[test]
    fn test_store_plaintext_file_no_overwrite() {
        let temp_dir = tempfile::TempDir::new().unwrap();

        let filename = SecureString::new("test.txt".to_string());
        let content = SecureBytes::new(b"new content".to_vec());
        let plaintext = PlaintextFile::new(filename, content);

        // Create existing file
        let output_path = temp_dir.path().join("test.txt");
        let existing_content = b"existing content";
        fs::write(&output_path, existing_content).unwrap();

        let result = store_plaintext_file(&plaintext, temp_dir.path());
        assert!(result.is_err());
        if let Err(WorkflowError::File(msg)) = result {
            assert!(msg.contains("already exists"));
        } else {
            panic!("Expected File error");
        }

        // Check existing content unchanged
        let read_content = fs::read(&output_path).unwrap();
        assert_eq!(read_content, existing_content);
    }

    /// Bytes that are not a valid encrypted file are rejected with an error.
    #[test]
    fn test_load_encrypted_file_invalid_data() {
        let mut temp_file = NamedTempFile::new().unwrap();
        let data = b"invalid encrypted data";
        temp_file.write_all(data).unwrap();
        temp_file.flush().unwrap();
        let path = temp_file.path();

        let input_file = DecryptionInputFile {
            path: path.to_path_buf(),
            filename: "test.shadow".to_string(),
            size: data.len() as u64,
        };

        let result = load_encrypted_file(&input_file);
        // Should error due to invalid data
        assert!(result.is_err());
    }
}