1#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38 CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39 reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
/// Resolves `path` to an absolute, normalized form used for include
/// bookkeeping (cycle detection, de-duplication and the path-traversal check).
///
/// Existing paths are canonicalized, which resolves symlinks. When that fails
/// (typically because the file does not exist), the path is made absolute
/// against the current directory and cleaned lexically:
///
/// - `.` components are dropped.
/// - `..` removes the preceding component, never climbing above the root.
///
/// Absolute inputs get the same cleanup. A missing include such as
/// `/ledger/../../etc/x` therefore becomes `/etc/x` and is not mistaken for a
/// path inside `/ledger`.
///
/// If the current directory cannot be determined, a relative `path` is
/// returned unchanged.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    if let Ok(canonical) = path.canonicalize() {
        return canonical;
    }

    let absolute = if path.is_absolute() {
        path.to_path_buf()
    } else {
        match std::env::current_dir() {
            Ok(cwd) => cwd.join(path),
            Err(_) => return path.to_path_buf(),
        }
    };

    let mut normalized = PathBuf::new();
    for component in absolute.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => {
                // `pop` is a no-op at the root, so `..` cannot escape it.
                normalized.pop();
            }
            // Pushing prefix/root components rebuilds the anchor correctly on
            // every platform (e.g. `C:` then `\` on Windows).
            Component::Prefix(_) | Component::RootDir | Component::Normal(_) => {
                normalized.push(component.as_os_str());
            }
        }
    }
    normalized
}
94
/// Errors produced while loading a ledger file or one of its includes.
#[derive(Debug, Error)]
pub enum LoadError {
    /// A file could not be read from disk.
    #[error("failed to read file {path}: {source}")]
    Io {
        /// Path of the file that could not be read.
        path: PathBuf,
        /// The underlying I/O error.
        #[source]
        source: std::io::Error,
    },

    /// A file was included while it was still being loaded.
    #[error("include cycle detected: {}", .cycle.join(" -> "))]
    IncludeCycle {
        /// The chain of files being loaded, outermost first, ending with the
        /// file that was included again.
        cycle: Vec<String>,
    },

    /// The parser reported one or more errors for a file.
    #[error("parse errors in {path}")]
    ParseErrors {
        /// Path of the file that failed to parse.
        path: PathBuf,
        /// Every error the parser reported for the file.
        errors: Vec<ParseError>,
    },

    /// An include resolved to a location outside the permitted root directory.
    #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
    PathTraversal {
        /// The include path as written in the including file.
        include_path: String,
        /// The root directory the include was required to stay within.
        base_dir: PathBuf,
    },

