Skip to main content

rustledger_booking/
lib.rs

1//! Beancount booking engine with interpolation.
2//!
3//! This crate provides:
4//! - Transaction interpolation (filling in missing amounts)
5//! - Transaction balancing verification
6//! - Tolerance calculation
7//!
8//! # Interpolation
9//!
10//! When a transaction has exactly one posting per currency without an amount,
11//! that amount can be calculated to make the transaction balance.
12//!
13//! ```ignore
14//! use rustledger_booking::interpolate;
15//!
16//! // Transaction with one missing amount
17//! // 2024-01-15 * "Groceries"
18//! //   Expenses:Food  50.00 USD
19//! //   Assets:Cash               <- amount inferred as -50.00 USD
20//! ```
21
22#![forbid(unsafe_code)]
23#![warn(missing_docs)]
24
25mod book;
26mod interpolate;
27mod pad;
28
29pub use book::{BookedTransaction, BookingEngine, BookingError, CapitalGain, book_transactions};
30pub use interpolate::{InterpolationError, InterpolationResult, interpolate};
31pub use pad::{PadError, PadResult, expand_pads, merge_with_padding, process_pads};
32
33use bigdecimal::BigDecimal;
34use rust_decimal::Decimal;
35use rust_decimal::prelude::Signed;
36use rustledger_core::{Amount, IncompleteAmount, InternedStr, Transaction};
37use std::collections::HashMap;
38
39/// Calculate the tolerance for a set of amounts.
40///
41/// Tolerance is the maximum of all individual amount tolerances.
42#[must_use]
43pub fn calculate_tolerance(amounts: &[&Amount]) -> HashMap<InternedStr, Decimal> {
44    // Pre-allocate for typical case (1-3 currencies per transaction)
45    let mut tolerances: HashMap<InternedStr, Decimal> =
46        HashMap::with_capacity(amounts.len().min(4));
47
48    for amount in amounts {
49        let tol = amount.inferred_tolerance();
50        tolerances
51            .entry(amount.currency.clone())
52            .and_modify(|t| *t = (*t).max(tol))
53            .or_insert(tol);
54    }
55
56    tolerances
57}
58
59/// Infer the cost currency from other postings in the transaction.
60///
61/// Python beancount infers cost currency from simple postings (those without
62/// cost specs) when a cost is specified without a currency like `{100}`.
63///
64/// Currency inference follows this priority:
65/// 1. An explicit currency in the cost specification itself (handled by the caller).
66/// 2. A price annotation on a simple posting (the price currency takes precedence).
67/// 3. The currency of other simple postings (units or currency-only amounts).
68/// 4. The currency from a cost spec (e.g., `{0 USD}` for zero-cost items).
69#[must_use]
70pub(crate) fn infer_cost_currency_from_postings(transaction: &Transaction) -> Option<InternedStr> {
71    // First pass: look for simple postings (no cost spec) - these take priority
72    for posting in &transaction.postings {
73        // Skip postings with cost specs in first pass
74        if posting.cost.is_some() {
75            continue;
76        }
77
78        // Get the currency from this posting's units
79        if let Some(units) = &posting.units {
80            match units {
81                IncompleteAmount::Complete(amount) => {
82                    // If this posting has a price annotation, the "real" currency
83                    // is the price currency, not the units currency
84                    if let Some(price) = &posting.price {
85                        match price {
86                            rustledger_core::PriceAnnotation::Unit(a)
87                            | rustledger_core::PriceAnnotation::Total(a) => {
88                                return Some(a.currency.clone());
89                            }
90                            rustledger_core::PriceAnnotation::UnitIncomplete(inc)
91                            | rustledger_core::PriceAnnotation::TotalIncomplete(inc) => {
92                                if let Some(a) = inc.as_amount() {
93                                    return Some(a.currency.clone());
94                                }
95                            }
96                            _ => {}
97                        }
98                    }
99                    // Simple posting - use its currency
100                    return Some(amount.currency.clone());
101                }
102                IncompleteAmount::CurrencyOnly(currency) => {
103                    return Some(currency.clone());
104                }
105                IncompleteAmount::NumberOnly(_) => {}
106            }
107        }
108    }
109
110    // Second pass: look for cost spec currencies (e.g., `{0 USD}`)
111    // This handles zero-cost postings where the cost currency should be used
112    for posting in &transaction.postings {
113        if let Some(cost) = &posting.cost
114            && let Some(currency) = &cost.currency
115        {
116            return Some(currency.clone());
117        }
118    }
119
120    None
121}
122
123/// Calculate the residual (imbalance) of a transaction.
124///
125/// Returns a map of currency -> residual amount.
126/// A balanced transaction has all residuals within tolerance.
127///
128/// # TLA+ Specification
129///
130/// Implements balance checking from `DoubleEntry.tla`:
131/// - Invariant: `TransactionsBalance` - For every transaction, `sum(postings) = 0`
132/// - Each currency is checked independently
133/// - A non-zero residual indicates a violation of double-entry bookkeeping
134///
135/// See: `spec/tla/DoubleEntry.tla`
136#[must_use]
137pub fn calculate_residual(transaction: &Transaction) -> HashMap<InternedStr, Decimal> {
138    // Pre-allocate for typical case (1-2 currencies per transaction)
139    let mut residuals: HashMap<InternedStr, Decimal> =
140        HashMap::with_capacity(transaction.postings.len().min(4));
141
142    // Lazily compute inferred currency only when needed (most transactions don't need it)
143    let mut inferred_cost_currency: Option<Option<InternedStr>> = None;
144    let get_inferred_currency = |cache: &mut Option<Option<InternedStr>>| -> Option<InternedStr> {
145        cache
146            .get_or_insert_with(|| infer_cost_currency_from_postings(transaction))
147            .clone()
148    };
149
150    for posting in &transaction.postings {
151        // Only process complete amounts
152        if let Some(IncompleteAmount::Complete(units)) = &posting.units {
153            // Determine the "weight" of this posting for balance purposes.
154            // - If there's a cost, the weight is in the cost currency (not units currency)
155            // - If there's a price annotation, the weight is in the price currency (not units currency)
156            // - Otherwise, the weight is just the units
157
158            // Check if cost spec has determinable values.
159            // If cost has number but no currency, try to infer currency from:
160            // 1. Price annotation
161            // 2. Other postings in the transaction
162            let cost_contribution = posting.cost.as_ref().and_then(|cost_spec| {
163                // Helper to get currency from price annotation
164                let price_currency = posting.price.as_ref().and_then(|p| match p {
165                    rustledger_core::PriceAnnotation::Unit(a)
166                    | rustledger_core::PriceAnnotation::Total(a) => Some(a.currency.clone()),
167                    rustledger_core::PriceAnnotation::UnitIncomplete(inc)
168                    | rustledger_core::PriceAnnotation::TotalIncomplete(inc) => {
169                        inc.as_amount().map(|a| a.currency.clone())
170                    }
171                    _ => None,
172                });
173
174                // Try to get cost currency, falling back to price currency, then other postings
175                let inferred_currency = cost_spec
176                    .currency
177                    .clone()
178                    .or(price_currency)
179                    .or_else(|| get_inferred_currency(&mut inferred_cost_currency));
180
181                // Check number_total first: when both per-unit and total are present
182                // (booking preserves total), use the total directly for exact residual
183                // calculation. Division-then-multiplication loses precision.
184                if let (Some(total), Some(cost_curr)) =
185                    (&cost_spec.number_total, &inferred_currency)
186                {
187                    Some((cost_curr.clone(), *total * units.number.signum()))
188                } else if let (Some(per_unit), Some(cost_curr)) =
189                    (&cost_spec.number_per, &inferred_currency)
190                {
191                    let cost_amount = units.number * per_unit;
192                    Some((cost_curr.clone(), cost_amount))
193                } else {
194                    None // Cost spec without determinable amount (e.g., empty `{}`)
195                }
196            });
197
198            if let Some((currency, amount)) = cost_contribution {
199                // Cost-based posting: weight is in the cost currency
200                *residuals.entry(currency).or_default() += amount;
201            } else if let Some(price) = &posting.price {
202                // Price annotation: converts units to price currency for balance purposes.
203                // The weight is in the price currency, not the units currency.
204                match price {
205                    rustledger_core::PriceAnnotation::Unit(price_amt) => {
206                        let converted = units.number.abs() * price_amt.number;
207                        *residuals.entry(price_amt.currency.clone()).or_default() +=
208                            converted * units.number.signum();
209                    }
210                    rustledger_core::PriceAnnotation::Total(price_amt) => {
211                        *residuals.entry(price_amt.currency.clone()).or_default() +=
212                            price_amt.number * units.number.signum();
213                    }
214                    // Incomplete price annotations - extract what we can
215                    rustledger_core::PriceAnnotation::UnitIncomplete(inc) => {
216                        if let Some(price_amt) = inc.as_amount() {
217                            let converted = units.number.abs() * price_amt.number;
218                            *residuals.entry(price_amt.currency.clone()).or_default() +=
219                                converted * units.number.signum();
220                        } else {
221                            // Can't calculate price conversion, fall back to units
222                            *residuals.entry(units.currency.clone()).or_default() += units.number;
223                        }
224                    }
225                    rustledger_core::PriceAnnotation::TotalIncomplete(inc) => {
226                        if let Some(price_amt) = inc.as_amount() {
227                            *residuals.entry(price_amt.currency.clone()).or_default() +=
228                                price_amt.number * units.number.signum();
229                        } else {
230                            // Can't calculate price conversion, fall back to units
231                            *residuals.entry(units.currency.clone()).or_default() += units.number;
232                        }
233                    }
234                    // Empty price annotations - fall back to units
235                    rustledger_core::PriceAnnotation::UnitEmpty
236                    | rustledger_core::PriceAnnotation::TotalEmpty => {
237                        *residuals.entry(units.currency.clone()).or_default() += units.number;
238                    }
239                }
240            } else if posting.cost.is_some() {
241                // Cost spec exists but is empty (e.g., `{}`), and no price annotation
242                // Don't contribute to residual - cost will be filled by lot matching
243            } else {
244                // Simple posting: weight is just the units
245                *residuals.entry(units.currency.clone()).or_default() += units.number;
246            }
247        }
248    }
249
250    residuals
251}
252
253/// Convert a `rust_decimal::Decimal` to `BigDecimal` for arbitrary-precision arithmetic.
254///
255/// Individual `Decimal` values are representable exactly (≤28 significant digits).
256/// The precision loss only occurs during arithmetic, so converting before operations
257/// preserves full precision.
258fn to_big(d: Decimal) -> BigDecimal {
259    use std::str::FromStr;
260    // rust_decimal Display is exact; BigDecimal FromStr handles any decimal string
261    BigDecimal::from_str(&d.to_string()).expect("Decimal always produces valid decimal string")
262}
263
264/// Calculate the residual of a transaction using arbitrary-precision arithmetic.
265///
266/// This mirrors [`calculate_residual`] but uses `BigDecimal` to avoid precision loss
267/// when amounts have near-28-digit precision. `rust_decimal` is limited to 28-29
268/// significant digits; this function handles arbitrary precision correctly.
269#[must_use]
270pub fn calculate_residual_precise(transaction: &Transaction) -> HashMap<InternedStr, BigDecimal> {
271    let mut residuals: HashMap<InternedStr, BigDecimal> =
272        HashMap::with_capacity(transaction.postings.len().min(4));
273
274    let mut inferred_cost_currency: Option<Option<InternedStr>> = None;
275    let get_inferred_currency = |cache: &mut Option<Option<InternedStr>>| -> Option<InternedStr> {
276        cache
277            .get_or_insert_with(|| infer_cost_currency_from_postings(transaction))
278            .clone()
279    };
280
281    for posting in &transaction.postings {
282        if let Some(IncompleteAmount::Complete(units)) = &posting.units {
283            let units_number = to_big(units.number);
284
285            let cost_contribution = posting.cost.as_ref().and_then(|cost_spec| {
286                let price_currency = posting.price.as_ref().and_then(|p| match p {
287                    rustledger_core::PriceAnnotation::Unit(a)
288                    | rustledger_core::PriceAnnotation::Total(a) => Some(a.currency.clone()),
289                    rustledger_core::PriceAnnotation::UnitIncomplete(inc)
290                    | rustledger_core::PriceAnnotation::TotalIncomplete(inc) => {
291                        inc.as_amount().map(|a| a.currency.clone())
292                    }
293                    _ => None,
294                });
295
296                let inferred_currency = cost_spec
297                    .currency
298                    .clone()
299                    .or(price_currency)
300                    .or_else(|| get_inferred_currency(&mut inferred_cost_currency));
301
302                // Check number_total first: when both per-unit and total are present
303                // (booking preserves total), use the total directly for exact residual
304                // calculation. Division-then-multiplication loses precision.
305                if let (Some(total), Some(cost_curr)) =
306                    (&cost_spec.number_total, &inferred_currency)
307                {
308                    Some((
309                        cost_curr.clone(),
310                        to_big(*total) * to_big(units.number.signum()),
311                    ))
312                } else if let (Some(per_unit), Some(cost_curr)) =
313                    (&cost_spec.number_per, &inferred_currency)
314                {
315                    let cost_amount = &units_number * to_big(*per_unit);
316                    Some((cost_curr.clone(), cost_amount))
317                } else {
318                    None
319                }
320            });
321
322            if let Some((currency, amount)) = cost_contribution {
323                *residuals.entry(currency).or_default() += amount;
324            } else if let Some(price) = &posting.price {
325                match price {
326                    rustledger_core::PriceAnnotation::Unit(price_amt) => {
327                        let converted = units_number.abs() * to_big(price_amt.number);
328                        *residuals.entry(price_amt.currency.clone()).or_default() +=
329                            converted * to_big(units.number.signum());
330                    }
331                    rustledger_core::PriceAnnotation::Total(price_amt) => {
332                        *residuals.entry(price_amt.currency.clone()).or_default() +=
333                            to_big(price_amt.number) * to_big(units.number.signum());
334                    }
335                    rustledger_core::PriceAnnotation::UnitIncomplete(inc) => {
336                        if let Some(price_amt) = inc.as_amount() {
337                            let converted = units_number.abs() * to_big(price_amt.number);
338                            *residuals.entry(price_amt.currency.clone()).or_default() +=
339                                converted * to_big(units.number.signum());
340                        } else {
341                            *residuals.entry(units.currency.clone()).or_default() +=
342                                units_number.clone();
343                        }
344                    }
345                    rustledger_core::PriceAnnotation::TotalIncomplete(inc) => {
346                        if let Some(price_amt) = inc.as_amount() {
347                            *residuals.entry(price_amt.currency.clone()).or_default() +=
348                                to_big(price_amt.number) * to_big(units.number.signum());
349                        } else {
350                            *residuals.entry(units.currency.clone()).or_default() +=
351                                units_number.clone();
352                        }
353                    }
354                    rustledger_core::PriceAnnotation::UnitEmpty
355                    | rustledger_core::PriceAnnotation::TotalEmpty => {
356                        *residuals.entry(units.currency.clone()).or_default() +=
357                            units_number.clone();
358                    }
359                }
360            } else if posting.cost.is_some() {
361                // Empty cost spec — don't contribute to residual
362            } else {
363                *residuals.entry(units.currency.clone()).or_default() += units_number;
364            }
365        }
366    }
367
368    residuals
369}
370
371/// Check if a transaction is balanced within tolerance.
372#[must_use]
373#[allow(clippy::implicit_hasher)]
374pub fn is_balanced(transaction: &Transaction, tolerances: &HashMap<InternedStr, Decimal>) -> bool {
375    let residuals = calculate_residual(transaction);
376
377    for (currency, residual) in residuals {
378        let tolerance = tolerances.get(&currency).copied().unwrap_or(Decimal::ZERO); // Default 0 (exact balance for integer-only currencies)
379
380        if residual.abs() > tolerance {
381            return false;
382        }
383    }
384
385    true
386}
387
388/// Normalize total prices (`@@`) to per-unit prices (`@`) on a transaction.
389///
390/// This converts `PriceAnnotation::Total` to `PriceAnnotation::Unit` by dividing
391/// the total price by the number of units. This should be called AFTER validation
392/// (balance checking) to preserve exact total prices for precise residual calculation.
393///
394/// Matches Python beancount behavior where `@@` is converted to `@`.
395pub fn normalize_prices(txn: &mut Transaction) {
396    use rustledger_core::PriceAnnotation;
397
398    for posting in &mut txn.postings {
399        if let (Some(IncompleteAmount::Complete(units)), Some(price)) =
400            (&posting.units, &posting.price)
401        {
402            let normalized = match price {
403                PriceAnnotation::Total(total_amount) if !units.number.is_zero() => {
404                    let per_unit = total_amount.number / units.number.abs();
405                    Some(PriceAnnotation::Unit(Amount::new(
406                        per_unit,
407                        &total_amount.currency,
408                    )))
409                }
410                PriceAnnotation::TotalIncomplete(inc) if !units.number.is_zero() => {
411                    if let Some(total_amount) = inc.as_amount() {
412                        let per_unit = total_amount.number / units.number.abs();
413                        Some(PriceAnnotation::Unit(Amount::new(
414                            per_unit,
415                            &total_amount.currency,
416                        )))
417                    } else {
418                        None
419                    }
420                }
421                PriceAnnotation::TotalEmpty => Some(PriceAnnotation::UnitEmpty),
422                _ => None,
423            };
424            if let Some(normalized_price) = normalized {
425                posting.price = Some(normalized_price);
426            }
427        }
428    }
429}
430
431#[cfg(test)]
432mod tests {
433    use super::*;
434    use rust_decimal_macros::dec;
435    use rustledger_core::{CostSpec, IncompleteAmount, NaiveDate, Posting, PriceAnnotation};
436
437    fn date(year: i32, month: u32, day: u32) -> NaiveDate {
438        rustledger_core::naive_date(year, month, day).unwrap()
439    }
440
441    // =========================================================================
442    // Basic residual tests (existing)
443    // =========================================================================
444
445    #[test]
446    fn test_calculate_residual_balanced() {
447        let txn = Transaction::new(date(2024, 1, 15), "Test")
448            .with_posting(Posting::new(
449                "Expenses:Food",
450                Amount::new(dec!(50.00), "USD"),
451            ))
452            .with_posting(Posting::new(
453                "Assets:Cash",
454                Amount::new(dec!(-50.00), "USD"),
455            ));
456
457        let residual = calculate_residual(&txn);
458        assert_eq!(residual.get("USD"), Some(&dec!(0)));
459    }
460
461    #[test]
462    fn test_calculate_residual_unbalanced() {
463        let txn = Transaction::new(date(2024, 1, 15), "Test")
464            .with_posting(Posting::new(
465                "Expenses:Food",
466                Amount::new(dec!(50.00), "USD"),
467            ))
468            .with_posting(Posting::new(
469                "Assets:Cash",
470                Amount::new(dec!(-45.00), "USD"),
471            ));
472
473        let residual = calculate_residual(&txn);
474        assert_eq!(residual.get("USD"), Some(&dec!(5.00)));
475    }
476
477    #[test]
478    fn test_is_balanced() {
479        let txn = Transaction::new(date(2024, 1, 15), "Test")
480            .with_posting(Posting::new(
481                "Expenses:Food",
482                Amount::new(dec!(50.00), "USD"),
483            ))
484            .with_posting(Posting::new(
485                "Assets:Cash",
486                Amount::new(dec!(-50.00), "USD"),
487            ));
488
489        let tolerances = calculate_tolerance(&[
490            &Amount::new(dec!(50.00), "USD"),
491            &Amount::new(dec!(-50.00), "USD"),
492        ]);
493
494        assert!(is_balanced(&txn, &tolerances));
495    }
496
497    #[test]
498    fn test_is_balanced_within_tolerance() {
499        let txn = Transaction::new(date(2024, 1, 15), "Test")
500            .with_posting(Posting::new(
501                "Expenses:Food",
502                Amount::new(dec!(50.004), "USD"),
503            ))
504            .with_posting(Posting::new(
505                "Assets:Cash",
506                Amount::new(dec!(-50.00), "USD"),
507            ));
508
509        let tolerances = calculate_tolerance(&[
510            &Amount::new(dec!(50.004), "USD"),
511            &Amount::new(dec!(-50.00), "USD"),
512        ]);
513
514        // 0.004 is within tolerance of 0.005 (scale 2 -> 0.005)
515        assert!(is_balanced(&txn, &tolerances));
516    }
517
518    #[test]
519    fn test_calculate_tolerance() {
520        let amounts = [
521            Amount::new(dec!(100), "USD"),    // scale 0 -> tol 0.5
522            Amount::new(dec!(50.00), "USD"),  // scale 2 -> tol 0.005
523            Amount::new(dec!(25.000), "EUR"), // scale 3 -> tol 0.0005
524        ];
525
526        let refs: Vec<&Amount> = amounts.iter().collect();
527        let tolerances = calculate_tolerance(&refs);
528
529        // USD should use the max tolerance (0.5 from scale 0)
530        assert_eq!(tolerances.get("USD"), Some(&dec!(0.5)));
531        assert_eq!(tolerances.get("EUR"), Some(&dec!(0.0005)));
532    }
533
534    // =========================================================================
535    // Cost-based residual tests
536    // =========================================================================
537
538    /// Test residual calculation with per-unit cost.
539    /// Buy 10 AAPL at $150 each = $1500 total cost in USD.
540    #[test]
541    fn test_calculate_residual_with_per_unit_cost() {
542        let txn = Transaction::new(date(2024, 1, 15), "Buy stock")
543            .with_posting(
544                Posting::new("Assets:Stock", Amount::new(dec!(10), "AAPL")).with_cost(
545                    CostSpec::empty()
546                        .with_number_per(dec!(150.00))
547                        .with_currency("USD"),
548                ),
549            )
550            .with_posting(Posting::new(
551                "Assets:Cash",
552                Amount::new(dec!(-1500.00), "USD"),
553            ));
554
555        let residual = calculate_residual(&txn);
556        // Cost posting contributes 10 * 150 = 1500 USD
557        // Cash posting contributes -1500 USD
558        // Residual should be 0
559        assert_eq!(residual.get("USD"), Some(&dec!(0)));
560        // AAPL should not appear in residuals (cost converts to USD)
561        assert_eq!(residual.get("AAPL"), None);
562    }
563
564    /// Test residual calculation with total cost.
565    /// Buy 10 AAPL with total cost of $1500.
566    #[test]
567    fn test_calculate_residual_with_total_cost() {
568        let txn = Transaction::new(date(2024, 1, 15), "Buy stock")
569            .with_posting(
570                Posting::new("Assets:Stock", Amount::new(dec!(10), "AAPL")).with_cost(
571                    CostSpec::empty()
572                        .with_number_total(dec!(1500.00))
573                        .with_currency("USD"),
574                ),
575            )
576            .with_posting(Posting::new(
577                "Assets:Cash",
578                Amount::new(dec!(-1500.00), "USD"),
579            ));
580
581        let residual = calculate_residual(&txn);
582        // Total cost posting contributes 1500 * signum(10) = 1500 USD
583        // Cash posting contributes -1500 USD
584        assert_eq!(residual.get("USD"), Some(&dec!(0)));
585    }
586
587    /// Test residual calculation with total cost and negative units (sell).
588    #[test]
589    fn test_calculate_residual_with_total_cost_negative_units() {
590        let txn = Transaction::new(date(2024, 1, 15), "Sell stock")
591            .with_posting(
592                Posting::new("Assets:Stock", Amount::new(dec!(-10), "AAPL")).with_cost(
593                    CostSpec::empty()
594                        .with_number_total(dec!(1500.00))
595                        .with_currency("USD"),
596                ),
597            )
598            .with_posting(Posting::new(
599                "Assets:Cash",
600                Amount::new(dec!(1500.00), "USD"),
601            ));
602
603        let residual = calculate_residual(&txn);
604        // Total cost with negative units: 1500 * signum(-10) = -1500 USD
605        // Cash posting contributes +1500 USD
606        assert_eq!(residual.get("USD"), Some(&dec!(0)));
607    }
608
609    /// Test cost spec without amount/currency falls back to units.
610    #[test]
611    fn test_calculate_residual_cost_without_amount_skips() {
612        // When a posting has an empty cost spec (e.g., `{}`) and no price annotation,
613        // it doesn't contribute to the residual because the cost will be determined
614        // by lot matching during booking. This matches Python beancount behavior.
615        let txn = Transaction::new(date(2024, 1, 15), "Test")
616            .with_posting(
617                Posting::new("Assets:Stock", Amount::new(dec!(10), "AAPL"))
618                    .with_cost(CostSpec::empty()), // Empty cost spec - doesn't contribute
619            )
620            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(-10), "AAPL")));
621
622        let residual = calculate_residual(&txn);
623        // Empty cost spec posting doesn't contribute, only the second posting does
624        assert_eq!(residual.get("AAPL"), Some(&dec!(-10)));
625    }
626
627    // =========================================================================
628    // Price annotation residual tests
629    // =========================================================================
630
631    /// Test residual with per-unit price annotation (@).
632    /// -100 USD @ 0.85 EUR means we're converting 100 USD to EUR at 0.85 rate.
633    #[test]
634    fn test_calculate_residual_with_unit_price() {
635        let txn = Transaction::new(date(2024, 1, 15), "Currency exchange")
636            .with_posting(
637                Posting::new("Assets:USD", Amount::new(dec!(-100.00), "USD"))
638                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(0.85), "EUR"))),
639            )
640            .with_posting(Posting::new("Assets:EUR", Amount::new(dec!(85.00), "EUR")));
641
642        let residual = calculate_residual(&txn);
643        // Price posting: |-100| * 0.85 * signum(-100) = -85 EUR
644        // EUR posting: +85 EUR
645        // Total: 0 EUR
646        assert_eq!(residual.get("EUR"), Some(&dec!(0)));
647        // USD should not appear (converted to EUR)
648        assert_eq!(residual.get("USD"), None);
649    }
650
651    /// Test residual with total price annotation (@@).
652    #[test]
653    fn test_calculate_residual_with_total_price() {
654        let txn = Transaction::new(date(2024, 1, 15), "Currency exchange")
655            .with_posting(
656                Posting::new("Assets:USD", Amount::new(dec!(-100.00), "USD"))
657                    .with_price(PriceAnnotation::Total(Amount::new(dec!(85.00), "EUR"))),
658            )
659            .with_posting(Posting::new("Assets:EUR", Amount::new(dec!(85.00), "EUR")));
660
661        let residual = calculate_residual(&txn);
662        // Total price: 85 * signum(-100) = -85 EUR
663        // EUR posting: +85 EUR
664        assert_eq!(residual.get("EUR"), Some(&dec!(0)));
665    }
666
667    /// Test residual with positive units and unit price.
668    #[test]
669    fn test_calculate_residual_with_unit_price_positive() {
670        let txn = Transaction::new(date(2024, 1, 15), "Buy EUR")
671            .with_posting(
672                Posting::new("Assets:EUR", Amount::new(dec!(85.00), "EUR"))
673                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(1.18), "USD"))),
674            )
675            .with_posting(Posting::new(
676                "Assets:USD",
677                Amount::new(dec!(-100.30), "USD"),
678            ));
679
680        let residual = calculate_residual(&txn);
681        // Price posting: |85| * 1.18 * signum(85) = 100.30 USD
682        // USD posting: -100.30 USD
683        assert_eq!(residual.get("USD"), Some(&dec!(0)));
684    }
685
686    /// Test `UnitIncomplete` price annotation with complete amount.
687    #[test]
688    fn test_calculate_residual_unit_incomplete_with_amount() {
689        let txn = Transaction::new(date(2024, 1, 15), "Exchange")
690            .with_posting(
691                Posting::new("Assets:USD", Amount::new(dec!(-100.00), "USD")).with_price(
692                    PriceAnnotation::UnitIncomplete(IncompleteAmount::Complete(Amount::new(
693                        dec!(0.85),
694                        "EUR",
695                    ))),
696                ),
697            )
698            .with_posting(Posting::new("Assets:EUR", Amount::new(dec!(85.00), "EUR")));
699
700        let residual = calculate_residual(&txn);
701        assert_eq!(residual.get("EUR"), Some(&dec!(0)));
702    }
703
704    /// Test `TotalIncomplete` price annotation with complete amount.
705    #[test]
706    fn test_calculate_residual_total_incomplete_with_amount() {
707        let txn = Transaction::new(date(2024, 1, 15), "Exchange")
708            .with_posting(
709                Posting::new("Assets:USD", Amount::new(dec!(-100.00), "USD")).with_price(
710                    PriceAnnotation::TotalIncomplete(IncompleteAmount::Complete(Amount::new(
711                        dec!(85.00),
712                        "EUR",
713                    ))),
714                ),
715            )
716            .with_posting(Posting::new("Assets:EUR", Amount::new(dec!(85.00), "EUR")));
717
718        let residual = calculate_residual(&txn);
719        assert_eq!(residual.get("EUR"), Some(&dec!(0)));
720    }
721
722    /// Test `UnitIncomplete` without amount falls back to units.
723    #[test]
724    fn test_calculate_residual_unit_incomplete_no_amount_fallback() {
725        let txn = Transaction::new(date(2024, 1, 15), "Test")
726            .with_posting(
727                Posting::new("Assets:USD", Amount::new(dec!(100.00), "USD")).with_price(
728                    PriceAnnotation::UnitIncomplete(IncompleteAmount::NumberOnly(dec!(0.85))),
729                ),
730            )
731            .with_posting(Posting::new(
732                "Assets:USD",
733                Amount::new(dec!(-100.00), "USD"),
734            ));
735
736        let residual = calculate_residual(&txn);
737        // Falls back to units since no currency in incomplete amount
738        assert_eq!(residual.get("USD"), Some(&dec!(0)));
739    }
740
741    /// Test `TotalIncomplete` without amount falls back to units.
742    #[test]
743    fn test_calculate_residual_total_incomplete_no_amount_fallback() {
744        let txn = Transaction::new(date(2024, 1, 15), "Test")
745            .with_posting(
746                Posting::new("Assets:USD", Amount::new(dec!(100.00), "USD")).with_price(
747                    PriceAnnotation::TotalIncomplete(IncompleteAmount::NumberOnly(dec!(85.00))),
748                ),
749            )
750            .with_posting(Posting::new(
751                "Assets:USD",
752                Amount::new(dec!(-100.00), "USD"),
753            ));
754
755        let residual = calculate_residual(&txn);
756        assert_eq!(residual.get("USD"), Some(&dec!(0)));
757    }
758
759    /// Test `UnitEmpty` price annotation falls back to units.
760    #[test]
761    fn test_calculate_residual_unit_empty_fallback() {
762        let txn = Transaction::new(date(2024, 1, 15), "Test")
763            .with_posting(
764                Posting::new("Assets:USD", Amount::new(dec!(100.00), "USD"))
765                    .with_price(PriceAnnotation::UnitEmpty),
766            )
767            .with_posting(Posting::new(
768                "Assets:USD",
769                Amount::new(dec!(-100.00), "USD"),
770            ));
771
772        let residual = calculate_residual(&txn);
773        // Falls back to units
774        assert_eq!(residual.get("USD"), Some(&dec!(0)));
775    }
776
777    /// Test `TotalEmpty` price annotation falls back to units.
778    #[test]
779    fn test_calculate_residual_total_empty_fallback() {
780        let txn = Transaction::new(date(2024, 1, 15), "Test")
781            .with_posting(
782                Posting::new("Assets:USD", Amount::new(dec!(100.00), "USD"))
783                    .with_price(PriceAnnotation::TotalEmpty),
784            )
785            .with_posting(Posting::new(
786                "Assets:USD",
787                Amount::new(dec!(-100.00), "USD"),
788            ));
789
790        let residual = calculate_residual(&txn);
791        assert_eq!(residual.get("USD"), Some(&dec!(0)));
792    }
793
794    // =========================================================================
795    // Mixed and edge case tests
796    // =========================================================================
797
798    /// Test transaction with both cost and regular postings.
799    #[test]
800    fn test_calculate_residual_mixed_cost_and_simple() {
801        let txn = Transaction::new(date(2024, 1, 15), "Buy with fee")
802            .with_posting(
803                Posting::new("Assets:Stock", Amount::new(dec!(10), "AAPL")).with_cost(
804                    CostSpec::empty()
805                        .with_number_per(dec!(150.00))
806                        .with_currency("USD"),
807                ),
808            )
809            .with_posting(Posting::new(
810                "Expenses:Fees",
811                Amount::new(dec!(10.00), "USD"),
812            ))
813            .with_posting(Posting::new(
814                "Assets:Cash",
815                Amount::new(dec!(-1510.00), "USD"),
816            ));
817
818        let residual = calculate_residual(&txn);
819        // 10 * 150 + 10 - 1510 = 0
820        assert_eq!(residual.get("USD"), Some(&dec!(0)));
821    }
822
823    /// Test sell with cost basis and capital gains.
824    #[test]
825    fn test_calculate_residual_sell_with_gains() {
826        let txn = Transaction::new(date(2024, 6, 15), "Sell stock")
827            .with_posting(
828                Posting::new("Assets:Stock", Amount::new(dec!(-10), "AAPL"))
829                    .with_cost(
830                        CostSpec::empty()
831                            .with_number_per(dec!(150.00))
832                            .with_currency("USD"),
833                    )
834                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(175.00), "USD"))),
835            )
836            .with_posting(Posting::new(
837                "Assets:Cash",
838                Amount::new(dec!(1750.00), "USD"),
839            ))
840            .with_posting(Posting::new(
841                "Income:CapitalGains",
842                Amount::new(dec!(-250.00), "USD"),
843            ));
844
845        let residual = calculate_residual(&txn);
846        // Stock posting with cost: -10 * 150 = -1500 USD (cost takes precedence)
847        // Cash: +1750 USD
848        // Gains: -250 USD
849        // Total: -1500 + 1750 - 250 = 0
850        assert_eq!(residual.get("USD"), Some(&dec!(0)));
851    }
852
853    /// Test multi-currency transaction with costs.
854    #[test]
855    fn test_calculate_residual_multi_currency_with_cost() {
856        let txn = Transaction::new(date(2024, 1, 15), "Multi-currency")
857            .with_posting(
858                Posting::new("Assets:Stock:US", Amount::new(dec!(10), "AAPL")).with_cost(
859                    CostSpec::empty()
860                        .with_number_per(dec!(150.00))
861                        .with_currency("USD"),
862                ),
863            )
864            .with_posting(
865                Posting::new("Assets:Stock:EU", Amount::new(dec!(5), "SAP")).with_cost(
866                    CostSpec::empty()
867                        .with_number_per(dec!(100.00))
868                        .with_currency("EUR"),
869                ),
870            )
871            .with_posting(Posting::new(
872                "Assets:Cash:USD",
873                Amount::new(dec!(-1500.00), "USD"),
874            ))
875            .with_posting(Posting::new(
876                "Assets:Cash:EUR",
877                Amount::new(dec!(-500.00), "EUR"),
878            ));
879
880        let residual = calculate_residual(&txn);
881        assert_eq!(residual.get("USD"), Some(&dec!(0)));
882        assert_eq!(residual.get("EUR"), Some(&dec!(0)));
883    }
884
885    /// Test that incomplete units (auto postings) are skipped.
886    #[test]
887    fn test_calculate_residual_skips_incomplete_units() {
888        let txn = Transaction::new(date(2024, 1, 15), "Test")
889            .with_posting(Posting::new(
890                "Expenses:Food",
891                Amount::new(dec!(50.00), "USD"),
892            ))
893            .with_posting(Posting::auto("Assets:Cash")); // No units
894
895        let residual = calculate_residual(&txn);
896        // Only the complete posting is counted
897        assert_eq!(residual.get("USD"), Some(&dec!(50.00)));
898    }
899
900    // =========================================================================
901    // Cost currency inference tests (issue #203)
902    // =========================================================================
903
904    /// Test cost currency is inferred from other postings.
905    /// This is the exact case from issue #203.
906    #[test]
907    fn test_calculate_residual_infers_cost_currency_from_other_posting() {
908        // 2026-01-01 * "Opening balance"
909        //   Assets:Vanguard:IRA:Trad:VFIFX  10 VFIFX {100}
910        //   Equity:Opening-Balances      -1000 USD
911        //
912        // Python beancount infers the cost currency as USD from the second posting.
913        let txn = Transaction::new(date(2026, 1, 1), "Opening balance")
914            .with_posting(
915                Posting::new(
916                    "Assets:Vanguard:IRA:Trad:VFIFX",
917                    Amount::new(dec!(10), "VFIFX"),
918                )
919                .with_cost(CostSpec::empty().with_number_per(dec!(100))),
920            )
921            .with_posting(Posting::new(
922                "Equity:Opening-Balances",
923                Amount::new(dec!(-1000), "USD"),
924            ));
925
926        let residual = calculate_residual(&txn);
927        // Cost posting should contribute 10 * 100 = 1000 USD (inferred from other posting)
928        // Equity posting contributes -1000 USD
929        // Residual should be 0
930        assert_eq!(
931            residual.get("USD"),
932            Some(&dec!(0)),
933            "Should balance when cost currency is inferred from other posting"
934        );
935        // VFIFX should not appear in residuals
936        assert_eq!(residual.get("VFIFX"), None);
937    }
938
939    /// Test cost currency inference with total cost.
940    #[test]
941    fn test_calculate_residual_infers_cost_currency_total_cost() {
942        // 10 VFIFX {{1000}} with -1000 USD posting
943        let txn = Transaction::new(date(2026, 1, 1), "Test")
944            .with_posting(
945                Posting::new("Assets:Stock", Amount::new(dec!(10), "VFIFX"))
946                    .with_cost(CostSpec::empty().with_number_total(dec!(1000))),
947            )
948            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(-1000), "USD")));
949
950        let residual = calculate_residual(&txn);
951        assert_eq!(residual.get("USD"), Some(&dec!(0)));
952    }
953
954    /// Test that explicit cost currency takes precedence over inference.
955    #[test]
956    fn test_calculate_residual_explicit_cost_currency_takes_precedence() {
957        // If cost has explicit currency, don't infer from other postings
958        let txn = Transaction::new(date(2026, 1, 1), "Test")
959            .with_posting(
960                Posting::new("Assets:Stock", Amount::new(dec!(10), "AAPL")).with_cost(
961                    CostSpec::empty()
962                        .with_number_per(dec!(100))
963                        .with_currency("EUR"), // Explicit EUR
964                ),
965            )
966            .with_posting(Posting::new(
967                "Assets:Cash",
968                Amount::new(dec!(-1000), "USD"), // USD posting
969            ));
970
971        let residual = calculate_residual(&txn);
972        // Should use EUR (explicit) not USD (from other posting)
973        assert_eq!(residual.get("EUR"), Some(&dec!(1000)));
974        assert_eq!(residual.get("USD"), Some(&dec!(-1000)));
975    }
976
977    /// Test that price annotation takes precedence over other posting inference.
978    #[test]
979    fn test_calculate_residual_price_annotation_takes_precedence() {
980        // If cost has price annotation, use that currency
981        let txn = Transaction::new(date(2026, 1, 1), "Test")
982            .with_posting(
983                Posting::new("Assets:Stock", Amount::new(dec!(10), "AAPL"))
984                    .with_cost(CostSpec::empty().with_number_per(dec!(100)))
985                    .with_price(PriceAnnotation::Unit(Amount::new(dec!(105), "EUR"))),
986            )
987            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(-1000), "USD")));
988
989        let residual = calculate_residual(&txn);
990        // Should use EUR (from price annotation) not USD (from other posting)
991        assert_eq!(residual.get("EUR"), Some(&dec!(1000)));
992        assert_eq!(residual.get("USD"), Some(&dec!(-1000)));
993    }
994
995    // =========================================================================
996    // infer_cost_currency_from_postings tests
997    // =========================================================================
998
999    /// Test that cost spec currency is used as fallback when no simple postings exist.
1000    #[test]
1001    fn test_infer_cost_currency_from_cost_spec() {
1002        // Transaction with only cost-spec posting - should get currency from cost spec
1003        let txn = Transaction::new(date(2022, 4, 16), "Free tokens")
1004            .with_posting(
1005                Posting::new("Assets:Crypto", Amount::new(dec!(100), "TOKEN")).with_cost(
1006                    CostSpec::empty()
1007                        .with_number_per(dec!(0))
1008                        .with_currency("USD"),
1009                ),
1010            )
1011            .with_posting(Posting::auto("Income:Bonus"));
1012
1013        let inferred = infer_cost_currency_from_postings(&txn);
1014        assert_eq!(inferred.as_deref(), Some("USD"));
1015    }
1016
1017    /// Test that simple posting currency takes precedence over cost spec currency.
1018    #[test]
1019    fn test_infer_cost_currency_simple_takes_precedence() {
1020        // Transaction with both simple posting and cost spec - simple should win
1021        let txn = Transaction::new(date(2022, 4, 16), "Trade")
1022            .with_posting(
1023                Posting::new("Assets:Crypto", Amount::new(dec!(100), "TOKEN")).with_cost(
1024                    CostSpec::empty()
1025                        .with_number_per(dec!(10))
1026                        .with_currency("EUR"),
1027                ),
1028            )
1029            .with_posting(Posting::new("Assets:Cash", Amount::new(dec!(-1000), "USD")));
1030
1031        let inferred = infer_cost_currency_from_postings(&txn);
1032        // Should get USD from the simple posting, not EUR from cost spec
1033        assert_eq!(inferred.as_deref(), Some("USD"));
1034    }
1035
1036    /// Test that zero-cost spec currency is still used for inference.
1037    #[test]
1038    fn test_infer_cost_currency_zero_cost() {
1039        // Zero cost should still provide the currency
1040        let txn = Transaction::new(date(2022, 4, 16), "Airdrop")
1041            .with_posting(
1042                Posting::new("Assets:Crypto", Amount::new(dec!(1000), "SHIB")).with_cost(
1043                    CostSpec::empty()
1044                        .with_number_per(dec!(0))
1045                        .with_currency("JPY"),
1046                ),
1047            )
1048            .with_posting(Posting::auto("Income:Airdrop"));
1049
1050        let inferred = infer_cost_currency_from_postings(&txn);
1051        assert_eq!(inferred.as_deref(), Some("JPY"));
1052    }
1053}