1#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38 invalidate_cache, load_cache_entry, reintern_directives, save_cache_entry, CacheEntry,
39 CachedOptions, CachedPlugin,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::Directive;
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
52#[derive(Debug, Error)]
54pub enum LoadError {
55 #[error("failed to read file {path}: {source}")]
57 Io {
58 path: PathBuf,
60 #[source]
62 source: std::io::Error,
63 },
64
65 #[error("include cycle detected: {}", .cycle.join(" -> "))]
67 IncludeCycle {
68 cycle: Vec<String>,
70 },
71
72 #[error("parse errors in {path}")]
74 ParseErrors {
75 path: PathBuf,
77 errors: Vec<ParseError>,
79 },
80
81 #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
83 PathTraversal {
84 include_path: String,
86 base_dir: PathBuf,
88 },
89
90 #[error("failed to decrypt {path}: {message}")]
92 Decryption {
93 path: PathBuf,
95 message: String,
97 },
98}
99
100#[derive(Debug)]
102pub struct LoadResult {
103 pub directives: Vec<Spanned<Directive>>,
105 pub options: Options,
107 pub plugins: Vec<Plugin>,
109 pub source_map: SourceMap,
111 pub errors: Vec<LoadError>,
113}
114
115#[derive(Debug, Clone)]
117pub struct Plugin {
118 pub name: String,
120 pub config: Option<String>,
122 pub span: Span,
124 pub file_id: usize,
126}
127
/// Returns `true` if `path` looks like a GPG-encrypted ledger file.
///
/// A `.gpg` extension is always treated as encrypted. An `.asc` file counts
/// as encrypted only if an ASCII-armored `-----BEGIN PGP MESSAGE-----` header
/// appears within its first 1024 bytes. An `.asc` file that cannot be opened
/// or read is treated as not encrypted. All other extensions return `false`.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    // ASCII-armor header that opens a PGP-encrypted message.
    const PGP_MESSAGE_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";
    // Number of leading bytes inspected for the armor header.
    const SNIFF_LEN: usize = 1024;

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            // Read only the head of the file, as raw bytes. This avoids
            // loading large files and tolerates non-UTF-8 content. It also
            // cannot panic on a multi-byte character straddling the cut-off,
            // which slicing a `str` at byte 1024 could.
            let mut head = Vec::with_capacity(SNIFF_LEN);
            let read_ok = fs::File::open(path)
                .and_then(|file| file.take(SNIFF_LEN as u64).read_to_end(&mut head))
                .is_ok();
            read_ok
                && head
                    .windows(PGP_MESSAGE_HEADER.len())
                    .any(|window| window == PGP_MESSAGE_HEADER)
        }
        _ => false,
    }
}
148
149fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
154 let output = Command::new("gpg")
155 .args(["--batch", "--decrypt"])
156 .arg(path)
157 .output()
158 .map_err(|e| LoadError::Decryption {
159 path: path.to_path_buf(),
160 message: format!("failed to run gpg: {e}"),
161 })?;
162
163 if !output.status.success() {
164 return Err(LoadError::Decryption {
165 path: path.to_path_buf(),
166 message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
167 });
168 }
169
170 String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
171 path: path.to_path_buf(),
172 message: format!("decrypted content is not valid UTF-8: {e}"),
173 })
174}
175
/// Loads a ledger file together with every file it (transitively) includes.
///
/// Configure it with [`Loader::with_path_security`] or
/// [`Loader::with_root_dir`], then call [`Loader::load`].
#[derive(Debug, Default)]
pub struct Loader {
    /// Canonical paths already loaded; a file included more than once is
    /// only loaded the first time.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded, outermost first; used to detect include
    /// cycles.
    include_stack: Vec<PathBuf>,
    /// Directory that includes must stay within when path security is on.
    root_dir: Option<PathBuf>,
    /// Whether includes resolving outside `root_dir` are rejected.
    enforce_path_security: bool,
}
189
190impl Loader {
191 #[must_use]
193 pub fn new() -> Self {
194 Self::default()
195 }
196
197 #[must_use]
211 pub const fn with_path_security(mut self, enabled: bool) -> Self {
212 self.enforce_path_security = enabled;
213 self
214 }
215
216 #[must_use]
221 pub fn with_root_dir(mut self, root: PathBuf) -> Self {
222 self.root_dir = Some(root);
223 self.enforce_path_security = true;
224 self
225 }
226
227 pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
243 let mut directives = Vec::new();
244 let mut options = Options::default();
245 let mut plugins = Vec::new();
246 let mut source_map = SourceMap::new();
247 let mut errors = Vec::new();
248
249 let canonical = path.canonicalize().map_err(|e| LoadError::Io {
251 path: path.to_path_buf(),
252 source: e,
253 })?;
254
255 if self.enforce_path_security && self.root_dir.is_none() {
257 self.root_dir = canonical.parent().map(Path::to_path_buf);
258 }
259
260 self.load_recursive(
261 &canonical,
262 &mut directives,
263 &mut options,
264 &mut plugins,
265 &mut source_map,
266 &mut errors,
267 )?;
268
269 Ok(LoadResult {
270 directives,
271 options,
272 plugins,
273 source_map,
274 errors,
275 })
276 }
277
278 fn load_recursive(
279 &mut self,
280 path: &Path,
281 directives: &mut Vec<Spanned<Directive>>,
282 options: &mut Options,
283 plugins: &mut Vec<Plugin>,
284 source_map: &mut SourceMap,
285 errors: &mut Vec<LoadError>,
286 ) -> Result<(), LoadError> {
287 let path_buf = path.to_path_buf();
289 if self.include_stack.contains(&path_buf) {
290 let mut cycle: Vec<String> = self
291 .include_stack
292 .iter()
293 .map(|p| p.display().to_string())
294 .collect();
295 cycle.push(path.display().to_string());
296 return Err(LoadError::IncludeCycle { cycle });
297 }
298
299 if self.loaded_files.contains(path) {
301 return Ok(());
302 }
303
304 let source: std::sync::Arc<str> = if is_encrypted_file(path) {
306 decrypt_gpg_file(path)?.into()
307 } else {
308 fs::read_to_string(path)
309 .map_err(|e| LoadError::Io {
310 path: path.to_path_buf(),
311 source: e,
312 })?
313 .into()
314 };
315
316 let file_id = source_map.add_file(path.to_path_buf(), std::sync::Arc::clone(&source));
318
319 self.include_stack.push(path.to_path_buf());
321 self.loaded_files.insert(path.to_path_buf());
322
323 let result = rustledger_parser::parse(&source);
325
326 if !result.errors.is_empty() {
328 errors.push(LoadError::ParseErrors {
329 path: path.to_path_buf(),
330 errors: result.errors,
331 });
332 }
333
334 for (key, value, _span) in result.options {
336 options.set(&key, &value);
337 }
338
339 for (name, config, span) in result.plugins {
341 plugins.push(Plugin {
342 name,
343 config,
344 span,
345 file_id,
346 });
347 }
348
349 let base_dir = path.parent().unwrap_or(Path::new("."));
351 for (include_path, _span) in &result.includes {
352 let full_path = base_dir.join(include_path);
353 let canonical = match full_path.canonicalize() {
354 Ok(p) => p,
355 Err(e) => {
356 errors.push(LoadError::Io {
357 path: full_path,
358 source: e,
359 });
360 continue;
361 }
362 };
363
364 if self.enforce_path_security {
366 if let Some(ref root) = self.root_dir {
367 if !canonical.starts_with(root) {
368 errors.push(LoadError::PathTraversal {
369 include_path: include_path.clone(),
370 base_dir: root.clone(),
371 });
372 continue;
373 }
374 }
375 }
376
377 if let Err(e) =
378 self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
379 {
380 errors.push(e);
381 }
382 }
383
384 directives.extend(result.directives);
386
387 self.include_stack.pop();
389
390 Ok(())
391 }
392}
393
394pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
398 Loader::new().load(path)
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        // `.gpg` is detected by extension alone; the file need not exist.
        let path = Path::new("test.beancount.gpg");
        assert!(is_encrypted_file(path));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        let path = Path::new("test.beancount");
        assert!(!is_encrypted_file(path));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let mut file = NamedTempFile::with_suffix(".asc").unwrap();
        writeln!(file, "-----BEGIN PGP MESSAGE-----").unwrap();
        writeln!(file, "some encrypted content").unwrap();
        writeln!(file, "-----END PGP MESSAGE-----").unwrap();
        file.flush().unwrap();

        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        // `.asc` alone is not enough: the PGP armor header must be present.
        let mut file = NamedTempFile::with_suffix(".asc").unwrap();
        writeln!(file, "This is just a plain text file").unwrap();
        writeln!(file, "with .asc extension but no PGP content").unwrap();
        file.flush().unwrap();

        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // Bogus ciphertext must yield a Decryption error with a message.
        // This holds whether gpg is missing (spawn failure) or installed
        // (decryption failure).
        let mut file = NamedTempFile::with_suffix(".gpg").unwrap();
        writeln!(file, "fake encrypted content").unwrap();
        file.flush().unwrap();

        match decrypt_gpg_file(file.path()) {
            Err(LoadError::Decryption { path, message }) => {
                assert_eq!(path, file.path().to_path_buf());
                assert!(!message.is_empty());
            }
            other => panic!("expected Decryption error, got {other:?}"),
        }
    }
}