Skip to main content

rustledger_plugin/native/plugins/
check_average_cost.rs

//! Validate that reducing postings use the average cost basis (intended for accounts with NONE booking).
2
3use crate::types::{DirectiveData, PluginError, PluginInput, PluginOutput};
4
5use super::super::NativePlugin;
6
7/// Plugin that validates reducing postings use average cost for accounts with NONE booking.
8///
9/// For accounts with booking method NONE (average cost), when selling/reducing positions,
10/// this plugin verifies that the cost basis used matches the calculated average cost
11/// within a specified tolerance.
12pub struct CheckAverageCostPlugin {
13    /// Tolerance for cost comparison (default: 0.01 = 1%).
14    tolerance: rust_decimal::Decimal,
15}
16
17impl CheckAverageCostPlugin {
18    /// Create with default tolerance (1%).
19    pub fn new() -> Self {
20        Self {
21            tolerance: rust_decimal::Decimal::new(1, 2), // 0.01 = 1%
22        }
23    }
24
25    /// Create with custom tolerance.
26    pub const fn with_tolerance(tolerance: rust_decimal::Decimal) -> Self {
27        Self { tolerance }
28    }
29}
30
31impl Default for CheckAverageCostPlugin {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37impl NativePlugin for CheckAverageCostPlugin {
38    fn name(&self) -> &'static str {
39        "check_average_cost"
40    }
41
42    fn description(&self) -> &'static str {
43        "Validate reducing postings match average cost"
44    }
45
46    fn process(&self, input: PluginInput) -> PluginOutput {
47        use rust_decimal::Decimal;
48        use std::collections::HashMap;
49        use std::str::FromStr;
50
51        // Parse optional tolerance from config
52        let tolerance = if let Some(config) = &input.config {
53            Decimal::from_str(config.trim()).unwrap_or(self.tolerance)
54        } else {
55            self.tolerance
56        };
57
58        // Track average cost per account per commodity
59        // Key: (account, commodity) -> (total_units, total_cost)
60        let mut inventory: HashMap<(String, String), (Decimal, Decimal)> = HashMap::new();
61
62        let mut errors = Vec::new();
63
64        for wrapper in &input.directives {
65            if let DirectiveData::Transaction(txn) = &wrapper.data {
66                for posting in &txn.postings {
67                    // Only process postings with units and cost
68                    let Some(units) = &posting.units else {
69                        continue;
70                    };
71                    let Some(cost) = &posting.cost else {
72                        continue;
73                    };
74
75                    let units_num = Decimal::from_str(&units.number).unwrap_or_default();
76                    let Some(cost_currency) = &cost.currency else {
77                        continue;
78                    };
79
80                    let key = (posting.account.clone(), units.currency.clone());
81
82                    if units_num > Decimal::ZERO {
83                        // Acquisition: add to inventory
84                        let cost_per = cost
85                            .number_per
86                            .as_ref()
87                            .and_then(|s| Decimal::from_str(s).ok())
88                            .unwrap_or_default();
89
90                        let entry = inventory
91                            .entry(key)
92                            .or_insert((Decimal::ZERO, Decimal::ZERO));
93                        entry.0 += units_num; // total units
94                        entry.1 += units_num * cost_per; // total cost
95                    } else if units_num < Decimal::ZERO {
96                        // Reduction: check against average cost
97                        let entry = inventory.get(&key);
98
99                        if let Some((total_units, total_cost)) = entry
100                            && *total_units > Decimal::ZERO
101                        {
102                            let avg_cost = *total_cost / *total_units;
103
104                            // Get the cost used in this posting
105                            let used_cost = cost
106                                .number_per
107                                .as_ref()
108                                .and_then(|s| Decimal::from_str(s).ok())
109                                .unwrap_or_default();
110
111                            // Calculate relative difference
112                            let diff = (used_cost - avg_cost).abs();
113                            let relative_diff = if avg_cost == Decimal::ZERO {
114                                diff
115                            } else {
116                                diff / avg_cost
117                            };
118
119                            if relative_diff > tolerance {
120                                errors.push(PluginError::warning(format!(
121                                        "Sale of {} {} in {} uses cost {} {} but average cost is {} {} (difference: {:.2}%)",
122                                        units_num.abs(),
123                                        units.currency,
124                                        posting.account,
125                                        used_cost,
126                                        cost_currency,
127                                        avg_cost.round_dp(4),
128                                        cost_currency,
129                                        relative_diff * Decimal::from(100)
130                                    )));
131                            }
132
133                            // Update inventory
134                            let entry = inventory.get_mut(&key).unwrap();
135                            let units_sold = units_num.abs();
136                            let cost_removed = units_sold * avg_cost;
137                            entry.0 -= units_sold;
138                            entry.1 -= cost_removed;
139                        }
140                    }
141                }
142            }
143        }
144
145        PluginOutput {
146            directives: input.directives,
147            errors,
148        }
149    }
150}
151
#[cfg(test)]
mod check_average_cost_tests {
    use super::*;
    use crate::types::*;

    /// Single-posting transaction on `Assets:Broker` moving `number` AAPL at a
    /// per-unit cost of `cost_per` USD.
    fn trade(date: &str, narration: &str, number: &str, cost_per: &str) -> DirectiveWrapper {
        let posting = PostingData {
            account: "Assets:Broker".to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "AAPL".to_string(),
            }),
            cost: Some(CostData {
                number_per: Some(cost_per.to_string()),
                number_total: None,
                currency: Some("USD".to_string()),
                date: None,
                label: None,
                merge: false,
            }),
            price: None,
            flag: None,
            metadata: vec![],
        };

        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![posting],
            }),
        }
    }

    /// Run a default-tolerance plugin over `directives` with no config.
    fn run(directives: Vec<DirectiveWrapper>) -> PluginOutput {
        let input = PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: None,
        };
        CheckAverageCostPlugin::new().process(input)
    }

    #[test]
    fn test_check_average_cost_matching() {
        let output = run(vec![
            trade("2024-01-01", "Buy", "10", "100.00"),
            // Sold at exactly the average cost.
            trade("2024-02-01", "Sell at avg cost", "-5", "100.00"),
        ]);
        assert_eq!(output.errors.len(), 0);
    }

    #[test]
    fn test_check_average_cost_mismatch() {
        let output = run(vec![
            trade("2024-01-01", "Buy at 100", "10", "100.00"),
            // 10% below the average, well outside the 1% tolerance.
            trade("2024-02-01", "Sell at wrong cost", "-5", "90.00"),
        ]);
        assert_eq!(output.errors.len(), 1);
        assert!(output.errors[0].message.contains("average cost"));
    }

    #[test]
    fn test_check_average_cost_multiple_buys() {
        // Buy 10 at $100, then 10 at $120 -> avg = $110
        let output = run(vec![
            trade("2024-01-01", "Buy at 100", "10", "100.00"),
            trade("2024-01-15", "Buy at 120", "10", "120.00"),
            trade("2024-02-01", "Sell at avg cost", "-5", "110.00"),
        ]);
        assert_eq!(output.errors.len(), 0);
    }
}