Skip to main content

rustledger_loader/
lib.rs

1//! Beancount file loader with include resolution.
2//!
3//! This crate handles loading beancount files, resolving includes,
4//! and collecting options. It builds on the parser to provide a
5//! complete loading pipeline.
6//!
7//! # Features
8//!
9//! - Recursive include resolution with cycle detection
10//! - Options collection and parsing
11//! - Plugin directive collection
12//! - Source map for error reporting
13//! - Push/pop tag and metadata handling
14//! - Automatic GPG decryption for encrypted files (`.gpg`, `.asc`)
15//!
16//! # Example
17//!
18//! ```ignore
19//! use rustledger_loader::Loader;
20//! use std::path::Path;
21//!
22//! let result = Loader::new().load(Path::new("ledger.beancount"))?;
23//! for directive in result.directives {
24//!     println!("{:?}", directive);
25//! }
26//! ```
27
28#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34#[cfg(any(feature = "booking", feature = "plugins", feature = "validation"))]
35mod process;
36mod source_map;
37
38#[cfg(feature = "cache")]
39pub use cache::{
40    CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
41    reintern_directives, save_cache_entry,
42};
43pub use options::Options;
44pub use source_map::{SourceFile, SourceMap};
45
46// Re-export processing API when features are enabled
47#[cfg(any(feature = "booking", feature = "plugins", feature = "validation"))]
48pub use process::{
49    ErrorLocation, ErrorSeverity, Ledger, LedgerError, LoadOptions, ProcessError, load, load_raw,
50    process,
51};
52
53use rustledger_core::{Directive, DisplayContext};
54use rustledger_parser::{ParseError, Span, Spanned};
55use std::collections::HashSet;
56use std::fs;
57use std::path::{Path, PathBuf};
58use std::process::Command;
59use thiserror::Error;
60
/// Normalize `path` into an absolute path, preferring a fully canonical form.
///
/// Resolution strategy:
/// 1. `Path::canonicalize()` is tried first; on success the result is absolute
///    and has its symlinks resolved.
/// 2. If that fails (the path does not exist, or the platform — e.g. WASI —
///    does not support canonicalization), the path is anchored at the current
///    directory when it is relative and then cleaned *lexically*: `.`
///    components are dropped and each `..` removes the preceding component.
///    Symlinks are not resolved on this route.
/// 3. If the path is relative and the current directory cannot be determined,
///    the path is returned unchanged.
///
/// The lexical cleanup is applied to absolute inputs as well. Include targets
/// are built as `<absolute base dir>/<include string>`, so without the cleanup
/// an include such as `../../secret` would keep its `..` components and still
/// satisfy a component-wise `starts_with(root)` containment check.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    // Preferred route: works on most platforms and resolves symlinks.
    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    // Relative paths are anchored at the current directory; absolute paths are
    // rebuilt from their own prefix/root components.
    let mut result = if path.is_absolute() {
        PathBuf::new()
    } else if let Ok(cwd) = std::env::current_dir() {
        cwd
    } else {
        // Last resort: nothing to anchor against, return the path as-is.
        return path.to_path_buf();
    };

    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // `pop` does nothing at the root, so `..` cannot climb above it.
                result.pop();
            }
            Component::Normal(name) => result.push(name),
            // Let `PathBuf::push` apply the platform rules for roots and
            // prefixes instead of hard-coding "/" as the root.
            Component::Prefix(_) | Component::RootDir => result.push(component.as_os_str()),
        }
    }

    result
}
103
104/// Errors that can occur during loading.
105#[derive(Debug, Error)]
106pub enum LoadError {
107    /// IO error reading a file.
108    #[error("failed to read file {path}: {source}")]
109    Io {
110        /// The path that failed to read.
111        path: PathBuf,
112        /// The underlying IO error.
113        #[source]
114        source: std::io::Error,
115    },
116
117    /// Include cycle detected.
118    #[error("include cycle detected: {}", .cycle.join(" -> "))]
119    IncludeCycle {
120        /// The cycle of file paths.
121        cycle: Vec<String>,
122    },
123
124    /// Parse errors occurred.
125    #[error("parse errors in {path}")]
126    ParseErrors {
127        /// The file with parse errors.
128        path: PathBuf,
129        /// The parse errors.
130        errors: Vec<ParseError>,
131    },
132
133    /// Path traversal attempt detected.
134    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
135    PathTraversal {
136        /// The include path that attempted traversal.
137        include_path: String,
138        /// The base directory.
139        base_dir: PathBuf,
140    },
141
142    /// GPG decryption failed.
143    #[error("failed to decrypt {path}: {message}")]
144    Decryption {
145        /// The encrypted file path.
146        path: PathBuf,
147        /// Error message from GPG.
148        message: String,
149    },
150}
151
152/// Result of loading a beancount file.
153#[derive(Debug)]
154pub struct LoadResult {
155    /// All directives from all files, in order.
156    pub directives: Vec<Spanned<Directive>>,
157    /// Parsed options.
158    pub options: Options,
159    /// Plugins to load.
160    pub plugins: Vec<Plugin>,
161    /// Source map for error reporting.
162    pub source_map: SourceMap,
163    /// All errors encountered during loading.
164    pub errors: Vec<LoadError>,
165    /// Display context for formatting numbers (tracks precision per currency).
166    pub display_context: DisplayContext,
167}
168
169/// A plugin directive.
170#[derive(Debug, Clone)]
171pub struct Plugin {
172    /// Plugin module name (with any `python:` prefix stripped).
173    pub name: String,
174    /// Optional configuration string.
175    pub config: Option<String>,
176    /// Source location.
177    pub span: Span,
178    /// File this plugin was declared in.
179    pub file_id: usize,
180    /// Whether the `python:` prefix was used to force Python execution.
181    pub force_python: bool,
182}
183
/// Check whether a file is GPG-encrypted, based on its extension or content.
///
/// Returns `true` for:
/// - any path with a `.gpg` extension (the file is not opened), and
/// - `.asc` files whose first 1024 bytes contain an ASCII-armored PGP message
///   header.
///
/// Everything else — including `.asc` files that cannot be opened or read —
/// yields `false`.
///
/// Only the head of the file is read, and it is inspected as raw bytes. This
/// keeps the check cheap for large files and avoids slicing a `str` at a fixed
/// byte offset, which panics when that offset falls inside a multi-byte
/// character.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    /// Armor header that opens an ASCII-armored PGP message.
    const PGP_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";
    /// Number of leading bytes inspected for the header.
    const SNIFF_LEN: u64 = 1024;

    match path.extension().and_then(|ext| ext.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            let Ok(file) = fs::File::open(path) else {
                return false;
            };
            let mut head = Vec::new();
            if file.take(SNIFF_LEN).read_to_end(&mut head).is_err() {
                return false;
            }
            head.windows(PGP_HEADER.len())
                .any(|window| window == PGP_HEADER)
        }
        _ => false,
    }
}
204
205/// Decrypt a GPG-encrypted file using the system `gpg` command.
206///
207/// This uses `gpg --batch --decrypt` which will use the user's
208/// GPG keyring and gpg-agent for passphrase handling.
209fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
210    let output = Command::new("gpg")
211        .args(["--batch", "--decrypt"])
212        .arg(path)
213        .output()
214        .map_err(|e| LoadError::Decryption {
215            path: path.to_path_buf(),
216            message: format!("failed to run gpg: {e}"),
217        })?;
218
219    if !output.status.success() {
220        return Err(LoadError::Decryption {
221            path: path.to_path_buf(),
222            message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
223        });
224    }
225
226    String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
227        path: path.to_path_buf(),
228        message: format!("decrypted content is not valid UTF-8: {e}"),
229    })
230}
231
/// Loads a beancount file together with everything it includes.
#[derive(Debug, Default)]
pub struct Loader {
    /// Files that have already been read; a file included more than once is
    /// only loaded the first time.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded, outermost first. Kept as an ordered list
    /// so that cycle errors can show the include chain.
    include_stack: Vec<PathBuf>,
    /// The same entries as `include_stack`, held in a set so that the cycle
    /// check is a hash lookup.
    include_stack_set: HashSet<PathBuf>,
    /// Directory that includes must stay within when path security is
    /// enforced. `None` means no root has been chosen.
    root_dir: Option<PathBuf>,
    /// Whether includes are checked against `root_dir`.
    enforce_path_security: bool,
}
247
248impl Loader {
249    /// Create a new loader.
250    #[must_use]
251    pub fn new() -> Self {
252        Self::default()
253    }
254
255    /// Enable path traversal protection.
256    ///
257    /// When enabled, include directives cannot escape the root directory
258    /// of the main beancount file. This prevents malicious ledger files
259    /// from accessing sensitive files outside the ledger directory.
260    ///
261    /// # Example
262    ///
263    /// ```ignore
264    /// let result = Loader::new()
265    ///     .with_path_security(true)
266    ///     .load(Path::new("ledger.beancount"))?;
267    /// ```
268    #[must_use]
269    pub const fn with_path_security(mut self, enabled: bool) -> Self {
270        self.enforce_path_security = enabled;
271        self
272    }
273
274    /// Set a custom root directory for path security.
275    ///
276    /// By default, the root directory is the parent directory of the main file.
277    /// This method allows overriding that to a custom directory.
278    #[must_use]
279    pub fn with_root_dir(mut self, root: PathBuf) -> Self {
280        self.root_dir = Some(root);
281        self.enforce_path_security = true;
282        self
283    }
284
285    /// Load a beancount file and all its includes.
286    ///
287    /// Parses the file, processes options and plugin directives, and recursively
288    /// loads any included files.
289    ///
290    /// # Errors
291    ///
292    /// Returns [`LoadError`] in the following cases:
293    ///
294    /// - [`LoadError::Io`] - Failed to read the file or an included file
295    /// - [`LoadError::IncludeCycle`] - Circular include detected
296    ///
297    /// Note: Parse errors and path traversal errors are collected in
298    /// [`LoadResult::errors`] rather than returned directly, allowing
299    /// partial results to be returned.
300    pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
301        let mut directives = Vec::new();
302        let mut options = Options::default();
303        let mut plugins = Vec::new();
304        let mut source_map = SourceMap::new();
305        let mut errors = Vec::new();
306
307        // Get normalized absolute path (WASI-compatible, doesn't require canonicalize)
308        let canonical = normalize_path(path);
309
310        // Set root directory for path security if enabled but not explicitly set
311        if self.enforce_path_security && self.root_dir.is_none() {
312            self.root_dir = canonical.parent().map(Path::to_path_buf);
313        }
314
315        self.load_recursive(
316            &canonical,
317            &mut directives,
318            &mut options,
319            &mut plugins,
320            &mut source_map,
321            &mut errors,
322        )?;
323
324        // Build display context from directives and options
325        let display_context = build_display_context(&directives, &options);
326
327        Ok(LoadResult {
328            directives,
329            options,
330            plugins,
331            source_map,
332            errors,
333            display_context,
334        })
335    }
336
337    fn load_recursive(
338        &mut self,
339        path: &Path,
340        directives: &mut Vec<Spanned<Directive>>,
341        options: &mut Options,
342        plugins: &mut Vec<Plugin>,
343        source_map: &mut SourceMap,
344        errors: &mut Vec<LoadError>,
345    ) -> Result<(), LoadError> {
346        // Allocate path once for reuse
347        let path_buf = path.to_path_buf();
348
349        // Check for cycles using O(1) HashSet lookup
350        if self.include_stack_set.contains(&path_buf) {
351            let mut cycle: Vec<String> = self
352                .include_stack
353                .iter()
354                .map(|p| p.display().to_string())
355                .collect();
356            cycle.push(path.display().to_string());
357            return Err(LoadError::IncludeCycle { cycle });
358        }
359
360        // Check if already loaded
361        if self.loaded_files.contains(&path_buf) {
362            return Ok(());
363        }
364
365        // Read file (decrypting if necessary)
366        // Try fast UTF-8 conversion first, fall back to lossy for non-UTF-8 files
367        let source: std::sync::Arc<str> = if is_encrypted_file(path) {
368            decrypt_gpg_file(path)?.into()
369        } else {
370            let bytes = fs::read(path).map_err(|e| LoadError::Io {
371                path: path_buf.clone(),
372                source: e,
373            })?;
374            // Try zero-copy conversion first (common case), fall back to lossy
375            match String::from_utf8(bytes) {
376                Ok(s) => s.into(),
377                Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned().into(),
378            }
379        };
380
381        // Add to source map (Arc::clone is cheap - just increments refcount)
382        let file_id = source_map.add_file(path_buf.clone(), std::sync::Arc::clone(&source));
383
384        // Mark as loading (update both stack and set)
385        self.include_stack_set.insert(path_buf.clone());
386        self.include_stack.push(path_buf.clone());
387        self.loaded_files.insert(path_buf);
388
389        // Parse (borrows from Arc, no allocation)
390        let result = rustledger_parser::parse(&source);
391
392        // Collect parse errors
393        if !result.errors.is_empty() {
394            errors.push(LoadError::ParseErrors {
395                path: path.to_path_buf(),
396                errors: result.errors,
397            });
398        }
399
400        // Process options
401        for (key, value, _span) in result.options {
402            options.set(&key, &value);
403        }
404
405        // Process plugins
406        for (name, config, span) in result.plugins {
407            // Check for "python:" prefix to force Python execution
408            let (actual_name, force_python) = if let Some(stripped) = name.strip_prefix("python:") {
409                (stripped.to_string(), true)
410            } else {
411                (name, false)
412            };
413            plugins.push(Plugin {
414                name: actual_name,
415                config,
416                span,
417                file_id,
418                force_python,
419            });
420        }
421
422        // Process includes
423        let base_dir = path.parent().unwrap_or(Path::new("."));
424        for (include_path, _span) in &result.includes {
425            let full_path = base_dir.join(include_path);
426            // Use normalize_path for WASI compatibility (canonicalize not supported)
427            let canonical = normalize_path(&full_path);
428
429            // Path traversal protection: ensure include stays within root directory
430            if self.enforce_path_security
431                && let Some(ref root) = self.root_dir
432                && !canonical.starts_with(root)
433            {
434                errors.push(LoadError::PathTraversal {
435                    include_path: include_path.clone(),
436                    base_dir: root.clone(),
437                });
438                continue;
439            }
440
441            if let Err(e) =
442                self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
443            {
444                errors.push(e);
445            }
446        }
447
448        // Add directives from this file, setting the file_id
449        directives.extend(
450            result
451                .directives
452                .into_iter()
453                .map(|d| d.with_file_id(file_id)),
454        );
455
456        // Pop from stack and set
457        if let Some(popped) = self.include_stack.pop() {
458            self.include_stack_set.remove(&popped);
459        }
460
461        Ok(())
462    }
463}
464
465/// Build a display context from loaded directives and options.
466///
467/// This scans all directives for amounts and tracks the maximum precision seen
468/// for each currency. Fixed precisions from `option "display_precision"` override
469/// the inferred values.
470fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
471    let mut ctx = DisplayContext::new();
472
473    // Set render_commas from options
474    ctx.set_render_commas(options.render_commas);
475
476    // Scan directives for amounts to infer precision
477    for spanned in directives {
478        match &spanned.value {
479            Directive::Transaction(txn) => {
480                for posting in &txn.postings {
481                    // Units (IncompleteAmount)
482                    if let Some(ref units) = posting.units
483                        && let (Some(number), Some(currency)) = (units.number(), units.currency())
484                    {
485                        ctx.update(number, currency);
486                    }
487                    // Cost (CostSpec)
488                    if let Some(ref cost) = posting.cost
489                        && let (Some(number), Some(currency)) =
490                            (cost.number_per.or(cost.number_total), &cost.currency)
491                    {
492                        ctx.update(number, currency.as_str());
493                    }
494                    // Price (PriceAnnotation)
495                    if let Some(ref price) = posting.price
496                        && let Some(amount) = price.amount()
497                    {
498                        ctx.update(amount.number, amount.currency.as_str());
499                    }
500                }
501            }
502            Directive::Balance(bal) => {
503                ctx.update(bal.amount.number, bal.amount.currency.as_str());
504                if let Some(tol) = bal.tolerance {
505                    ctx.update(tol, bal.amount.currency.as_str());
506                }
507            }
508            Directive::Price(price) => {
509                ctx.update(price.amount.number, price.amount.currency.as_str());
510            }
511            Directive::Pad(_)
512            | Directive::Open(_)
513            | Directive::Close(_)
514            | Directive::Commodity(_)
515            | Directive::Event(_)
516            | Directive::Query(_)
517            | Directive::Note(_)
518            | Directive::Document(_)
519            | Directive::Custom(_) => {}
520        }
521    }
522
523    // Apply fixed precisions from options (these override inferred values)
524    for (currency, precision) in &options.display_precision {
525        ctx.set_fixed_precision(currency, *precision);
526    }
527
528    ctx
529}
530
531/// Load a beancount file without processing.
532///
533/// This is a convenience function that creates a loader and loads a single file.
534/// For fully processed results (booking, plugins, validation), use the
535/// [`load`] function with [`LoadOptions`] instead.
536#[cfg(not(any(feature = "booking", feature = "plugins", feature = "validation")))]
537pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
538    Loader::new().load(path)
539}
540
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary file with the given suffix containing `lines`,
    /// one per line, flushed to disk.
    fn temp_file_with_lines(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with_lines(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );

        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with_lines(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );

        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // A .gpg file whose content is not real GPG data.
        let file = temp_file_with_lines(".gpg", &["fake encrypted content"]);

        // Decryption must fail: the content is not GPG-encrypted, or gpg is
        // not installed, or no matching key exists.
        let result = decrypt_gpg_file(file.path());
        assert!(result.is_err());

        match result {
            Err(LoadError::Decryption { path, message }) => {
                assert_eq!(path, file.path().to_path_buf());
                assert!(!message.is_empty());
            }
            _ => panic!("Expected Decryption error"),
        }
    }

    #[test]
    fn test_plugin_force_python_prefix() {
        let file = temp_file_with_lines(
            ".beancount",
            &[r#"plugin "python:my_plugin""#, r#"plugin "regular_plugin""#],
        );

        let result = Loader::new().load(file.path()).unwrap();
        assert_eq!(result.plugins.len(), 2);

        // The prefixed plugin loses the prefix and is flagged for Python.
        let prefixed = &result.plugins[0];
        assert_eq!(prefixed.name, "my_plugin");
        assert!(prefixed.force_python);

        // The plain plugin keeps its name and is not flagged.
        let plain = &result.plugins[1];
        assert_eq!(plain.name, "regular_plugin");
        assert!(!plain.force_python);
    }

    #[test]
    fn test_plugin_force_python_with_config() {
        let file = temp_file_with_lines(
            ".beancount",
            &[r#"plugin "python:my_plugin" "config_value""#],
        );

        let result = Loader::new().load(file.path()).unwrap();
        assert_eq!(result.plugins.len(), 1);

        let plugin = &result.plugins[0];
        assert_eq!(plugin.name, "my_plugin");
        assert!(plugin.force_python);
        assert_eq!(plugin.config, Some("config_value".to_string()));
    }
}