// rustledger_loader/lib.rs

1//! Beancount file loader with include resolution.
2//!
3//! This crate handles loading beancount files, resolving includes,
4//! and collecting options. It builds on the parser to provide a
5//! complete loading pipeline.
6//!
7//! # Features
8//!
9//! - Recursive include resolution with cycle detection
10//! - Options collection and parsing
11//! - Plugin directive collection
12//! - Source map for error reporting
13//! - Push/pop tag and metadata handling
14//! - Automatic GPG decryption for encrypted files (`.gpg`, `.asc`)
15//!
16//! # Example
17//!
18//! ```ignore
19//! use rustledger_loader::Loader;
20//! use std::path::Path;
21//!
22//! let result = Loader::new().load(Path::new("ledger.beancount"))?;
23//! for directive in result.directives {
24//!     println!("{:?}", directive);
25//! }
26//! ```
27
28#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34#[cfg(any(feature = "booking", feature = "plugins", feature = "validation"))]
35mod process;
36mod source_map;
37
38#[cfg(feature = "cache")]
39pub use cache::{
40    CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
41    reintern_directives, save_cache_entry,
42};
43pub use options::Options;
44pub use source_map::{SourceFile, SourceMap};
45
46// Re-export processing API when features are enabled
47#[cfg(any(feature = "booking", feature = "plugins", feature = "validation"))]
48pub use process::{
49    ErrorLocation, ErrorSeverity, Ledger, LedgerError, LoadOptions, ProcessError, load, load_raw,
50    process,
51};
52
53use rustledger_core::{Directive, DisplayContext};
54use rustledger_parser::{ParseError, Span, Spanned};
55use std::collections::HashSet;
56use std::fs;
57use std::path::{Path, PathBuf};
58use std::process::Command;
59use thiserror::Error;
60
/// Normalize a path to an absolute form, preferring full canonicalization.
///
/// Resolution order:
/// 1. `Path::canonicalize()`, which resolves symlinks and returns an absolute path.
/// 2. If that fails (for example the file does not exist, or the platform — e.g.
///    WASI — does not support it), the path is cleaned *lexically*: a relative
///    path is first anchored at the current directory, `.` components are
///    dropped, and every `..` removes the component before it.
/// 3. If the path is relative and the current directory cannot be determined,
///    the path is returned unchanged.
///
/// The lexical cleanup is applied to absolute inputs as well. Callers compare the
/// result against a root directory with `Path::starts_with`, which only compares
/// components; leaving `..` in place would let `/root/../../secret` satisfy a
/// `starts_with("/root")` check.
fn normalize_path(path: &Path) -> PathBuf {
    // Preferred: real canonicalization (resolves symlinks).
    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    // Fallback: choose the starting point for lexical normalization.
    let mut result = if path.is_absolute() {
        // The path's own prefix/root components rebuild the anchor below.
        PathBuf::new()
    } else if let Ok(cwd) = std::env::current_dir() {
        cwd
    } else {
        // Last resort: nothing to anchor a relative path to.
        return path.to_path_buf();
    };

    for component in path.components() {
        match component {
            std::path::Component::Prefix(prefix) => {
                result = PathBuf::from(prefix.as_os_str());
            }
            std::path::Component::RootDir => {
                // Pushing a root replaces everything except a drive prefix,
                // so `C:` + `\` becomes `C:\` and anything on Unix becomes `/`.
                result.push(component.as_os_str());
            }
            std::path::Component::CurDir => {}
            std::path::Component::ParentDir => {
                // `pop` is a no-op at the root, so `/..` stays `/`.
                result.pop();
            }
            std::path::Component::Normal(name) => {
                result.push(name);
            }
        }
    }
    result
}
103
104/// Errors that can occur during loading.
105#[derive(Debug, Error)]
106pub enum LoadError {
107    /// IO error reading a file.
108    #[error("failed to read file {path}: {source}")]
109    Io {
110        /// The path that failed to read.
111        path: PathBuf,
112        /// The underlying IO error.
113        #[source]
114        source: std::io::Error,
115    },
116
117    /// Include cycle detected.
118    #[error("include cycle detected: {}", .cycle.join(" -> "))]
119    IncludeCycle {
120        /// The cycle of file paths.
121        cycle: Vec<String>,
122    },
123
124    /// Parse errors occurred.
125    #[error("parse errors in {path}")]
126    ParseErrors {
127        /// The file with parse errors.
128        path: PathBuf,
129        /// The parse errors.
130        errors: Vec<ParseError>,
131    },
132
133    /// Path traversal attempt detected.
134    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
135    PathTraversal {
136        /// The include path that attempted traversal.
137        include_path: String,
138        /// The base directory.
139        base_dir: PathBuf,
140    },
141
142    /// GPG decryption failed.
143    #[error("failed to decrypt {path}: {message}")]
144    Decryption {
145        /// The encrypted file path.
146        path: PathBuf,
147        /// Error message from GPG.
148        message: String,
149    },
150
151    /// Glob pattern did not match any files.
152    #[error("include pattern \"{pattern}\" does not match any files")]
153    GlobNoMatch {
154        /// The glob pattern that matched nothing.
155        pattern: String,
156    },
157
158    /// Glob pattern expansion failed.
159    #[error("failed to expand include pattern \"{pattern}\": {message}")]
160    GlobError {
161        /// The glob pattern that failed.
162        pattern: String,
163        /// The error message.
164        message: String,
165    },
166}
167
168/// Result of loading a beancount file.
169#[derive(Debug)]
170pub struct LoadResult {
171    /// All directives from all files, in order.
172    pub directives: Vec<Spanned<Directive>>,
173    /// Parsed options.
174    pub options: Options,
175    /// Plugins to load.
176    pub plugins: Vec<Plugin>,
177    /// Source map for error reporting.
178    pub source_map: SourceMap,
179    /// All errors encountered during loading.
180    pub errors: Vec<LoadError>,
181    /// Display context for formatting numbers (tracks precision per currency).
182    pub display_context: DisplayContext,
183}
184
185/// A plugin directive.
186#[derive(Debug, Clone)]
187pub struct Plugin {
188    /// Plugin module name (with any `python:` prefix stripped).
189    pub name: String,
190    /// Optional configuration string.
191    pub config: Option<String>,
192    /// Source location.
193    pub span: Span,
194    /// File this plugin was declared in.
195    pub file_id: usize,
196    /// Whether the `python:` prefix was used to force Python execution.
197    pub force_python: bool,
198}
199
/// Decide whether a file should be treated as GPG-encrypted.
///
/// Returns `true` for:
/// - any path with a `.gpg` extension (the file is not opened), and
/// - a `.asc` file that is readable as UTF-8 and contains the armored PGP
///   message header entirely within its first 1024 bytes.
///
/// Every other case — including a `.asc` file that cannot be read — yields
/// `false`.
fn is_encrypted_file(path: &Path) -> bool {
    /// Armor header that marks an encrypted PGP message.
    const PGP_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";
    /// Only this many leading bytes are inspected.
    const PROBE_LEN: usize = 1024;

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => fs::read_to_string(path).is_ok_and(|content| {
            // Probe bytes rather than a `str` slice: cutting a `str` at a fixed
            // offset panics when that offset lands inside a multi-byte
            // character. The header is pure ASCII, so a byte search is exact.
            let bytes = content.as_bytes();
            let head = &bytes[..bytes.len().min(PROBE_LEN)];
            head.windows(PGP_HEADER.len())
                .any(|window| window == PGP_HEADER)
        }),
        _ => false,
    }
}
220
221/// Decrypt a GPG-encrypted file using the system `gpg` command.
222///
223/// This uses `gpg --batch --decrypt` which will use the user's
224/// GPG keyring and gpg-agent for passphrase handling.
225fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
226    let output = Command::new("gpg")
227        .args(["--batch", "--decrypt"])
228        .arg(path)
229        .output()
230        .map_err(|e| LoadError::Decryption {
231            path: path.to_path_buf(),
232            message: format!("failed to run gpg: {e}"),
233        })?;
234
235    if !output.status.success() {
236        return Err(LoadError::Decryption {
237            path: path.to_path_buf(),
238            message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
239        });
240    }
241
242    String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
243        path: path.to_path_buf(),
244        message: format!("decrypted content is not valid UTF-8: {e}"),
245    })
246}
247
/// Loads beancount files and follows their `include` directives.
///
/// The default value has no files recorded, no root directory, and path
/// security switched off.
#[derive(Default, Debug)]
pub struct Loader {
    /// Every file loaded so far; a file already in this set is not loaded again.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded, outermost first (used to report cycles in order).
    include_stack: Vec<PathBuf>,
    /// Same entries as `include_stack`, kept as a set for fast membership checks.
    include_stack_set: HashSet<PathBuf>,
    /// Directory that includes must stay inside when path security is enforced.
    /// `None` until it is set explicitly or derived from the main file.
    root_dir: Option<PathBuf>,
    /// Whether includes are checked against `root_dir`.
    enforce_path_security: bool,
}
263
264impl Loader {
265    /// Create a new loader.
266    #[must_use]
267    pub fn new() -> Self {
268        Self::default()
269    }
270
271    /// Enable path traversal protection.
272    ///
273    /// When enabled, include directives cannot escape the root directory
274    /// of the main beancount file. This prevents malicious ledger files
275    /// from accessing sensitive files outside the ledger directory.
276    ///
277    /// # Example
278    ///
279    /// ```ignore
280    /// let result = Loader::new()
281    ///     .with_path_security(true)
282    ///     .load(Path::new("ledger.beancount"))?;
283    /// ```
284    #[must_use]
285    pub const fn with_path_security(mut self, enabled: bool) -> Self {
286        self.enforce_path_security = enabled;
287        self
288    }
289
290    /// Set a custom root directory for path security.
291    ///
292    /// By default, the root directory is the parent directory of the main file.
293    /// This method allows overriding that to a custom directory.
294    #[must_use]
295    pub fn with_root_dir(mut self, root: PathBuf) -> Self {
296        self.root_dir = Some(root);
297        self.enforce_path_security = true;
298        self
299    }
300
301    /// Load a beancount file and all its includes.
302    ///
303    /// Parses the file, processes options and plugin directives, and recursively
304    /// loads any included files.
305    ///
306    /// # Errors
307    ///
308    /// Returns [`LoadError`] in the following cases:
309    ///
310    /// - [`LoadError::Io`] - Failed to read the file or an included file
311    /// - [`LoadError::IncludeCycle`] - Circular include detected
312    ///
313    /// Note: Parse errors and path traversal errors are collected in
314    /// [`LoadResult::errors`] rather than returned directly, allowing
315    /// partial results to be returned.
316    pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
317        let mut directives = Vec::new();
318        let mut options = Options::default();
319        let mut plugins = Vec::new();
320        let mut source_map = SourceMap::new();
321        let mut errors = Vec::new();
322
323        // Get normalized absolute path (WASI-compatible, doesn't require canonicalize)
324        let canonical = normalize_path(path);
325
326        // Set root directory for path security if enabled but not explicitly set
327        if self.enforce_path_security && self.root_dir.is_none() {
328            self.root_dir = canonical.parent().map(Path::to_path_buf);
329        }
330
331        self.load_recursive(
332            &canonical,
333            &mut directives,
334            &mut options,
335            &mut plugins,
336            &mut source_map,
337            &mut errors,
338        )?;
339
340        // Build display context from directives and options
341        let display_context = build_display_context(&directives, &options);
342
343        Ok(LoadResult {
344            directives,
345            options,
346            plugins,
347            source_map,
348            errors,
349            display_context,
350        })
351    }
352
353    fn load_recursive(
354        &mut self,
355        path: &Path,
356        directives: &mut Vec<Spanned<Directive>>,
357        options: &mut Options,
358        plugins: &mut Vec<Plugin>,
359        source_map: &mut SourceMap,
360        errors: &mut Vec<LoadError>,
361    ) -> Result<(), LoadError> {
362        // Allocate path once for reuse
363        let path_buf = path.to_path_buf();
364
365        // Check for cycles using O(1) HashSet lookup
366        if self.include_stack_set.contains(&path_buf) {
367            let mut cycle: Vec<String> = self
368                .include_stack
369                .iter()
370                .map(|p| p.display().to_string())
371                .collect();
372            cycle.push(path.display().to_string());
373            return Err(LoadError::IncludeCycle { cycle });
374        }
375
376        // Check if already loaded
377        if self.loaded_files.contains(&path_buf) {
378            return Ok(());
379        }
380
381        // Read file (decrypting if necessary)
382        // Try fast UTF-8 conversion first, fall back to lossy for non-UTF-8 files
383        let source: std::sync::Arc<str> = if is_encrypted_file(path) {
384            decrypt_gpg_file(path)?.into()
385        } else {
386            let bytes = fs::read(path).map_err(|e| LoadError::Io {
387                path: path_buf.clone(),
388                source: e,
389            })?;
390            // Try zero-copy conversion first (common case), fall back to lossy
391            match String::from_utf8(bytes) {
392                Ok(s) => s.into(),
393                Err(e) => String::from_utf8_lossy(e.as_bytes()).into_owned().into(),
394            }
395        };
396
397        // Add to source map (Arc::clone is cheap - just increments refcount)
398        let file_id = source_map.add_file(path_buf.clone(), std::sync::Arc::clone(&source));
399
400        // Mark as loading (update both stack and set)
401        self.include_stack_set.insert(path_buf.clone());
402        self.include_stack.push(path_buf.clone());
403        self.loaded_files.insert(path_buf);
404
405        // Parse (borrows from Arc, no allocation)
406        let result = rustledger_parser::parse(&source);
407
408        // Collect parse errors
409        if !result.errors.is_empty() {
410            errors.push(LoadError::ParseErrors {
411                path: path.to_path_buf(),
412                errors: result.errors,
413            });
414        }
415
416        // Process options
417        for (key, value, _span) in result.options {
418            options.set(&key, &value);
419        }
420
421        // Process plugins
422        for (name, config, span) in result.plugins {
423            // Check for "python:" prefix to force Python execution
424            let (actual_name, force_python) = if let Some(stripped) = name.strip_prefix("python:") {
425                (stripped.to_string(), true)
426            } else {
427                (name, false)
428            };
429            plugins.push(Plugin {
430                name: actual_name,
431                config,
432                span,
433                file_id,
434                force_python,
435            });
436        }
437
438        // Process includes (with glob pattern support)
439        let base_dir = path.parent().unwrap_or(Path::new("."));
440        for (include_path, _span) in &result.includes {
441            // Check if the include path contains glob metacharacters
442            // (check on include_path, not full_path, to avoid false positives from directory names)
443            let has_glob = include_path.contains('*')
444                || include_path.contains('?')
445                || include_path.contains('[');
446
447            let full_path = base_dir.join(include_path);
448
449            // Path traversal protection: check BEFORE glob expansion to avoid
450            // enumerating files outside the allowed root directory
451            if self.enforce_path_security
452                && let Some(ref root) = self.root_dir
453            {
454                // For glob patterns, extract and check the non-glob prefix
455                let path_to_check = if has_glob {
456                    // Find where the first glob metacharacter is
457                    let glob_start = include_path
458                        .find(['*', '?', '['])
459                        .unwrap_or(include_path.len());
460                    // Get the directory prefix before the glob
461                    let prefix = &include_path[..glob_start];
462                    let prefix_path = if let Some(last_sep) = prefix.rfind('/') {
463                        base_dir.join(&include_path[..=last_sep])
464                    } else {
465                        base_dir.to_path_buf()
466                    };
467                    normalize_path(&prefix_path)
468                } else {
469                    normalize_path(&full_path)
470                };
471
472                if !path_to_check.starts_with(root) {
473                    errors.push(LoadError::PathTraversal {
474                        include_path: include_path.clone(),
475                        base_dir: root.clone(),
476                    });
477                    continue;
478                }
479            }
480
481            let full_path_str = full_path.to_string_lossy();
482
483            // Expand glob patterns or use literal path
484            let paths_to_load: Vec<PathBuf> = if has_glob {
485                match glob::glob(&full_path_str) {
486                    Ok(entries) => {
487                        let mut matched: Vec<PathBuf> = Vec::new();
488                        for entry in entries {
489                            match entry {
490                                Ok(p) => matched.push(p),
491                                Err(e) => {
492                                    errors.push(LoadError::GlobError {
493                                        pattern: include_path.clone(),
494                                        message: e.to_string(),
495                                    });
496                                }
497                            }
498                        }
499                        // Sort for deterministic ordering
500                        matched.sort();
501                        matched
502                    }
503                    Err(e) => {
504                        errors.push(LoadError::GlobError {
505                            pattern: include_path.clone(),
506                            message: e.to_string(),
507                        });
508                        continue;
509                    }
510                }
511            } else {
512                vec![full_path.clone()]
513            };
514
515            // Check if glob matched nothing
516            if has_glob && paths_to_load.is_empty() {
517                errors.push(LoadError::GlobNoMatch {
518                    pattern: include_path.clone(),
519                });
520                continue;
521            }
522
523            // Load each matched file
524            for matched_path in paths_to_load {
525                // Use normalize_path for WASI compatibility (canonicalize not supported)
526                let canonical = normalize_path(&matched_path);
527
528                // Additional security check for each matched file
529                // (glob could still match files outside root via symlinks)
530                if self.enforce_path_security
531                    && let Some(ref root) = self.root_dir
532                    && !canonical.starts_with(root)
533                {
534                    errors.push(LoadError::PathTraversal {
535                        include_path: matched_path.to_string_lossy().into_owned(),
536                        base_dir: root.clone(),
537                    });
538                    continue;
539                }
540
541                if let Err(e) = self
542                    .load_recursive(&canonical, directives, options, plugins, source_map, errors)
543                {
544                    errors.push(e);
545                }
546            }
547        }
548
549        // Add directives from this file, setting the file_id
550        directives.extend(
551            result
552                .directives
553                .into_iter()
554                .map(|d| d.with_file_id(file_id)),
555        );
556
557        // Pop from stack and set
558        if let Some(popped) = self.include_stack.pop() {
559            self.include_stack_set.remove(&popped);
560        }
561
562        Ok(())
563    }
564}
565
566/// Build a display context from loaded directives and options.
567///
568/// This scans all directives for amounts and tracks the maximum precision seen
569/// for each currency. Fixed precisions from `option "display_precision"` override
570/// the inferred values.
571fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
572    let mut ctx = DisplayContext::new();
573
574    // Set render_commas from options
575    ctx.set_render_commas(options.render_commas);
576
577    // Scan directives for amounts to infer precision
578    for spanned in directives {
579        match &spanned.value {
580            Directive::Transaction(txn) => {
581                for posting in &txn.postings {
582                    // Units (IncompleteAmount)
583                    if let Some(ref units) = posting.units
584                        && let (Some(number), Some(currency)) = (units.number(), units.currency())
585                    {
586                        ctx.update(number, currency);
587                    }
588                    // Cost (CostSpec)
589                    if let Some(ref cost) = posting.cost
590                        && let (Some(number), Some(currency)) =
591                            (cost.number_per.or(cost.number_total), &cost.currency)
592                    {
593                        ctx.update(number, currency.as_str());
594                    }
595                    // Price (PriceAnnotation)
596                    if let Some(ref price) = posting.price
597                        && let Some(amount) = price.amount()
598                    {
599                        ctx.update(amount.number, amount.currency.as_str());
600                    }
601                }
602            }
603            Directive::Balance(bal) => {
604                ctx.update(bal.amount.number, bal.amount.currency.as_str());
605                if let Some(tol) = bal.tolerance {
606                    ctx.update(tol, bal.amount.currency.as_str());
607                }
608            }
609            Directive::Price(price) => {
610                ctx.update(price.amount.number, price.amount.currency.as_str());
611            }
612            Directive::Pad(_)
613            | Directive::Open(_)
614            | Directive::Close(_)
615            | Directive::Commodity(_)
616            | Directive::Event(_)
617            | Directive::Query(_)
618            | Directive::Note(_)
619            | Directive::Document(_)
620            | Directive::Custom(_) => {}
621        }
622    }
623
624    // Apply fixed precisions from options (these override inferred values)
625    for (currency, precision) in &options.display_precision {
626        ctx.set_fixed_precision(currency, *precision);
627    }
628
629    ctx
630}
631
632/// Load a beancount file without processing.
633///
634/// This is a convenience function that creates a loader and loads a single file.
635/// For fully processed results (booking, plugins, validation), use the
636/// [`load`] function with [`LoadOptions`] instead.
637#[cfg(not(any(feature = "booking", feature = "plugins", feature = "validation")))]
638pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
639    Loader::new().load(path)
640}
641
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary file with the given suffix holding `lines`,
    /// each terminated by a newline, and flush it to disk.
    fn temp_file_with(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );

        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );

        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // A file that merely looks like a .gpg file
        let file = temp_file_with(".gpg", &["fake encrypted content"]);

        // Decryption must fail: the content is not real GPG data
        // (or gpg is not installed, or no matching key exists)
        let result = decrypt_gpg_file(file.path());
        assert!(result.is_err());

        match result {
            Err(LoadError::Decryption { path, message }) => {
                assert_eq!(path, file.path().to_path_buf());
                assert!(!message.is_empty());
            }
            _ => panic!("Expected Decryption error"),
        }
    }

    #[test]
    fn test_plugin_force_python_prefix() {
        let file = temp_file_with(
            ".beancount",
            &[r#"plugin "python:my_plugin""#, r#"plugin "regular_plugin""#],
        );

        let loaded = Loader::new().load(file.path()).unwrap();
        assert_eq!(loaded.plugins.len(), 2);

        // Prefixed plugin: prefix stripped from the name, flag set
        let forced = &loaded.plugins[0];
        assert_eq!(forced.name, "my_plugin");
        assert!(forced.force_python);

        // Unprefixed plugin: name untouched, flag clear
        let regular = &loaded.plugins[1];
        assert_eq!(regular.name, "regular_plugin");
        assert!(!regular.force_python);
    }

    #[test]
    fn test_plugin_force_python_with_config() {
        let file = temp_file_with(".beancount", &[r#"plugin "python:my_plugin" "config_value""#]);

        let loaded = Loader::new().load(file.path()).unwrap();
        assert_eq!(loaded.plugins.len(), 1);

        let plugin = &loaded.plugins[0];
        assert_eq!(plugin.name, "my_plugin");
        assert!(plugin.force_python);
        assert_eq!(plugin.config, Some("config_value".to_string()));
    }
}