rustledger_loader/
lib.rs

1//! Beancount file loader with include resolution.
2//!
3//! This crate handles loading beancount files, resolving includes,
4//! and collecting options. It builds on the parser to provide a
5//! complete loading pipeline.
6//!
7//! # Features
8//!
9//! - Recursive include resolution with cycle detection
10//! - Options collection and parsing
11//! - Plugin directive collection
12//! - Source map for error reporting
13//! - Push/pop tag and metadata handling
//! - Automatic GPG decryption for encrypted files (`.gpg`, and `.asc` files containing a PGP message)
15//!
16//! # Example
17//!
18//! ```ignore
19//! use rustledger_loader::Loader;
20//! use std::path::Path;
21//!
22//! let result = Loader::new().load(Path::new("ledger.beancount"))?;
23//! for directive in result.directives {
24//!     println!("{:?}", directive);
25//! }
26//! ```
27
28#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38    CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39    reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
/// Errors that can occur during loading.
///
/// [`Loader::load`] returns an error directly only when the main file itself
/// cannot be loaded; every other error is collected in [`LoadResult::errors`].
#[derive(Debug, Error)]
pub enum LoadError {
    /// IO error reading a file or resolving (canonicalizing) its path.
    #[error("failed to read file {path}: {source}")]
    Io {
        /// The path that failed to read or resolve.
        path: PathBuf,
        /// The underlying IO error.
        #[source]
        source: std::io::Error,
    },

    /// Include cycle detected.
    #[error("include cycle detected: {}", .cycle.join(" -> "))]
    IncludeCycle {
        /// The include chain, from the outermost file down to the file that
        /// was included again (so that file appears twice in the list).
        cycle: Vec<String>,
    },

    /// One or more parse errors in a single file.
    #[error("parse errors in {path}")]
    ParseErrors {
        /// The file with parse errors.
        path: PathBuf,
        /// The parse errors.
        errors: Vec<ParseError>,
    },

    /// An include resolved outside the root directory while path security
    /// was enabled.
    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
    PathTraversal {
        /// The include path, as written in the `include` directive.
        include_path: String,
        /// The root directory the include escaped (not necessarily the
        /// directory of the including file).
        base_dir: PathBuf,
    },

