1use regex::Regex;
22use rust_decimal::Decimal;
23use rustledger_core::NaiveDate;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29 AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOp, PluginOutput,
30 PostingData, TransactionData,
31};
32
33use super::super::{NativePlugin, RegularPlugin};
34
35static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38 Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39 .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Native plugin `long_short`: reclassifies capital-gains postings into short-term and
/// long-term accounts based on each sold lot's holding period.
pub struct CapitalGainsLongShortPlugin;
44
/// Native plugin `gain_loss`: reclassifies capital-gains postings into gains or losses
/// accounts based on the sign of each posting's amount.
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49 fn name(&self) -> &'static str {
50 "long_short"
51 }
52
53 fn description(&self) -> &'static str {
54 "Classify capital gains into long-term vs short-term based on holding period"
55 }
56
57 fn process(&self, input: PluginInput) -> PluginOutput {
58 process_long_short(input)
59 }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63 fn name(&self) -> &'static str {
64 "gain_loss"
65 }
66
67 fn description(&self) -> &'static str {
68 "Classify capital gains into gains vs losses based on posting amount"
69 }
70
71 fn process(&self, input: PluginInput) -> PluginOutput {
72 process_gain_loss(input)
73 }
74}
75
76impl RegularPlugin for CapitalGainsLongShortPlugin {}
77impl RegularPlugin for CapitalGainsGainLossPlugin {}
78
79struct LongShortConfig {
81 pattern: Regex,
82 account_to_replace: String,
83 short_replacement: String,
84 long_replacement: String,
85}
86
87struct GainLossConfig {
89 pattern: Regex,
90 account_to_replace: String,
91 gains_replacement: String,
92 losses_replacement: String,
93}
94
95fn process_long_short(input: PluginInput) -> PluginOutput {
97 let config = match &input.config {
98 Some(c) => match parse_long_short_config(c) {
99 Some(cfg) => cfg,
100 None => {
101 return PluginOutput {
102 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
103 errors: Vec::new(),
104 };
105 }
106 },
107 None => {
108 return PluginOutput {
109 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
110 errors: Vec::new(),
111 };
112 }
113 };
114
115 let mut new_accounts: HashSet<String> = HashSet::new();
116 let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
117 let earliest_date = input
119 .directives
120 .iter()
121 .map(|d| d.date.as_str())
122 .min()
123 .unwrap_or("1970-01-01")
124 .to_string();
125 let existing_opens: HashSet<String> = input
128 .directives
129 .iter()
130 .filter_map(|d| match &d.data {
131 DirectiveData::Open(open) => Some(open.account.clone()),
132 _ => None,
133 })
134 .collect();
135
136 for (i, directive) in input.directives.into_iter().enumerate() {
137 if directive.directive_type != "transaction" {
138 ops.push(PluginOp::Keep(i));
139 continue;
140 }
141
142 if let DirectiveData::Transaction(txn) = &directive.data {
143 let has_generic = txn
145 .postings
146 .iter()
147 .any(|p| config.pattern.is_match(&p.account));
148 let has_specific = txn.postings.iter().any(|p| {
149 p.account.contains(&config.short_replacement)
150 || p.account.contains(&config.long_replacement)
151 });
152
153 if !has_generic || has_specific {
154 ops.push(PluginOp::Keep(i));
155 continue;
156 }
157
158 let reductions: Vec<&PostingData> = txn
160 .postings
161 .iter()
162 .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
163 .collect();
164
165 if reductions.is_empty() {
166 ops.push(PluginOp::Keep(i));
167 continue;
168 }
169
170 let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
172 d
173 } else {
174 ops.push(PluginOp::Keep(i));
175 continue;
176 };
177
178 let any_missing_cost_date = reductions.iter().any(|p| {
185 p.cost
186 .as_ref()
187 .and_then(|c| c.date.as_ref())
188 .and_then(|d| d.parse::<NaiveDate>().ok())
189 .is_none()
190 });
191 if any_missing_cost_date {
192 ops.push(PluginOp::Keep(i));
193 continue;
194 }
195
196 let mut short_gains = Decimal::ZERO;
197 let mut long_gains = Decimal::ZERO;
198
199 for posting in &reductions {
200 if let (Some(cost), Some(units), Some(price)) =
201 (&posting.cost, &posting.units, &posting.price)
202 {
203 let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
205
206 if let Some(cost_date) = cost_date {
207 let cost_number = cost
213 .number
214 .as_ref()
215 .and_then(|cn| cn.per_unit())
216 .map_or(Decimal::ZERO, |s| {
217 Decimal::from_str(s).unwrap_or(Decimal::ZERO)
218 });
219 let price_number = price
220 .amount
221 .as_ref()
222 .and_then(|a| Decimal::from_str(&a.number).ok())
223 .unwrap_or(Decimal::ZERO);
224 let units_number =
225 Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
226
227 let gain = (cost_number - price_number) * units_number.abs();
228
229 let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
231 let years_held = (days_held / 365) as u32;
232 let is_long_term = years_held > 1
233 || (years_held == 1
234 && (entry_date.month() > cost_date.month()
235 || (entry_date.month() == cost_date.month()
236 && entry_date.day() >= cost_date.day())));
237
238 if is_long_term {
239 long_gains += gain;
240 } else {
241 short_gains += gain;
242 }
243 }
244 }
245 }
246
247 let orig_postings: Vec<&PostingData> = txn
249 .postings
250 .iter()
251 .filter(|p| config.pattern.is_match(&p.account))
252 .collect();
253
254 if orig_postings.is_empty() {
255 ops.push(PluginOp::Keep(i));
256 continue;
257 }
258
259 let orig_sum: Decimal = orig_postings
260 .iter()
261 .filter_map(|p| p.units.as_ref())
262 .filter_map(|u| Decimal::from_str(&u.number).ok())
263 .sum();
264
265 let diff = orig_sum - (short_gains + long_gains);
267 if diff.abs() > Decimal::new(1, 6) {
268 let total = short_gains + long_gains;
269 if total != Decimal::ZERO {
270 short_gains += (short_gains / total) * diff;
271 long_gains += (long_gains / total) * diff;
272 }
273 }
274
275 let mut new_postings: Vec<PostingData> = txn
277 .postings
278 .iter()
279 .filter(|p| !config.pattern.is_match(&p.account))
280 .cloned()
281 .collect();
282
283 let template = orig_postings[0];
284
285 if short_gains != Decimal::ZERO {
286 let new_account = template
287 .account
288 .replace(&config.account_to_replace, &config.short_replacement);
289 new_accounts.insert(new_account.clone());
290 new_postings.push(PostingData {
291 account: new_account,
292 units: template.units.as_ref().map(|u| AmountData {
293 number: format_decimal(short_gains),
294 currency: u.currency.clone(),
295 }),
296 cost: None,
297 price: None,
298 flag: template.flag.clone(),
299 metadata: vec![],
300 span: None,
301 });
302 }
303
304 if long_gains != Decimal::ZERO {
305 let new_account = template
306 .account
307 .replace(&config.account_to_replace, &config.long_replacement);
308 new_accounts.insert(new_account.clone());
309 new_postings.push(PostingData {
310 account: new_account,
311 units: template.units.as_ref().map(|u| AmountData {
312 number: format_decimal(long_gains),
313 currency: u.currency.clone(),
314 }),
315 cost: None,
316 price: None,
317 flag: template.flag.clone(),
318 metadata: vec![],
319 span: None,
320 });
321 }
322
323 ops.push(PluginOp::Modify(
324 i,
325 DirectiveWrapper {
326 directive_type: "transaction".to_string(),
327 date: directive.date.clone(),
328 filename: directive.filename.clone(),
329 lineno: directive.lineno,
330 data: DirectiveData::Transaction(TransactionData {
331 flag: txn.flag.clone(),
332 payee: txn.payee.clone(),
333 narration: txn.narration.clone(),
334 tags: txn.tags.clone(),
335 links: txn.links.clone(),
336 metadata: txn.metadata.clone(),
337 postings: new_postings,
338 }),
339 },
340 ));
341 } else {
342 ops.push(PluginOp::Keep(i));
343 }
344 }
345
346 for account in &new_accounts {
349 if existing_opens.contains(account) {
350 continue;
351 }
352 ops.push(PluginOp::Insert(DirectiveWrapper {
353 directive_type: "open".to_string(),
354 date: earliest_date.clone(),
355 filename: Some("<long_short>".to_string()),
356 lineno: Some(0),
357 data: DirectiveData::Open(OpenData {
358 account: account.clone(),
359 currencies: vec![],
360 booking: None,
361 metadata: vec![],
362 }),
363 }));
364 }
365
366 PluginOutput {
367 ops,
368 errors: Vec::new(),
369 }
370}
371
372fn process_gain_loss(input: PluginInput) -> PluginOutput {
374 let config = match &input.config {
375 Some(c) => match parse_gain_loss_config(c) {
376 Some(cfg) => cfg,
377 None => {
378 return PluginOutput {
379 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
380 errors: Vec::new(),
381 };
382 }
383 },
384 None => {
385 return PluginOutput {
386 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
387 errors: Vec::new(),
388 };
389 }
390 };
391
392 let mut new_accounts: HashSet<String> = HashSet::new();
393 let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
394 let earliest_date = input
396 .directives
397 .iter()
398 .map(|d| d.date.as_str())
399 .min()
400 .unwrap_or("1970-01-01")
401 .to_string();
402 let existing_opens: HashSet<String> = input
404 .directives
405 .iter()
406 .filter_map(|d| match &d.data {
407 DirectiveData::Open(open) => Some(open.account.clone()),
408 _ => None,
409 })
410 .collect();
411
412 for (i, directive) in input.directives.into_iter().enumerate() {
413 if directive.directive_type != "transaction" {
414 ops.push(PluginOp::Keep(i));
415 continue;
416 }
417
418 if let DirectiveData::Transaction(txn) = &directive.data {
419 let mut modified = false;
420 let mut new_postings: Vec<PostingData> = Vec::new();
421
422 for posting in &txn.postings {
423 if config.pattern.is_match(&posting.account)
424 && let Some(units) = &posting.units
425 && let Ok(number) = Decimal::from_str(&units.number)
426 {
427 let new_account = if number < Decimal::ZERO {
428 posting
430 .account
431 .replace(&config.account_to_replace, &config.gains_replacement)
432 } else {
433 posting
435 .account
436 .replace(&config.account_to_replace, &config.losses_replacement)
437 };
438
439 new_accounts.insert(new_account.clone());
440 new_postings.push(PostingData {
441 account: new_account,
442 units: posting.units.clone(),
443 cost: posting.cost.clone(),
444 price: posting.price.clone(),
445 flag: posting.flag.clone(),
446 metadata: posting.metadata.clone(),
447 span: None,
448 });
449 modified = true;
450 continue;
451 }
452 new_postings.push(posting.clone());
453 }
454
455 if modified {
456 ops.push(PluginOp::Modify(
457 i,
458 DirectiveWrapper {
459 directive_type: "transaction".to_string(),
460 date: directive.date.clone(),
461 filename: directive.filename.clone(),
462 lineno: directive.lineno,
463 data: DirectiveData::Transaction(TransactionData {
464 flag: txn.flag.clone(),
465 payee: txn.payee.clone(),
466 narration: txn.narration.clone(),
467 tags: txn.tags.clone(),
468 links: txn.links.clone(),
469 metadata: txn.metadata.clone(),
470 postings: new_postings,
471 }),
472 },
473 ));
474 } else {
475 ops.push(PluginOp::Keep(i));
476 }
477 } else {
478 ops.push(PluginOp::Keep(i));
479 }
480 }
481
482 for account in &new_accounts {
485 if existing_opens.contains(account) {
486 continue;
487 }
488 ops.push(PluginOp::Insert(DirectiveWrapper {
489 directive_type: "open".to_string(),
490 date: earliest_date.clone(),
491 filename: Some("<gain_loss>".to_string()),
492 lineno: Some(0),
493 data: DirectiveData::Open(OpenData {
494 account: account.clone(),
495 currencies: vec![],
496 booking: None,
497 metadata: vec![],
498 }),
499 }));
500 }
501
502 PluginOutput {
503 ops,
504 errors: Vec::new(),
505 }
506}
507
508fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
511 let cap = CONFIG_ENTRY_RE.captures(config)?;
513 let pattern = Regex::new(&cap[1]).ok()?;
514 let account_to_replace = cap[2].to_string();
515 let short_replacement = cap[3].to_string();
516 let long_replacement = cap[4].to_string();
517
518 Some(LongShortConfig {
519 pattern,
520 account_to_replace,
521 short_replacement,
522 long_replacement,
523 })
524}
525
526fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
529 let cap = CONFIG_ENTRY_RE.captures(config)?;
531 let pattern = Regex::new(&cap[1]).ok()?;
532 let account_to_replace = cap[2].to_string();
533 let gains_replacement = cap[3].to_string();
534 let losses_replacement = cap[4].to_string();
535
536 Some(GainLossConfig {
537 pattern,
538 account_to_replace,
539 gains_replacement,
540 losses_replacement,
541 })
542}
543
544fn format_decimal(d: Decimal) -> String {
546 let s = d.to_string();
547 if s.contains('.') {
548 s.trim_end_matches('0').trim_end_matches('.').to_string()
549 } else {
550 s
551 }
552}
553
#[cfg(test)]
mod tests {
    use super::super::utils::materialize_ops;
    use super::*;
    use crate::types::*;

