Skip to main content

rustledger_plugin/native/plugins/
capital_gains_classifier.rs

1//! Capital gains classifier plugin.
2//!
3//! This plugin rebooks capital gains into separate accounts based on:
4//! - **`long_short`**: Whether gains are long-term (held > 1 year) or short-term
5//! - **`gain_loss`**: Whether the posting is a gain (negative income) or loss (positive income)
6//!
7//! Usage for `long_short`:
8//! ```text
9//! plugin "beancount_reds_plugins.capital_gains_classifier.long_short" "{
10//!   'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']
11//! }"
12//! ```
13//!
14//! Usage for `gain_loss`:
15//! ```text
16//! plugin "beancount_reds_plugins.capital_gains_classifier.gain_loss" "{
17//!   'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']
18//! }"
19//! ```
20
21use chrono::{Datelike, NaiveDate};
22use regex::Regex;
23use rust_decimal::Decimal;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29    AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
30    TransactionData,
31};
32
33use super::super::NativePlugin;
34
/// Regex for parsing capital gains config entries.
/// Format: `'pattern': ['to_replace', 'repl1', 'repl2']`
///
/// Capture groups:
/// 1. the single-quoted key (must be non-empty),
/// 2. the first single-quoted list item (may be empty),
/// 3. the second single-quoted list item (may be empty),
/// 4. the third single-quoted list item (may be empty).
///
/// The regex is compiled on first access through `LazyLock`; the `expect` can only fire if the
/// literal pattern below is itself malformed.
static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
        .expect("CONFIG_ENTRY_RE: invalid regex pattern")
});
41
/// Plugin for classifying capital gains into long/short term categories.
///
/// This is a stateless unit type; all behavior lives in its `NativePlugin` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsLongShortPlugin;
44
/// Plugin for classifying capital gains into gains/losses categories.
///
/// This is a stateless unit type; all behavior lives in its `NativePlugin` implementation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49    fn name(&self) -> &'static str {
50        "long_short"
51    }
52
53    fn description(&self) -> &'static str {
54        "Classify capital gains into long-term vs short-term based on holding period"
55    }
56
57    fn process(&self, input: PluginInput) -> PluginOutput {
58        process_long_short(input)
59    }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63    fn name(&self) -> &'static str {
64        "gain_loss"
65    }
66
67    fn description(&self) -> &'static str {
68        "Classify capital gains into gains vs losses based on posting amount"
69    }
70
71    fn process(&self, input: PluginInput) -> PluginOutput {
72        process_gain_loss(input)
73    }
74}
75
/// Configuration for `long_short` classification.
///
/// Built by `parse_long_short_config` from a single
/// `'pattern': ['to_replace', 'short_repl', 'long_repl']` config entry.
struct LongShortConfig {
    /// Regex tested against posting account names to find the generic capital gains postings.
    pattern: Regex,
    /// Substring of the matched account name that gets substituted.
    account_to_replace: String,
    /// Text substituted for `account_to_replace` to form the short-term account name.
    short_replacement: String,
    /// Text substituted for `account_to_replace` to form the long-term account name.
    long_replacement: String,
}
83
/// Configuration for `gain_loss` classification.
///
/// Built by `parse_gain_loss_config` from a single
/// `'pattern': ['to_replace', 'gains_repl', 'losses_repl']` config entry.
struct GainLossConfig {
    /// Regex tested against posting account names to find the postings to reclassify.
    pattern: Regex,
    /// Substring of the matched account name that gets substituted.
    account_to_replace: String,
    /// Text substituted for `account_to_replace` when the posting amount is negative (a gain).
    gains_replacement: String,
    /// Text substituted for `account_to_replace` when the posting amount is zero or positive
    /// (a loss).
    losses_replacement: String,
}
91
92/// Process entries with `long_short` classification.
93fn process_long_short(input: PluginInput) -> PluginOutput {
94    let config = match &input.config {
95        Some(c) => match parse_long_short_config(c) {
96            Some(cfg) => cfg,
97            None => {
98                return PluginOutput {
99                    directives: input.directives,
100                    errors: Vec::new(),
101                };
102            }
103        },
104        None => {
105            return PluginOutput {
106                directives: input.directives,
107                errors: Vec::new(),
108            };
109        }
110    };
111
112    let mut new_accounts: HashSet<String> = HashSet::new();
113    let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
114
115    for directive in input.directives {
116        if directive.directive_type != "transaction" {
117            new_directives.push(directive);
118            continue;
119        }
120
121        if let DirectiveData::Transaction(txn) = &directive.data {
122            // Check if transaction has matching capital gains postings
123            let has_generic = txn
124                .postings
125                .iter()
126                .any(|p| config.pattern.is_match(&p.account));
127            let has_specific = txn.postings.iter().any(|p| {
128                p.account.contains(&config.short_replacement)
129                    || p.account.contains(&config.long_replacement)
130            });
131
132            if !has_generic || has_specific {
133                new_directives.push(directive);
134                continue;
135            }
136
137            // Find reduction postings (sales with cost and price)
138            let reductions: Vec<&PostingData> = txn
139                .postings
140                .iter()
141                .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
142                .collect();
143
144            if reductions.is_empty() {
145                new_directives.push(directive);
146                continue;
147            }
148
149            // Calculate short vs long gains
150            let entry_date = if let Ok(d) = NaiveDate::parse_from_str(&directive.date, "%Y-%m-%d") {
151                d
152            } else {
153                new_directives.push(directive);
154                continue;
155            };
156
157            let mut short_gains = Decimal::ZERO;
158            let mut long_gains = Decimal::ZERO;
159
160            for posting in &reductions {
161                if let (Some(cost), Some(units), Some(price)) =
162                    (&posting.cost, &posting.units, &posting.price)
163                {
164                    // Get cost date
165                    let cost_date = cost
166                        .date
167                        .as_ref()
168                        .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok());
169
170                    if let Some(cost_date) = cost_date {
171                        // Calculate gain
172                        let cost_number = cost
173                            .number_per
174                            .as_ref()
175                            .and_then(|n| Decimal::from_str(n).ok())
176                            .unwrap_or(Decimal::ZERO);
177                        let price_number = price
178                            .amount
179                            .as_ref()
180                            .and_then(|a| Decimal::from_str(&a.number).ok())
181                            .unwrap_or(Decimal::ZERO);
182                        let units_number =
183                            Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
184
185                        let gain = (cost_number - price_number) * units_number.abs();
186
187                        // Check if long-term (> 1 year)
188                        let years_held = entry_date.years_since(cost_date).unwrap_or(0);
189                        let is_long_term = years_held > 1
190                            || (years_held == 1
191                                && (entry_date.month() > cost_date.month()
192                                    || (entry_date.month() == cost_date.month()
193                                        && entry_date.day() >= cost_date.day())));
194
195                        if is_long_term {
196                            long_gains += gain;
197                        } else {
198                            short_gains += gain;
199                        }
200                    }
201                }
202            }
203
204            // Find and remove original capital gains postings
205            let orig_postings: Vec<&PostingData> = txn
206                .postings
207                .iter()
208                .filter(|p| config.pattern.is_match(&p.account))
209                .collect();
210
211            if orig_postings.is_empty() {
212                new_directives.push(directive);
213                continue;
214            }
215
216            let orig_sum: Decimal = orig_postings
217                .iter()
218                .filter_map(|p| p.units.as_ref())
219                .filter_map(|u| Decimal::from_str(&u.number).ok())
220                .sum();
221
222            // Adjust for rounding differences
223            let diff = orig_sum - (short_gains + long_gains);
224            if diff.abs() > Decimal::new(1, 6) {
225                let total = short_gains + long_gains;
226                if total != Decimal::ZERO {
227                    short_gains += (short_gains / total) * diff;
228                    long_gains += (long_gains / total) * diff;
229                }
230            }
231
232            // Create new postings
233            let mut new_postings: Vec<PostingData> = txn
234                .postings
235                .iter()
236                .filter(|p| !config.pattern.is_match(&p.account))
237                .cloned()
238                .collect();
239
240            let template = orig_postings[0];
241
242            if short_gains != Decimal::ZERO {
243                let new_account = template
244                    .account
245                    .replace(&config.account_to_replace, &config.short_replacement);
246                new_accounts.insert(new_account.clone());
247                new_postings.push(PostingData {
248                    account: new_account,
249                    units: template.units.as_ref().map(|u| AmountData {
250                        number: format_decimal(short_gains),
251                        currency: u.currency.clone(),
252                    }),
253                    cost: None,
254                    price: None,
255                    flag: template.flag.clone(),
256                    metadata: vec![],
257                });
258            }
259
260            if long_gains != Decimal::ZERO {
261                let new_account = template
262                    .account
263                    .replace(&config.account_to_replace, &config.long_replacement);
264                new_accounts.insert(new_account.clone());
265                new_postings.push(PostingData {
266                    account: new_account,
267                    units: template.units.as_ref().map(|u| AmountData {
268                        number: format_decimal(long_gains),
269                        currency: u.currency.clone(),
270                    }),
271                    cost: None,
272                    price: None,
273                    flag: template.flag.clone(),
274                    metadata: vec![],
275                });
276            }
277
278            new_directives.push(DirectiveWrapper {
279                directive_type: "transaction".to_string(),
280                date: directive.date.clone(),
281                filename: directive.filename.clone(),
282                lineno: directive.lineno,
283                data: DirectiveData::Transaction(TransactionData {
284                    flag: txn.flag.clone(),
285                    payee: txn.payee.clone(),
286                    narration: txn.narration.clone(),
287                    tags: txn.tags.clone(),
288                    links: txn.links.clone(),
289                    metadata: txn.metadata.clone(),
290                    postings: new_postings,
291                }),
292            });
293        } else {
294            new_directives.push(directive);
295        }
296    }
297
298    // Create Open directives for new accounts
299    let earliest_date = new_directives
300        .iter()
301        .map(|d| d.date.as_str())
302        .min()
303        .unwrap_or("1970-01-01")
304        .to_string();
305
306    let mut open_directives: Vec<DirectiveWrapper> = new_accounts
307        .iter()
308        .map(|account| DirectiveWrapper {
309            directive_type: "open".to_string(),
310            date: earliest_date.clone(),
311            filename: Some("<long_short>".to_string()),
312            lineno: Some(0),
313            data: DirectiveData::Open(OpenData {
314                account: account.clone(),
315                currencies: vec![],
316                booking: None,
317                metadata: vec![],
318            }),
319        })
320        .collect();
321
322    open_directives.extend(new_directives);
323
324    PluginOutput {
325        directives: open_directives,
326        errors: Vec::new(),
327    }
328}
329
330/// Process entries with `gain_loss` classification.
331fn process_gain_loss(input: PluginInput) -> PluginOutput {
332    let config = match &input.config {
333        Some(c) => match parse_gain_loss_config(c) {
334            Some(cfg) => cfg,
335            None => {
336                return PluginOutput {
337                    directives: input.directives,
338                    errors: Vec::new(),
339                };
340            }
341        },
342        None => {
343            return PluginOutput {
344                directives: input.directives,
345                errors: Vec::new(),
346            };
347        }
348    };
349
350    let mut new_accounts: HashSet<String> = HashSet::new();
351    let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
352
353    for directive in input.directives {
354        if directive.directive_type != "transaction" {
355            new_directives.push(directive);
356            continue;
357        }
358
359        if let DirectiveData::Transaction(txn) = &directive.data {
360            let mut modified = false;
361            let mut new_postings: Vec<PostingData> = Vec::new();
362
363            for posting in &txn.postings {
364                if config.pattern.is_match(&posting.account)
365                    && let Some(units) = &posting.units
366                    && let Ok(number) = Decimal::from_str(&units.number)
367                {
368                    let new_account = if number < Decimal::ZERO {
369                        // Negative = gains (income is negative)
370                        posting
371                            .account
372                            .replace(&config.account_to_replace, &config.gains_replacement)
373                    } else {
374                        // Positive = losses
375                        posting
376                            .account
377                            .replace(&config.account_to_replace, &config.losses_replacement)
378                    };
379
380                    new_accounts.insert(new_account.clone());
381                    new_postings.push(PostingData {
382                        account: new_account,
383                        units: posting.units.clone(),
384                        cost: posting.cost.clone(),
385                        price: posting.price.clone(),
386                        flag: posting.flag.clone(),
387                        metadata: posting.metadata.clone(),
388                    });
389                    modified = true;
390                    continue;
391                }
392                new_postings.push(posting.clone());
393            }
394
395            if modified {
396                new_directives.push(DirectiveWrapper {
397                    directive_type: "transaction".to_string(),
398                    date: directive.date.clone(),
399                    filename: directive.filename.clone(),
400                    lineno: directive.lineno,
401                    data: DirectiveData::Transaction(TransactionData {
402                        flag: txn.flag.clone(),
403                        payee: txn.payee.clone(),
404                        narration: txn.narration.clone(),
405                        tags: txn.tags.clone(),
406                        links: txn.links.clone(),
407                        metadata: txn.metadata.clone(),
408                        postings: new_postings,
409                    }),
410                });
411            } else {
412                new_directives.push(directive);
413            }
414        } else {
415            new_directives.push(directive);
416        }
417    }
418
419    // Create Open directives for new accounts
420    let earliest_date = new_directives
421        .iter()
422        .map(|d| d.date.as_str())
423        .min()
424        .unwrap_or("1970-01-01")
425        .to_string();
426
427    let mut open_directives: Vec<DirectiveWrapper> = new_accounts
428        .iter()
429        .map(|account| DirectiveWrapper {
430            directive_type: "open".to_string(),
431            date: earliest_date.clone(),
432            filename: Some("<gain_loss>".to_string()),
433            lineno: Some(0),
434            data: DirectiveData::Open(OpenData {
435                account: account.clone(),
436                currencies: vec![],
437                booking: None,
438                metadata: vec![],
439            }),
440        })
441        .collect();
442
443    open_directives.extend(new_directives);
444
445    PluginOutput {
446        directives: open_directives,
447        errors: Vec::new(),
448    }
449}
450
451/// Parse `long_short` configuration.
452/// Format: `{'pattern': ['to_replace', 'short_repl', 'long_repl']}`
453fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
454    // Parse pattern: 'key': ['val1', 'val2', 'val3']
455    let cap = CONFIG_ENTRY_RE.captures(config)?;
456    let pattern = Regex::new(&cap[1]).ok()?;
457    let account_to_replace = cap[2].to_string();
458    let short_replacement = cap[3].to_string();
459    let long_replacement = cap[4].to_string();
460
461    Some(LongShortConfig {
462        pattern,
463        account_to_replace,
464        short_replacement,
465        long_replacement,
466    })
467}
468
469/// Parse `gain_loss` configuration.
470/// Format: `{'pattern': ['to_replace', 'gains_repl', 'losses_repl']}`
471fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
472    // Parse pattern: 'key': ['val1', 'val2', 'val3']
473    let cap = CONFIG_ENTRY_RE.captures(config)?;
474    let pattern = Regex::new(&cap[1]).ok()?;
475    let account_to_replace = cap[2].to_string();
476    let gains_replacement = cap[3].to_string();
477    let losses_replacement = cap[4].to_string();
478
479    Some(GainLossConfig {
480        pattern,
481        account_to_replace,
482        gains_replacement,
483        losses_replacement,
484    })
485}
486
487/// Format a decimal number.
488fn format_decimal(d: Decimal) -> String {
489    let s = d.to_string();
490    if s.contains('.') {
491        s.trim_end_matches('0').trim_end_matches('.').to_string()
492    } else {
493        s
494    }
495}
496
/// Unit tests for config parsing and the `gain_loss` plugin.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// The three list items of a `long_short` entry end up in the matching config fields.
    #[test]
    fn test_parse_long_short_config() {
        let config = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
        let parsed = parse_long_short_config(config);
        assert!(parsed.is_some());
        let cfg = parsed.unwrap();
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    /// The three list items of a `gain_loss` entry end up in the matching config fields.
    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let parsed = parse_gain_loss_config(config);
        assert!(parsed.is_some());
        let cfg = parsed.unwrap();
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    /// A negative posting on a matching account is rebooked to the gains account.
    #[test]
    fn test_gain_loss_classification() {
        let plugin = CapitalGainsGainLossPlugin;

        // One transaction: cash proceeds plus a negative (gain) posting on the generic account.
        let input = PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings: vec![
                        PostingData {
                            account: "Assets:Broker".to_string(),
                            units: Some(AmountData {
                                number: "1000".to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                        PostingData {
                            account: "Income:Capital-Gains:Long".to_string(),
                            units: Some(AmountData {
                                number: "-100".to_string(),
                                currency: "USD".to_string(),
                            }),
                            cost: None,
                            price: None,
                            flag: None,
                            metadata: vec![],
                        },
                    ],
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(
                "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}"
                    .to_string(),
            ),
        };

        let output = plugin.process(input);
        assert_eq!(output.errors.len(), 0);

        // Find the transaction
        // (the output also contains generated `open` directives, so search by type)
        let txn = output
            .directives
            .iter()
            .find(|d| d.directive_type == "transaction");
        assert!(txn.is_some());

        // NOTE(review): if the directive's data were not a Transaction this block would be
        // skipped and the test would still pass.
        if let DirectiveData::Transaction(t) = &txn.unwrap().data {
            // The negative posting should be renamed to :Long:Gains
            let gains_posting = t
                .postings
                .iter()
                .find(|p| p.account.contains(":Long:Gains"));
            assert!(gains_posting.is_some());
        }
    }
}