Skip to main content

rustledger_plugin/native/plugins/
check_average_cost.rs

1//! Validate that reducing postings use the average cost for accounts
2//! opened with the NONE booking method.
3
4use crate::types::{DirectiveData, PluginError, PluginInput, PluginOp, PluginOutput};
5
6use super::super::{NativePlugin, RegularPlugin};
7
8/// Plugin that validates reducing postings against the running average cost
9/// for accounts opened with the `NONE` booking method.
10///
11/// When an account is opened with `NONE` booking, the ledger author is responsible
12/// for lot matching — there is no booker to enforce it. This plugin is a safety net
13/// for that case: it verifies that the cost basis used on any reducing leg is within
14/// tolerance of the running average cost basis in the account.
15///
16/// Accounts opened with any other booking method (`STRICT`, `STRICT_WITH_SIZE`,
17/// `FIFO`, `LIFO`, `HIFO`, `AVERAGE`, …) are skipped — their booker already validates
18/// lot matching, so re-checking here would produce false positives (see issue #907).
19/// This matches Python beancount's `beancount.plugins.check_average_cost` behavior.
20pub struct CheckAverageCostPlugin {
21    /// Tolerance for cost comparison (default: 0.01 = 1%).
22    tolerance: rust_decimal::Decimal,
23}
24
25impl CheckAverageCostPlugin {
26    /// Create with default tolerance (1%).
27    pub fn new() -> Self {
28        Self {
29            tolerance: rust_decimal::Decimal::new(1, 2), // 0.01 = 1%
30        }
31    }
32
33    /// Create with custom tolerance.
34    pub const fn with_tolerance(tolerance: rust_decimal::Decimal) -> Self {
35        Self { tolerance }
36    }
37}
38
39impl Default for CheckAverageCostPlugin {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl NativePlugin for CheckAverageCostPlugin {
46    fn name(&self) -> &'static str {
47        "check_average_cost"
48    }
49
50    fn description(&self) -> &'static str {
51        "Validate reducing postings match average cost"
52    }
53
54    fn process(&self, input: PluginInput) -> PluginOutput {
55        use rust_decimal::Decimal;
56        use std::collections::{HashMap, HashSet};
57        use std::str::FromStr;
58
59        // Parse optional tolerance from config
60        let tolerance = if let Some(config) = &input.config {
61            Decimal::from_str(config.trim()).unwrap_or(self.tolerance)
62        } else {
63            self.tolerance
64        };
65
66        // First pass: collect accounts opened with the NONE booking method.
67        // Only these accounts are subject to the average-cost check — see the
68        // type-level docstring and issue #907 for rationale.
69        let none_booking_accounts: HashSet<&str> = input
70            .directives
71            .iter()
72            .filter_map(|w| match &w.data {
73                DirectiveData::Open(o) => o
74                    .booking
75                    .as_deref()
76                    .filter(|b| b.eq_ignore_ascii_case("NONE"))
77                    .map(|_| o.account.as_str()),
78                _ => None,
79            })
80            .collect();
81
82        // Track average cost per account per commodity
83        // Key: (account, commodity) -> (total_units, total_cost)
84        let mut inventory: HashMap<(String, String), (Decimal, Decimal)> = HashMap::new();
85
86        let mut errors = Vec::new();
87
88        for wrapper in &input.directives {
89            if let DirectiveData::Transaction(txn) = &wrapper.data {
90                for posting in &txn.postings {
91                    // Only process accounts opened with NONE booking (issue #907).
92                    if !none_booking_accounts.contains(posting.account.as_str()) {
93                        continue;
94                    }
95
96                    // Only process postings with units and cost
97                    let Some(units) = &posting.units else {
98                        continue;
99                    };
100                    let Some(cost) = &posting.cost else {
101                        continue;
102                    };
103
104                    let units_num = Decimal::from_str(&units.number).unwrap_or_default();
105                    let Some(cost_currency) = &cost.currency else {
106                        continue;
107                    };
108
109                    let key = (posting.account.clone(), units.currency.clone());
110
111                    if units_num > Decimal::ZERO {
112                        // Acquisition: add to inventory. `per_unit()`
113                        // covers both PerUnit and PerUnitFromTotal;
114                        // raw Total (pre-booking) has no per-unit yet
115                        // and yields zero (gate is permissive there).
116                        let cost_per = cost
117                            .number
118                            .as_ref()
119                            .and_then(|cn| cn.per_unit())
120                            .map(|s| Decimal::from_str(s).unwrap_or_default())
121                            .unwrap_or_default();
122
123                        let entry = inventory
124                            .entry(key)
125                            .or_insert((Decimal::ZERO, Decimal::ZERO));
126                        entry.0 += units_num; // total units
127                        entry.1 += units_num * cost_per; // total cost
128                    } else if units_num < Decimal::ZERO {
129                        // Reduction: check against average cost
130                        let entry = inventory.get(&key);
131
132                        if let Some((total_units, total_cost)) = entry
133                            && *total_units > Decimal::ZERO
134                        {
135                            let avg_cost = *total_cost / *total_units;
136
137                            // Get the cost used in this posting.
138                            // Same per-unit-only read as above.
139                            let used_cost = cost
140                                .number
141                                .as_ref()
142                                .and_then(|cn| cn.per_unit())
143                                .map(|s| Decimal::from_str(s).unwrap_or_default())
144                                .unwrap_or_default();
145
146                            // Calculate relative difference
147                            let diff = (used_cost - avg_cost).abs();
148                            let relative_diff = if avg_cost == Decimal::ZERO {
149                                diff
150                            } else {
151                                diff / avg_cost
152                            };
153
154                            if relative_diff > tolerance {
155                                errors.push(PluginError::warning(format!(
156                                        "Sale of {} {} in {} uses cost {} {} but average cost is {} {} (difference: {:.2}%)",
157                                        units_num.abs(),
158                                        units.currency,
159                                        posting.account,
160                                        used_cost,
161                                        cost_currency,
162                                        avg_cost.round_dp(4),
163                                        cost_currency,
164                                        relative_diff * Decimal::from(100)
165                                    )));
166                            }
167
168                            // Update inventory
169                            let entry = inventory.get_mut(&key).unwrap();
170                            let units_sold = units_num.abs();
171                            let cost_removed = units_sold * avg_cost;
172                            entry.0 -= units_sold;
173                            entry.1 -= cost_removed;
174                        }
175                    }
176                }
177            }
178        }
179
180        PluginOutput {
181            ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
182            errors,
183        }
184    }
185}
186
// Empty impl: `RegularPlugin` (imported from the parent module) is implemented
// with no items, so this type relies entirely on the trait's defaults.
// NOTE(review): presumably this marks the plugin as an ordinary post-booking
// plugin; confirm against the trait's definition.
impl RegularPlugin for CheckAverageCostPlugin {}
188
#[cfg(test)]
mod check_average_cost_tests {
    use super::*;
    use crate::types::*;

