1use chrono::{Datelike, NaiveDate};
22use regex::Regex;
23use rust_decimal::Decimal;
24use std::collections::HashSet;
25use std::str::FromStr;
26
27use crate::types::{
28 AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
29 TransactionData,
30};
31
32use super::super::NativePlugin;
33
/// Native plugin (`long_short`) that classifies capital gains into short-term and
/// long-term accounts based on each lot's holding period; the work is done by
/// `process_long_short`.
pub struct CapitalGainsLongShortPlugin;
36
/// Native plugin (`gain_loss`) that classifies capital-gains postings into gain and
/// loss accounts based on the sign of the posting amount; the work is done by
/// `process_gain_loss`.
pub struct CapitalGainsGainLossPlugin;
39
40impl NativePlugin for CapitalGainsLongShortPlugin {
41 fn name(&self) -> &'static str {
42 "long_short"
43 }
44
45 fn description(&self) -> &'static str {
46 "Classify capital gains into long-term vs short-term based on holding period"
47 }
48
49 fn process(&self, input: PluginInput) -> PluginOutput {
50 process_long_short(input)
51 }
52}
53
54impl NativePlugin for CapitalGainsGainLossPlugin {
55 fn name(&self) -> &'static str {
56 "gain_loss"
57 }
58
59 fn description(&self) -> &'static str {
60 "Classify capital gains into gains vs losses based on posting amount"
61 }
62
63 fn process(&self, input: PluginInput) -> PluginOutput {
64 process_gain_loss(input)
65 }
66}
67
68struct LongShortConfig {
70 pattern: Regex,
71 account_to_replace: String,
72 short_replacement: String,
73 long_replacement: String,
74}
75
76struct GainLossConfig {
78 pattern: Regex,
79 account_to_replace: String,
80 gains_replacement: String,
81 losses_replacement: String,
82}
83
84fn process_long_short(input: PluginInput) -> PluginOutput {
86 let config = match &input.config {
87 Some(c) => match parse_long_short_config(c) {
88 Some(cfg) => cfg,
89 None => {
90 return PluginOutput {
91 directives: input.directives,
92 errors: Vec::new(),
93 };
94 }
95 },
96 None => {
97 return PluginOutput {
98 directives: input.directives,
99 errors: Vec::new(),
100 };
101 }
102 };
103
104 let mut new_accounts: HashSet<String> = HashSet::new();
105 let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
106
107 for directive in input.directives {
108 if directive.directive_type != "transaction" {
109 new_directives.push(directive);
110 continue;
111 }
112
113 if let DirectiveData::Transaction(txn) = &directive.data {
114 let has_generic = txn
116 .postings
117 .iter()
118 .any(|p| config.pattern.is_match(&p.account));
119 let has_specific = txn.postings.iter().any(|p| {
120 p.account.contains(&config.short_replacement)
121 || p.account.contains(&config.long_replacement)
122 });
123
124 if !has_generic || has_specific {
125 new_directives.push(directive);
126 continue;
127 }
128
129 let reductions: Vec<&PostingData> = txn
131 .postings
132 .iter()
133 .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
134 .collect();
135
136 if reductions.is_empty() {
137 new_directives.push(directive);
138 continue;
139 }
140
141 let entry_date = if let Ok(d) = NaiveDate::parse_from_str(&directive.date, "%Y-%m-%d") {
143 d
144 } else {
145 new_directives.push(directive);
146 continue;
147 };
148
149 let mut short_gains = Decimal::ZERO;
150 let mut long_gains = Decimal::ZERO;
151
152 for posting in &reductions {
153 if let (Some(cost), Some(units), Some(price)) =
154 (&posting.cost, &posting.units, &posting.price)
155 {
156 let cost_date = cost
158 .date
159 .as_ref()
160 .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok());
161
162 if let Some(cost_date) = cost_date {
163 let cost_number = cost
165 .number_per
166 .as_ref()
167 .and_then(|n| Decimal::from_str(n).ok())
168 .unwrap_or(Decimal::ZERO);
169 let price_number = price
170 .amount
171 .as_ref()
172 .and_then(|a| Decimal::from_str(&a.number).ok())
173 .unwrap_or(Decimal::ZERO);
174 let units_number =
175 Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
176
177 let gain = (cost_number - price_number) * units_number.abs();
178
179 let years_held = entry_date.years_since(cost_date).unwrap_or(0);
181 let is_long_term = years_held > 1
182 || (years_held == 1
183 && (entry_date.month() > cost_date.month()
184 || (entry_date.month() == cost_date.month()
185 && entry_date.day() >= cost_date.day())));
186
187 if is_long_term {
188 long_gains += gain;
189 } else {
190 short_gains += gain;
191 }
192 }
193 }
194 }
195
196 let orig_postings: Vec<&PostingData> = txn
198 .postings
199 .iter()
200 .filter(|p| config.pattern.is_match(&p.account))
201 .collect();
202
203 if orig_postings.is_empty() {
204 new_directives.push(directive);
205 continue;
206 }
207
208 let orig_sum: Decimal = orig_postings
209 .iter()
210 .filter_map(|p| p.units.as_ref())
211 .filter_map(|u| Decimal::from_str(&u.number).ok())
212 .sum();
213
214 let diff = orig_sum - (short_gains + long_gains);
216 if diff.abs() > Decimal::new(1, 6) {
217 let total = short_gains + long_gains;
218 if total != Decimal::ZERO {
219 short_gains += (short_gains / total) * diff;
220 long_gains += (long_gains / total) * diff;
221 }
222 }
223
224 let mut new_postings: Vec<PostingData> = txn
226 .postings
227 .iter()
228 .filter(|p| !config.pattern.is_match(&p.account))
229 .cloned()
230 .collect();
231
232 let template = orig_postings[0];
233
234 if short_gains != Decimal::ZERO {
235 let new_account = template
236 .account
237 .replace(&config.account_to_replace, &config.short_replacement);
238 new_accounts.insert(new_account.clone());
239 new_postings.push(PostingData {
240 account: new_account,
241 units: template.units.as_ref().map(|u| AmountData {
242 number: format_decimal(short_gains),
243 currency: u.currency.clone(),
244 }),
245 cost: None,
246 price: None,
247 flag: template.flag.clone(),
248 metadata: vec![],
249 });
250 }
251
252 if long_gains != Decimal::ZERO {
253 let new_account = template
254 .account
255 .replace(&config.account_to_replace, &config.long_replacement);
256 new_accounts.insert(new_account.clone());
257 new_postings.push(PostingData {
258 account: new_account,
259 units: template.units.as_ref().map(|u| AmountData {
260 number: format_decimal(long_gains),
261 currency: u.currency.clone(),
262 }),
263 cost: None,
264 price: None,
265 flag: template.flag.clone(),
266 metadata: vec![],
267 });
268 }
269
270 new_directives.push(DirectiveWrapper {
271 directive_type: "transaction".to_string(),
272 date: directive.date.clone(),
273 filename: directive.filename.clone(),
274 lineno: directive.lineno,
275 data: DirectiveData::Transaction(TransactionData {
276 flag: txn.flag.clone(),
277 payee: txn.payee.clone(),
278 narration: txn.narration.clone(),
279 tags: txn.tags.clone(),
280 links: txn.links.clone(),
281 metadata: txn.metadata.clone(),
282 postings: new_postings,
283 }),
284 });
285 } else {
286 new_directives.push(directive);
287 }
288 }
289
290 let earliest_date = new_directives
292 .iter()
293 .map(|d| d.date.as_str())
294 .min()
295 .unwrap_or("1970-01-01")
296 .to_string();
297
298 let mut open_directives: Vec<DirectiveWrapper> = new_accounts
299 .iter()
300 .map(|account| DirectiveWrapper {
301 directive_type: "open".to_string(),
302 date: earliest_date.clone(),
303 filename: Some("<long_short>".to_string()),
304 lineno: Some(0),
305 data: DirectiveData::Open(OpenData {
306 account: account.clone(),
307 currencies: vec![],
308 booking: None,
309 metadata: vec![],
310 }),
311 })
312 .collect();
313
314 open_directives.extend(new_directives);
315
316 PluginOutput {
317 directives: open_directives,
318 errors: Vec::new(),
319 }
320}
321
322fn process_gain_loss(input: PluginInput) -> PluginOutput {
324 let config = match &input.config {
325 Some(c) => match parse_gain_loss_config(c) {
326 Some(cfg) => cfg,
327 None => {
328 return PluginOutput {
329 directives: input.directives,
330 errors: Vec::new(),
331 };
332 }
333 },
334 None => {
335 return PluginOutput {
336 directives: input.directives,
337 errors: Vec::new(),
338 };
339 }
340 };
341
342 let mut new_accounts: HashSet<String> = HashSet::new();
343 let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
344
345 for directive in input.directives {
346 if directive.directive_type != "transaction" {
347 new_directives.push(directive);
348 continue;
349 }
350
351 if let DirectiveData::Transaction(txn) = &directive.data {
352 let mut modified = false;
353 let mut new_postings: Vec<PostingData> = Vec::new();
354
355 for posting in &txn.postings {
356 if config.pattern.is_match(&posting.account)
357 && let Some(units) = &posting.units
358 && let Ok(number) = Decimal::from_str(&units.number)
359 {
360 let new_account = if number < Decimal::ZERO {
361 posting
363 .account
364 .replace(&config.account_to_replace, &config.gains_replacement)
365 } else {
366 posting
368 .account
369 .replace(&config.account_to_replace, &config.losses_replacement)
370 };
371
372 new_accounts.insert(new_account.clone());
373 new_postings.push(PostingData {
374 account: new_account,
375 units: posting.units.clone(),
376 cost: posting.cost.clone(),
377 price: posting.price.clone(),
378 flag: posting.flag.clone(),
379 metadata: posting.metadata.clone(),
380 });
381 modified = true;
382 continue;
383 }
384 new_postings.push(posting.clone());
385 }
386
387 if modified {
388 new_directives.push(DirectiveWrapper {
389 directive_type: "transaction".to_string(),
390 date: directive.date.clone(),
391 filename: directive.filename.clone(),
392 lineno: directive.lineno,
393 data: DirectiveData::Transaction(TransactionData {
394 flag: txn.flag.clone(),
395 payee: txn.payee.clone(),
396 narration: txn.narration.clone(),
397 tags: txn.tags.clone(),
398 links: txn.links.clone(),
399 metadata: txn.metadata.clone(),
400 postings: new_postings,
401 }),
402 });
403 } else {
404 new_directives.push(directive);
405 }
406 } else {
407 new_directives.push(directive);
408 }
409 }
410
411 let earliest_date = new_directives
413 .iter()
414 .map(|d| d.date.as_str())
415 .min()
416 .unwrap_or("1970-01-01")
417 .to_string();
418
419 let mut open_directives: Vec<DirectiveWrapper> = new_accounts
420 .iter()
421 .map(|account| DirectiveWrapper {
422 directive_type: "open".to_string(),
423 date: earliest_date.clone(),
424 filename: Some("<gain_loss>".to_string()),
425 lineno: Some(0),
426 data: DirectiveData::Open(OpenData {
427 account: account.clone(),
428 currencies: vec![],
429 booking: None,
430 metadata: vec![],
431 }),
432 })
433 .collect();
434
435 open_directives.extend(new_directives);
436
437 PluginOutput {
438 directives: open_directives,
439 errors: Vec::new(),
440 }
441}
442
443fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
446 let re =
448 Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]").ok()?;
449
450 let cap = re.captures(config)?;
451 let pattern = Regex::new(&cap[1]).ok()?;
452 let account_to_replace = cap[2].to_string();
453 let short_replacement = cap[3].to_string();
454 let long_replacement = cap[4].to_string();
455
456 Some(LongShortConfig {
457 pattern,
458 account_to_replace,
459 short_replacement,
460 long_replacement,
461 })
462}
463
464fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
467 let re =
469 Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]").ok()?;
470
471 let cap = re.captures(config)?;
472 let pattern = Regex::new(&cap[1]).ok()?;
473 let account_to_replace = cap[2].to_string();
474 let gains_replacement = cap[3].to_string();
475 let losses_replacement = cap[4].to_string();
476
477 Some(GainLossConfig {
478 pattern,
479 account_to_replace,
480 gains_replacement,
481 losses_replacement,
482 })
483}
484
485fn format_decimal(d: Decimal) -> String {
487 let s = d.to_string();
488 if s.contains('.') {
489 s.trim_end_matches('0').trim_end_matches('.').to_string()
490 } else {
491 s
492 }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    const GAIN_LOSS_CONFIG: &str =
        "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
    const LONG_SHORT_CONFIG: &str =
        "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";

    /// Builds plugin input holding one transaction that books `number` USD to
    /// `gains_account` (next to a broker posting), run with `config`.
    fn sale_input(gains_account: &str, number: &str, config: Option<&str>) -> PluginInput {
        PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings: vec![
                        PostingData {
                            account: "Assets:Broker".to_string(),
                            units: Some(AmountData {
                                number: "1000".to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                        PostingData {
                            account: gains_account.to_string(),
                            units: Some(AmountData {
                                number: number.to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                    ],
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: config.map(String::from),
        }
    }

    /// Returns the postings of the transaction in `output`, failing the test if absent.
    fn transaction_postings(output: &PluginOutput) -> &[PostingData] {
        let txn = output
            .directives
            .iter()
            .find(|d| d.directive_type == "transaction")
            .expect("output should contain the transaction");
        let DirectiveData::Transaction(t) = &txn.data else {
            panic!("transaction directive should carry transaction data");
        };
        &t.postings
    }

    /// Whether `output` contains an `Open` directive for `account`.
    fn has_open(output: &PluginOutput, account: &str) -> bool {
        output
            .directives
            .iter()
            .any(|d| matches!(&d.data, DirectiveData::Open(open) if open.account == account))
    }

    #[test]
    fn test_parse_long_short_config() {
        let cfg = parse_long_short_config(LONG_SHORT_CONFIG).expect("config should parse");
        assert!(cfg.pattern.is_match("Income:US:Capital-Gains"));
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("config should parse");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_parse_config_rejects_malformed_input() {
        assert!(parse_long_short_config("{}").is_none());
        // Syntactically valid entry whose pattern is not a valid regex.
        assert!(parse_gain_loss_config("{'Income[': [':a', ':b', ':c']}").is_none());
    }

    #[test]
    fn test_gain_loss_classification() {
        let output = CapitalGainsGainLossPlugin.process(sale_input(
            "Income:Capital-Gains:Long",
            "-100",
            Some(GAIN_LOSS_CONFIG),
        ));
        assert_eq!(output.errors.len(), 0);

        let gains = transaction_postings(&output)
            .iter()
            .find(|p| p.account == "Income:Capital-Gains:Long:Gains")
            .expect("a negative amount should be booked as a gain");
        assert_eq!(gains.units.as_ref().map(|u| u.number.as_str()), Some("-100"));
        assert!(has_open(&output, "Income:Capital-Gains:Long:Gains"));
    }

    #[test]
    fn test_gain_loss_loss_classification() {
        let output = CapitalGainsGainLossPlugin.process(sale_input(
            "Income:Capital-Gains:Long",
            "25",
            Some(GAIN_LOSS_CONFIG),
        ));
        assert_eq!(output.errors.len(), 0);

        let losses = transaction_postings(&output)
            .iter()
            .find(|p| p.account == "Income:Capital-Gains:Long:Losses")
            .expect("a positive amount should be booked as a loss");
        assert_eq!(losses.units.as_ref().map(|u| u.number.as_str()), Some("25"));
        assert!(has_open(&output, "Income:Capital-Gains:Long:Losses"));
    }

    #[test]
    fn test_plugins_pass_through_without_config() {
        let output =
            CapitalGainsLongShortPlugin.process(sale_input("Income:Capital-Gains", "-100", None));
        assert_eq!(output.directives.len(), 1);
        assert!(
            transaction_postings(&output)
                .iter()
                .any(|p| p.account == "Income:Capital-Gains")
        );

        let output = CapitalGainsGainLossPlugin.process(sale_input(
            "Income:Capital-Gains:Long",
            "-100",
            None,
        ));
        assert_eq!(output.directives.len(), 1);
        assert!(
            transaction_postings(&output)
                .iter()
                .any(|p| p.account == "Income:Capital-Gains:Long")
        );
    }

    #[test]
    fn test_long_short_ignores_transactions_without_reductions() {
        let output = CapitalGainsLongShortPlugin.process(sale_input(
            "Income:Capital-Gains",
            "-100",
            Some(LONG_SHORT_CONFIG),
        ));
        assert_eq!(output.errors.len(), 0);
        // No lot reductions: nothing to classify, no accounts to open.
        assert_eq!(output.directives.len(), 1);
        assert!(
            transaction_postings(&output)
                .iter()
                .any(|p| p.account == "Income:Capital-Gains")
        );
    }
}