    /// GPG decryption failed.
    #[error("failed to decrypt {path}: {message}")]
    Decryption {
        /// The encrypted file path.
        path: PathBuf,
        /// Error message from GPG, or a description of why `gpg` could not
        /// be run or produced unusable output.
        message: String,
    },
}
99
/// Result of loading a beancount file.
#[derive(Debug)]
pub struct LoadResult {
    /// All directives from all files. The directives of an included file come
    /// before those of the file that includes it; no sorting is applied here.
    pub directives: Vec<Spanned<Directive>>,
    /// Options collected from the `option` directives of every loaded file.
    pub options: Options,
    /// Plugins to load, in the order their `plugin` directives were processed.
    pub plugins: Vec<Plugin>,
    /// Source map for error reporting (one entry per loaded file).
    pub source_map: SourceMap,
    /// Non-fatal errors encountered during loading: parse errors in any file
    /// and problems with included files.
    pub errors: Vec<LoadError>,
    /// Display context for formatting numbers (tracks precision per currency).
    pub display_context: DisplayContext,
}
116
/// A plugin directive.
#[derive(Debug, Clone)]
pub struct Plugin {
    /// Plugin module name.
    pub name: String,
    /// Optional configuration string.
    pub config: Option<String>,
    /// Source location of the `plugin` directive.
    pub span: Span,
    /// File this plugin was declared in (the id returned by
    /// `SourceMap::add_file` for that file).
    pub file_id: usize,
}
129
/// Check if a file is GPG-encrypted based on extension or content.
///
/// Returns `true` for:
/// - Files with a `.gpg` extension (the file is not opened)
/// - Files with an `.asc` extension whose first 1024 bytes contain a PGP
///   message header (`-----BEGIN PGP MESSAGE-----`)
///
/// Any I/O error while probing an `.asc` file yields `false`, so the caller
/// falls back to reading it as a plain file.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    /// Number of leading bytes inspected for the ASCII-armor header.
    const PROBE_LEN: u64 = 1024;
    const PGP_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            // Read only the prefix instead of the whole file, and search raw
            // bytes: slicing a `str` at a fixed byte offset panics when that
            // offset falls inside a multi-byte UTF-8 character.
            let mut prefix = Vec::new();
            let read = fs::File::open(path)
                .and_then(|file| file.take(PROBE_LEN).read_to_end(&mut prefix));
            read.is_ok()
                && prefix
                    .windows(PGP_HEADER.len())
                    .any(|window| window == PGP_HEADER)
        }
        _ => false,
    }
}
150
151/// Decrypt a GPG-encrypted file using the system `gpg` command.
152///
153/// This uses `gpg --batch --decrypt` which will use the user's
154/// GPG keyring and gpg-agent for passphrase handling.
155fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
156    let output = Command::new("gpg")
157        .args(["--batch", "--decrypt"])
158        .arg(path)
159        .output()
160        .map_err(|e| LoadError::Decryption {
161            path: path.to_path_buf(),
162            message: format!("failed to run gpg: {e}"),
163        })?;
164
165    if !output.status.success() {
166        return Err(LoadError::Decryption {
167            path: path.to_path_buf(),
168            message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
169        });
170    }
171
172    String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
173        path: path.to_path_buf(),
174        message: format!("decrypted content is not valid UTF-8: {e}"),
175    })
176}
177
/// Beancount file loader.
///
/// Create one with [`Loader::new`], optionally configure it with the `with_*`
/// builder methods, then call [`Loader::load`].
#[derive(Debug, Default)]
pub struct Loader {
    /// Files that have already been loaded, so that a file included more than
    /// once is only read and parsed once.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded (the chain of includes leading to the
    /// file being processed); used for cycle detection.
    include_stack: Vec<PathBuf>,
    /// Root directory for path traversal protection.
    /// When path security is enforced, includes must resolve to paths within
    /// this directory.
    root_dir: Option<PathBuf>,
    /// Whether to enforce path traversal protection.
    enforce_path_security: bool,
}
191
192impl Loader {
193    /// Create a new loader.
194    #[must_use]
195    pub fn new() -> Self {
196        Self::default()
197    }
198
199    /// Enable path traversal protection.
200    ///
201    /// When enabled, include directives cannot escape the root directory
202    /// of the main beancount file. This prevents malicious ledger files
203    /// from accessing sensitive files outside the ledger directory.
204    ///
205    /// # Example
206    ///
207    /// ```ignore
208    /// let result = Loader::new()
209    ///     .with_path_security(true)
210    ///     .load(Path::new("ledger.beancount"))?;
211    /// ```
212    #[must_use]
213    pub const fn with_path_security(mut self, enabled: bool) -> Self {
214        self.enforce_path_security = enabled;
215        self
216    }
217
218    /// Set a custom root directory for path security.
219    ///
220    /// By default, the root directory is the parent directory of the main file.
221    /// This method allows overriding that to a custom directory.
222    #[must_use]
223    pub fn with_root_dir(mut self, root: PathBuf) -> Self {
224        self.root_dir = Some(root);
225        self.enforce_path_security = true;
226        self
227    }
228
229    /// Load a beancount file and all its includes.
230    ///
231    /// Parses the file, processes options and plugin directives, and recursively
232    /// loads any included files.
233    ///
234    /// # Errors
235    ///
236    /// Returns [`LoadError`] in the following cases:
237    ///
238    /// - [`LoadError::Io`] - Failed to read the file or an included file
239    /// - [`LoadError::IncludeCycle`] - Circular include detected
240    ///
241    /// Note: Parse errors and path traversal errors are collected in
242    /// [`LoadResult::errors`] rather than returned directly, allowing
243    /// partial results to be returned.
244    pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
245        let mut directives = Vec::new();
246        let mut options = Options::default();
247        let mut plugins = Vec::new();
248        let mut source_map = SourceMap::new();
249        let mut errors = Vec::new();
250
251        // Get canonical path
252        let canonical = path.canonicalize().map_err(|e| LoadError::Io {
253            path: path.to_path_buf(),
254            source: e,
255        })?;
256
257        // Set root directory for path security if enabled but not explicitly set
258        if self.enforce_path_security && self.root_dir.is_none() {
259            self.root_dir = canonical.parent().map(Path::to_path_buf);
260        }
261
262        self.load_recursive(
263            &canonical,
264            &mut directives,
265            &mut options,
266            &mut plugins,
267            &mut source_map,
268            &mut errors,
269        )?;
270
271        // Build display context from directives and options
272        let display_context = build_display_context(&directives, &options);
273
274        Ok(LoadResult {
275            directives,
276            options,
277            plugins,
278            source_map,
279            errors,
280            display_context,
281        })
282    }
283
284    fn load_recursive(
285        &mut self,
286        path: &Path,
287        directives: &mut Vec<Spanned<Directive>>,
288        options: &mut Options,
289        plugins: &mut Vec<Plugin>,
290        source_map: &mut SourceMap,
291        errors: &mut Vec<LoadError>,
292    ) -> Result<(), LoadError> {
293        // Check for cycles
294        let path_buf = path.to_path_buf();
295        if self.include_stack.contains(&path_buf) {
296            let mut cycle: Vec<String> = self
297                .include_stack
298                .iter()
299                .map(|p| p.display().to_string())
300                .collect();
301            cycle.push(path.display().to_string());
302            return Err(LoadError::IncludeCycle { cycle });
303        }
304
305        // Check if already loaded
306        if self.loaded_files.contains(path) {
307            return Ok(());
308        }
309
310        // Read file (decrypting if necessary)
311        // Use lossy UTF-8 decoding to handle non-UTF-8 files gracefully (like Python beancount)
312        let source: std::sync::Arc<str> = if is_encrypted_file(path) {
313            decrypt_gpg_file(path)?.into()
314        } else {
315            let bytes = fs::read(path).map_err(|e| LoadError::Io {
316                path: path.to_path_buf(),
317                source: e,
318            })?;
319            String::from_utf8_lossy(&bytes).into_owned().into()
320        };
321
322        // Add to source map (Arc::clone is cheap - just increments refcount)
323        let file_id = source_map.add_file(path.to_path_buf(), std::sync::Arc::clone(&source));
324
325        // Mark as loading
326        self.include_stack.push(path.to_path_buf());
327        self.loaded_files.insert(path.to_path_buf());
328
329        // Parse (borrows from Arc, no allocation)
330        let result = rustledger_parser::parse(&source);
331
332        // Collect parse errors
333        if !result.errors.is_empty() {
334            errors.push(LoadError::ParseErrors {
335                path: path.to_path_buf(),
336                errors: result.errors,
337            });
338        }
339
340        // Process options
341        for (key, value, _span) in result.options {
342            options.set(&key, &value);
343        }
344
345        // Process plugins
346        for (name, config, span) in result.plugins {
347            plugins.push(Plugin {
348                name,
349                config,
350                span,
351                file_id,
352            });
353        }
354
355        // Process includes
356        let base_dir = path.parent().unwrap_or(Path::new("."));
357        for (include_path, _span) in &result.includes {
358            let full_path = base_dir.join(include_path);
359            let canonical = match full_path.canonicalize() {
360                Ok(p) => p,
361                Err(e) => {
362                    errors.push(LoadError::Io {
363                        path: full_path,
364                        source: e,
365                    });
366                    continue;
367                }
368            };
369
370            // Path traversal protection: ensure include stays within root directory
371            if self.enforce_path_security {
372                if let Some(ref root) = self.root_dir {
373                    if !canonical.starts_with(root) {
374                        errors.push(LoadError::PathTraversal {
375                            include_path: include_path.clone(),
376                            base_dir: root.clone(),
377                        });
378                        continue;
379                    }
380                }
381            }
382
383            if let Err(e) =
384                self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
385            {
386                errors.push(e);
387            }
388        }
389
390        // Add directives from this file, setting the file_id
391        directives.extend(
392            result
393                .directives
394                .into_iter()
395                .map(|d| d.with_file_id(file_id)),
396        );
397
398        // Pop from stack
399        self.include_stack.pop();
400
401        Ok(())
402    }
403}
404
405/// Build a display context from loaded directives and options.
406///
407/// This scans all directives for amounts and tracks the maximum precision seen
408/// for each currency. Fixed precisions from `option "display_precision"` override
409/// the inferred values.
410fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
411    let mut ctx = DisplayContext::new();
412
413    // Set render_commas from options
414    ctx.set_render_commas(options.render_commas);
415
416    // Scan directives for amounts to infer precision
417    for spanned in directives {
418        match &spanned.value {
419            Directive::Transaction(txn) => {
420                for posting in &txn.postings {
421                    // Units (IncompleteAmount)
422                    if let Some(ref units) = posting.units {
423                        if let (Some(number), Some(currency)) = (units.number(), units.currency()) {
424                            ctx.update(number, currency);
425                        }
426                    }
427                    // Cost (CostSpec)
428                    if let Some(ref cost) = posting.cost {
429                        if let (Some(number), Some(currency)) =
430                            (cost.number_per.or(cost.number_total), &cost.currency)
431                        {
432                            ctx.update(number, currency.as_str());
433                        }
434                    }
435                    // Price (PriceAnnotation)
436                    if let Some(ref price) = posting.price {
437                        if let Some(amount) = price.amount() {
438                            ctx.update(amount.number, amount.currency.as_str());
439                        }
440                    }
441                }
442            }
443            Directive::Balance(bal) => {
444                ctx.update(bal.amount.number, bal.amount.currency.as_str());
445                if let Some(tol) = bal.tolerance {
446                    ctx.update(tol, bal.amount.currency.as_str());
447                }
448            }
449            Directive::Price(price) => {
450                ctx.update(price.amount.number, price.amount.currency.as_str());
451            }
452            Directive::Pad(_)
453            | Directive::Open(_)
454            | Directive::Close(_)
455            | Directive::Commodity(_)
456            | Directive::Event(_)
457            | Directive::Query(_)
458            | Directive::Note(_)
459            | Directive::Document(_)
460            | Directive::Custom(_) => {}
461        }
462    }
463
464    // Apply fixed precisions from options (these override inferred values)
465    for (currency, precision) in &options.display_precision {
466        ctx.set_fixed_precision(currency, *precision);
467    }
468
469    ctx
470}
471
472/// Load a beancount file.
473///
474/// This is a convenience function that creates a loader and loads a single file.
475pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
476    Loader::new().load(path)
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Create a temporary file with `suffix` holding `lines`, one per line.
    fn temp_file_with(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );
        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );
        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // The content is not real ciphertext, so decryption fails whether gpg
        // is missing, has no matching key, or rejects the input.
        let file = temp_file_with(".gpg", &["fake encrypted content"]);

        let result = decrypt_gpg_file(file.path());
        assert!(result.is_err());

        let Err(LoadError::Decryption { path, message }) = result else {
            panic!("Expected Decryption error");
        };
        assert_eq!(path, file.path().to_path_buf());
        assert!(!message.is_empty());
    }
}