1use regex::Regex;
22use rust_decimal::Decimal;
23use rustledger_core::NaiveDate;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29 AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOp, PluginOutput,
30 PostingData, TransactionData,
31};
32
33use super::super::NativePlugin;
34
35static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38 Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39 .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Native plugin that classifies capital gains into long-term vs short-term
/// accounts based on holding period; registered under the name `long_short`.
pub struct CapitalGainsLongShortPlugin;
44
/// Native plugin that classifies capital gains into gains vs losses accounts
/// based on the posting amount; registered under the name `gain_loss`.
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49 fn name(&self) -> &'static str {
50 "long_short"
51 }
52
53 fn description(&self) -> &'static str {
54 "Classify capital gains into long-term vs short-term based on holding period"
55 }
56
57 fn process(&self, input: PluginInput) -> PluginOutput {
58 process_long_short(input)
59 }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63 fn name(&self) -> &'static str {
64 "gain_loss"
65 }
66
67 fn description(&self) -> &'static str {
68 "Classify capital gains into gains vs losses based on posting amount"
69 }
70
71 fn process(&self, input: PluginInput) -> PluginOutput {
72 process_gain_loss(input)
73 }
74}
75
/// One parsed `long_short` config entry:
/// `'pattern': ['account_to_replace', 'short_replacement', 'long_replacement']`.
struct LongShortConfig {
    /// Regex selecting the generic capital-gains postings to split.
    pattern: Regex,
    /// Substring of the matched account that is replaced.
    account_to_replace: String,
    /// Replacement used to build the short-term account name.
    short_replacement: String,
    /// Replacement used to build the long-term account name.
    long_replacement: String,
}
83
/// One parsed `gain_loss` config entry:
/// `'pattern': ['account_to_replace', 'gains_replacement', 'losses_replacement']`.
struct GainLossConfig {
    /// Regex selecting the postings to reclassify.
    pattern: Regex,
    /// Substring of the matched account that is replaced.
    account_to_replace: String,
    /// Replacement used when the posting amount is negative.
    gains_replacement: String,
    /// Replacement used when the posting amount is zero or positive.
    losses_replacement: String,
}
91
92fn process_long_short(input: PluginInput) -> PluginOutput {
94 let config = match &input.config {
95 Some(c) => match parse_long_short_config(c) {
96 Some(cfg) => cfg,
97 None => {
98 return PluginOutput {
99 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
100 errors: Vec::new(),
101 };
102 }
103 },
104 None => {
105 return PluginOutput {
106 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
107 errors: Vec::new(),
108 };
109 }
110 };
111
112 let mut new_accounts: HashSet<String> = HashSet::new();
113 let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
114 let earliest_date = input
116 .directives
117 .iter()
118 .map(|d| d.date.as_str())
119 .min()
120 .unwrap_or("1970-01-01")
121 .to_string();
122 let existing_opens: HashSet<String> = input
125 .directives
126 .iter()
127 .filter_map(|d| match &d.data {
128 DirectiveData::Open(open) => Some(open.account.clone()),
129 _ => None,
130 })
131 .collect();
132
133 for (i, directive) in input.directives.into_iter().enumerate() {
134 if directive.directive_type != "transaction" {
135 ops.push(PluginOp::Keep(i));
136 continue;
137 }
138
139 if let DirectiveData::Transaction(txn) = &directive.data {
140 let has_generic = txn
142 .postings
143 .iter()
144 .any(|p| config.pattern.is_match(&p.account));
145 let has_specific = txn.postings.iter().any(|p| {
146 p.account.contains(&config.short_replacement)
147 || p.account.contains(&config.long_replacement)
148 });
149
150 if !has_generic || has_specific {
151 ops.push(PluginOp::Keep(i));
152 continue;
153 }
154
155 let reductions: Vec<&PostingData> = txn
157 .postings
158 .iter()
159 .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
160 .collect();
161
162 if reductions.is_empty() {
163 ops.push(PluginOp::Keep(i));
164 continue;
165 }
166
167 let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
169 d
170 } else {
171 ops.push(PluginOp::Keep(i));
172 continue;
173 };
174
175 let any_missing_cost_date = reductions.iter().any(|p| {
182 p.cost
183 .as_ref()
184 .and_then(|c| c.date.as_ref())
185 .and_then(|d| d.parse::<NaiveDate>().ok())
186 .is_none()
187 });
188 if any_missing_cost_date {
189 ops.push(PluginOp::Keep(i));
190 continue;
191 }
192
193 let mut short_gains = Decimal::ZERO;
194 let mut long_gains = Decimal::ZERO;
195
196 for posting in &reductions {
197 if let (Some(cost), Some(units), Some(price)) =
198 (&posting.cost, &posting.units, &posting.price)
199 {
200 let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
202
203 if let Some(cost_date) = cost_date {
204 let cost_number = cost
206 .number_per
207 .as_ref()
208 .and_then(|n| Decimal::from_str(n).ok())
209 .unwrap_or(Decimal::ZERO);
210 let price_number = price
211 .amount
212 .as_ref()
213 .and_then(|a| Decimal::from_str(&a.number).ok())
214 .unwrap_or(Decimal::ZERO);
215 let units_number =
216 Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
217
218 let gain = (cost_number - price_number) * units_number.abs();
219
220 let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
222 let years_held = (days_held / 365) as u32;
223 let is_long_term = years_held > 1
224 || (years_held == 1
225 && (entry_date.month() > cost_date.month()
226 || (entry_date.month() == cost_date.month()
227 && entry_date.day() >= cost_date.day())));
228
229 if is_long_term {
230 long_gains += gain;
231 } else {
232 short_gains += gain;
233 }
234 }
235 }
236 }
237
238 let orig_postings: Vec<&PostingData> = txn
240 .postings
241 .iter()
242 .filter(|p| config.pattern.is_match(&p.account))
243 .collect();
244
245 if orig_postings.is_empty() {
246 ops.push(PluginOp::Keep(i));
247 continue;
248 }
249
250 let orig_sum: Decimal = orig_postings
251 .iter()
252 .filter_map(|p| p.units.as_ref())
253 .filter_map(|u| Decimal::from_str(&u.number).ok())
254 .sum();
255
256 let diff = orig_sum - (short_gains + long_gains);
258 if diff.abs() > Decimal::new(1, 6) {
259 let total = short_gains + long_gains;
260 if total != Decimal::ZERO {
261 short_gains += (short_gains / total) * diff;
262 long_gains += (long_gains / total) * diff;
263 }
264 }
265
266 let mut new_postings: Vec<PostingData> = txn
268 .postings
269 .iter()
270 .filter(|p| !config.pattern.is_match(&p.account))
271 .cloned()
272 .collect();
273
274 let template = orig_postings[0];
275
276 if short_gains != Decimal::ZERO {
277 let new_account = template
278 .account
279 .replace(&config.account_to_replace, &config.short_replacement);
280 new_accounts.insert(new_account.clone());
281 new_postings.push(PostingData {
282 account: new_account,
283 units: template.units.as_ref().map(|u| AmountData {
284 number: format_decimal(short_gains),
285 currency: u.currency.clone(),
286 }),
287 cost: None,
288 price: None,
289 flag: template.flag.clone(),
290 metadata: vec![],
291 });
292 }
293
294 if long_gains != Decimal::ZERO {
295 let new_account = template
296 .account
297 .replace(&config.account_to_replace, &config.long_replacement);
298 new_accounts.insert(new_account.clone());
299 new_postings.push(PostingData {
300 account: new_account,
301 units: template.units.as_ref().map(|u| AmountData {
302 number: format_decimal(long_gains),
303 currency: u.currency.clone(),
304 }),
305 cost: None,
306 price: None,
307 flag: template.flag.clone(),
308 metadata: vec![],
309 });
310 }
311
312 ops.push(PluginOp::Modify(
313 i,
314 DirectiveWrapper {
315 directive_type: "transaction".to_string(),
316 date: directive.date.clone(),
317 filename: directive.filename.clone(),
318 lineno: directive.lineno,
319 data: DirectiveData::Transaction(TransactionData {
320 flag: txn.flag.clone(),
321 payee: txn.payee.clone(),
322 narration: txn.narration.clone(),
323 tags: txn.tags.clone(),
324 links: txn.links.clone(),
325 metadata: txn.metadata.clone(),
326 postings: new_postings,
327 }),
328 },
329 ));
330 } else {
331 ops.push(PluginOp::Keep(i));
332 }
333 }
334
335 for account in &new_accounts {
338 if existing_opens.contains(account) {
339 continue;
340 }
341 ops.push(PluginOp::Insert(DirectiveWrapper {
342 directive_type: "open".to_string(),
343 date: earliest_date.clone(),
344 filename: Some("<long_short>".to_string()),
345 lineno: Some(0),
346 data: DirectiveData::Open(OpenData {
347 account: account.clone(),
348 currencies: vec![],
349 booking: None,
350 metadata: vec![],
351 }),
352 }));
353 }
354
355 PluginOutput {
356 ops,
357 errors: Vec::new(),
358 }
359}
360
361fn process_gain_loss(input: PluginInput) -> PluginOutput {
363 let config = match &input.config {
364 Some(c) => match parse_gain_loss_config(c) {
365 Some(cfg) => cfg,
366 None => {
367 return PluginOutput {
368 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
369 errors: Vec::new(),
370 };
371 }
372 },
373 None => {
374 return PluginOutput {
375 ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
376 errors: Vec::new(),
377 };
378 }
379 };
380
381 let mut new_accounts: HashSet<String> = HashSet::new();
382 let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
383 let earliest_date = input
385 .directives
386 .iter()
387 .map(|d| d.date.as_str())
388 .min()
389 .unwrap_or("1970-01-01")
390 .to_string();
391 let existing_opens: HashSet<String> = input
393 .directives
394 .iter()
395 .filter_map(|d| match &d.data {
396 DirectiveData::Open(open) => Some(open.account.clone()),
397 _ => None,
398 })
399 .collect();
400
401 for (i, directive) in input.directives.into_iter().enumerate() {
402 if directive.directive_type != "transaction" {
403 ops.push(PluginOp::Keep(i));
404 continue;
405 }
406
407 if let DirectiveData::Transaction(txn) = &directive.data {
408 let mut modified = false;
409 let mut new_postings: Vec<PostingData> = Vec::new();
410
411 for posting in &txn.postings {
412 if config.pattern.is_match(&posting.account)
413 && let Some(units) = &posting.units
414 && let Ok(number) = Decimal::from_str(&units.number)
415 {
416 let new_account = if number < Decimal::ZERO {
417 posting
419 .account
420 .replace(&config.account_to_replace, &config.gains_replacement)
421 } else {
422 posting
424 .account
425 .replace(&config.account_to_replace, &config.losses_replacement)
426 };
427
428 new_accounts.insert(new_account.clone());
429 new_postings.push(PostingData {
430 account: new_account,
431 units: posting.units.clone(),
432 cost: posting.cost.clone(),
433 price: posting.price.clone(),
434 flag: posting.flag.clone(),
435 metadata: posting.metadata.clone(),
436 });
437 modified = true;
438 continue;
439 }
440 new_postings.push(posting.clone());
441 }
442
443 if modified {
444 ops.push(PluginOp::Modify(
445 i,
446 DirectiveWrapper {
447 directive_type: "transaction".to_string(),
448 date: directive.date.clone(),
449 filename: directive.filename.clone(),
450 lineno: directive.lineno,
451 data: DirectiveData::Transaction(TransactionData {
452 flag: txn.flag.clone(),
453 payee: txn.payee.clone(),
454 narration: txn.narration.clone(),
455 tags: txn.tags.clone(),
456 links: txn.links.clone(),
457 metadata: txn.metadata.clone(),
458 postings: new_postings,
459 }),
460 },
461 ));
462 } else {
463 ops.push(PluginOp::Keep(i));
464 }
465 } else {
466 ops.push(PluginOp::Keep(i));
467 }
468 }
469
470 for account in &new_accounts {
473 if existing_opens.contains(account) {
474 continue;
475 }
476 ops.push(PluginOp::Insert(DirectiveWrapper {
477 directive_type: "open".to_string(),
478 date: earliest_date.clone(),
479 filename: Some("<gain_loss>".to_string()),
480 lineno: Some(0),
481 data: DirectiveData::Open(OpenData {
482 account: account.clone(),
483 currencies: vec![],
484 booking: None,
485 metadata: vec![],
486 }),
487 }));
488 }
489
490 PluginOutput {
491 ops,
492 errors: Vec::new(),
493 }
494}
495
496fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
499 let cap = CONFIG_ENTRY_RE.captures(config)?;
501 let pattern = Regex::new(&cap[1]).ok()?;
502 let account_to_replace = cap[2].to_string();
503 let short_replacement = cap[3].to_string();
504 let long_replacement = cap[4].to_string();
505
506 Some(LongShortConfig {
507 pattern,
508 account_to_replace,
509 short_replacement,
510 long_replacement,
511 })
512}
513
514fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
517 let cap = CONFIG_ENTRY_RE.captures(config)?;
519 let pattern = Regex::new(&cap[1]).ok()?;
520 let account_to_replace = cap[2].to_string();
521 let gains_replacement = cap[3].to_string();
522 let losses_replacement = cap[4].to_string();
523
524 Some(GainLossConfig {
525 pattern,
526 account_to_replace,
527 gains_replacement,
528 losses_replacement,
529 })
530}
531
532fn format_decimal(d: Decimal) -> String {
534 let s = d.to_string();
535 if s.contains('.') {
536 s.trim_end_matches('0').trim_end_matches('.').to_string()
537 } else {
538 s
539 }
540}
541
#[cfg(test)]
mod tests {
    use super::super::utils::materialize_ops;
    use super::*;
    use crate::types::*;

    /// USD posting with plain units and no cost, price, flag or metadata.
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    #[test]
    fn test_parse_long_short_config() {
        let config = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
        let cfg = parse_long_short_config(config).expect("long_short config should parse");
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("gain_loss config should parse");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_gain_loss_classification() {
        let sale = DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: "Sell stock".to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    usd_posting("Assets:Broker", "1000"),
                    usd_posting("Income:Capital-Gains:Long", "-100"),
                ],
            }),
        };
        let input = PluginInput {
            directives: vec![sale],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(
                "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}"
                    .to_string(),
            ),
        };

        let input_dirs = input.directives.clone();
        let output = CapitalGainsGainLossPlugin.process(input);
        assert!(output.errors.is_empty());

        let directives = materialize_ops(&input_dirs, &output);
        let txn = directives
            .iter()
            .find_map(|d| match &d.data {
                DirectiveData::Transaction(t) => Some(t),
                _ => None,
            })
            .expect("the transaction should survive the plugin");
        assert!(txn
            .postings
            .iter()
            .any(|p| p.account.contains(":Long:Gains")));
    }
}