Skip to main content

rustledger_plugin/native/plugins/
check_average_cost.rs

//! Validates that reducing postings on accounts opened with the `NONE` booking
//! method use a cost within tolerance of the account's running average cost.
3
4use crate::types::{DirectiveData, PluginError, PluginInput, PluginOp, PluginOutput};
5
6use super::super::NativePlugin;
7
8/// Plugin that validates reducing postings against the running average cost
9/// for accounts opened with the `NONE` booking method.
10///
11/// When an account is opened with `NONE` booking, the ledger author is responsible
12/// for lot matching — there is no booker to enforce it. This plugin is a safety net
13/// for that case: it verifies that the cost basis used on any reducing leg is within
14/// tolerance of the running average cost basis in the account.
15///
16/// Accounts opened with any other booking method (`STRICT`, `STRICT_WITH_SIZE`,
17/// `FIFO`, `LIFO`, `HIFO`, `AVERAGE`, …) are skipped — their booker already validates
18/// lot matching, so re-checking here would produce false positives (see issue #907).
19/// This matches Python beancount's `beancount.plugins.check_average_cost` behavior.
20pub struct CheckAverageCostPlugin {
21    /// Tolerance for cost comparison (default: 0.01 = 1%).
22    tolerance: rust_decimal::Decimal,
23}
24
25impl CheckAverageCostPlugin {
26    /// Create with default tolerance (1%).
27    pub fn new() -> Self {
28        Self {
29            tolerance: rust_decimal::Decimal::new(1, 2), // 0.01 = 1%
30        }
31    }
32
33    /// Create with custom tolerance.
34    pub const fn with_tolerance(tolerance: rust_decimal::Decimal) -> Self {
35        Self { tolerance }
36    }
37}
38
39impl Default for CheckAverageCostPlugin {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45impl NativePlugin for CheckAverageCostPlugin {
46    fn name(&self) -> &'static str {
47        "check_average_cost"
48    }
49
50    fn description(&self) -> &'static str {
51        "Validate reducing postings match average cost"
52    }
53
54    fn process(&self, input: PluginInput) -> PluginOutput {
55        use rust_decimal::Decimal;
56        use std::collections::{HashMap, HashSet};
57        use std::str::FromStr;
58
59        // Parse optional tolerance from config
60        let tolerance = if let Some(config) = &input.config {
61            Decimal::from_str(config.trim()).unwrap_or(self.tolerance)
62        } else {
63            self.tolerance
64        };
65
66        // First pass: collect accounts opened with the NONE booking method.
67        // Only these accounts are subject to the average-cost check — see the
68        // type-level docstring and issue #907 for rationale.
69        let none_booking_accounts: HashSet<&str> = input
70            .directives
71            .iter()
72            .filter_map(|w| match &w.data {
73                DirectiveData::Open(o) => o
74                    .booking
75                    .as_deref()
76                    .filter(|b| b.eq_ignore_ascii_case("NONE"))
77                    .map(|_| o.account.as_str()),
78                _ => None,
79            })
80            .collect();
81
82        // Track average cost per account per commodity
83        // Key: (account, commodity) -> (total_units, total_cost)
84        let mut inventory: HashMap<(String, String), (Decimal, Decimal)> = HashMap::new();
85
86        let mut errors = Vec::new();
87
88        for wrapper in &input.directives {
89            if let DirectiveData::Transaction(txn) = &wrapper.data {
90                for posting in &txn.postings {
91                    // Only process accounts opened with NONE booking (issue #907).
92                    if !none_booking_accounts.contains(posting.account.as_str()) {
93                        continue;
94                    }
95
96                    // Only process postings with units and cost
97                    let Some(units) = &posting.units else {
98                        continue;
99                    };
100                    let Some(cost) = &posting.cost else {
101                        continue;
102                    };
103
104                    let units_num = Decimal::from_str(&units.number).unwrap_or_default();
105                    let Some(cost_currency) = &cost.currency else {
106                        continue;
107                    };
108
109                    let key = (posting.account.clone(), units.currency.clone());
110
111                    if units_num > Decimal::ZERO {
112                        // Acquisition: add to inventory
113                        let cost_per = cost
114                            .number_per
115                            .as_ref()
116                            .and_then(|s| Decimal::from_str(s).ok())
117                            .unwrap_or_default();
118
119                        let entry = inventory
120                            .entry(key)
121                            .or_insert((Decimal::ZERO, Decimal::ZERO));
122                        entry.0 += units_num; // total units
123                        entry.1 += units_num * cost_per; // total cost
124                    } else if units_num < Decimal::ZERO {
125                        // Reduction: check against average cost
126                        let entry = inventory.get(&key);
127
128                        if let Some((total_units, total_cost)) = entry
129                            && *total_units > Decimal::ZERO
130                        {
131                            let avg_cost = *total_cost / *total_units;
132
133                            // Get the cost used in this posting
134                            let used_cost = cost
135                                .number_per
136                                .as_ref()
137                                .and_then(|s| Decimal::from_str(s).ok())
138                                .unwrap_or_default();
139
140                            // Calculate relative difference
141                            let diff = (used_cost - avg_cost).abs();
142                            let relative_diff = if avg_cost == Decimal::ZERO {
143                                diff
144                            } else {
145                                diff / avg_cost
146                            };
147
148                            if relative_diff > tolerance {
149                                errors.push(PluginError::warning(format!(
150                                        "Sale of {} {} in {} uses cost {} {} but average cost is {} {} (difference: {:.2}%)",
151                                        units_num.abs(),
152                                        units.currency,
153                                        posting.account,
154                                        used_cost,
155                                        cost_currency,
156                                        avg_cost.round_dp(4),
157                                        cost_currency,
158                                        relative_diff * Decimal::from(100)
159                                    )));
160                            }
161
162                            // Update inventory
163                            let entry = inventory.get_mut(&key).unwrap();
164                            let units_sold = units_num.abs();
165                            let cost_removed = units_sold * avg_cost;
166                            entry.0 -= units_sold;
167                            entry.1 -= cost_removed;
168                        }
169                    }
170                }
171            }
172        }
173
174        PluginOutput {
175            ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
176            errors,
177        }
178    }
179}
180
#[cfg(test)]
mod check_average_cost_tests {
    use super::*;
    use crate::types::*;