    /// An encrypted file could not be decrypted.
    #[error("failed to decrypt {path}: {message}")]
    Decryption {
        /// Path of the encrypted file.
        path: PathBuf,
        /// Why decryption failed, e.g. gpg's error output.
        message: String,
    },
}
142
/// Everything produced by loading a ledger file and the files it includes.
#[derive(Debug)]
pub struct LoadResult {
    /// Directives from every loaded file, each tagged with its file id.
    /// An included file's directives precede those of the file including it.
    pub directives: Vec<Spanned<Directive>>,
    /// Options accumulated from the option statements of every loaded file.
    pub options: Options,
    /// Plugin declarations from every loaded file.
    pub plugins: Vec<Plugin>,
    /// Source text of every loaded file; `file_id` values refer to entries here.
    pub source_map: SourceMap,
    /// Non-fatal problems: parse errors and failures to load included files.
    pub errors: Vec<LoadError>,
    /// Number-formatting context derived from the loaded amounts and options.
    pub display_context: DisplayContext,
}
159
/// A plugin declaration collected from a parsed file.
#[derive(Debug, Clone)]
pub struct Plugin {
    /// Plugin name as given in the source.
    pub name: String,
    /// Optional configuration string supplied with the plugin.
    pub config: Option<String>,
    /// Location of the declaration within its file.
    pub span: Span,
    /// Id (in the [`SourceMap`]) of the file containing the declaration.
    pub file_id: usize,
}
172
/// Returns `true` if `path` looks like a GPG-encrypted file.
///
/// Files ending in `.gpg` are always treated as encrypted. Files ending in
/// `.asc` count as encrypted only if the ASCII-armor header
/// `-----BEGIN PGP MESSAGE-----` appears within their first 1024 bytes.
/// An `.asc` file that cannot be opened or read is reported as not encrypted.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    // Header line that opens an ASCII-armored PGP message.
    const PGP_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";
    // Number of leading bytes of an `.asc` file inspected for the header.
    const SNIFF_LEN: u64 = 1024;

    match path.extension().and_then(|e| e.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            // Read only the prefix instead of the whole file, and search raw
            // bytes: slicing a `str` at a fixed byte offset panics when that
            // offset falls inside a multi-byte UTF-8 character.
            let mut prefix = Vec::new();
            let read = fs::File::open(path)
                .and_then(|file| file.take(SNIFF_LEN).read_to_end(&mut prefix));
            read.is_ok() && prefix.windows(PGP_HEADER.len()).any(|w| w == PGP_HEADER)
        }
        _ => false,
    }
}
193
194fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
199 let output = Command::new("gpg")
200 .args(["--batch", "--decrypt"])
201 .arg(path)
202 .output()
203 .map_err(|e| LoadError::Decryption {
204 path: path.to_path_buf(),
205 message: format!("failed to run gpg: {e}"),
206 })?;
207
208 if !output.status.success() {
209 return Err(LoadError::Decryption {
210 path: path.to_path_buf(),
211 message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
212 });
213 }
214
215 String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
216 path: path.to_path_buf(),
217 message: format!("decrypted content is not valid UTF-8: {e}"),
218 })
219}
220
/// Loads a ledger file and, recursively, the files it includes.
///
/// Use [`Loader::with_path_security`] or [`Loader::with_root_dir`] to confine
/// includes to a directory.
#[derive(Debug, Default)]
pub struct Loader {
    /// Normalized paths of files already loaded; a file included more than
    /// once is only loaded the first time.
    loaded_files: HashSet<PathBuf>,
    /// Files currently being loaded, outermost first; a path reappearing here
    /// signals an include cycle.
    include_stack: Vec<PathBuf>,
    /// Directory that includes must stay within when path security is enforced.
    root_dir: Option<PathBuf>,
    /// Whether includes resolving outside `root_dir` are rejected.
    enforce_path_security: bool,
}
234
235impl Loader {
236 #[must_use]
238 pub fn new() -> Self {
239 Self::default()
240 }
241
242 #[must_use]
256 pub const fn with_path_security(mut self, enabled: bool) -> Self {
257 self.enforce_path_security = enabled;
258 self
259 }
260
261 #[must_use]
266 pub fn with_root_dir(mut self, root: PathBuf) -> Self {
267 self.root_dir = Some(root);
268 self.enforce_path_security = true;
269 self
270 }
271
272 pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
288 let mut directives = Vec::new();
289 let mut options = Options::default();
290 let mut plugins = Vec::new();
291 let mut source_map = SourceMap::new();
292 let mut errors = Vec::new();
293
294 let canonical = normalize_path(path);
296
297 if self.enforce_path_security && self.root_dir.is_none() {
299 self.root_dir = canonical.parent().map(Path::to_path_buf);
300 }
301
302 self.load_recursive(
303 &canonical,
304 &mut directives,
305 &mut options,
306 &mut plugins,
307 &mut source_map,
308 &mut errors,
309 )?;
310
311 let display_context = build_display_context(&directives, &options);
313
314 Ok(LoadResult {
315 directives,
316 options,
317 plugins,
318 source_map,
319 errors,
320 display_context,
321 })
322 }
323
324 fn load_recursive(
325 &mut self,
326 path: &Path,
327 directives: &mut Vec<Spanned<Directive>>,
328 options: &mut Options,
329 plugins: &mut Vec<Plugin>,
330 source_map: &mut SourceMap,
331 errors: &mut Vec<LoadError>,
332 ) -> Result<(), LoadError> {
333 let path_buf = path.to_path_buf();
335 if self.include_stack.contains(&path_buf) {
336 let mut cycle: Vec<String> = self
337 .include_stack
338 .iter()
339 .map(|p| p.display().to_string())
340 .collect();
341 cycle.push(path.display().to_string());
342 return Err(LoadError::IncludeCycle { cycle });
343 }
344
345 if self.loaded_files.contains(path) {
347 return Ok(());
348 }
349
350 let source: std::sync::Arc<str> = if is_encrypted_file(path) {
353 decrypt_gpg_file(path)?.into()
354 } else {
355 let bytes = fs::read(path).map_err(|e| LoadError::Io {
356 path: path.to_path_buf(),
357 source: e,
358 })?;
359 String::from_utf8_lossy(&bytes).into_owned().into()
360 };
361
362 let file_id = source_map.add_file(path.to_path_buf(), std::sync::Arc::clone(&source));
364
365 self.include_stack.push(path.to_path_buf());
367 self.loaded_files.insert(path.to_path_buf());
368
369 let result = rustledger_parser::parse(&source);
371
372 if !result.errors.is_empty() {
374 errors.push(LoadError::ParseErrors {
375 path: path.to_path_buf(),
376 errors: result.errors,
377 });
378 }
379
380 for (key, value, _span) in result.options {
382 options.set(&key, &value);
383 }
384
385 for (name, config, span) in result.plugins {
387 plugins.push(Plugin {
388 name,
389 config,
390 span,
391 file_id,
392 });
393 }
394
395 let base_dir = path.parent().unwrap_or(Path::new("."));
397 for (include_path, _span) in &result.includes {
398 let full_path = base_dir.join(include_path);
399 let canonical = normalize_path(&full_path);
401
402 if self.enforce_path_security {
404 if let Some(ref root) = self.root_dir {
405 if !canonical.starts_with(root) {
406 errors.push(LoadError::PathTraversal {
407 include_path: include_path.clone(),
408 base_dir: root.clone(),
409 });
410 continue;
411 }
412 }
413 }
414
415 if let Err(e) =
416 self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
417 {
418 errors.push(e);
419 }
420 }
421
422 directives.extend(
424 result
425 .directives
426 .into_iter()
427 .map(|d| d.with_file_id(file_id)),
428 );
429
430 self.include_stack.pop();
432
433 Ok(())
434 }
435}
436
437fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
443 let mut ctx = DisplayContext::new();
444
445 ctx.set_render_commas(options.render_commas);
447
448 for spanned in directives {
450 match &spanned.value {
451 Directive::Transaction(txn) => {
452 for posting in &txn.postings {
453 if let Some(ref units) = posting.units {
455 if let (Some(number), Some(currency)) = (units.number(), units.currency()) {
456 ctx.update(number, currency);
457 }
458 }
459 if let Some(ref cost) = posting.cost {
461 if let (Some(number), Some(currency)) =
462 (cost.number_per.or(cost.number_total), &cost.currency)
463 {
464 ctx.update(number, currency.as_str());
465 }
466 }
467 if let Some(ref price) = posting.price {
469 if let Some(amount) = price.amount() {
470 ctx.update(amount.number, amount.currency.as_str());
471 }
472 }
473 }
474 }
475 Directive::Balance(bal) => {
476 ctx.update(bal.amount.number, bal.amount.currency.as_str());
477 if let Some(tol) = bal.tolerance {
478 ctx.update(tol, bal.amount.currency.as_str());
479 }
480 }
481 Directive::Price(price) => {
482 ctx.update(price.amount.number, price.amount.currency.as_str());
483 }
484 Directive::Pad(_)
485 | Directive::Open(_)
486 | Directive::Close(_)
487 | Directive::Commodity(_)
488 | Directive::Event(_)
489 | Directive::Query(_)
490 | Directive::Note(_)
491 | Directive::Document(_)
492 | Directive::Custom(_) => {}
493 }
494 }
495
496 for (currency, precision) in &options.display_precision {
498 ctx.set_fixed_precision(currency, *precision);
499 }
500
501 ctx
502}
503
504pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
508 Loader::new().load(path)
509}
510
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file whose name ends in `suffix`, containing exactly `contents`.
    fn temp_file_with(suffix: &str, contents: &str) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        file.write_all(contents.as_bytes()).unwrap();
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with(
            ".asc",
            "-----BEGIN PGP MESSAGE-----\nsome encrypted content\n-----END PGP MESSAGE-----\n",
        );
        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with(
            ".asc",
            "This is just a plain text file\nwith .asc extension but no PGP content\n",
        );
        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // Whether gpg is missing or present, garbage input must fail to decrypt.
        let file = temp_file_with(".gpg", "fake encrypted content\n");

        match decrypt_gpg_file(file.path()) {
            Err(LoadError::Decryption { path, message }) => {
                assert_eq!(path, file.path().to_path_buf());
                assert!(!message.is_empty());
            }
            _ => panic!("Expected Decryption error"),
        }
    }
}
569}