    const LONG_SHORT_CONFIG: &str = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
    const GAIN_LOSS_CONFIG: &str =
        "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}";

    /// Builds a plain USD posting (no cost, price, flag or metadata).
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        }
    }

    /// Builds a plugin input holding one "Sell stock" transaction dated 2024-01-15.
    fn single_txn_input(postings: Vec<PostingData>, config: Option<&str>) -> PluginInput {
        PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings,
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: config.map(str::to_string),
        }
    }

    /// Runs `plugin` on `input` and applies its ops.
    ///
    /// Returns the posting accounts of every resulting transaction, followed by the
    /// accounts of every resulting `open` directive.
    fn run(plugin: &impl NativePlugin, input: PluginInput) -> (Vec<String>, Vec<String>) {
        let input_dirs = input.directives.clone();
        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 0);
        let directives = materialize_ops(&input_dirs, &output);

        let mut posting_accounts = Vec::new();
        let mut opened_accounts = Vec::new();
        for directive in directives.iter() {
            match &directive.data {
                DirectiveData::Transaction(t) => {
                    posting_accounts.extend(t.postings.iter().map(|p| p.account.clone()));
                }
                DirectiveData::Open(o) => opened_accounts.push(o.account.clone()),
                _ => {}
            }
        }
        (posting_accounts, opened_accounts)
    }

    #[test]
    fn test_parse_long_short_config() {
        let cfg = parse_long_short_config(LONG_SHORT_CONFIG).expect("config should parse");
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
        assert!(cfg.pattern.is_match("Income:Capital-Gains"));
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("config should parse");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_parse_config_rejects_malformed_input() {
        assert!(parse_long_short_config("no entries here").is_none());
        // Well-formed entry, but the account regex itself does not compile.
        assert!(parse_gain_loss_config("{'(': ['a', 'b', 'c']}").is_none());
    }

    #[test]
    fn test_gain_loss_classification() {
        let input = single_txn_input(
            vec![
                usd_posting("Assets:Broker", "1000"),
                usd_posting("Income:Capital-Gains:Long", "-100"),
            ],
            Some(GAIN_LOSS_CONFIG),
        );
        let (accounts, opened) = run(&CapitalGainsGainLossPlugin, input);
        assert_eq!(accounts, vec!["Assets:Broker", "Income:Capital-Gains:Long:Gains"]);
        assert_eq!(opened, vec!["Income:Capital-Gains:Long:Gains"]);
    }

    #[test]
    fn test_gain_loss_positive_amount_is_loss() {
        let input = single_txn_input(
            vec![
                usd_posting("Assets:Broker", "-50"),
                usd_posting("Income:Capital-Gains:Long", "50"),
            ],
            Some(GAIN_LOSS_CONFIG),
        );
        let (accounts, opened) = run(&CapitalGainsGainLossPlugin, input);
        assert_eq!(accounts, vec!["Assets:Broker", "Income:Capital-Gains:Long:Losses"]);
        assert_eq!(opened, vec!["Income:Capital-Gains:Long:Losses"]);
    }

    #[test]
    fn test_gain_loss_without_config_keeps_everything() {
        let input = single_txn_input(
            vec![
                usd_posting("Assets:Broker", "1000"),
                usd_posting("Income:Capital-Gains:Long", "-100"),
            ],
            None,
        );
        let (accounts, opened) = run(&CapitalGainsGainLossPlugin, input);
        assert_eq!(accounts, vec!["Assets:Broker", "Income:Capital-Gains:Long"]);
        assert!(opened.is_empty());
    }

    #[test]
    fn test_long_short_without_reductions_keeps_transaction() {
        // No posting carries cost + price, so there is nothing to classify.
        let input = single_txn_input(
            vec![
                usd_posting("Assets:Broker", "1000"),
                usd_posting("Income:Capital-Gains", "-100"),
            ],
            Some(LONG_SHORT_CONFIG),
        );
        let (accounts, opened) = run(&CapitalGainsLongShortPlugin, input);
        assert_eq!(accounts, vec!["Assets:Broker", "Income:Capital-Gains"]);
        assert!(opened.is_empty());
    }
}