1use chrono::{Datelike, NaiveDate};
22use regex::Regex;
23use rust_decimal::Decimal;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29 AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
30 TransactionData,
31};
32
33use super::super::NativePlugin;
34
35static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38 Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39 .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Native plugin (name `long_short`) that splits capital-gains postings into
/// short-term and long-term accounts based on the holding period of the sold lots.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsLongShortPlugin;

/// Native plugin (name `gain_loss`) that relabels capital-gains postings as gains or
/// losses depending on the sign of the posted amount.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49 fn name(&self) -> &'static str {
50 "long_short"
51 }
52
53 fn description(&self) -> &'static str {
54 "Classify capital gains into long-term vs short-term based on holding period"
55 }
56
57 fn process(&self, input: PluginInput) -> PluginOutput {
58 process_long_short(input)
59 }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63 fn name(&self) -> &'static str {
64 "gain_loss"
65 }
66
67 fn description(&self) -> &'static str {
68 "Classify capital gains into gains vs losses based on posting amount"
69 }
70
71 fn process(&self, input: PluginInput) -> PluginOutput {
72 process_gain_loss(input)
73 }
74}
75
76struct LongShortConfig {
78 pattern: Regex,
79 account_to_replace: String,
80 short_replacement: String,
81 long_replacement: String,
82}
83
84struct GainLossConfig {
86 pattern: Regex,
87 account_to_replace: String,
88 gains_replacement: String,
89 losses_replacement: String,
90}
91
92fn process_long_short(input: PluginInput) -> PluginOutput {
94 let config = match &input.config {
95 Some(c) => match parse_long_short_config(c) {
96 Some(cfg) => cfg,
97 None => {
98 return PluginOutput {
99 directives: input.directives,
100 errors: Vec::new(),
101 };
102 }
103 },
104 None => {
105 return PluginOutput {
106 directives: input.directives,
107 errors: Vec::new(),
108 };
109 }
110 };
111
112 let mut new_accounts: HashSet<String> = HashSet::new();
113 let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
114
115 for directive in input.directives {
116 if directive.directive_type != "transaction" {
117 new_directives.push(directive);
118 continue;
119 }
120
121 if let DirectiveData::Transaction(txn) = &directive.data {
122 let has_generic = txn
124 .postings
125 .iter()
126 .any(|p| config.pattern.is_match(&p.account));
127 let has_specific = txn.postings.iter().any(|p| {
128 p.account.contains(&config.short_replacement)
129 || p.account.contains(&config.long_replacement)
130 });
131
132 if !has_generic || has_specific {
133 new_directives.push(directive);
134 continue;
135 }
136
137 let reductions: Vec<&PostingData> = txn
139 .postings
140 .iter()
141 .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
142 .collect();
143
144 if reductions.is_empty() {
145 new_directives.push(directive);
146 continue;
147 }
148
149 let entry_date = if let Ok(d) = NaiveDate::parse_from_str(&directive.date, "%Y-%m-%d") {
151 d
152 } else {
153 new_directives.push(directive);
154 continue;
155 };
156
157 let mut short_gains = Decimal::ZERO;
158 let mut long_gains = Decimal::ZERO;
159
160 for posting in &reductions {
161 if let (Some(cost), Some(units), Some(price)) =
162 (&posting.cost, &posting.units, &posting.price)
163 {
164 let cost_date = cost
166 .date
167 .as_ref()
168 .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok());
169
170 if let Some(cost_date) = cost_date {
171 let cost_number = cost
173 .number_per
174 .as_ref()
175 .and_then(|n| Decimal::from_str(n).ok())
176 .unwrap_or(Decimal::ZERO);
177 let price_number = price
178 .amount
179 .as_ref()
180 .and_then(|a| Decimal::from_str(&a.number).ok())
181 .unwrap_or(Decimal::ZERO);
182 let units_number =
183 Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
184
185 let gain = (cost_number - price_number) * units_number.abs();
186
187 let years_held = entry_date.years_since(cost_date).unwrap_or(0);
189 let is_long_term = years_held > 1
190 || (years_held == 1
191 && (entry_date.month() > cost_date.month()
192 || (entry_date.month() == cost_date.month()
193 && entry_date.day() >= cost_date.day())));
194
195 if is_long_term {
196 long_gains += gain;
197 } else {
198 short_gains += gain;
199 }
200 }
201 }
202 }
203
204 let orig_postings: Vec<&PostingData> = txn
206 .postings
207 .iter()
208 .filter(|p| config.pattern.is_match(&p.account))
209 .collect();
210
211 if orig_postings.is_empty() {
212 new_directives.push(directive);
213 continue;
214 }
215
216 let orig_sum: Decimal = orig_postings
217 .iter()
218 .filter_map(|p| p.units.as_ref())
219 .filter_map(|u| Decimal::from_str(&u.number).ok())
220 .sum();
221
222 let diff = orig_sum - (short_gains + long_gains);
224 if diff.abs() > Decimal::new(1, 6) {
225 let total = short_gains + long_gains;
226 if total != Decimal::ZERO {
227 short_gains += (short_gains / total) * diff;
228 long_gains += (long_gains / total) * diff;
229 }
230 }
231
232 let mut new_postings: Vec<PostingData> = txn
234 .postings
235 .iter()
236 .filter(|p| !config.pattern.is_match(&p.account))
237 .cloned()
238 .collect();
239
240 let template = orig_postings[0];
241
242 if short_gains != Decimal::ZERO {
243 let new_account = template
244 .account
245 .replace(&config.account_to_replace, &config.short_replacement);
246 new_accounts.insert(new_account.clone());
247 new_postings.push(PostingData {
248 account: new_account,
249 units: template.units.as_ref().map(|u| AmountData {
250 number: format_decimal(short_gains),
251 currency: u.currency.clone(),
252 }),
253 cost: None,
254 price: None,
255 flag: template.flag.clone(),
256 metadata: vec![],
257 });
258 }
259
260 if long_gains != Decimal::ZERO {
261 let new_account = template
262 .account
263 .replace(&config.account_to_replace, &config.long_replacement);
264 new_accounts.insert(new_account.clone());
265 new_postings.push(PostingData {
266 account: new_account,
267 units: template.units.as_ref().map(|u| AmountData {
268 number: format_decimal(long_gains),
269 currency: u.currency.clone(),
270 }),
271 cost: None,
272 price: None,
273 flag: template.flag.clone(),
274 metadata: vec![],
275 });
276 }
277
278 new_directives.push(DirectiveWrapper {
279 directive_type: "transaction".to_string(),
280 date: directive.date.clone(),
281 filename: directive.filename.clone(),
282 lineno: directive.lineno,
283 data: DirectiveData::Transaction(TransactionData {
284 flag: txn.flag.clone(),
285 payee: txn.payee.clone(),
286 narration: txn.narration.clone(),
287 tags: txn.tags.clone(),
288 links: txn.links.clone(),
289 metadata: txn.metadata.clone(),
290 postings: new_postings,
291 }),
292 });
293 } else {
294 new_directives.push(directive);
295 }
296 }
297
298 let earliest_date = new_directives
300 .iter()
301 .map(|d| d.date.as_str())
302 .min()
303 .unwrap_or("1970-01-01")
304 .to_string();
305
306 let mut open_directives: Vec<DirectiveWrapper> = new_accounts
307 .iter()
308 .map(|account| DirectiveWrapper {
309 directive_type: "open".to_string(),
310 date: earliest_date.clone(),
311 filename: Some("<long_short>".to_string()),
312 lineno: Some(0),
313 data: DirectiveData::Open(OpenData {
314 account: account.clone(),
315 currencies: vec![],
316 booking: None,
317 metadata: vec![],
318 }),
319 })
320 .collect();
321
322 open_directives.extend(new_directives);
323
324 PluginOutput {
325 directives: open_directives,
326 errors: Vec::new(),
327 }
328}
329
330fn process_gain_loss(input: PluginInput) -> PluginOutput {
332 let config = match &input.config {
333 Some(c) => match parse_gain_loss_config(c) {
334 Some(cfg) => cfg,
335 None => {
336 return PluginOutput {
337 directives: input.directives,
338 errors: Vec::new(),
339 };
340 }
341 },
342 None => {
343 return PluginOutput {
344 directives: input.directives,
345 errors: Vec::new(),
346 };
347 }
348 };
349
350 let mut new_accounts: HashSet<String> = HashSet::new();
351 let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
352
353 for directive in input.directives {
354 if directive.directive_type != "transaction" {
355 new_directives.push(directive);
356 continue;
357 }
358
359 if let DirectiveData::Transaction(txn) = &directive.data {
360 let mut modified = false;
361 let mut new_postings: Vec<PostingData> = Vec::new();
362
363 for posting in &txn.postings {
364 if config.pattern.is_match(&posting.account)
365 && let Some(units) = &posting.units
366 && let Ok(number) = Decimal::from_str(&units.number)
367 {
368 let new_account = if number < Decimal::ZERO {
369 posting
371 .account
372 .replace(&config.account_to_replace, &config.gains_replacement)
373 } else {
374 posting
376 .account
377 .replace(&config.account_to_replace, &config.losses_replacement)
378 };
379
380 new_accounts.insert(new_account.clone());
381 new_postings.push(PostingData {
382 account: new_account,
383 units: posting.units.clone(),
384 cost: posting.cost.clone(),
385 price: posting.price.clone(),
386 flag: posting.flag.clone(),
387 metadata: posting.metadata.clone(),
388 });
389 modified = true;
390 continue;
391 }
392 new_postings.push(posting.clone());
393 }
394
395 if modified {
396 new_directives.push(DirectiveWrapper {
397 directive_type: "transaction".to_string(),
398 date: directive.date.clone(),
399 filename: directive.filename.clone(),
400 lineno: directive.lineno,
401 data: DirectiveData::Transaction(TransactionData {
402 flag: txn.flag.clone(),
403 payee: txn.payee.clone(),
404 narration: txn.narration.clone(),
405 tags: txn.tags.clone(),
406 links: txn.links.clone(),
407 metadata: txn.metadata.clone(),
408 postings: new_postings,
409 }),
410 });
411 } else {
412 new_directives.push(directive);
413 }
414 } else {
415 new_directives.push(directive);
416 }
417 }
418
419 let earliest_date = new_directives
421 .iter()
422 .map(|d| d.date.as_str())
423 .min()
424 .unwrap_or("1970-01-01")
425 .to_string();
426
427 let mut open_directives: Vec<DirectiveWrapper> = new_accounts
428 .iter()
429 .map(|account| DirectiveWrapper {
430 directive_type: "open".to_string(),
431 date: earliest_date.clone(),
432 filename: Some("<gain_loss>".to_string()),
433 lineno: Some(0),
434 data: DirectiveData::Open(OpenData {
435 account: account.clone(),
436 currencies: vec![],
437 booking: None,
438 metadata: vec![],
439 }),
440 })
441 .collect();
442
443 open_directives.extend(new_directives);
444
445 PluginOutput {
446 directives: open_directives,
447 errors: Vec::new(),
448 }
449}
450
451fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
454 let cap = CONFIG_ENTRY_RE.captures(config)?;
456 let pattern = Regex::new(&cap[1]).ok()?;
457 let account_to_replace = cap[2].to_string();
458 let short_replacement = cap[3].to_string();
459 let long_replacement = cap[4].to_string();
460
461 Some(LongShortConfig {
462 pattern,
463 account_to_replace,
464 short_replacement,
465 long_replacement,
466 })
467}
468
469fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
472 let cap = CONFIG_ENTRY_RE.captures(config)?;
474 let pattern = Regex::new(&cap[1]).ok()?;
475 let account_to_replace = cap[2].to_string();
476 let gains_replacement = cap[3].to_string();
477 let losses_replacement = cap[4].to_string();
478
479 Some(GainLossConfig {
480 pattern,
481 account_to_replace,
482 gains_replacement,
483 losses_replacement,
484 })
485}
486
487fn format_decimal(d: Decimal) -> String {
489 let s = d.to_string();
490 if s.contains('.') {
491 s.trim_end_matches('0').trim_end_matches('.').to_string()
492 } else {
493 s
494 }
495}
496
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    const GAIN_LOSS_CONFIG: &str =
        "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
    const LONG_SHORT_CONFIG: &str = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";

    /// Builds a USD posting with no cost, price, flag or metadata.
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    /// Builds a "Sell stock" transaction dated 2024-01-15 with the given postings.
    fn transaction(postings: Vec<PostingData>) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: "Sell stock".to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings,
            }),
        }
    }

    /// Wraps directives and an optional config string into a `PluginInput`.
    fn plugin_input(directives: Vec<DirectiveWrapper>, config: Option<&str>) -> PluginInput {
        PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: config.map(str::to_string),
        }
    }

    /// Returns the postings of the single transaction in `output`; fails the test if there
    /// is not exactly one transaction.
    fn only_transaction_postings(output: &PluginOutput) -> &[PostingData] {
        let mut transactions = output.directives.iter().filter_map(|d| match &d.data {
            DirectiveData::Transaction(t) => Some(t),
            _ => None,
        });
        let txn = transactions.next().expect("output has no transaction");
        assert!(
            transactions.next().is_none(),
            "output has more than one transaction"
        );
        &txn.postings
    }

    /// Accounts of all `open` directives in `output`, in output order.
    fn open_accounts(output: &PluginOutput) -> Vec<&str> {
        output
            .directives
            .iter()
            .filter_map(|d| match &d.data {
                DirectiveData::Open(open) => Some(open.account.as_str()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn test_parse_long_short_config() {
        let cfg = parse_long_short_config(LONG_SHORT_CONFIG).expect("config should parse");
        assert_eq!(cfg.pattern.as_str(), "Income.*:Capital-Gains");
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("config should parse");
        assert_eq!(cfg.pattern.as_str(), "Income.*:Long");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_parse_config_rejects_malformed_input() {
        assert!(parse_long_short_config("not a python dict").is_none());
        // Only two replacement strings instead of three.
        assert!(parse_gain_loss_config("{'Income': [':a', ':b']}").is_none());
        // The captured account pattern is not a valid regex.
        assert!(parse_gain_loss_config("{'(': [':a', ':b', ':c']}").is_none());
    }

    #[test]
    fn test_gain_loss_classification() {
        let plugin = CapitalGainsGainLossPlugin;
        let input = plugin_input(
            vec![transaction(vec![
                usd_posting("Assets:Broker", "1000"),
                usd_posting("Income:Capital-Gains:Long", "-100"),
            ])],
            Some(GAIN_LOSS_CONFIG),
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());

        let postings = only_transaction_postings(&output);
        let gains = postings
            .iter()
            .find(|p| p.account == "Income:Capital-Gains:Long:Gains")
            .expect("gains posting missing");
        assert_eq!(
            gains.units.as_ref().map(|u| u.number.as_str()),
            Some("-100")
        );
        assert!(postings
            .iter()
            .all(|p| p.account != "Income:Capital-Gains:Long"));
        assert_eq!(
            open_accounts(&output),
            vec!["Income:Capital-Gains:Long:Gains"]
        );
    }

    #[test]
    fn test_gain_loss_classifies_positive_amount_as_loss() {
        let plugin = CapitalGainsGainLossPlugin;
        let input = plugin_input(
            vec![transaction(vec![
                usd_posting("Assets:Broker", "900"),
                usd_posting("Income:Capital-Gains:Long", "100"),
            ])],
            Some(GAIN_LOSS_CONFIG),
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());

        let postings = only_transaction_postings(&output);
        assert!(postings
            .iter()
            .any(|p| p.account == "Income:Capital-Gains:Long:Losses"));
        assert_eq!(
            open_accounts(&output),
            vec!["Income:Capital-Gains:Long:Losses"]
        );
    }

    #[test]
    fn test_missing_config_passes_input_through() {
        let plugin = CapitalGainsGainLossPlugin;
        let input = plugin_input(
            vec![transaction(vec![usd_posting(
                "Income:Capital-Gains:Long",
                "-100",
            )])],
            None,
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());
        assert_eq!(output.directives.len(), 1);
        let postings = only_transaction_postings(&output);
        assert_eq!(postings[0].account, "Income:Capital-Gains:Long");
    }

    #[test]
    fn test_long_short_leaves_transactions_without_lots_unchanged() {
        let plugin = CapitalGainsLongShortPlugin;
        let input = plugin_input(
            vec![transaction(vec![
                usd_posting("Assets:Broker", "100"),
                usd_posting("Income:Capital-Gains", "-100"),
            ])],
            Some(LONG_SHORT_CONFIG),
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());
        assert!(open_accounts(&output).is_empty());
        let postings = only_transaction_postings(&output);
        assert!(postings.iter().any(|p| p.account == "Income:Capital-Gains"));
    }
}