1use crate::types::{DirectiveData, PluginError, PluginInput, PluginOutput};
4
5use super::super::NativePlugin;
6
7pub struct CheckAverageCostPlugin {
13 tolerance: rust_decimal::Decimal,
15}
16
17impl CheckAverageCostPlugin {
18 pub fn new() -> Self {
20 Self {
21 tolerance: rust_decimal::Decimal::new(1, 2), }
23 }
24
25 pub const fn with_tolerance(tolerance: rust_decimal::Decimal) -> Self {
27 Self { tolerance }
28 }
29}
30
31impl Default for CheckAverageCostPlugin {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl NativePlugin for CheckAverageCostPlugin {
38 fn name(&self) -> &'static str {
39 "check_average_cost"
40 }
41
42 fn description(&self) -> &'static str {
43 "Validate reducing postings match average cost"
44 }
45
46 fn process(&self, input: PluginInput) -> PluginOutput {
47 use rust_decimal::Decimal;
48 use std::collections::HashMap;
49 use std::str::FromStr;
50
51 let tolerance = if let Some(config) = &input.config {
53 Decimal::from_str(config.trim()).unwrap_or(self.tolerance)
54 } else {
55 self.tolerance
56 };
57
58 let mut inventory: HashMap<(String, String), (Decimal, Decimal)> = HashMap::new();
61
62 let mut errors = Vec::new();
63
64 for wrapper in &input.directives {
65 if let DirectiveData::Transaction(txn) = &wrapper.data {
66 for posting in &txn.postings {
67 let Some(units) = &posting.units else {
69 continue;
70 };
71 let Some(cost) = &posting.cost else {
72 continue;
73 };
74
75 let units_num = Decimal::from_str(&units.number).unwrap_or_default();
76 let Some(cost_currency) = &cost.currency else {
77 continue;
78 };
79
80 let key = (posting.account.clone(), units.currency.clone());
81
82 if units_num > Decimal::ZERO {
83 let cost_per = cost
85 .number_per
86 .as_ref()
87 .and_then(|s| Decimal::from_str(s).ok())
88 .unwrap_or_default();
89
90 let entry = inventory
91 .entry(key)
92 .or_insert((Decimal::ZERO, Decimal::ZERO));
93 entry.0 += units_num; entry.1 += units_num * cost_per; } else if units_num < Decimal::ZERO {
96 let entry = inventory.get(&key);
98
99 if let Some((total_units, total_cost)) = entry
100 && *total_units > Decimal::ZERO
101 {
102 let avg_cost = *total_cost / *total_units;
103
104 let used_cost = cost
106 .number_per
107 .as_ref()
108 .and_then(|s| Decimal::from_str(s).ok())
109 .unwrap_or_default();
110
111 let diff = (used_cost - avg_cost).abs();
113 let relative_diff = if avg_cost == Decimal::ZERO {
114 diff
115 } else {
116 diff / avg_cost
117 };
118
119 if relative_diff > tolerance {
120 errors.push(PluginError::warning(format!(
121 "Sale of {} {} in {} uses cost {} {} but average cost is {} {} (difference: {:.2}%)",
122 units_num.abs(),
123 units.currency,
124 posting.account,
125 used_cost,
126 cost_currency,
127 avg_cost.round_dp(4),
128 cost_currency,
129 relative_diff * Decimal::from(100)
130 )));
131 }
132
133 let entry = inventory.get_mut(&key).unwrap();
135 let units_sold = units_num.abs();
136 let cost_removed = units_sold * avg_cost;
137 entry.0 -= units_sold;
138 entry.1 -= cost_removed;
139 }
140 }
141 }
142 }
143 }
144
145 PluginOutput {
146 directives: input.directives,
147 errors,
148 }
149 }
150}
151
#[cfg(test)]
mod check_average_cost_tests {
    use super::*;
    use crate::types::*;

    /// Builds a one-posting `Assets:Broker` AAPL transaction whose cost is
    /// stated per unit in USD.
    fn aapl_txn(date: &str, narration: &str, units: &str, cost_per: &str) -> DirectiveWrapper {
        let posting = PostingData {
            account: "Assets:Broker".to_string(),
            units: Some(AmountData {
                number: units.to_string(),
                currency: "AAPL".to_string(),
            }),
            cost: Some(CostData {
                number_per: Some(cost_per.to_string()),
                number_total: None,
                currency: Some("USD".to_string()),
                date: None,
                label: None,
                merge: false,
            }),
            price: None,
            flag: None,
            metadata: vec![],
        };

        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![posting],
            }),
        }
    }

    /// Runs a default-tolerance plugin (no config, USD operating currency)
    /// over `directives`.
    fn run(directives: Vec<DirectiveWrapper>) -> PluginOutput {
        let input = PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: None,
        };
        CheckAverageCostPlugin::new().process(input)
    }

    #[test]
    fn test_check_average_cost_matching() {
        let output = run(vec![
            aapl_txn("2024-01-01", "Buy", "10", "100.00"),
            aapl_txn("2024-02-01", "Sell at avg cost", "-5", "100.00"),
        ]);
        assert_eq!(output.errors.len(), 0);
    }

    #[test]
    fn test_check_average_cost_mismatch() {
        let output = run(vec![
            aapl_txn("2024-01-01", "Buy at 100", "10", "100.00"),
            aapl_txn("2024-02-01", "Sell at wrong cost", "-5", "90.00"),
        ]);
        assert_eq!(output.errors.len(), 1);
        assert!(output.errors[0].message.contains("average cost"));
    }

    #[test]
    fn test_check_average_cost_multiple_buys() {
        // 10 @ 100 plus 10 @ 120 averages to 110 per unit.
        let output = run(vec![
            aapl_txn("2024-01-01", "Buy at 100", "10", "100.00"),
            aapl_txn("2024-01-15", "Buy at 120", "10", "120.00"),
            aapl_txn("2024-02-01", "Sell at avg cost", "-5", "110.00"),
        ]);
        assert_eq!(output.errors.len(), 0);
    }
}