Skip to main content

rustledger_plugin/native/plugins/
capital_gains_classifier.rs

1//! Capital gains classifier plugin.
2//!
3//! This plugin rebooks capital gains into separate accounts based on:
4//! - **`long_short`**: Whether gains are long-term (held > 1 year) or short-term
5//! - **`gain_loss`**: Whether the posting is a gain (negative income) or loss (positive income)
6//!
7//! Usage for `long_short`:
8//! ```text
9//! plugin "beancount_reds_plugins.capital_gains_classifier.long_short" "{
10//!   'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']
11//! }"
12//! ```
13//!
14//! Usage for `gain_loss`:
15//! ```text
16//! plugin "beancount_reds_plugins.capital_gains_classifier.gain_loss" "{
17//!   'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']
18//! }"
19//! ```
20
21use regex::Regex;
22use rust_decimal::Decimal;
23use rustledger_core::NaiveDate;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29    AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOp, PluginOutput,
30    PostingData, TransactionData,
31};
32
33use super::super::NativePlugin;
34
35/// Regex for parsing capital gains config entries.
36/// Format: `'pattern': ['to_replace', 'repl1', 'repl2']`
37static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38    Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39        .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Native plugin that rebooks capital-gains postings into short-term and
/// long-term accounts.
///
/// The type carries no state; all behavior lives in its `NativePlugin`
/// implementation.
pub struct CapitalGainsLongShortPlugin;
44
/// Native plugin that rebooks capital-gains postings into separate gains
/// and losses accounts.
///
/// The type carries no state; all behavior lives in its `NativePlugin`
/// implementation.
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49    fn name(&self) -> &'static str {
50        "long_short"
51    }
52
53    fn description(&self) -> &'static str {
54        "Classify capital gains into long-term vs short-term based on holding period"
55    }
56
57    fn process(&self, input: PluginInput) -> PluginOutput {
58        process_long_short(input)
59    }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63    fn name(&self) -> &'static str {
64        "gain_loss"
65    }
66
67    fn description(&self) -> &'static str {
68        "Classify capital gains into gains vs losses based on posting amount"
69    }
70
71    fn process(&self, input: PluginInput) -> PluginOutput {
72        process_gain_loss(input)
73    }
74}
75
76/// Configuration for `long_short` classification.
77struct LongShortConfig {
78    pattern: Regex,
79    account_to_replace: String,
80    short_replacement: String,
81    long_replacement: String,
82}
83
84/// Configuration for `gain_loss` classification.
85struct GainLossConfig {
86    pattern: Regex,
87    account_to_replace: String,
88    gains_replacement: String,
89    losses_replacement: String,
90}
91
92/// Process entries with `long_short` classification.
93fn process_long_short(input: PluginInput) -> PluginOutput {
94    let config = match &input.config {
95        Some(c) => match parse_long_short_config(c) {
96            Some(cfg) => cfg,
97            None => {
98                return PluginOutput {
99                    ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
100                    errors: Vec::new(),
101                };
102            }
103        },
104        None => {
105            return PluginOutput {
106                ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
107                errors: Vec::new(),
108            };
109        }
110    };
111
112    let mut new_accounts: HashSet<String> = HashSet::new();
113    let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
114    // Track earliest date from input directives for synthesized Open directives.
115    let earliest_date = input
116        .directives
117        .iter()
118        .map(|d| d.date.as_str())
119        .min()
120        .unwrap_or("1970-01-01")
121        .to_string();
122    // Accounts already opened by the user; we must not synthesize a
123    // duplicate Open or Late validation will emit E1002.
124    let existing_opens: HashSet<String> = input
125        .directives
126        .iter()
127        .filter_map(|d| match &d.data {
128            DirectiveData::Open(open) => Some(open.account.clone()),
129            _ => None,
130        })
131        .collect();
132
133    for (i, directive) in input.directives.into_iter().enumerate() {
134        if directive.directive_type != "transaction" {
135            ops.push(PluginOp::Keep(i));
136            continue;
137        }
138
139        if let DirectiveData::Transaction(txn) = &directive.data {
140            // Check if transaction has matching capital gains postings
141            let has_generic = txn
142                .postings
143                .iter()
144                .any(|p| config.pattern.is_match(&p.account));
145            let has_specific = txn.postings.iter().any(|p| {
146                p.account.contains(&config.short_replacement)
147                    || p.account.contains(&config.long_replacement)
148            });
149
150            if !has_generic || has_specific {
151                ops.push(PluginOp::Keep(i));
152                continue;
153            }
154
155            // Find reduction postings (sales with cost and price)
156            let reductions: Vec<&PostingData> = txn
157                .postings
158                .iter()
159                .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
160                .collect();
161
162            if reductions.is_empty() {
163                ops.push(PluginOp::Keep(i));
164                continue;
165            }
166
167            // Calculate short vs long gains
168            let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
169                d
170            } else {
171                ops.push(PluginOp::Keep(i));
172                continue;
173            };
174
175            // Fall through if ANY reduction lacks a parseable cost
176            // date. Without it the plugin can't classify holding
177            // period, and pre-fix (issue #1010) it would silently
178            // drop the generic Income:Capital-Gains posting in the
179            // post-loop filter, leaving the transaction unbalanced.
180            // Falling through preserves the user's ledger.
181            let any_missing_cost_date = reductions.iter().any(|p| {
182                p.cost
183                    .as_ref()
184                    .and_then(|c| c.date.as_ref())
185                    .and_then(|d| d.parse::<NaiveDate>().ok())
186                    .is_none()
187            });
188            if any_missing_cost_date {
189                ops.push(PluginOp::Keep(i));
190                continue;
191            }
192
193            let mut short_gains = Decimal::ZERO;
194            let mut long_gains = Decimal::ZERO;
195
196            for posting in &reductions {
197                if let (Some(cost), Some(units), Some(price)) =
198                    (&posting.cost, &posting.units, &posting.price)
199                {
200                    // Get cost date
201                    let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
202
203                    if let Some(cost_date) = cost_date {
204                        // Calculate gain
205                        let cost_number = cost
206                            .number_per
207                            .as_ref()
208                            .and_then(|n| Decimal::from_str(n).ok())
209                            .unwrap_or(Decimal::ZERO);
210                        let price_number = price
211                            .amount
212                            .as_ref()
213                            .and_then(|a| Decimal::from_str(&a.number).ok())
214                            .unwrap_or(Decimal::ZERO);
215                        let units_number =
216                            Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
217
218                        let gain = (cost_number - price_number) * units_number.abs();
219
220                        // Check if long-term (> 1 year)
221                        let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
222                        let years_held = (days_held / 365) as u32;
223                        let is_long_term = years_held > 1
224                            || (years_held == 1
225                                && (entry_date.month() > cost_date.month()
226                                    || (entry_date.month() == cost_date.month()
227                                        && entry_date.day() >= cost_date.day())));
228
229                        if is_long_term {
230                            long_gains += gain;
231                        } else {
232                            short_gains += gain;
233                        }
234                    }
235                }
236            }
237
238            // Find and remove original capital gains postings
239            let orig_postings: Vec<&PostingData> = txn
240                .postings
241                .iter()
242                .filter(|p| config.pattern.is_match(&p.account))
243                .collect();
244
245            if orig_postings.is_empty() {
246                ops.push(PluginOp::Keep(i));
247                continue;
248            }
249
250            let orig_sum: Decimal = orig_postings
251                .iter()
252                .filter_map(|p| p.units.as_ref())
253                .filter_map(|u| Decimal::from_str(&u.number).ok())
254                .sum();
255
256            // Adjust for rounding differences
257            let diff = orig_sum - (short_gains + long_gains);
258            if diff.abs() > Decimal::new(1, 6) {
259                let total = short_gains + long_gains;
260                if total != Decimal::ZERO {
261                    short_gains += (short_gains / total) * diff;
262                    long_gains += (long_gains / total) * diff;
263                }
264            }
265
266            // Create new postings
267            let mut new_postings: Vec<PostingData> = txn
268                .postings
269                .iter()
270                .filter(|p| !config.pattern.is_match(&p.account))
271                .cloned()
272                .collect();
273
274            let template = orig_postings[0];
275
276            if short_gains != Decimal::ZERO {
277                let new_account = template
278                    .account
279                    .replace(&config.account_to_replace, &config.short_replacement);
280                new_accounts.insert(new_account.clone());
281                new_postings.push(PostingData {
282                    account: new_account,
283                    units: template.units.as_ref().map(|u| AmountData {
284                        number: format_decimal(short_gains),
285                        currency: u.currency.clone(),
286                    }),
287                    cost: None,
288                    price: None,
289                    flag: template.flag.clone(),
290                    metadata: vec![],
291                });
292            }
293
294            if long_gains != Decimal::ZERO {
295                let new_account = template
296                    .account
297                    .replace(&config.account_to_replace, &config.long_replacement);
298                new_accounts.insert(new_account.clone());
299                new_postings.push(PostingData {
300                    account: new_account,
301                    units: template.units.as_ref().map(|u| AmountData {
302                        number: format_decimal(long_gains),
303                        currency: u.currency.clone(),
304                    }),
305                    cost: None,
306                    price: None,
307                    flag: template.flag.clone(),
308                    metadata: vec![],
309                });
310            }
311
312            ops.push(PluginOp::Modify(
313                i,
314                DirectiveWrapper {
315                    directive_type: "transaction".to_string(),
316                    date: directive.date.clone(),
317                    filename: directive.filename.clone(),
318                    lineno: directive.lineno,
319                    data: DirectiveData::Transaction(TransactionData {
320                        flag: txn.flag.clone(),
321                        payee: txn.payee.clone(),
322                        narration: txn.narration.clone(),
323                        tags: txn.tags.clone(),
324                        links: txn.links.clone(),
325                        metadata: txn.metadata.clone(),
326                        postings: new_postings,
327                    }),
328                },
329            ));
330        } else {
331            ops.push(PluginOp::Keep(i));
332        }
333    }
334
335    // Insert Open directives for newly synthesized accounts the user
336    // hasn't already opened.
337    for account in &new_accounts {
338        if existing_opens.contains(account) {
339            continue;
340        }
341        ops.push(PluginOp::Insert(DirectiveWrapper {
342            directive_type: "open".to_string(),
343            date: earliest_date.clone(),
344            filename: Some("<long_short>".to_string()),
345            lineno: Some(0),
346            data: DirectiveData::Open(OpenData {
347                account: account.clone(),
348                currencies: vec![],
349                booking: None,
350                metadata: vec![],
351            }),
352        }));
353    }
354
355    PluginOutput {
356        ops,
357        errors: Vec::new(),
358    }
359}
360
361/// Process entries with `gain_loss` classification.
362fn process_gain_loss(input: PluginInput) -> PluginOutput {
363    let config = match &input.config {
364        Some(c) => match parse_gain_loss_config(c) {
365            Some(cfg) => cfg,
366            None => {
367                return PluginOutput {
368                    ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
369                    errors: Vec::new(),
370                };
371            }
372        },
373        None => {
374            return PluginOutput {
375                ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
376                errors: Vec::new(),
377            };
378        }
379    };
380
381    let mut new_accounts: HashSet<String> = HashSet::new();
382    let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
383    // Earliest date from input for synthesized Open directives.
384    let earliest_date = input
385        .directives
386        .iter()
387        .map(|d| d.date.as_str())
388        .min()
389        .unwrap_or("1970-01-01")
390        .to_string();
391    // Accounts already opened by the user; suppress duplicate Opens.
392    let existing_opens: HashSet<String> = input
393        .directives
394        .iter()
395        .filter_map(|d| match &d.data {
396            DirectiveData::Open(open) => Some(open.account.clone()),
397            _ => None,
398        })
399        .collect();
400
401    for (i, directive) in input.directives.into_iter().enumerate() {
402        if directive.directive_type != "transaction" {
403            ops.push(PluginOp::Keep(i));
404            continue;
405        }
406
407        if let DirectiveData::Transaction(txn) = &directive.data {
408            let mut modified = false;
409            let mut new_postings: Vec<PostingData> = Vec::new();
410
411            for posting in &txn.postings {
412                if config.pattern.is_match(&posting.account)
413                    && let Some(units) = &posting.units
414                    && let Ok(number) = Decimal::from_str(&units.number)
415                {
416                    let new_account = if number < Decimal::ZERO {
417                        // Negative = gains (income is negative)
418                        posting
419                            .account
420                            .replace(&config.account_to_replace, &config.gains_replacement)
421                    } else {
422                        // Positive = losses
423                        posting
424                            .account
425                            .replace(&config.account_to_replace, &config.losses_replacement)
426                    };
427
428                    new_accounts.insert(new_account.clone());
429                    new_postings.push(PostingData {
430                        account: new_account,
431                        units: posting.units.clone(),
432                        cost: posting.cost.clone(),
433                        price: posting.price.clone(),
434                        flag: posting.flag.clone(),
435                        metadata: posting.metadata.clone(),
436                    });
437                    modified = true;
438                    continue;
439                }
440                new_postings.push(posting.clone());
441            }
442
443            if modified {
444                ops.push(PluginOp::Modify(
445                    i,
446                    DirectiveWrapper {
447                        directive_type: "transaction".to_string(),
448                        date: directive.date.clone(),
449                        filename: directive.filename.clone(),
450                        lineno: directive.lineno,
451                        data: DirectiveData::Transaction(TransactionData {
452                            flag: txn.flag.clone(),
453                            payee: txn.payee.clone(),
454                            narration: txn.narration.clone(),
455                            tags: txn.tags.clone(),
456                            links: txn.links.clone(),
457                            metadata: txn.metadata.clone(),
458                            postings: new_postings,
459                        }),
460                    },
461                ));
462            } else {
463                ops.push(PluginOp::Keep(i));
464            }
465        } else {
466            ops.push(PluginOp::Keep(i));
467        }
468    }
469
470    // Insert Open directives for newly synthesized accounts the user
471    // hasn't already opened.
472    for account in &new_accounts {
473        if existing_opens.contains(account) {
474            continue;
475        }
476        ops.push(PluginOp::Insert(DirectiveWrapper {
477            directive_type: "open".to_string(),
478            date: earliest_date.clone(),
479            filename: Some("<gain_loss>".to_string()),
480            lineno: Some(0),
481            data: DirectiveData::Open(OpenData {
482                account: account.clone(),
483                currencies: vec![],
484                booking: None,
485                metadata: vec![],
486            }),
487        }));
488    }
489
490    PluginOutput {
491        ops,
492        errors: Vec::new(),
493    }
494}
495
496/// Parse `long_short` configuration.
497/// Format: `{'pattern': ['to_replace', 'short_repl', 'long_repl']}`
498fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
499    // Parse pattern: 'key': ['val1', 'val2', 'val3']
500    let cap = CONFIG_ENTRY_RE.captures(config)?;
501    let pattern = Regex::new(&cap[1]).ok()?;
502    let account_to_replace = cap[2].to_string();
503    let short_replacement = cap[3].to_string();
504    let long_replacement = cap[4].to_string();
505
506    Some(LongShortConfig {
507        pattern,
508        account_to_replace,
509        short_replacement,
510        long_replacement,
511    })
512}
513
514/// Parse `gain_loss` configuration.
515/// Format: `{'pattern': ['to_replace', 'gains_repl', 'losses_repl']}`
516fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
517    // Parse pattern: 'key': ['val1', 'val2', 'val3']
518    let cap = CONFIG_ENTRY_RE.captures(config)?;
519    let pattern = Regex::new(&cap[1]).ok()?;
520    let account_to_replace = cap[2].to_string();
521    let gains_replacement = cap[3].to_string();
522    let losses_replacement = cap[4].to_string();
523
524    Some(GainLossConfig {
525        pattern,
526        account_to_replace,
527        gains_replacement,
528        losses_replacement,
529    })
530}
531
532/// Format a decimal number.
533fn format_decimal(d: Decimal) -> String {
534    let s = d.to_string();
535    if s.contains('.') {
536        s.trim_end_matches('0').trim_end_matches('.').to_string()
537    } else {
538        s
539    }
540}
541
#[cfg(test)]
mod tests {
    use super::super::utils::materialize_ops;
    use super::*;
    use crate::types::*;