    const BROKER: &str = "Assets:Broker";

    /// Build an `open` directive for `account` with an optional booking method.
    fn open_account(account: &str, booking: Option<&str>) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "open".to_string(),
            date: "2024-01-01".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Open(OpenData {
                account: account.to_string(),
                currencies: vec![],
                booking: booking.map(str::to_string),
                metadata: vec![],
            }),
        }
    }

    fn open_with_none_booking(account: &str) -> DirectiveWrapper {
        open_account(account, Some("NONE"))
    }

    /// Build a posting of `number` units of `currency` at a per-unit cost.
    fn cost_posting(
        account: &str,
        number: &str,
        currency: &str,
        per_unit: &str,
        cost_currency: &str,
    ) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: currency.to_string(),
            }),
            cost: Some(CostData {
                number: Some(rustledger_plugin_types::CostNumberData::PerUnit {
                    value: per_unit.to_string(),
                }),
                currency: Some(cost_currency.to_string()),
                date: None,
                label: None,
                merge: false,
            }),
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        }
    }

    /// Build a posting with units but no cost at all.
    fn plain_posting(account: &str, number: &str, currency: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: currency.to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        }
    }

    /// Wrap postings into a `*` transaction directive.
    fn transaction(date: &str, narration: &str, postings: Vec<PostingData>) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings,
            }),
        }
    }

    /// A single-posting AAPL trade in `Assets:Broker` priced in USD.
    /// Positive `number` buys, negative sells.
    fn trade(date: &str, narration: &str, number: &str, per_unit: &str) -> DirectiveWrapper {
        transaction(
            date,
            narration,
            vec![cost_posting(BROKER, number, "AAPL", per_unit, "USD")],
        )
    }

    fn plugin_input(directives: Vec<DirectiveWrapper>, config: Option<&str>) -> PluginInput {
        PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: config.map(str::to_string),
        }
    }

    #[test]
    fn test_check_average_cost_matching() {
        let plugin = CheckAverageCostPlugin::new();
        let input = plugin_input(
            vec![
                open_with_none_booking(BROKER),
                trade("2024-01-01", "Buy", "10", "100.00"),
                trade("2024-02-01", "Sell at avg cost", "-5", "100.00"), // Matches average
            ],
            None,
        );

        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 0);
    }

    #[test]
    fn test_check_average_cost_mismatch() {
        let plugin = CheckAverageCostPlugin::new();
        let input = plugin_input(
            vec![
                open_with_none_booking(BROKER),
                trade("2024-01-01", "Buy at 100", "10", "100.00"),
                // 10% different from avg
                trade("2024-02-01", "Sell at wrong cost", "-5", "90.00"),
            ],
            None,
        );

        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 1);
        assert!(output.errors[0].message.contains("average cost"));
    }

    #[test]
    fn test_check_average_cost_multiple_buys() {
        let plugin = CheckAverageCostPlugin::new();

        // Buy 10 at $100, then 10 at $120 -> avg = $110
        let input = plugin_input(
            vec![
                open_with_none_booking(BROKER),
                trade("2024-01-01", "Buy at 100", "10", "100.00"),
                trade("2024-01-15", "Buy at 120", "10", "120.00"),
                trade("2024-02-01", "Sell at avg cost", "-5", "110.00"), // Matches average
            ],
            None,
        );

        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 0);
    }

    #[test]
    fn test_non_none_booking_is_skipped() {
        // Regression test for issue #907: accounts opened with any booking
        // method other than NONE (including the default, an unspecified
        // booking, or explicit STRICT/FIFO/etc.) must be skipped entirely.
        // The booker is responsible for lot matching in those cases, so
        // re-checking here produces false positives like the reporter's
        // "500 USD vs. avg 566.67 USD" error.
        let plugin = CheckAverageCostPlugin::new();

        // No booking specified (booking: None). Whatever the effective
        // default is — STRICT unless overridden by `option
        // "booking_method"` or a loader setting — it is NOT `NONE`,
        // so the plugin MUST skip this account.
        let input = plugin_input(
            vec![
                open_account(BROKER, None),
                trade("2024-01-01", "Buy at 100", "10", "100.00"),
                // 50% off the average — would fire for NONE,
                // but must be silent for STRICT/default.
                trade("2024-02-01", "Sell at way-off cost", "-5", "50.00"),
            ],
            None,
        );

        let output = plugin.process(input);
        assert!(
            output.errors.is_empty(),
            "non-NONE accounts must be skipped; got errors: {:?}",
            output.errors
        );
    }

    #[test]
    fn test_explicit_strict_booking_is_skipped() {
        let plugin = CheckAverageCostPlugin::new();
        let input = plugin_input(
            vec![
                open_account(BROKER, Some("STRICT")),
                trade("2024-01-01", "Buy at 100", "10", "100.00"),
                trade("2024-02-01", "Sell at way-off cost", "-5", "50.00"),
            ],
            None,
        );

        let output = plugin.process(input);
        assert!(
            output.errors.is_empty(),
            "STRICT accounts must be skipped; got errors: {:?}",
            output.errors
        );
    }

    #[test]
    fn test_none_booking_is_case_insensitive() {
        let plugin = CheckAverageCostPlugin::new();
        let input = plugin_input(
            vec![
                open_account(BROKER, Some("none")),
                trade("2024-01-01", "Buy at 100", "10", "100.00"),
                trade("2024-02-01", "Sell at wrong cost", "-5", "90.00"),
            ],
            None,
        );

        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 1);
    }

    #[test]
    fn test_config_overrides_tolerance() {
        // A 10% deviation is within a 20% tolerance supplied via config.
        let plugin = CheckAverageCostPlugin::new();
        let input = plugin_input(
            vec![
                open_with_none_booking(BROKER),
                trade("2024-01-01", "Buy at 100", "10", "100.00"),
                trade("2024-02-01", "Sell within widened tolerance", "-5", "90.00"),
            ],
            Some("0.20"),
        );

        let output = plugin.process(input);
        assert!(
            output.errors.is_empty(),
            "config tolerance must be honored; got errors: {:?}",
            output.errors
        );
    }

    #[test]
    fn test_postings_without_cost_are_ignored() {
        let plugin = CheckAverageCostPlugin::new();
        let input = plugin_input(
            vec![
                open_with_none_booking(BROKER),
                transaction("2024-01-01", "Deposit", vec![plain_posting(BROKER, "10", "AAPL")]),
                transaction("2024-02-01", "Withdraw", vec![plain_posting(BROKER, "-5", "AAPL")]),
            ],
            None,
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());
    }
}
619}