1use regex::Regex;
22use rust_decimal::Decimal;
23use rustledger_core::NaiveDate;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29 AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
30 TransactionData,
31};
32
33use super::super::NativePlugin;
34
35static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38 Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39 .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Native plugin that splits generic capital-gains postings into short-term
/// and long-term postings based on the holding period of each sold lot.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsLongShortPlugin;
44
/// Native plugin that renames capital-gains postings to a "gains" or
/// "losses" account depending on the sign of the posting amount.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49 fn name(&self) -> &'static str {
50 "long_short"
51 }
52
53 fn description(&self) -> &'static str {
54 "Classify capital gains into long-term vs short-term based on holding period"
55 }
56
57 fn process(&self, input: PluginInput) -> PluginOutput {
58 process_long_short(input)
59 }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63 fn name(&self) -> &'static str {
64 "gain_loss"
65 }
66
67 fn description(&self) -> &'static str {
68 "Classify capital gains into gains vs losses based on posting amount"
69 }
70
71 fn process(&self, input: PluginInput) -> PluginOutput {
72 process_gain_loss(input)
73 }
74}
75
76struct LongShortConfig {
78 pattern: Regex,
79 account_to_replace: String,
80 short_replacement: String,
81 long_replacement: String,
82}
83
84struct GainLossConfig {
86 pattern: Regex,
87 account_to_replace: String,
88 gains_replacement: String,
89 losses_replacement: String,
90}
91
92fn process_long_short(input: PluginInput) -> PluginOutput {
94 let config = match &input.config {
95 Some(c) => match parse_long_short_config(c) {
96 Some(cfg) => cfg,
97 None => {
98 return PluginOutput {
99 directives: input.directives,
100 errors: Vec::new(),
101 };
102 }
103 },
104 None => {
105 return PluginOutput {
106 directives: input.directives,
107 errors: Vec::new(),
108 };
109 }
110 };
111
112 let mut new_accounts: HashSet<String> = HashSet::new();
113 let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
114
115 for directive in input.directives {
116 if directive.directive_type != "transaction" {
117 new_directives.push(directive);
118 continue;
119 }
120
121 if let DirectiveData::Transaction(txn) = &directive.data {
122 let has_generic = txn
124 .postings
125 .iter()
126 .any(|p| config.pattern.is_match(&p.account));
127 let has_specific = txn.postings.iter().any(|p| {
128 p.account.contains(&config.short_replacement)
129 || p.account.contains(&config.long_replacement)
130 });
131
132 if !has_generic || has_specific {
133 new_directives.push(directive);
134 continue;
135 }
136
137 let reductions: Vec<&PostingData> = txn
139 .postings
140 .iter()
141 .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
142 .collect();
143
144 if reductions.is_empty() {
145 new_directives.push(directive);
146 continue;
147 }
148
149 let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
151 d
152 } else {
153 new_directives.push(directive);
154 continue;
155 };
156
157 let mut short_gains = Decimal::ZERO;
158 let mut long_gains = Decimal::ZERO;
159
160 for posting in &reductions {
161 if let (Some(cost), Some(units), Some(price)) =
162 (&posting.cost, &posting.units, &posting.price)
163 {
164 let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
166
167 if let Some(cost_date) = cost_date {
168 let cost_number = cost
170 .number_per
171 .as_ref()
172 .and_then(|n| Decimal::from_str(n).ok())
173 .unwrap_or(Decimal::ZERO);
174 let price_number = price
175 .amount
176 .as_ref()
177 .and_then(|a| Decimal::from_str(&a.number).ok())
178 .unwrap_or(Decimal::ZERO);
179 let units_number =
180 Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
181
182 let gain = (cost_number - price_number) * units_number.abs();
183
184 let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
186 let years_held = (days_held / 365) as u32;
187 let is_long_term = years_held > 1
188 || (years_held == 1
189 && (entry_date.month() > cost_date.month()
190 || (entry_date.month() == cost_date.month()
191 && entry_date.day() >= cost_date.day())));
192
193 if is_long_term {
194 long_gains += gain;
195 } else {
196 short_gains += gain;
197 }
198 }
199 }
200 }
201
202 let orig_postings: Vec<&PostingData> = txn
204 .postings
205 .iter()
206 .filter(|p| config.pattern.is_match(&p.account))
207 .collect();
208
209 if orig_postings.is_empty() {
210 new_directives.push(directive);
211 continue;
212 }
213
214 let orig_sum: Decimal = orig_postings
215 .iter()
216 .filter_map(|p| p.units.as_ref())
217 .filter_map(|u| Decimal::from_str(&u.number).ok())
218 .sum();
219
220 let diff = orig_sum - (short_gains + long_gains);
222 if diff.abs() > Decimal::new(1, 6) {
223 let total = short_gains + long_gains;
224 if total != Decimal::ZERO {
225 short_gains += (short_gains / total) * diff;
226 long_gains += (long_gains / total) * diff;
227 }
228 }
229
230 let mut new_postings: Vec<PostingData> = txn
232 .postings
233 .iter()
234 .filter(|p| !config.pattern.is_match(&p.account))
235 .cloned()
236 .collect();
237
238 let template = orig_postings[0];
239
240 if short_gains != Decimal::ZERO {
241 let new_account = template
242 .account
243 .replace(&config.account_to_replace, &config.short_replacement);
244 new_accounts.insert(new_account.clone());
245 new_postings.push(PostingData {
246 account: new_account,
247 units: template.units.as_ref().map(|u| AmountData {
248 number: format_decimal(short_gains),
249 currency: u.currency.clone(),
250 }),
251 cost: None,
252 price: None,
253 flag: template.flag.clone(),
254 metadata: vec![],
255 });
256 }
257
258 if long_gains != Decimal::ZERO {
259 let new_account = template
260 .account
261 .replace(&config.account_to_replace, &config.long_replacement);
262 new_accounts.insert(new_account.clone());
263 new_postings.push(PostingData {
264 account: new_account,
265 units: template.units.as_ref().map(|u| AmountData {
266 number: format_decimal(long_gains),
267 currency: u.currency.clone(),
268 }),
269 cost: None,
270 price: None,
271 flag: template.flag.clone(),
272 metadata: vec![],
273 });
274 }
275
276 new_directives.push(DirectiveWrapper {
277 directive_type: "transaction".to_string(),
278 date: directive.date.clone(),
279 filename: directive.filename.clone(),
280 lineno: directive.lineno,
281 data: DirectiveData::Transaction(TransactionData {
282 flag: txn.flag.clone(),
283 payee: txn.payee.clone(),
284 narration: txn.narration.clone(),
285 tags: txn.tags.clone(),
286 links: txn.links.clone(),
287 metadata: txn.metadata.clone(),
288 postings: new_postings,
289 }),
290 });
291 } else {
292 new_directives.push(directive);
293 }
294 }
295
296 let earliest_date = new_directives
298 .iter()
299 .map(|d| d.date.as_str())
300 .min()
301 .unwrap_or("1970-01-01")
302 .to_string();
303
304 let mut open_directives: Vec<DirectiveWrapper> = new_accounts
305 .iter()
306 .map(|account| DirectiveWrapper {
307 directive_type: "open".to_string(),
308 date: earliest_date.clone(),
309 filename: Some("<long_short>".to_string()),
310 lineno: Some(0),
311 data: DirectiveData::Open(OpenData {
312 account: account.clone(),
313 currencies: vec![],
314 booking: None,
315 metadata: vec![],
316 }),
317 })
318 .collect();
319
320 open_directives.extend(new_directives);
321
322 PluginOutput {
323 directives: open_directives,
324 errors: Vec::new(),
325 }
326}
327
328fn process_gain_loss(input: PluginInput) -> PluginOutput {
330 let config = match &input.config {
331 Some(c) => match parse_gain_loss_config(c) {
332 Some(cfg) => cfg,
333 None => {
334 return PluginOutput {
335 directives: input.directives,
336 errors: Vec::new(),
337 };
338 }
339 },
340 None => {
341 return PluginOutput {
342 directives: input.directives,
343 errors: Vec::new(),
344 };
345 }
346 };
347
348 let mut new_accounts: HashSet<String> = HashSet::new();
349 let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
350
351 for directive in input.directives {
352 if directive.directive_type != "transaction" {
353 new_directives.push(directive);
354 continue;
355 }
356
357 if let DirectiveData::Transaction(txn) = &directive.data {
358 let mut modified = false;
359 let mut new_postings: Vec<PostingData> = Vec::new();
360
361 for posting in &txn.postings {
362 if config.pattern.is_match(&posting.account)
363 && let Some(units) = &posting.units
364 && let Ok(number) = Decimal::from_str(&units.number)
365 {
366 let new_account = if number < Decimal::ZERO {
367 posting
369 .account
370 .replace(&config.account_to_replace, &config.gains_replacement)
371 } else {
372 posting
374 .account
375 .replace(&config.account_to_replace, &config.losses_replacement)
376 };
377
378 new_accounts.insert(new_account.clone());
379 new_postings.push(PostingData {
380 account: new_account,
381 units: posting.units.clone(),
382 cost: posting.cost.clone(),
383 price: posting.price.clone(),
384 flag: posting.flag.clone(),
385 metadata: posting.metadata.clone(),
386 });
387 modified = true;
388 continue;
389 }
390 new_postings.push(posting.clone());
391 }
392
393 if modified {
394 new_directives.push(DirectiveWrapper {
395 directive_type: "transaction".to_string(),
396 date: directive.date.clone(),
397 filename: directive.filename.clone(),
398 lineno: directive.lineno,
399 data: DirectiveData::Transaction(TransactionData {
400 flag: txn.flag.clone(),
401 payee: txn.payee.clone(),
402 narration: txn.narration.clone(),
403 tags: txn.tags.clone(),
404 links: txn.links.clone(),
405 metadata: txn.metadata.clone(),
406 postings: new_postings,
407 }),
408 });
409 } else {
410 new_directives.push(directive);
411 }
412 } else {
413 new_directives.push(directive);
414 }
415 }
416
417 let earliest_date = new_directives
419 .iter()
420 .map(|d| d.date.as_str())
421 .min()
422 .unwrap_or("1970-01-01")
423 .to_string();
424
425 let mut open_directives: Vec<DirectiveWrapper> = new_accounts
426 .iter()
427 .map(|account| DirectiveWrapper {
428 directive_type: "open".to_string(),
429 date: earliest_date.clone(),
430 filename: Some("<gain_loss>".to_string()),
431 lineno: Some(0),
432 data: DirectiveData::Open(OpenData {
433 account: account.clone(),
434 currencies: vec![],
435 booking: None,
436 metadata: vec![],
437 }),
438 })
439 .collect();
440
441 open_directives.extend(new_directives);
442
443 PluginOutput {
444 directives: open_directives,
445 errors: Vec::new(),
446 }
447}
448
449fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
452 let cap = CONFIG_ENTRY_RE.captures(config)?;
454 let pattern = Regex::new(&cap[1]).ok()?;
455 let account_to_replace = cap[2].to_string();
456 let short_replacement = cap[3].to_string();
457 let long_replacement = cap[4].to_string();
458
459 Some(LongShortConfig {
460 pattern,
461 account_to_replace,
462 short_replacement,
463 long_replacement,
464 })
465}
466
467fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
470 let cap = CONFIG_ENTRY_RE.captures(config)?;
472 let pattern = Regex::new(&cap[1]).ok()?;
473 let account_to_replace = cap[2].to_string();
474 let gains_replacement = cap[3].to_string();
475 let losses_replacement = cap[4].to_string();
476
477 Some(GainLossConfig {
478 pattern,
479 account_to_replace,
480 gains_replacement,
481 losses_replacement,
482 })
483}
484
485fn format_decimal(d: Decimal) -> String {
487 let s = d.to_string();
488 if s.contains('.') {
489 s.trim_end_matches('0').trim_end_matches('.').to_string()
490 } else {
491 s
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    const GAIN_LOSS_CONFIG: &str =
        "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}";

    /// Builds a one-transaction input whose capital-gains posting carries
    /// `gain_amount` USD.
    fn sale_input(gain_amount: &str, config: Option<&str>) -> PluginInput {
        PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings: vec![
                        PostingData {
                            account: "Assets:Broker".to_string(),
                            units: Some(AmountData {
                                number: "1000".to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                        PostingData {
                            account: "Income:Capital-Gains:Long".to_string(),
                            units: Some(AmountData {
                                number: gain_amount.to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                    ],
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: config.map(str::to_string),
        }
    }

    /// Posting accounts of the first transaction in `output`. Panics if
    /// there is none, so assertions on it can never pass vacuously.
    fn transaction_accounts(output: &PluginOutput) -> Vec<&str> {
        let Some(DirectiveData::Transaction(txn)) = output
            .directives
            .iter()
            .find(|d| d.directive_type == "transaction")
            .map(|d| &d.data)
        else {
            panic!("plugin output contains no transaction");
        };
        txn.postings.iter().map(|p| p.account.as_str()).collect()
    }

    /// Accounts opened by `open` directives in `output`.
    fn opened_accounts(output: &PluginOutput) -> Vec<&str> {
        output
            .directives
            .iter()
            .filter_map(|d| match &d.data {
                DirectiveData::Open(open) => Some(open.account.as_str()),
                _ => None,
            })
            .collect()
    }

    #[test]
    fn test_parse_long_short_config() {
        let config = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
        let cfg = parse_long_short_config(config).expect("config should parse");
        assert!(cfg.pattern.is_match("Income:Broker:Capital-Gains"));
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("config should parse");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_parse_config_rejects_malformed_input() {
        assert!(parse_long_short_config("").is_none());
        // Exactly three list values are required.
        assert!(parse_gain_loss_config("{'Income': [':A', ':B']}").is_none());
        // The account regex itself must compile.
        assert!(parse_long_short_config("{'Income(': [':A', ':B', ':C']}").is_none());
    }

    #[test]
    fn test_gain_loss_classification() {
        let output = CapitalGainsGainLossPlugin.process(sale_input("-100", Some(GAIN_LOSS_CONFIG)));
        assert!(output.errors.is_empty());

        let accounts = transaction_accounts(&output);
        assert!(accounts.contains(&"Income:Capital-Gains:Long:Gains"));
        assert!(!accounts.contains(&"Income:Capital-Gains:Long"));
    }

    #[test]
    fn test_gain_loss_classifies_losses() {
        let output = CapitalGainsGainLossPlugin.process(sale_input("100", Some(GAIN_LOSS_CONFIG)));
        assert!(output.errors.is_empty());

        let accounts = transaction_accounts(&output);
        assert!(accounts.contains(&"Income:Capital-Gains:Long:Losses"));
        assert!(!accounts.contains(&"Income:Capital-Gains:Long:Gains"));
    }

    #[test]
    fn test_gain_loss_opens_new_account() {
        let output = CapitalGainsGainLossPlugin.process(sale_input("-100", Some(GAIN_LOSS_CONFIG)));
        assert_eq!(opened_accounts(&output), vec!["Income:Capital-Gains:Long:Gains"]);
    }

    #[test]
    fn test_missing_config_passes_input_through() {
        for output in [
            CapitalGainsGainLossPlugin.process(sale_input("-100", None)),
            CapitalGainsLongShortPlugin.process(sale_input("-100", None)),
        ] {
            assert!(output.errors.is_empty());
            assert_eq!(output.directives.len(), 1);
            assert!(opened_accounts(&output).is_empty());
            assert_eq!(
                transaction_accounts(&output),
                vec!["Assets:Broker", "Income:Capital-Gains:Long"]
            );
        }
    }
}