rustledger_loader/
lib.rs

//! Beancount file loader with include resolution.
//!
//! This crate handles loading beancount files, resolving includes,
//! and collecting options. It builds on the parser to provide a
//! complete loading pipeline.
//!
//! # Features
//!
//! - Recursive include resolution with cycle detection
//! - Options collection and parsing
//! - Plugin directive collection
//! - Source map for error reporting
//! - Push/pop tag and metadata handling
//! - Automatic GPG decryption for encrypted files (`.gpg`, `.asc`)
//!
//! # Example
//!
//! ```ignore
//! use rustledger_loader::Loader;
//! use std::path::Path;
//!
//! let result = Loader::new().load(Path::new("ledger.beancount"))?;
//! for directive in result.directives {
//!     println!("{:?}", directive);
//! }
//! ```
27
28#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38    CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39    reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::Directive;
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
52/// Errors that can occur during loading.
53#[derive(Debug, Error)]
54pub enum LoadError {
55    /// IO error reading a file.
56    #[error("failed to read file {path}: {source}")]
57    Io {
58        /// The path that failed to read.
59        path: PathBuf,
60        /// The underlying IO error.
61        #[source]
62        source: std::io::Error,
63    },
64
65    /// Include cycle detected.
66    #[error("include cycle detected: {}", .cycle.join(" -> "))]
67    IncludeCycle {
68        /// The cycle of file paths.
69        cycle: Vec<String>,
70    },
71
72    /// Parse errors occurred.
73    #[error("parse errors in {path}")]
74    ParseErrors {
75        /// The file with parse errors.
76        path: PathBuf,
77        /// The parse errors.
78        errors: Vec<ParseError>,
79    },
80
81    /// Path traversal attempt detected.
82    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
83    PathTraversal {
84        /// The include path that attempted traversal.
85        include_path: String,
86        /// The base directory.
87        base_dir: PathBuf,
88    },
89
90    /// GPG decryption failed.
91    #[error("failed to decrypt {path}: {message}")]
92    Decryption {
93        /// The encrypted file path.
94        path: PathBuf,
95        /// Error message from GPG.
96        message: String,
97    },
98}
99
100/// Result of loading a beancount file.
101#[derive(Debug)]
102pub struct LoadResult {
103    /// All directives from all files, in order.
104    pub directives: Vec<Spanned<Directive>>,
105    /// Parsed options.
106    pub options: Options,
107    /// Plugins to load.
108    pub plugins: Vec<Plugin>,
109    /// Source map for error reporting.
110    pub source_map: SourceMap,
111    /// All errors encountered during loading.
112    pub errors: Vec<LoadError>,
113}
114
115/// A plugin directive.
116#[derive(Debug, Clone)]
117pub struct Plugin {
118    /// Plugin module name.
119    pub name: String,
120    /// Optional configuration string.
121    pub config: Option<String>,
122    /// Source location.
123    pub span: Span,
124    /// File this plugin was declared in.
125    pub file_id: usize,
126}
127
/// Check if a file is GPG-encrypted based on extension or content.
///
/// Returns `true` for:
/// - Files with a `.gpg` extension (the file itself is not read)
/// - Files with an `.asc` extension whose first 1024 bytes contain an
///   ASCII-armored `-----BEGIN PGP MESSAGE-----` header
///
/// Any I/O error while probing an `.asc` file yields `false`. The caller then
/// reads the file as plain text and reports any I/O error from that read.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    // Only this many leading bytes are inspected for the armor header.
    const PROBE_LEN: usize = 1024;
    const PGP_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            // Read just the probe window rather than the whole file, and search
            // raw bytes: slicing a `String` at a fixed byte offset panics when
            // that offset lands inside a multi-byte UTF-8 character.
            let mut head = Vec::with_capacity(PROBE_LEN);
            let read_ok = fs::File::open(path)
                .and_then(|file| file.take(PROBE_LEN as u64).read_to_end(&mut head))
                .is_ok();
            read_ok
                && head
                    .windows(PGP_HEADER.len())
                    .any(|window| window == PGP_HEADER)
        }
        _ => false,
    }
}
148
149/// Decrypt a GPG-encrypted file using the system `gpg` command.
150///
151/// This uses `gpg --batch --decrypt` which will use the user's
152/// GPG keyring and gpg-agent for passphrase handling.
153fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
154    let output = Command::new("gpg")
155        .args(["--batch", "--decrypt"])
156        .arg(path)
157        .output()
158        .map_err(|e| LoadError::Decryption {
159            path: path.to_path_buf(),
160            message: format!("failed to run gpg: {e}"),
161        })?;
162
163    if !output.status.success() {
164        return Err(LoadError::Decryption {
165            path: path.to_path_buf(),
166            message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
167        });
168    }
169
170    String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
171        path: path.to_path_buf(),
172        message: format!("decrypted content is not valid UTF-8: {e}"),
173    })
174}
175
/// Beancount file loader.
///
/// Holds the bookkeeping for one load (files already seen, the current include
/// chain) together with the optional path-security configuration applied to
/// `include` directives.
#[derive(Default, Debug)]
pub struct Loader {
    /// Canonical paths of files already loaded; a file reached again is skipped.
    loaded_files: HashSet<PathBuf>,
    /// Canonical paths of the files currently being loaded, outermost first.
    /// A path that is already on this stack signals an include cycle.
    include_stack: Vec<PathBuf>,
    /// Directory that include targets must resolve inside when path security is on.
    root_dir: Option<PathBuf>,
    /// Whether include targets are checked against `root_dir`.
    enforce_path_security: bool,
}
189
190impl Loader {
191    /// Create a new loader.
192    #[must_use]
193    pub fn new() -> Self {
194        Self::default()
195    }
196
197    /// Enable path traversal protection.
198    ///
199    /// When enabled, include directives cannot escape the root directory
200    /// of the main beancount file. This prevents malicious ledger files
201    /// from accessing sensitive files outside the ledger directory.
202    ///
203    /// # Example
204    ///
205    /// ```ignore
206    /// let result = Loader::new()
207    ///     .with_path_security(true)
208    ///     .load(Path::new("ledger.beancount"))?;
209    /// ```
210    #[must_use]
211    pub const fn with_path_security(mut self, enabled: bool) -> Self {
212        self.enforce_path_security = enabled;
213        self
214    }
215
216    /// Set a custom root directory for path security.
217    ///
218    /// By default, the root directory is the parent directory of the main file.
219    /// This method allows overriding that to a custom directory.
220    #[must_use]
221    pub fn with_root_dir(mut self, root: PathBuf) -> Self {
222        self.root_dir = Some(root);
223        self.enforce_path_security = true;
224        self
225    }
226
227    /// Load a beancount file and all its includes.
228    ///
229    /// Parses the file, processes options and plugin directives, and recursively
230    /// loads any included files.
231    ///
232    /// # Errors
233    ///
234    /// Returns [`LoadError`] in the following cases:
235    ///
236    /// - [`LoadError::Io`] - Failed to read the file or an included file
237    /// - [`LoadError::IncludeCycle`] - Circular include detected
238    ///
239    /// Note: Parse errors and path traversal errors are collected in
240    /// [`LoadResult::errors`] rather than returned directly, allowing
241    /// partial results to be returned.
242    pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
243        let mut directives = Vec::new();
244        let mut options = Options::default();
245        let mut plugins = Vec::new();
246        let mut source_map = SourceMap::new();
247        let mut errors = Vec::new();
248
249        // Get canonical path
250        let canonical = path.canonicalize().map_err(|e| LoadError::Io {
251            path: path.to_path_buf(),
252            source: e,
253        })?;
254
255        // Set root directory for path security if enabled but not explicitly set
256        if self.enforce_path_security && self.root_dir.is_none() {
257            self.root_dir = canonical.parent().map(Path::to_path_buf);
258        }
259
260        self.load_recursive(
261            &canonical,
262            &mut directives,
263            &mut options,
264            &mut plugins,
265            &mut source_map,
266            &mut errors,
267        )?;
268
269        Ok(LoadResult {
270            directives,
271            options,
272            plugins,
273            source_map,
274            errors,
275        })
276    }
277
278    fn load_recursive(
279        &mut self,
280        path: &Path,
281        directives: &mut Vec<Spanned<Directive>>,
282        options: &mut Options,
283        plugins: &mut Vec<Plugin>,
284        source_map: &mut SourceMap,
285        errors: &mut Vec<LoadError>,
286    ) -> Result<(), LoadError> {
287        // Check for cycles
288        let path_buf = path.to_path_buf();
289        if self.include_stack.contains(&path_buf) {
290            let mut cycle: Vec<String> = self
291                .include_stack
292                .iter()
293                .map(|p| p.display().to_string())
294                .collect();
295            cycle.push(path.display().to_string());
296            return Err(LoadError::IncludeCycle { cycle });
297        }
298
299        // Check if already loaded
300        if self.loaded_files.contains(path) {
301            return Ok(());
302        }
303
304        // Read file (decrypting if necessary)
305        let source: std::sync::Arc<str> = if is_encrypted_file(path) {
306            decrypt_gpg_file(path)?.into()
307        } else {
308            fs::read_to_string(path)
309                .map_err(|e| LoadError::Io {
310                    path: path.to_path_buf(),
311                    source: e,
312                })?
313                .into()
314        };
315
316        // Add to source map (Arc::clone is cheap - just increments refcount)
317        let file_id = source_map.add_file(path.to_path_buf(), std::sync::Arc::clone(&source));
318
319        // Mark as loading
320        self.include_stack.push(path.to_path_buf());
321        self.loaded_files.insert(path.to_path_buf());
322
323        // Parse (borrows from Arc, no allocation)
324        let result = rustledger_parser::parse(&source);
325
326        // Collect parse errors
327        if !result.errors.is_empty() {
328            errors.push(LoadError::ParseErrors {
329                path: path.to_path_buf(),
330                errors: result.errors,
331            });
332        }
333
334        // Process options
335        for (key, value, _span) in result.options {
336            options.set(&key, &value);
337        }
338
339        // Process plugins
340        for (name, config, span) in result.plugins {
341            plugins.push(Plugin {
342                name,
343                config,
344                span,
345                file_id,
346            });
347        }
348
349        // Process includes
350        let base_dir = path.parent().unwrap_or(Path::new("."));
351        for (include_path, _span) in &result.includes {
352            let full_path = base_dir.join(include_path);
353            let canonical = match full_path.canonicalize() {
354                Ok(p) => p,
355                Err(e) => {
356                    errors.push(LoadError::Io {
357                        path: full_path,
358                        source: e,
359                    });
360                    continue;
361                }
362            };
363
364            // Path traversal protection: ensure include stays within root directory
365            if self.enforce_path_security {
366                if let Some(ref root) = self.root_dir {
367                    if !canonical.starts_with(root) {
368                        errors.push(LoadError::PathTraversal {
369                            include_path: include_path.clone(),
370                            base_dir: root.clone(),
371                        });
372                        continue;
373                    }
374                }
375            }
376
377            if let Err(e) =
378                self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
379            {
380                errors.push(e);
381            }
382        }
383
384        // Add directives from this file
385        directives.extend(result.directives);
386
387        // Pop from stack
388        self.include_stack.pop();
389
390        Ok(())
391    }
392}
393
394/// Load a beancount file.
395///
396/// This is a convenience function that creates a loader and loads a single file.
397pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
398    Loader::new().load(path)
399}
400
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temp file with the given suffix, writing each entry of `lines`
    /// followed by a newline.
    fn temp_file_with(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );
        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );
        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // The content is not real ciphertext. Decryption therefore fails whether
        // gpg is missing, rejects the data, or has no matching key.
        let file = temp_file_with(".gpg", &["fake encrypted content"]);

        let Err(LoadError::Decryption { path, message }) = decrypt_gpg_file(file.path()) else {
            panic!("Expected Decryption error");
        };
        assert_eq!(path.as_path(), file.path());
        assert!(!message.is_empty());
    }
}
459}