Skip to main content

rustledger_plugin/native/plugins/
capital_gains_classifier.rs

1//! Capital gains classifier plugin.
2//!
3//! This plugin rebooks capital gains into separate accounts based on:
4//! - **`long_short`**: Whether gains are long-term (held > 1 year) or short-term
5//! - **`gain_loss`**: Whether the posting is a gain (negative income) or loss (positive income)
6//!
7//! Usage for `long_short`:
8//! ```text
9//! plugin "beancount_reds_plugins.capital_gains_classifier.long_short" "{
10//!   'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']
11//! }"
12//! ```
13//!
14//! Usage for `gain_loss`:
15//! ```text
16//! plugin "beancount_reds_plugins.capital_gains_classifier.gain_loss" "{
17//!   'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']
18//! }"
19//! ```
20
21use regex::Regex;
22use rust_decimal::Decimal;
23use rustledger_core::NaiveDate;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29    AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOp, PluginOutput,
30    PostingData, TransactionData,
31};
32
33use super::super::{NativePlugin, RegularPlugin};
34
35/// Regex for parsing capital gains config entries.
36/// Format: `'pattern': ['to_replace', 'repl1', 'repl2']`
37static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38    Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39        .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Native plugin registered under the name `long_short`.
///
/// Rebooks generic capital-gains postings onto separate short-term and
/// long-term accounts according to how long each sold lot was held.
pub struct CapitalGainsLongShortPlugin;
44
/// Native plugin registered under the name `gain_loss`.
///
/// Renames matching capital-gains postings onto separate gains and losses
/// accounts according to the sign of each posting's amount.
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49    fn name(&self) -> &'static str {
50        "long_short"
51    }
52
53    fn description(&self) -> &'static str {
54        "Classify capital gains into long-term vs short-term based on holding period"
55    }
56
57    fn process(&self, input: PluginInput) -> PluginOutput {
58        process_long_short(input)
59    }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63    fn name(&self) -> &'static str {
64        "gain_loss"
65    }
66
67    fn description(&self) -> &'static str {
68        "Classify capital gains into gains vs losses based on posting amount"
69    }
70
71    fn process(&self, input: PluginInput) -> PluginOutput {
72        process_gain_loss(input)
73    }
74}
75
76impl RegularPlugin for CapitalGainsLongShortPlugin {}
77impl RegularPlugin for CapitalGainsGainLossPlugin {}
78
79/// Configuration for `long_short` classification.
80struct LongShortConfig {
81    pattern: Regex,
82    account_to_replace: String,
83    short_replacement: String,
84    long_replacement: String,
85}
86
87/// Configuration for `gain_loss` classification.
88struct GainLossConfig {
89    pattern: Regex,
90    account_to_replace: String,
91    gains_replacement: String,
92    losses_replacement: String,
93}
94
95/// Process entries with `long_short` classification.
96fn process_long_short(input: PluginInput) -> PluginOutput {
97    let config = match &input.config {
98        Some(c) => match parse_long_short_config(c) {
99            Some(cfg) => cfg,
100            None => {
101                return PluginOutput {
102                    ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
103                    errors: Vec::new(),
104                };
105            }
106        },
107        None => {
108            return PluginOutput {
109                ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
110                errors: Vec::new(),
111            };
112        }
113    };
114
115    let mut new_accounts: HashSet<String> = HashSet::new();
116    let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
117    // Track earliest date from input directives for synthesized Open directives.
118    let earliest_date = input
119        .directives
120        .iter()
121        .map(|d| d.date.as_str())
122        .min()
123        .unwrap_or("1970-01-01")
124        .to_string();
125    // Accounts already opened by the user; we must not synthesize a
126    // duplicate Open or Late validation will emit E1002.
127    let existing_opens: HashSet<String> = input
128        .directives
129        .iter()
130        .filter_map(|d| match &d.data {
131            DirectiveData::Open(open) => Some(open.account.clone()),
132            _ => None,
133        })
134        .collect();
135
136    for (i, directive) in input.directives.into_iter().enumerate() {
137        if directive.directive_type != "transaction" {
138            ops.push(PluginOp::Keep(i));
139            continue;
140        }
141
142        if let DirectiveData::Transaction(txn) = &directive.data {
143            // Check if transaction has matching capital gains postings
144            let has_generic = txn
145                .postings
146                .iter()
147                .any(|p| config.pattern.is_match(&p.account));
148            let has_specific = txn.postings.iter().any(|p| {
149                p.account.contains(&config.short_replacement)
150                    || p.account.contains(&config.long_replacement)
151            });
152
153            if !has_generic || has_specific {
154                ops.push(PluginOp::Keep(i));
155                continue;
156            }
157
158            // Find reduction postings (sales with cost and price)
159            let reductions: Vec<&PostingData> = txn
160                .postings
161                .iter()
162                .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
163                .collect();
164
165            if reductions.is_empty() {
166                ops.push(PluginOp::Keep(i));
167                continue;
168            }
169
170            // Calculate short vs long gains
171            let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
172                d
173            } else {
174                ops.push(PluginOp::Keep(i));
175                continue;
176            };
177
178            // Fall through if ANY reduction lacks a parseable cost
179            // date. Without it the plugin can't classify holding
180            // period, and pre-fix (issue #1010) it would silently
181            // drop the generic Income:Capital-Gains posting in the
182            // post-loop filter, leaving the transaction unbalanced.
183            // Falling through preserves the user's ledger.
184            let any_missing_cost_date = reductions.iter().any(|p| {
185                p.cost
186                    .as_ref()
187                    .and_then(|c| c.date.as_ref())
188                    .and_then(|d| d.parse::<NaiveDate>().ok())
189                    .is_none()
190            });
191            if any_missing_cost_date {
192                ops.push(PluginOp::Keep(i));
193                continue;
194            }
195
196            let mut short_gains = Decimal::ZERO;
197            let mut long_gains = Decimal::ZERO;
198
199            for posting in &reductions {
200                if let (Some(cost), Some(units), Some(price)) =
201                    (&posting.cost, &posting.units, &posting.price)
202                {
203                    // Get cost date
204                    let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
205
206                    if let Some(cost_date) = cost_date {
207                        // Calculate gain
208                        // Capital-gains classification operates on
209                        // post-booking transactions where the cost
210                        // carries a per-unit value (both PerUnit and
211                        // PerUnitFromTotal expose per_unit()).
212                        let cost_number = cost
213                            .number
214                            .as_ref()
215                            .and_then(|cn| cn.per_unit())
216                            .map_or(Decimal::ZERO, |s| {
217                                Decimal::from_str(s).unwrap_or(Decimal::ZERO)
218                            });
219                        let price_number = price
220                            .amount
221                            .as_ref()
222                            .and_then(|a| Decimal::from_str(&a.number).ok())
223                            .unwrap_or(Decimal::ZERO);
224                        let units_number =
225                            Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
226
227                        let gain = (cost_number - price_number) * units_number.abs();
228
229                        // Check if long-term (> 1 year)
230                        let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
231                        let years_held = (days_held / 365) as u32;
232                        let is_long_term = years_held > 1
233                            || (years_held == 1
234                                && (entry_date.month() > cost_date.month()
235                                    || (entry_date.month() == cost_date.month()
236                                        && entry_date.day() >= cost_date.day())));
237
238                        if is_long_term {
239                            long_gains += gain;
240                        } else {
241                            short_gains += gain;
242                        }
243                    }
244                }
245            }
246
247            // Find and remove original capital gains postings
248            let orig_postings: Vec<&PostingData> = txn
249                .postings
250                .iter()
251                .filter(|p| config.pattern.is_match(&p.account))
252                .collect();
253
254            if orig_postings.is_empty() {
255                ops.push(PluginOp::Keep(i));
256                continue;
257            }
258
259            let orig_sum: Decimal = orig_postings
260                .iter()
261                .filter_map(|p| p.units.as_ref())
262                .filter_map(|u| Decimal::from_str(&u.number).ok())
263                .sum();
264
265            // Adjust for rounding differences
266            let diff = orig_sum - (short_gains + long_gains);
267            if diff.abs() > Decimal::new(1, 6) {
268                let total = short_gains + long_gains;
269                if total != Decimal::ZERO {
270                    short_gains += (short_gains / total) * diff;
271                    long_gains += (long_gains / total) * diff;
272                }
273            }
274
275            // Create new postings
276            let mut new_postings: Vec<PostingData> = txn
277                .postings
278                .iter()
279                .filter(|p| !config.pattern.is_match(&p.account))
280                .cloned()
281                .collect();
282
283            let template = orig_postings[0];
284
285            if short_gains != Decimal::ZERO {
286                let new_account = template
287                    .account
288                    .replace(&config.account_to_replace, &config.short_replacement);
289                new_accounts.insert(new_account.clone());
290                new_postings.push(PostingData {
291                    account: new_account,
292                    units: template.units.as_ref().map(|u| AmountData {
293                        number: format_decimal(short_gains),
294                        currency: u.currency.clone(),
295                    }),
296                    cost: None,
297                    price: None,
298                    flag: template.flag.clone(),
299                    metadata: vec![],
300                    span: None,
301                });
302            }
303
304            if long_gains != Decimal::ZERO {
305                let new_account = template
306                    .account
307                    .replace(&config.account_to_replace, &config.long_replacement);
308                new_accounts.insert(new_account.clone());
309                new_postings.push(PostingData {
310                    account: new_account,
311                    units: template.units.as_ref().map(|u| AmountData {
312                        number: format_decimal(long_gains),
313                        currency: u.currency.clone(),
314                    }),
315                    cost: None,
316                    price: None,
317                    flag: template.flag.clone(),
318                    metadata: vec![],
319                    span: None,
320                });
321            }
322
323            ops.push(PluginOp::Modify(
324                i,
325                DirectiveWrapper {
326                    directive_type: "transaction".to_string(),
327                    date: directive.date.clone(),
328                    filename: directive.filename.clone(),
329                    lineno: directive.lineno,
330                    data: DirectiveData::Transaction(TransactionData {
331                        flag: txn.flag.clone(),
332                        payee: txn.payee.clone(),
333                        narration: txn.narration.clone(),
334                        tags: txn.tags.clone(),
335                        links: txn.links.clone(),
336                        metadata: txn.metadata.clone(),
337                        postings: new_postings,
338                    }),
339                },
340            ));
341        } else {
342            ops.push(PluginOp::Keep(i));
343        }
344    }
345
346    // Insert Open directives for newly synthesized accounts the user
347    // hasn't already opened.
348    for account in &new_accounts {
349        if existing_opens.contains(account) {
350            continue;
351        }
352        ops.push(PluginOp::Insert(DirectiveWrapper {
353            directive_type: "open".to_string(),
354            date: earliest_date.clone(),
355            filename: Some("<long_short>".to_string()),
356            lineno: Some(0),
357            data: DirectiveData::Open(OpenData {
358                account: account.clone(),
359                currencies: vec![],
360                booking: None,
361                metadata: vec![],
362            }),
363        }));
364    }
365
366    PluginOutput {
367        ops,
368        errors: Vec::new(),
369    }
370}
371
372/// Process entries with `gain_loss` classification.
373fn process_gain_loss(input: PluginInput) -> PluginOutput {
374    let config = match &input.config {
375        Some(c) => match parse_gain_loss_config(c) {
376            Some(cfg) => cfg,
377            None => {
378                return PluginOutput {
379                    ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
380                    errors: Vec::new(),
381                };
382            }
383        },
384        None => {
385            return PluginOutput {
386                ops: (0..input.directives.len()).map(PluginOp::Keep).collect(),
387                errors: Vec::new(),
388            };
389        }
390    };
391
392    let mut new_accounts: HashSet<String> = HashSet::new();
393    let mut ops: Vec<PluginOp> = Vec::with_capacity(input.directives.len());
394    // Earliest date from input for synthesized Open directives.
395    let earliest_date = input
396        .directives
397        .iter()
398        .map(|d| d.date.as_str())
399        .min()
400        .unwrap_or("1970-01-01")
401        .to_string();
402    // Accounts already opened by the user; suppress duplicate Opens.
403    let existing_opens: HashSet<String> = input
404        .directives
405        .iter()
406        .filter_map(|d| match &d.data {
407            DirectiveData::Open(open) => Some(open.account.clone()),
408            _ => None,
409        })
410        .collect();
411
412    for (i, directive) in input.directives.into_iter().enumerate() {
413        if directive.directive_type != "transaction" {
414            ops.push(PluginOp::Keep(i));
415            continue;
416        }
417
418        if let DirectiveData::Transaction(txn) = &directive.data {
419            let mut modified = false;
420            let mut new_postings: Vec<PostingData> = Vec::new();
421
422            for posting in &txn.postings {
423                if config.pattern.is_match(&posting.account)
424                    && let Some(units) = &posting.units
425                    && let Ok(number) = Decimal::from_str(&units.number)
426                {
427                    let new_account = if number < Decimal::ZERO {
428                        // Negative = gains (income is negative)
429                        posting
430                            .account
431                            .replace(&config.account_to_replace, &config.gains_replacement)
432                    } else {
433                        // Positive = losses
434                        posting
435                            .account
436                            .replace(&config.account_to_replace, &config.losses_replacement)
437                    };
438
439                    new_accounts.insert(new_account.clone());
440                    new_postings.push(PostingData {
441                        account: new_account,
442                        units: posting.units.clone(),
443                        cost: posting.cost.clone(),
444                        price: posting.price.clone(),
445                        flag: posting.flag.clone(),
446                        metadata: posting.metadata.clone(),
447                        span: None,
448                    });
449                    modified = true;
450                    continue;
451                }
452                new_postings.push(posting.clone());
453            }
454
455            if modified {
456                ops.push(PluginOp::Modify(
457                    i,
458                    DirectiveWrapper {
459                        directive_type: "transaction".to_string(),
460                        date: directive.date.clone(),
461                        filename: directive.filename.clone(),
462                        lineno: directive.lineno,
463                        data: DirectiveData::Transaction(TransactionData {
464                            flag: txn.flag.clone(),
465                            payee: txn.payee.clone(),
466                            narration: txn.narration.clone(),
467                            tags: txn.tags.clone(),
468                            links: txn.links.clone(),
469                            metadata: txn.metadata.clone(),
470                            postings: new_postings,
471                        }),
472                    },
473                ));
474            } else {
475                ops.push(PluginOp::Keep(i));
476            }
477        } else {
478            ops.push(PluginOp::Keep(i));
479        }
480    }
481
482    // Insert Open directives for newly synthesized accounts the user
483    // hasn't already opened.
484    for account in &new_accounts {
485        if existing_opens.contains(account) {
486            continue;
487        }
488        ops.push(PluginOp::Insert(DirectiveWrapper {
489            directive_type: "open".to_string(),
490            date: earliest_date.clone(),
491            filename: Some("<gain_loss>".to_string()),
492            lineno: Some(0),
493            data: DirectiveData::Open(OpenData {
494                account: account.clone(),
495                currencies: vec![],
496                booking: None,
497                metadata: vec![],
498            }),
499        }));
500    }
501
502    PluginOutput {
503        ops,
504        errors: Vec::new(),
505    }
506}
507
508/// Parse `long_short` configuration.
509/// Format: `{'pattern': ['to_replace', 'short_repl', 'long_repl']}`
510fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
511    // Parse pattern: 'key': ['val1', 'val2', 'val3']
512    let cap = CONFIG_ENTRY_RE.captures(config)?;
513    let pattern = Regex::new(&cap[1]).ok()?;
514    let account_to_replace = cap[2].to_string();
515    let short_replacement = cap[3].to_string();
516    let long_replacement = cap[4].to_string();
517
518    Some(LongShortConfig {
519        pattern,
520        account_to_replace,
521        short_replacement,
522        long_replacement,
523    })
524}
525
526/// Parse `gain_loss` configuration.
527/// Format: `{'pattern': ['to_replace', 'gains_repl', 'losses_repl']}`
528fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
529    // Parse pattern: 'key': ['val1', 'val2', 'val3']
530    let cap = CONFIG_ENTRY_RE.captures(config)?;
531    let pattern = Regex::new(&cap[1]).ok()?;
532    let account_to_replace = cap[2].to_string();
533    let gains_replacement = cap[3].to_string();
534    let losses_replacement = cap[4].to_string();
535
536    Some(GainLossConfig {
537        pattern,
538        account_to_replace,
539        gains_replacement,
540        losses_replacement,
541    })
542}
543
544/// Format a decimal number.
545fn format_decimal(d: Decimal) -> String {
546    let s = d.to_string();
547    if s.contains('.') {
548        s.trim_end_matches('0').trim_end_matches('.').to_string()
549    } else {
550        s
551    }
552}
553
#[cfg(test)]
mod tests {
    use super::super::utils::materialize_ops;
    use super::*;
    use crate::types::*;

