1#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38 CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39 reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
/// Normalizes `path` into an absolute form used as the identity of a loaded file.
///
/// The path is canonicalized through the filesystem when possible, which also
/// resolves symlinks. When canonicalization fails (for example because the file
/// does not exist), the path is normalized lexically instead:
/// - relative paths are anchored at the current working directory;
/// - `.` components are dropped;
/// - `..` components remove the preceding component.
///
/// The lexical fallback is applied to absolute paths too, so `/a/b/../c` becomes
/// `/a/c` even when nothing exists on disk. This matters for the include-path
/// security check, which compares normalized paths with `starts_with`.
///
/// If the path is relative and the current directory cannot be determined,
/// the path is returned unchanged.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    // Absolute paths carry their own root (and, on Windows, prefix) as leading
    // components, so they start from an empty buffer; relative paths start from
    // the current directory.
    let mut result = if path.is_absolute() {
        PathBuf::new()
    } else {
        match std::env::current_dir() {
            Ok(cwd) => cwd,
            Err(_) => return path.to_path_buf(),
        }
    };

    for component in path.components() {
        match component {
            Component::Prefix(prefix) => result = PathBuf::from(prefix.as_os_str()),
            // Pushing a root keeps an existing prefix (e.g. a Windows drive) but
            // discards everything after it.
            Component::RootDir => result.push(component.as_os_str()),
            Component::CurDir => {}
            Component::ParentDir => {
                // Popping at the root is a no-op, so `..` cannot climb above it.
                result.pop();
            }
            Component::Normal(name) => result.push(name),
        }
    }
    result
}
94
/// Errors that can occur while loading a ledger.
///
/// Parse errors, and problems with included files, are collected in
/// [`LoadResult::errors`] by the loader rather than being returned.
#[derive(Debug, Error)]
pub enum LoadError {
    /// A file could not be read from disk.
    #[error("failed to read file {path}: {source}")]
    Io {
        /// Path of the file that could not be read.
        path: PathBuf,
        /// The underlying I/O error.
        #[source]
        source: std::io::Error,
    },

    /// A file includes itself, directly or through other files.
    #[error("include cycle detected: {}", .cycle.join(" -> "))]
    IncludeCycle {
        /// Displayed paths of the include chain, ending with the file that
        /// closes the cycle.
        cycle: Vec<String>,
    },

    /// A file was read but the parser reported errors in it.
    #[error("parse errors in {path}")]
    ParseErrors {
        /// Path of the file that failed to parse cleanly.
        path: PathBuf,
        /// Every parse error reported for the file.
        errors: Vec<ParseError>,
    },

    /// An include resolved outside the permitted root directory while path
    /// security was enforced.
    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
    PathTraversal {
        /// The include path as written in the including file.
        include_path: String,
        /// The root directory includes are confined to.
        base_dir: PathBuf,
    },

