// File: shadow_crypt_shell/decryption/file_ops.rs
file_ops.rs1use std::io::{Read, Write};
2
3use shadow_crypt_core::v1::{
4 file::{EncryptedFile, PlaintextFile},
5 file_ops::get_encrypted_file_from_bytes,
6};
7
8use crate::{
9 decryption::file::{DecryptionInputFile, DecryptionOutputFile},
10 errors::{WorkflowError, WorkflowResult},
11};
12
13pub fn store_plaintext_file(
14 file: &PlaintextFile,
15 output_dir: &std::path::Path,
16) -> WorkflowResult<DecryptionOutputFile> {
17 let output_file = DecryptionOutputFile {
18 path: output_dir.join(file.filename().as_str()),
19 filename: file.filename().as_str().to_string(),
20 };
21
22 if output_file.path.exists() {
23 return Err(WorkflowError::File(format!(
24 "Output file '{}' already exists",
25 output_file.filename
26 )));
27 }
28
29 let mut f = std::fs::File::create(output_file.path.as_path())?;
30 f.write_all(file.content().as_slice())?;
31
32 Ok(output_file)
33}
34
35pub fn load_encrypted_file(file: &DecryptionInputFile) -> WorkflowResult<EncryptedFile> {
36 let size: usize = file.size as usize;
37
38 let mut f = std::fs::File::open(&file.path)?;
39 let mut buffer: Vec<u8> = Vec::with_capacity(size);
40
41 f.read_to_end(&mut buffer)?;
42
43 Ok(get_encrypted_file_from_bytes(buffer.as_slice())?)
44}
45
#[cfg(test)]
mod tests {
    use crate::utils::read_n_bytes_from_file;

    use super::*;
    use shadow_crypt_core::memory::{SecureBytes, SecureString};
    use std::fs;
    use std::io::Write;
    use tempfile::{NamedTempFile, TempDir};

    /// Creates a named temporary file pre-filled with `data`.
    ///
    /// The file is deleted when the returned handle is dropped, so keep it alive
    /// for as long as the path is used.
    fn temp_file_with(data: &[u8]) -> NamedTempFile {
        let mut temp_file = NamedTempFile::new().unwrap();
        temp_file.write_all(data).unwrap();
        temp_file.flush().unwrap();
        temp_file
    }

    /// Builds a `PlaintextFile` named `name` holding `content`.
    fn make_plaintext(name: &str, content: &[u8]) -> PlaintextFile {
        PlaintextFile::new(
            SecureString::new(name.to_string()),
            SecureBytes::new(content.to_vec()),
        )
    }

    // NOTE(review): the two tests below exercise `crate::utils::read_n_bytes_from_file`,
    // not this module; consider moving them next to that helper.
    #[test]
    fn test_read_n_bytes_from_file() {
        let temp_file = temp_file_with(b"hello world");

        let result = read_n_bytes_from_file(temp_file.path(), 5).unwrap();
        assert_eq!(result.as_slice(), b"hello");
    }

    #[test]
    fn test_read_n_bytes_from_file_more_than_size() {
        let temp_file = temp_file_with(b"hi");

        let result = read_n_bytes_from_file(temp_file.path(), 10).unwrap();
        assert_eq!(result.as_slice(), b"hi");
    }

    // The store tests each write into their own `TempDir` instead of the process
    // working directory. The working directory is shared by every test the
    // harness runs in parallel (so changing or writing into it races), and
    // writing there pollutes the checkout if an assertion fails before cleanup.
    #[test]
    fn test_store_plaintext_file() {
        let temp_dir = TempDir::new().unwrap();
        let plaintext = make_plaintext("test.txt", b"test content");

        let output = store_plaintext_file(&plaintext, temp_dir.path()).unwrap();
        assert_eq!(output.filename, "test.txt");
        assert_eq!(output.path, temp_dir.path().join("test.txt"));

        let read_content = fs::read(&output.path).unwrap();
        assert_eq!(read_content, b"test content");
    }

    #[test]
    fn test_store_plaintext_file_no_overwrite() {
        let temp_dir = TempDir::new().unwrap();
        let plaintext = make_plaintext("test.txt", b"new content");

        let output_path = temp_dir.path().join("test.txt");
        let existing_content = b"existing content";
        fs::write(&output_path, existing_content).unwrap();

        let result = store_plaintext_file(&plaintext, temp_dir.path());
        assert!(result.is_err());
        if let Err(WorkflowError::File(msg)) = result {
            assert!(msg.contains("already exists"));
        } else {
            panic!("Expected File error");
        }

        // The pre-existing file must be left untouched.
        let read_content = fs::read(&output_path).unwrap();
        assert_eq!(read_content, existing_content);
    }

    #[test]
    fn test_load_encrypted_file_invalid_data() {
        let data = b"invalid encrypted data";
        let temp_file = temp_file_with(data);

        let input_file = DecryptionInputFile {
            path: temp_file.path().to_path_buf(),
            filename: "test.shadow".to_string(),
            size: data.len() as u64,
        };

        let result = load_encrypted_file(&input_file);
        assert!(result.is_err());
    }
}