    /// Builds a bare USD posting (no cost, price, flag or metadata).
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    /// The three list items of a `long_short` entry land in the right fields.
    #[test]
    fn test_parse_long_short_config() {
        let raw = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
        let cfg = parse_long_short_config(raw).expect("long_short config should parse");
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    /// The three list items of a `gain_loss` entry land in the right fields.
    #[test]
    fn test_parse_gain_loss_config() {
        let raw = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(raw).expect("gain_loss config should parse");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    /// A negative capital-gains posting is rebooked onto the gains account.
    #[test]
    fn test_gain_loss_classification() {
        let plugin = CapitalGainsGainLossPlugin;

        let sale = DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: "Sell stock".to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    usd_posting("Assets:Broker", "1000"),
                    usd_posting("Income:Capital-Gains:Long", "-100"),
                ],
            }),
        };

        let input = PluginInput {
            directives: vec![sale],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(
                "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}"
                    .to_string(),
            ),
        };

        // Keep a copy of the input: `process` consumes it and the ops are
        // applied against the original directives.
        let input_dirs = input.directives.clone();
        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 0);
        let directives = materialize_ops(&input_dirs, &output);

        // Find the transaction
        let txn = directives.iter().find_map(|d| match &d.data {
            DirectiveData::Transaction(t) => Some(t),
            _ => None,
        });
        assert!(txn.is_some());
        let t = txn.unwrap();

        // The negative posting should be renamed to :Long:Gains
        assert!(t.postings.iter().any(|p| p.account.contains(":Long:Gains")));
    }
}