    /// An encrypted file could not be decrypted with `gpg`.
    #[error("failed to decrypt {path}: {message}")]
    Decryption {
        /// Path of the encrypted file.
        path: PathBuf,
        /// Description of the failure (e.g. gpg's error output).
        message: String,
    },
}
142
/// Everything produced by loading a ledger file and its includes.
#[derive(Debug)]
pub struct LoadResult {
    /// Directives from the top-level file and all included files, each tagged
    /// (via `with_file_id`) with the source-map id of the file it came from.
    pub directives: Vec<Spanned<Directive>>,
    /// Options collected from `option` directives across all loaded files.
    pub options: Options,
    /// Plugin declarations, in the order they were encountered.
    pub plugins: Vec<Plugin>,
    /// Source text of every loaded file, addressable by file id.
    pub source_map: SourceMap,
    /// Non-fatal errors: parse errors, failed includes, include cycles and
    /// rejected include paths.
    pub errors: Vec<LoadError>,
    /// Number-formatting context built from the loaded amounts and the
    /// display-related options.
    pub display_context: DisplayContext,
}
159
/// A `plugin` directive found while loading.
#[derive(Debug, Clone)]
pub struct Plugin {
    /// Plugin name as reported by the parser.
    pub name: String,
    /// Optional configuration string given with the directive.
    pub config: Option<String>,
    /// Location of the directive within its file.
    pub span: Span,
    /// Source-map id of the file containing the directive.
    pub file_id: usize,
}
172
/// Returns `true` if `path` looks like a GPG-encrypted ledger.
///
/// Files with a `.gpg` extension are always treated as encrypted, and the file
/// does not need to exist. Files with an `.asc` extension are treated as
/// encrypted only if the ASCII-armor header `-----BEGIN PGP MESSAGE-----`
/// appears within their first 1024 bytes. `.asc` files that cannot be opened
/// or read are reported as not encrypted.
///
/// Only those first 1024 bytes are read, and they are searched as raw bytes.
/// Large files are therefore never slurped, and a multi-byte UTF-8 character
/// straddling the 1024-byte boundary cannot cause a slicing panic.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    // Number of leading bytes inspected when sniffing an `.asc` file.
    const SNIFF_LEN: u64 = 1024;
    const PGP_MESSAGE_MARKER: &[u8] = b"-----BEGIN PGP MESSAGE-----";

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            let mut head = Vec::new();
            let read_ok = fs::File::open(path)
                .and_then(|file| file.take(SNIFF_LEN).read_to_end(&mut head))
                .is_ok();
            read_ok
                && head
                    .windows(PGP_MESSAGE_MARKER.len())
                    .any(|window| window == PGP_MESSAGE_MARKER)
        }
        _ => false,
    }
}
193
194fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
199 let output = Command::new("gpg")
200 .args(["--batch", "--decrypt"])
201 .arg(path)
202 .output()
203 .map_err(|e| LoadError::Decryption {
204 path: path.to_path_buf(),
205 message: format!("failed to run gpg: {e}"),
206 })?;
207
208 if !output.status.success() {
209 return Err(LoadError::Decryption {
210 path: path.to_path_buf(),
211 message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
212 });
213 }
214
215 String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
216 path: path.to_path_buf(),
217 message: format!("decrypted content is not valid UTF-8: {e}"),
218 })
219}
220
/// Loads a ledger file together with all of the files it includes.
///
/// Tracks which files have been loaded and the current include chain. This
/// lets a file included more than once be read a single time, and lets include
/// cycles be reported.
#[derive(Debug, Default)]
pub struct Loader {
    /// Normalized paths of files already loaded; used to skip repeat includes.
    loaded_files: HashSet<PathBuf>,
    /// Current include chain, from the top-level file down to the file being
    /// loaded; used to report the path of an include cycle.
    include_stack: Vec<PathBuf>,
    /// The same paths as `include_stack`, kept as a set for fast cycle checks.
    include_stack_set: HashSet<PathBuf>,
    /// Directory that include paths must stay within when path security is
    /// enforced. If unset, it is derived from the top-level file's directory.
    root_dir: Option<PathBuf>,
    /// Whether include paths are checked against the root directory.
    enforce_path_security: bool,
}
236
237impl Loader {
238 #[must_use]
240 pub fn new() -> Self {
241 Self::default()
242 }
243
244 #[must_use]
258 pub const fn with_path_security(mut self, enabled: bool) -> Self {
259 self.enforce_path_security = enabled;
260 self
261 }
262
263 #[must_use]
268 pub fn with_root_dir(mut self, root: PathBuf) -> Self {
269 self.root_dir = Some(root);
270 self.enforce_path_security = true;
271 self
272 }
273
274 pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
290 let mut directives = Vec::new();
291 let mut options = Options::default();
292 let mut plugins = Vec::new();
293 let mut source_map = SourceMap::new();
294 let mut errors = Vec::new();
295
296 let canonical = normalize_path(path);
298
299 if self.enforce_path_security && self.root_dir.is_none() {
301 self.root_dir = canonical.parent().map(Path::to_path_buf);
302 }
303
304 self.load_recursive(
305 &canonical,
306 &mut directives,
307 &mut options,
308 &mut plugins,
309 &mut source_map,
310 &mut errors,
311 )?;
312
313 let display_context = build_display_context(&directives, &options);
315
316 Ok(LoadResult {
317 directives,
318 options,
319 plugins,
320 source_map,
321 errors,
322 display_context,
323 })
324 }
325
326 fn load_recursive(
327 &mut self,
328 path: &Path,
329 directives: &mut Vec<Spanned<Directive>>,
330 options: &mut Options,
331 plugins: &mut Vec<Plugin>,
332 source_map: &mut SourceMap,
333 errors: &mut Vec<LoadError>,
334 ) -> Result<(), LoadError> {
335 let path_buf = path.to_path_buf();
337
338 if self.include_stack_set.contains(&path_buf) {
340 let mut cycle: Vec<String> = self
341 .include_stack
342 .iter()
343 .map(|p| p.display().to_string())
344 .collect();
345 cycle.push(path.display().to_string());
346 return Err(LoadError::IncludeCycle { cycle });
347 }
348
349 if self.loaded_files.contains(&path_buf) {
351 return Ok(());
352 }
353
354 let source: std::sync::Arc<str> = if is_encrypted_file(path) {
357 decrypt_gpg_file(path)?.into()
358 } else {
359 let bytes = fs::read(path).map_err(|e| LoadError::Io {
360 path: path_buf.clone(),
361 source: e,
362 })?;
363 match String::from_utf8(bytes) {
365 Ok(s) => s.into(),
366 Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned().into(),
367 }
368 };
369
370 let file_id = source_map.add_file(path_buf.clone(), std::sync::Arc::clone(&source));
372
373 self.include_stack_set.insert(path_buf.clone());
375 self.include_stack.push(path_buf.clone());
376 self.loaded_files.insert(path_buf);
377
378 let result = rustledger_parser::parse(&source);
380
381 if !result.errors.is_empty() {
383 errors.push(LoadError::ParseErrors {
384 path: path.to_path_buf(),
385 errors: result.errors,
386 });
387 }
388
389 for (key, value, _span) in result.options {
391 options.set(&key, &value);
392 }
393
394 for (name, config, span) in result.plugins {
396 plugins.push(Plugin {
397 name,
398 config,
399 span,
400 file_id,
401 });
402 }
403
404 let base_dir = path.parent().unwrap_or(Path::new("."));
406 for (include_path, _span) in &result.includes {
407 let full_path = base_dir.join(include_path);
408 let canonical = normalize_path(&full_path);
410
411 if self.enforce_path_security {
413 if let Some(ref root) = self.root_dir {
414 if !canonical.starts_with(root) {
415 errors.push(LoadError::PathTraversal {
416 include_path: include_path.clone(),
417 base_dir: root.clone(),
418 });
419 continue;
420 }
421 }
422 }
423
424 if let Err(e) =
425 self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
426 {
427 errors.push(e);
428 }
429 }
430
431 directives.extend(
433 result
434 .directives
435 .into_iter()
436 .map(|d| d.with_file_id(file_id)),
437 );
438
439 if let Some(popped) = self.include_stack.pop() {
441 self.include_stack_set.remove(&popped);
442 }
443
444 Ok(())
445 }
446}
447
448fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
454 let mut ctx = DisplayContext::new();
455
456 ctx.set_render_commas(options.render_commas);
458
459 for spanned in directives {
461 match &spanned.value {
462 Directive::Transaction(txn) => {
463 for posting in &txn.postings {
464 if let Some(ref units) = posting.units {
466 if let (Some(number), Some(currency)) = (units.number(), units.currency()) {
467 ctx.update(number, currency);
468 }
469 }
470 if let Some(ref cost) = posting.cost {
472 if let (Some(number), Some(currency)) =
473 (cost.number_per.or(cost.number_total), &cost.currency)
474 {
475 ctx.update(number, currency.as_str());
476 }
477 }
478 if let Some(ref price) = posting.price {
480 if let Some(amount) = price.amount() {
481 ctx.update(amount.number, amount.currency.as_str());
482 }
483 }
484 }
485 }
486 Directive::Balance(bal) => {
487 ctx.update(bal.amount.number, bal.amount.currency.as_str());
488 if let Some(tol) = bal.tolerance {
489 ctx.update(tol, bal.amount.currency.as_str());
490 }
491 }
492 Directive::Price(price) => {
493 ctx.update(price.amount.number, price.amount.currency.as_str());
494 }
495 Directive::Pad(_)
496 | Directive::Open(_)
497 | Directive::Close(_)
498 | Directive::Commodity(_)
499 | Directive::Event(_)
500 | Directive::Query(_)
501 | Directive::Note(_)
502 | Directive::Document(_)
503 | Directive::Custom(_) => {}
504 }
505 }
506
507 for (currency, precision) in &options.display_precision {
509 ctx.set_fixed_precision(currency, *precision);
510 }
511
512 ctx
513}
514
515pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
519 Loader::new().load(path)
520}
521
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a named temporary file ending in `suffix`, writes each of `lines`
    /// followed by a newline, and flushes it to disk.
    fn temp_file_with_lines(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut temp = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(temp, "{line}").unwrap();
        }
        temp.flush().unwrap();
        temp
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        // `.gpg` files are classified by extension alone; the file need not exist.
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let armored = temp_file_with_lines(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );
        assert!(is_encrypted_file(armored.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        // An `.asc` extension alone is not enough; the armor header is required.
        let plain = temp_file_with_lines(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );
        assert!(!is_encrypted_file(plain.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // Whether or not `gpg` is installed, this content cannot be decrypted:
        // either spawning gpg fails or gpg rejects the data.
        let bogus = temp_file_with_lines(".gpg", &["fake encrypted content"]);

        let result = decrypt_gpg_file(bogus.path());
        assert!(result.is_err());

        let Err(LoadError::Decryption { path, message }) = result else {
            panic!("Expected Decryption error");
        };
        assert_eq!(path, bogus.path().to_path_buf());
        assert!(!message.is_empty());
    }
}