    /// Account used by every scenario below.
    const ACCOUNT: &str = "Assets:Broker";

    /// `open` directive for `account` with the given booking method (`None` =
    /// no booking specified).
    fn open_directive(account: &str, booking: Option<&str>) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "open".to_string(),
            date: "2024-01-01".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Open(OpenData {
                account: account.to_string(),
                currencies: vec![],
                booking: booking.map(String::from),
                metadata: vec![],
            }),
        }
    }

    /// `open` directive for `account` with the `NONE` booking method.
    fn open_with_none_booking(account: &str) -> DirectiveWrapper {
        open_directive(account, Some("NONE"))
    }

    /// Single-posting transaction on [`ACCOUNT`] moving `units` AAPL held at
    /// `cost_per` USD per unit (positive = buy, negative = sell).
    fn aapl_txn(date: &str, narration: &str, units: &str, cost_per: &str) -> DirectiveWrapper {
        let posting = PostingData {
            account: ACCOUNT.to_string(),
            units: Some(AmountData {
                number: units.to_string(),
                currency: "AAPL".to_string(),
            }),
            cost: Some(CostData {
                number_per: Some(cost_per.to_string()),
                number_total: None,
                currency: Some("USD".to_string()),
                date: None,
                label: None,
                merge: false,
            }),
            price: None,
            flag: None,
            metadata: vec![],
        };

        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![posting],
            }),
        }
    }

    /// Wraps `directives` in a plugin input with USD as operating currency and
    /// no plugin config.
    fn input_of(directives: Vec<DirectiveWrapper>) -> PluginInput {
        PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: None,
        }
    }

    #[test]
    fn test_check_average_cost_matching() {
        let input = input_of(vec![
            open_with_none_booking(ACCOUNT),
            aapl_txn("2024-01-01", "Buy", "10", "100.00"),
            // Matches average
            aapl_txn("2024-02-01", "Sell at avg cost", "-5", "100.00"),
        ]);

        let output = CheckAverageCostPlugin::new().process(input);
        assert_eq!(output.errors.len(), 0);
    }

    #[test]
    fn test_check_average_cost_mismatch() {
        let input = input_of(vec![
            open_with_none_booking(ACCOUNT),
            aapl_txn("2024-01-01", "Buy at 100", "10", "100.00"),
            // 10% different from avg
            aapl_txn("2024-02-01", "Sell at wrong cost", "-5", "90.00"),
        ]);

        let output = CheckAverageCostPlugin::new().process(input);
        assert_eq!(output.errors.len(), 1);
        assert!(output.errors[0].message.contains("average cost"));
    }

    #[test]
    fn test_check_average_cost_multiple_buys() {
        // Buy 10 at $100, then 10 at $120 -> avg = $110
        let input = input_of(vec![
            open_with_none_booking(ACCOUNT),
            aapl_txn("2024-01-01", "Buy at 100", "10", "100.00"),
            aapl_txn("2024-01-15", "Buy at 120", "10", "120.00"),
            // Matches average
            aapl_txn("2024-02-01", "Sell at avg cost", "-5", "110.00"),
        ]);

        let output = CheckAverageCostPlugin::new().process(input);
        assert_eq!(output.errors.len(), 0);
    }

    #[test]
    fn test_non_none_booking_is_skipped() {
        // Regression test for issue #907: accounts opened with any booking
        // method other than NONE (including the default, an unspecified
        // booking, or explicit STRICT/FIFO/etc.) must be skipped entirely.
        // The booker is responsible for lot matching in those cases, so
        // re-checking here produces false positives like the reporter's
        // "500 USD vs. avg 566.67 USD" error.
        let input = input_of(vec![
            // No booking specified (booking: None). Whatever the effective
            // default is — STRICT unless overridden by `option
            // "booking_method"` or a loader setting — it is NOT `NONE`,
            // so the plugin MUST skip this account.
            open_directive(ACCOUNT, None),
            aapl_txn("2024-01-01", "Buy at 100", "10", "100.00"),
            // 50% off the average — would fire for NONE,
            // but must be silent for STRICT/default.
            aapl_txn("2024-02-01", "Sell at way-off cost", "-5", "50.00"),
        ]);

        let output = CheckAverageCostPlugin::new().process(input);
        assert!(
            output.errors.is_empty(),
            "non-NONE accounts must be skipped; got errors: {:?}",
            output.errors
        );
    }
}