1#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38 CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39 reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
/// Returns an absolute, normalized form of `path`, used as the identity key for
/// loaded files and for include-path containment checks.
///
/// When the path exists it is canonicalized (symlinks resolved). Otherwise it is
/// normalized lexically: relative paths are anchored at the current working
/// directory, `.` components are dropped and `..` pops the preceding component.
///
/// Lexical normalization is applied to absolute paths as well. Without it, a
/// non-existent `/root/../elsewhere/x` would be returned verbatim and would
/// still satisfy `starts_with("/root")`, because `Path::starts_with` compares
/// components without resolving `..`.
///
/// If `path` is relative and the current directory cannot be determined,
/// `path` is returned unchanged.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    let mut normalized = if path.is_absolute() {
        // An absolute path supplies its own prefix/root components below.
        PathBuf::new()
    } else {
        match std::env::current_dir() {
            Ok(cwd) => cwd,
            Err(_) => return path.to_path_buf(),
        }
    };

    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // `pop` is a no-op at the root, so `/..` stays `/`.
                normalized.pop();
            }
            // For a root or prefix, `PathBuf::push` discards what came before
            // (a bare root keeps an existing drive prefix), mirroring how the
            // OS resolves these components.
            Component::Normal(_) | Component::RootDir | Component::Prefix(_) => {
                normalized.push(component.as_os_str());
            }
        }
    }
    normalized
}
94
95#[derive(Debug, Error)]
97pub enum LoadError {
98 #[error("failed to read file {path}: {source}")]
100 Io {
101 path: PathBuf,
103 #[source]
105 source: std::io::Error,
106 },
107
108 #[error("include cycle detected: {}", .cycle.join(" -> "))]
110 IncludeCycle {
111 cycle: Vec<String>,
113 },
114
115 #[error("parse errors in {path}")]
117 ParseErrors {
118 path: PathBuf,
120 errors: Vec<ParseError>,
122 },
123
124 #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
126 PathTraversal {
127 include_path: String,
129 base_dir: PathBuf,
131 },
132
133 #[error("failed to decrypt {path}: {message}")]
135 Decryption {
136 path: PathBuf,
138 message: String,
140 },
141}
142
143#[derive(Debug)]
145pub struct LoadResult {
146 pub directives: Vec<Spanned<Directive>>,
148 pub options: Options,
150 pub plugins: Vec<Plugin>,
152 pub source_map: SourceMap,
154 pub errors: Vec<LoadError>,
156 pub display_context: DisplayContext,
158}
159
160#[derive(Debug, Clone)]
162pub struct Plugin {
163 pub name: String,
165 pub config: Option<String>,
167 pub span: Span,
169 pub file_id: usize,
171}
172
/// Returns `true` when `path` should be decrypted with gpg before parsing.
///
/// A `.gpg` extension is trusted on its own. A `.asc` extension is treated as
/// encrypted only when the first 1024 bytes of the file contain the
/// ASCII-armor header `-----BEGIN PGP MESSAGE-----`, since an `.asc` file is
/// not necessarily an encrypted message. Unreadable `.asc` files and all other
/// extensions yield `false`.
///
/// Only the leading bytes are read, and they are searched as raw bytes. Large
/// files are therefore never loaded in full here, and a multi-byte UTF-8
/// character straddling the 1024-byte mark cannot cause a slicing panic.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    // How much of a `.asc` file is inspected for the armor header.
    const ARMOR_SCAN_LEN: u64 = 1024;
    const ARMOR_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            let mut head = Vec::new();
            let read_ok = fs::File::open(path)
                .and_then(|file| file.take(ARMOR_SCAN_LEN).read_to_end(&mut head))
                .is_ok();
            read_ok
                && head
                    .windows(ARMOR_HEADER.len())
                    .any(|window| window == ARMOR_HEADER)
        }
        _ => false,
    }
}
193
194fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
199 let output = Command::new("gpg")
200 .args(["--batch", "--decrypt"])
201 .arg(path)
202 .output()
203 .map_err(|e| LoadError::Decryption {
204 path: path.to_path_buf(),
205 message: format!("failed to run gpg: {e}"),
206 })?;
207
208 if !output.status.success() {
209 return Err(LoadError::Decryption {
210 path: path.to_path_buf(),
211 message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
212 });
213 }
214
215 String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
216 path: path.to_path_buf(),
217 message: format!("decrypted content is not valid UTF-8: {e}"),
218 })
219}
220
/// Loads ledger files, following `include` directives recursively.
#[derive(Debug, Default)]
pub struct Loader {
    /// Normalized paths of every file loaded so far; a file reached a second
    /// time through a different include chain is skipped.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded, outermost first; used to report the
    /// chain when an include cycle is found.
    include_stack: Vec<PathBuf>,
    /// The same paths as `include_stack`, kept as a set for fast cycle checks.
    include_stack_set: HashSet<PathBuf>,
    /// Root directory that includes must stay within when path security is on.
    root_dir: Option<PathBuf>,
    /// Whether include paths are checked against the root directory.
    enforce_path_security: bool,
}
236
237impl Loader {
238 #[must_use]
240 pub fn new() -> Self {
241 Self::default()
242 }
243
244 #[must_use]
258 pub const fn with_path_security(mut self, enabled: bool) -> Self {
259 self.enforce_path_security = enabled;
260 self
261 }
262
263 #[must_use]
268 pub fn with_root_dir(mut self, root: PathBuf) -> Self {
269 self.root_dir = Some(root);
270 self.enforce_path_security = true;
271 self
272 }
273
274 pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
290 let mut directives = Vec::new();
291 let mut options = Options::default();
292 let mut plugins = Vec::new();
293 let mut source_map = SourceMap::new();
294 let mut errors = Vec::new();
295
296 let canonical = normalize_path(path);
298
299 if self.enforce_path_security && self.root_dir.is_none() {
301 self.root_dir = canonical.parent().map(Path::to_path_buf);
302 }
303
304 self.load_recursive(
305 &canonical,
306 &mut directives,
307 &mut options,
308 &mut plugins,
309 &mut source_map,
310 &mut errors,
311 )?;
312
313 let display_context = build_display_context(&directives, &options);
315
316 Ok(LoadResult {
317 directives,
318 options,
319 plugins,
320 source_map,
321 errors,
322 display_context,
323 })
324 }
325
326 fn load_recursive(
327 &mut self,
328 path: &Path,
329 directives: &mut Vec<Spanned<Directive>>,
330 options: &mut Options,
331 plugins: &mut Vec<Plugin>,
332 source_map: &mut SourceMap,
333 errors: &mut Vec<LoadError>,
334 ) -> Result<(), LoadError> {
335 let path_buf = path.to_path_buf();
337
338 if self.include_stack_set.contains(&path_buf) {
340 let mut cycle: Vec<String> = self
341 .include_stack
342 .iter()
343 .map(|p| p.display().to_string())
344 .collect();
345 cycle.push(path.display().to_string());
346 return Err(LoadError::IncludeCycle { cycle });
347 }
348
349 if self.loaded_files.contains(&path_buf) {
351 return Ok(());
352 }
353
354 let source: std::sync::Arc<str> = if is_encrypted_file(path) {
357 decrypt_gpg_file(path)?.into()
358 } else {
359 let bytes = fs::read(path).map_err(|e| LoadError::Io {
360 path: path_buf.clone(),
361 source: e,
362 })?;
363 match String::from_utf8(bytes) {
365 Ok(s) => s.into(),
366 Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned().into(),
367 }
368 };
369
370 let file_id = source_map.add_file(path_buf.clone(), std::sync::Arc::clone(&source));
372
373 self.include_stack_set.insert(path_buf.clone());
375 self.include_stack.push(path_buf.clone());
376 self.loaded_files.insert(path_buf);
377
378 let result = rustledger_parser::parse(&source);
380
381 if !result.errors.is_empty() {
383 errors.push(LoadError::ParseErrors {
384 path: path.to_path_buf(),
385 errors: result.errors,
386 });
387 }
388
389 for (key, value, _span) in result.options {
391 options.set(&key, &value);
392 }
393
394 for (name, config, span) in result.plugins {
396 plugins.push(Plugin {
397 name,
398 config,
399 span,
400 file_id,
401 });
402 }
403
404 let base_dir = path.parent().unwrap_or(Path::new("."));
406 for (include_path, _span) in &result.includes {
407 let full_path = base_dir.join(include_path);
408 let canonical = normalize_path(&full_path);
410
411 if self.enforce_path_security
413 && let Some(ref root) = self.root_dir
414 && !canonical.starts_with(root)
415 {
416 errors.push(LoadError::PathTraversal {
417 include_path: include_path.clone(),
418 base_dir: root.clone(),
419 });
420 continue;
421 }
422
423 if let Err(e) =
424 self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
425 {
426 errors.push(e);
427 }
428 }
429
430 directives.extend(
432 result
433 .directives
434 .into_iter()
435 .map(|d| d.with_file_id(file_id)),
436 );
437
438 if let Some(popped) = self.include_stack.pop() {
440 self.include_stack_set.remove(&popped);
441 }
442
443 Ok(())
444 }
445}
446
447fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
453 let mut ctx = DisplayContext::new();
454
455 ctx.set_render_commas(options.render_commas);
457
458 for spanned in directives {
460 match &spanned.value {
461 Directive::Transaction(txn) => {
462 for posting in &txn.postings {
463 if let Some(ref units) = posting.units
465 && let (Some(number), Some(currency)) = (units.number(), units.currency())
466 {
467 ctx.update(number, currency);
468 }
469 if let Some(ref cost) = posting.cost
471 && let (Some(number), Some(currency)) =
472 (cost.number_per.or(cost.number_total), &cost.currency)
473 {
474 ctx.update(number, currency.as_str());
475 }
476 if let Some(ref price) = posting.price
478 && let Some(amount) = price.amount()
479 {
480 ctx.update(amount.number, amount.currency.as_str());
481 }
482 }
483 }
484 Directive::Balance(bal) => {
485 ctx.update(bal.amount.number, bal.amount.currency.as_str());
486 if let Some(tol) = bal.tolerance {
487 ctx.update(tol, bal.amount.currency.as_str());
488 }
489 }
490 Directive::Price(price) => {
491 ctx.update(price.amount.number, price.amount.currency.as_str());
492 }
493 Directive::Pad(_)
494 | Directive::Open(_)
495 | Directive::Close(_)
496 | Directive::Commodity(_)
497 | Directive::Event(_)
498 | Directive::Query(_)
499 | Directive::Note(_)
500 | Directive::Document(_)
501 | Directive::Custom(_) => {}
502 }
503 }
504
505 for (currency, precision) in &options.display_precision {
507 ctx.set_fixed_precision(currency, *precision);
508 }
509
510 ctx
511}
512
513pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
517 Loader::new().load(path)
518}
519
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file with `suffix` containing `lines`, one per line.
    fn temp_file_with_lines(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with_lines(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );
        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with_lines(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );
        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // The content is not real ciphertext, so decryption is expected to fail
        // whether gpg is missing or present.
        let file = temp_file_with_lines(".gpg", &["fake encrypted content"]);

        let Err(LoadError::Decryption { path, message }) = decrypt_gpg_file(file.path()) else {
            panic!("Expected Decryption error");
        };
        assert_eq!(path, file.path().to_path_buf());
        assert!(!message.is_empty());
    }
}