    /// Build a plain USD posting with no cost, price, flag or metadata.
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        }
    }

    /// The three list items of a `long_short` entry land in the right fields.
    #[test]
    fn test_parse_long_short_config() {
        let cfg = parse_long_short_config(
            "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}",
        )
        .expect("long_short config should parse");

        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    /// The three list items of a `gain_loss` entry land in the right fields.
    #[test]
    fn test_parse_gain_loss_config() {
        let cfg =
            parse_gain_loss_config("{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}")
                .expect("gain_loss config should parse");

        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    /// A negative posting on a matching account is moved to the gains account.
    #[test]
    fn test_gain_loss_classification() {
        let sale = DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: "Sell stock".to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    usd_posting("Assets:Broker", "1000"),
                    usd_posting("Income:Capital-Gains:Long", "-100"),
                ],
            }),
        };
        let input = PluginInput {
            directives: vec![sale],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(
                "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}"
                    .to_string(),
            ),
        };

        // Keep a copy of the input: `process` consumes it, while
        // `materialize_ops` needs the original directives to apply the ops to.
        let input_dirs = input.directives.clone();
        let output = CapitalGainsGainLossPlugin.process(input);
        assert_eq!(output.errors.len(), 0);
        let directives = materialize_ops(&input_dirs, &output);

        let rebooked = directives
            .iter()
            .find_map(|d| match &d.data {
                DirectiveData::Transaction(t) => Some(t),
                _ => None,
            })
            .expect("output should still contain the transaction");

        // The negative posting should be renamed to :Long:Gains
        assert!(
            rebooked
                .postings
                .iter()
                .any(|p| p.account.contains(":Long:Gains"))
        );
    }
}