1#![forbid(unsafe_code)]
29#![warn(missing_docs)]
30
31#[cfg(feature = "cache")]
32pub mod cache;
33mod options;
34mod source_map;
35
36#[cfg(feature = "cache")]
37pub use cache::{
38 CacheEntry, CachedOptions, CachedPlugin, invalidate_cache, load_cache_entry,
39 reintern_directives, save_cache_entry,
40};
41pub use options::Options;
42pub use source_map::{SourceFile, SourceMap};
43
44use rustledger_core::{Directive, DisplayContext};
45use rustledger_parser::{ParseError, Span, Spanned};
46use std::collections::HashSet;
47use std::fs;
48use std::path::{Path, PathBuf};
49use std::process::Command;
50use thiserror::Error;
51
52#[derive(Debug, Error)]
54pub enum LoadError {
55 #[error("failed to read file {path}: {source}")]
57 Io {
58 path: PathBuf,
60 #[source]
62 source: std::io::Error,
63 },
64
65 #[error("include cycle detected: {}", .cycle.join(" -> "))]
67 IncludeCycle {
68 cycle: Vec<String>,
70 },
71
72 #[error("parse errors in {path}")]
74 ParseErrors {
75 path: PathBuf,
77 errors: Vec<ParseError>,
79 },
80
81 #[error("path traversal not allowed: {include_path} escapes base directory {base_dir}")]
83 PathTraversal {
84 include_path: String,
86 base_dir: PathBuf,
88 },
89
90 #[error("failed to decrypt {path}: {message}")]
92 Decryption {
93 path: PathBuf,
95 message: String,
97 },
98}
99
100#[derive(Debug)]
102pub struct LoadResult {
103 pub directives: Vec<Spanned<Directive>>,
105 pub options: Options,
107 pub plugins: Vec<Plugin>,
109 pub source_map: SourceMap,
111 pub errors: Vec<LoadError>,
113 pub display_context: DisplayContext,
115}
116
117#[derive(Debug, Clone)]
119pub struct Plugin {
120 pub name: String,
122 pub config: Option<String>,
124 pub span: Span,
126 pub file_id: usize,
128}
129
/// Returns `true` if `path` looks like a GPG-encrypted file.
///
/// Files ending in `.gpg` are assumed to be encrypted. Files ending in `.asc`
/// count as encrypted only if the ASCII-armor header
/// `-----BEGIN PGP MESSAGE-----` appears within their first 1024 bytes; an
/// `.asc` file that cannot be opened or read is treated as not encrypted.
/// Extensions are matched case-sensitively.
fn is_encrypted_file(path: &Path) -> bool {
    use std::io::Read;

    // Only this many leading bytes are searched for the armor header.
    const HEADER_WINDOW: u64 = 1024;
    const ARMOR_HEADER: &[u8] = b"-----BEGIN PGP MESSAGE-----";

    match path.extension().and_then(|ext| ext.to_str()) {
        Some("gpg") => true,
        Some("asc") => {
            // Read just the window instead of the whole file, and search raw
            // bytes: slicing a `String` at a fixed byte offset panics when the
            // offset falls inside a multi-byte character.
            let mut head = Vec::new();
            let read = fs::File::open(path)
                .and_then(|file| file.take(HEADER_WINDOW).read_to_end(&mut head));
            read.is_ok()
                && head
                    .windows(ARMOR_HEADER.len())
                    .any(|window| window == ARMOR_HEADER)
        }
        _ => false,
    }
}
150
151fn decrypt_gpg_file(path: &Path) -> Result<String, LoadError> {
156 let output = Command::new("gpg")
157 .args(["--batch", "--decrypt"])
158 .arg(path)
159 .output()
160 .map_err(|e| LoadError::Decryption {
161 path: path.to_path_buf(),
162 message: format!("failed to run gpg: {e}"),
163 })?;
164
165 if !output.status.success() {
166 return Err(LoadError::Decryption {
167 path: path.to_path_buf(),
168 message: String::from_utf8_lossy(&output.stderr).trim().to_string(),
169 });
170 }
171
172 String::from_utf8(output.stdout).map_err(|e| LoadError::Decryption {
173 path: path.to_path_buf(),
174 message: format!("decrypted content is not valid UTF-8: {e}"),
175 })
176}
177
/// Loads a ledger file together with every file it includes.
///
/// Use [`load`] for the common case; construct a `Loader` directly to
/// configure path security for `include` directives.
#[derive(Debug, Default)]
pub struct Loader {
    /// Canonical paths of files already loaded, so each is loaded only once.
    loaded_files: HashSet<PathBuf>,
    /// Canonical paths of the files currently being loaded, outermost first;
    /// used to detect include cycles.
    include_stack: Vec<PathBuf>,
    /// Directory that included files must stay within when path security is on.
    root_dir: Option<PathBuf>,
    /// Whether includes resolving outside `root_dir` are rejected.
    enforce_path_security: bool,
}
191
192impl Loader {
193 #[must_use]
195 pub fn new() -> Self {
196 Self::default()
197 }
198
199 #[must_use]
213 pub const fn with_path_security(mut self, enabled: bool) -> Self {
214 self.enforce_path_security = enabled;
215 self
216 }
217
218 #[must_use]
223 pub fn with_root_dir(mut self, root: PathBuf) -> Self {
224 self.root_dir = Some(root);
225 self.enforce_path_security = true;
226 self
227 }
228
229 pub fn load(&mut self, path: &Path) -> Result<LoadResult, LoadError> {
245 let mut directives = Vec::new();
246 let mut options = Options::default();
247 let mut plugins = Vec::new();
248 let mut source_map = SourceMap::new();
249 let mut errors = Vec::new();
250
251 let canonical = path.canonicalize().map_err(|e| LoadError::Io {
253 path: path.to_path_buf(),
254 source: e,
255 })?;
256
257 if self.enforce_path_security && self.root_dir.is_none() {
259 self.root_dir = canonical.parent().map(Path::to_path_buf);
260 }
261
262 self.load_recursive(
263 &canonical,
264 &mut directives,
265 &mut options,
266 &mut plugins,
267 &mut source_map,
268 &mut errors,
269 )?;
270
271 let display_context = build_display_context(&directives, &options);
273
274 Ok(LoadResult {
275 directives,
276 options,
277 plugins,
278 source_map,
279 errors,
280 display_context,
281 })
282 }
283
284 fn load_recursive(
285 &mut self,
286 path: &Path,
287 directives: &mut Vec<Spanned<Directive>>,
288 options: &mut Options,
289 plugins: &mut Vec<Plugin>,
290 source_map: &mut SourceMap,
291 errors: &mut Vec<LoadError>,
292 ) -> Result<(), LoadError> {
293 let path_buf = path.to_path_buf();
295 if self.include_stack.contains(&path_buf) {
296 let mut cycle: Vec<String> = self
297 .include_stack
298 .iter()
299 .map(|p| p.display().to_string())
300 .collect();
301 cycle.push(path.display().to_string());
302 return Err(LoadError::IncludeCycle { cycle });
303 }
304
305 if self.loaded_files.contains(path) {
307 return Ok(());
308 }
309
310 let source: std::sync::Arc<str> = if is_encrypted_file(path) {
313 decrypt_gpg_file(path)?.into()
314 } else {
315 let bytes = fs::read(path).map_err(|e| LoadError::Io {
316 path: path.to_path_buf(),
317 source: e,
318 })?;
319 String::from_utf8_lossy(&bytes).into_owned().into()
320 };
321
322 let file_id = source_map.add_file(path.to_path_buf(), std::sync::Arc::clone(&source));
324
325 self.include_stack.push(path.to_path_buf());
327 self.loaded_files.insert(path.to_path_buf());
328
329 let result = rustledger_parser::parse(&source);
331
332 if !result.errors.is_empty() {
334 errors.push(LoadError::ParseErrors {
335 path: path.to_path_buf(),
336 errors: result.errors,
337 });
338 }
339
340 for (key, value, _span) in result.options {
342 options.set(&key, &value);
343 }
344
345 for (name, config, span) in result.plugins {
347 plugins.push(Plugin {
348 name,
349 config,
350 span,
351 file_id,
352 });
353 }
354
355 let base_dir = path.parent().unwrap_or(Path::new("."));
357 for (include_path, _span) in &result.includes {
358 let full_path = base_dir.join(include_path);
359 let canonical = match full_path.canonicalize() {
360 Ok(p) => p,
361 Err(e) => {
362 errors.push(LoadError::Io {
363 path: full_path,
364 source: e,
365 });
366 continue;
367 }
368 };
369
370 if self.enforce_path_security {
372 if let Some(ref root) = self.root_dir {
373 if !canonical.starts_with(root) {
374 errors.push(LoadError::PathTraversal {
375 include_path: include_path.clone(),
376 base_dir: root.clone(),
377 });
378 continue;
379 }
380 }
381 }
382
383 if let Err(e) =
384 self.load_recursive(&canonical, directives, options, plugins, source_map, errors)
385 {
386 errors.push(e);
387 }
388 }
389
390 directives.extend(
392 result
393 .directives
394 .into_iter()
395 .map(|d| d.with_file_id(file_id)),
396 );
397
398 self.include_stack.pop();
400
401 Ok(())
402 }
403}
404
405fn build_display_context(directives: &[Spanned<Directive>], options: &Options) -> DisplayContext {
411 let mut ctx = DisplayContext::new();
412
413 ctx.set_render_commas(options.render_commas);
415
416 for spanned in directives {
418 match &spanned.value {
419 Directive::Transaction(txn) => {
420 for posting in &txn.postings {
421 if let Some(ref units) = posting.units {
423 if let (Some(number), Some(currency)) = (units.number(), units.currency()) {
424 ctx.update(number, currency);
425 }
426 }
427 if let Some(ref cost) = posting.cost {
429 if let (Some(number), Some(currency)) =
430 (cost.number_per.or(cost.number_total), &cost.currency)
431 {
432 ctx.update(number, currency.as_str());
433 }
434 }
435 if let Some(ref price) = posting.price {
437 if let Some(amount) = price.amount() {
438 ctx.update(amount.number, amount.currency.as_str());
439 }
440 }
441 }
442 }
443 Directive::Balance(bal) => {
444 ctx.update(bal.amount.number, bal.amount.currency.as_str());
445 if let Some(tol) = bal.tolerance {
446 ctx.update(tol, bal.amount.currency.as_str());
447 }
448 }
449 Directive::Price(price) => {
450 ctx.update(price.amount.number, price.amount.currency.as_str());
451 }
452 Directive::Pad(_)
453 | Directive::Open(_)
454 | Directive::Close(_)
455 | Directive::Commodity(_)
456 | Directive::Event(_)
457 | Directive::Query(_)
458 | Directive::Note(_)
459 | Directive::Document(_)
460 | Directive::Custom(_) => {}
461 }
462 }
463
464 for (currency, precision) in &options.display_precision {
466 ctx.set_fixed_precision(currency, *precision);
467 }
468
469 ctx
470}
471
472pub fn load(path: &Path) -> Result<LoadResult, LoadError> {
476 Loader::new().load(path)
477}
478
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file with the given suffix, writing each entry of
    /// `lines` followed by a newline.
    fn temp_file_with(suffix: &str, lines: &[&str]) -> NamedTempFile {
        let mut file = NamedTempFile::with_suffix(suffix).unwrap();
        for line in lines {
            writeln!(file, "{line}").unwrap();
        }
        file.flush().unwrap();
        file
    }

    #[test]
    fn test_is_encrypted_file_gpg_extension() {
        assert!(is_encrypted_file(Path::new("test.beancount.gpg")));
    }

    #[test]
    fn test_is_encrypted_file_plain_beancount() {
        assert!(!is_encrypted_file(Path::new("test.beancount")));
    }

    #[test]
    fn test_is_encrypted_file_asc_with_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "-----BEGIN PGP MESSAGE-----",
                "some encrypted content",
                "-----END PGP MESSAGE-----",
            ],
        );
        assert!(is_encrypted_file(file.path()));
    }

    #[test]
    fn test_is_encrypted_file_asc_without_pgp_header() {
        let file = temp_file_with(
            ".asc",
            &[
                "This is just a plain text file",
                "with .asc extension but no PGP content",
            ],
        );
        assert!(!is_encrypted_file(file.path()));
    }

    #[test]
    fn test_decrypt_gpg_file_missing_gpg() {
        // The content is not real ciphertext, so decryption fails whether or
        // not gpg is installed.
        let file = temp_file_with(".gpg", &["fake encrypted content"]);

        let Err(LoadError::Decryption { path, message }) = decrypt_gpg_file(file.path()) else {
            panic!("Expected Decryption error");
        };
        assert_eq!(path, file.path().to_path_buf());
        assert!(!message.is_empty());
    }
}