Skip to main content

rustledger_plugin/native/plugins/
capital_gains_classifier.rs

1//! Capital gains classifier plugin.
2//!
3//! This plugin rebooks capital gains into separate accounts based on:
4//! - **`long_short`**: Whether gains are long-term (held > 1 year) or short-term
5//! - **`gain_loss`**: Whether the posting is a gain (negative income) or loss (positive income)
6//!
7//! Usage for `long_short`:
8//! ```text
9//! plugin "beancount_reds_plugins.capital_gains_classifier.long_short" "{
10//!   'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']
11//! }"
12//! ```
13//!
14//! Usage for `gain_loss`:
15//! ```text
16//! plugin "beancount_reds_plugins.capital_gains_classifier.gain_loss" "{
17//!   'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']
18//! }"
19//! ```
20
21use regex::Regex;
22use rust_decimal::Decimal;
23use rustledger_core::NaiveDate;
24use std::collections::HashSet;
25use std::str::FromStr;
26use std::sync::LazyLock;
27
28use crate::types::{
29    AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
30    TransactionData,
31};
32
33use super::super::NativePlugin;
34
35/// Regex for parsing capital gains config entries.
36/// Format: `'pattern': ['to_replace', 'repl1', 'repl2']`
37static CONFIG_ENTRY_RE: LazyLock<Regex> = LazyLock::new(|| {
38    Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]")
39        .expect("CONFIG_ENTRY_RE: invalid regex pattern")
40});
41
/// Plugin for classifying capital gains into long/short term categories.
///
/// Stateless unit struct: all behavior is driven by the config passed in with
/// each plugin invocation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsLongShortPlugin;
44
/// Plugin for classifying capital gains into gains/losses categories.
///
/// Stateless unit struct: all behavior is driven by the config passed in with
/// each plugin invocation.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsGainLossPlugin;
47
48impl NativePlugin for CapitalGainsLongShortPlugin {
49    fn name(&self) -> &'static str {
50        "long_short"
51    }
52
53    fn description(&self) -> &'static str {
54        "Classify capital gains into long-term vs short-term based on holding period"
55    }
56
57    fn process(&self, input: PluginInput) -> PluginOutput {
58        process_long_short(input)
59    }
60}
61
62impl NativePlugin for CapitalGainsGainLossPlugin {
63    fn name(&self) -> &'static str {
64        "gain_loss"
65    }
66
67    fn description(&self) -> &'static str {
68        "Classify capital gains into gains vs losses based on posting amount"
69    }
70
71    fn process(&self, input: PluginInput) -> PluginOutput {
72        process_gain_loss(input)
73    }
74}
75
76/// Configuration for `long_short` classification.
77struct LongShortConfig {
78    pattern: Regex,
79    account_to_replace: String,
80    short_replacement: String,
81    long_replacement: String,
82}
83
84/// Configuration for `gain_loss` classification.
85struct GainLossConfig {
86    pattern: Regex,
87    account_to_replace: String,
88    gains_replacement: String,
89    losses_replacement: String,
90}
91
92/// Process entries with `long_short` classification.
93fn process_long_short(input: PluginInput) -> PluginOutput {
94    let config = match &input.config {
95        Some(c) => match parse_long_short_config(c) {
96            Some(cfg) => cfg,
97            None => {
98                return PluginOutput {
99                    directives: input.directives,
100                    errors: Vec::new(),
101                };
102            }
103        },
104        None => {
105            return PluginOutput {
106                directives: input.directives,
107                errors: Vec::new(),
108            };
109        }
110    };
111
112    let mut new_accounts: HashSet<String> = HashSet::new();
113    let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
114
115    for directive in input.directives {
116        if directive.directive_type != "transaction" {
117            new_directives.push(directive);
118            continue;
119        }
120
121        if let DirectiveData::Transaction(txn) = &directive.data {
122            // Check if transaction has matching capital gains postings
123            let has_generic = txn
124                .postings
125                .iter()
126                .any(|p| config.pattern.is_match(&p.account));
127            let has_specific = txn.postings.iter().any(|p| {
128                p.account.contains(&config.short_replacement)
129                    || p.account.contains(&config.long_replacement)
130            });
131
132            if !has_generic || has_specific {
133                new_directives.push(directive);
134                continue;
135            }
136
137            // Find reduction postings (sales with cost and price)
138            let reductions: Vec<&PostingData> = txn
139                .postings
140                .iter()
141                .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
142                .collect();
143
144            if reductions.is_empty() {
145                new_directives.push(directive);
146                continue;
147            }
148
149            // Calculate short vs long gains
150            let entry_date = if let Ok(d) = directive.date.parse::<NaiveDate>() {
151                d
152            } else {
153                new_directives.push(directive);
154                continue;
155            };
156
157            let mut short_gains = Decimal::ZERO;
158            let mut long_gains = Decimal::ZERO;
159
160            for posting in &reductions {
161                if let (Some(cost), Some(units), Some(price)) =
162                    (&posting.cost, &posting.units, &posting.price)
163                {
164                    // Get cost date
165                    let cost_date = cost.date.as_ref().and_then(|d| d.parse::<NaiveDate>().ok());
166
167                    if let Some(cost_date) = cost_date {
168                        // Calculate gain
169                        let cost_number = cost
170                            .number_per
171                            .as_ref()
172                            .and_then(|n| Decimal::from_str(n).ok())
173                            .unwrap_or(Decimal::ZERO);
174                        let price_number = price
175                            .amount
176                            .as_ref()
177                            .and_then(|a| Decimal::from_str(&a.number).ok())
178                            .unwrap_or(Decimal::ZERO);
179                        let units_number =
180                            Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
181
182                        let gain = (cost_number - price_number) * units_number.abs();
183
184                        // Check if long-term (> 1 year)
185                        let days_held = entry_date.since(cost_date).map_or(0, |s| s.get_days());
186                        let years_held = (days_held / 365) as u32;
187                        let is_long_term = years_held > 1
188                            || (years_held == 1
189                                && (entry_date.month() > cost_date.month()
190                                    || (entry_date.month() == cost_date.month()
191                                        && entry_date.day() >= cost_date.day())));
192
193                        if is_long_term {
194                            long_gains += gain;
195                        } else {
196                            short_gains += gain;
197                        }
198                    }
199                }
200            }
201
202            // Find and remove original capital gains postings
203            let orig_postings: Vec<&PostingData> = txn
204                .postings
205                .iter()
206                .filter(|p| config.pattern.is_match(&p.account))
207                .collect();
208
209            if orig_postings.is_empty() {
210                new_directives.push(directive);
211                continue;
212            }
213
214            let orig_sum: Decimal = orig_postings
215                .iter()
216                .filter_map(|p| p.units.as_ref())
217                .filter_map(|u| Decimal::from_str(&u.number).ok())
218                .sum();
219
220            // Adjust for rounding differences
221            let diff = orig_sum - (short_gains + long_gains);
222            if diff.abs() > Decimal::new(1, 6) {
223                let total = short_gains + long_gains;
224                if total != Decimal::ZERO {
225                    short_gains += (short_gains / total) * diff;
226                    long_gains += (long_gains / total) * diff;
227                }
228            }
229
230            // Create new postings
231            let mut new_postings: Vec<PostingData> = txn
232                .postings
233                .iter()
234                .filter(|p| !config.pattern.is_match(&p.account))
235                .cloned()
236                .collect();
237
238            let template = orig_postings[0];
239
240            if short_gains != Decimal::ZERO {
241                let new_account = template
242                    .account
243                    .replace(&config.account_to_replace, &config.short_replacement);
244                new_accounts.insert(new_account.clone());
245                new_postings.push(PostingData {
246                    account: new_account,
247                    units: template.units.as_ref().map(|u| AmountData {
248                        number: format_decimal(short_gains),
249                        currency: u.currency.clone(),
250                    }),
251                    cost: None,
252                    price: None,
253                    flag: template.flag.clone(),
254                    metadata: vec![],
255                });
256            }
257
258            if long_gains != Decimal::ZERO {
259                let new_account = template
260                    .account
261                    .replace(&config.account_to_replace, &config.long_replacement);
262                new_accounts.insert(new_account.clone());
263                new_postings.push(PostingData {
264                    account: new_account,
265                    units: template.units.as_ref().map(|u| AmountData {
266                        number: format_decimal(long_gains),
267                        currency: u.currency.clone(),
268                    }),
269                    cost: None,
270                    price: None,
271                    flag: template.flag.clone(),
272                    metadata: vec![],
273                });
274            }
275
276            new_directives.push(DirectiveWrapper {
277                directive_type: "transaction".to_string(),
278                date: directive.date.clone(),
279                filename: directive.filename.clone(),
280                lineno: directive.lineno,
281                data: DirectiveData::Transaction(TransactionData {
282                    flag: txn.flag.clone(),
283                    payee: txn.payee.clone(),
284                    narration: txn.narration.clone(),
285                    tags: txn.tags.clone(),
286                    links: txn.links.clone(),
287                    metadata: txn.metadata.clone(),
288                    postings: new_postings,
289                }),
290            });
291        } else {
292            new_directives.push(directive);
293        }
294    }
295
296    // Create Open directives for new accounts
297    let earliest_date = new_directives
298        .iter()
299        .map(|d| d.date.as_str())
300        .min()
301        .unwrap_or("1970-01-01")
302        .to_string();
303
304    let mut open_directives: Vec<DirectiveWrapper> = new_accounts
305        .iter()
306        .map(|account| DirectiveWrapper {
307            directive_type: "open".to_string(),
308            date: earliest_date.clone(),
309            filename: Some("<long_short>".to_string()),
310            lineno: Some(0),
311            data: DirectiveData::Open(OpenData {
312                account: account.clone(),
313                currencies: vec![],
314                booking: None,
315                metadata: vec![],
316            }),
317        })
318        .collect();
319
320    open_directives.extend(new_directives);
321
322    PluginOutput {
323        directives: open_directives,
324        errors: Vec::new(),
325    }
326}
327
328/// Process entries with `gain_loss` classification.
329fn process_gain_loss(input: PluginInput) -> PluginOutput {
330    let config = match &input.config {
331        Some(c) => match parse_gain_loss_config(c) {
332            Some(cfg) => cfg,
333            None => {
334                return PluginOutput {
335                    directives: input.directives,
336                    errors: Vec::new(),
337                };
338            }
339        },
340        None => {
341            return PluginOutput {
342                directives: input.directives,
343                errors: Vec::new(),
344            };
345        }
346    };
347
348    let mut new_accounts: HashSet<String> = HashSet::new();
349    let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
350
351    for directive in input.directives {
352        if directive.directive_type != "transaction" {
353            new_directives.push(directive);
354            continue;
355        }
356
357        if let DirectiveData::Transaction(txn) = &directive.data {
358            let mut modified = false;
359            let mut new_postings: Vec<PostingData> = Vec::new();
360
361            for posting in &txn.postings {
362                if config.pattern.is_match(&posting.account)
363                    && let Some(units) = &posting.units
364                    && let Ok(number) = Decimal::from_str(&units.number)
365                {
366                    let new_account = if number < Decimal::ZERO {
367                        // Negative = gains (income is negative)
368                        posting
369                            .account
370                            .replace(&config.account_to_replace, &config.gains_replacement)
371                    } else {
372                        // Positive = losses
373                        posting
374                            .account
375                            .replace(&config.account_to_replace, &config.losses_replacement)
376                    };
377
378                    new_accounts.insert(new_account.clone());
379                    new_postings.push(PostingData {
380                        account: new_account,
381                        units: posting.units.clone(),
382                        cost: posting.cost.clone(),
383                        price: posting.price.clone(),
384                        flag: posting.flag.clone(),
385                        metadata: posting.metadata.clone(),
386                    });
387                    modified = true;
388                    continue;
389                }
390                new_postings.push(posting.clone());
391            }
392
393            if modified {
394                new_directives.push(DirectiveWrapper {
395                    directive_type: "transaction".to_string(),
396                    date: directive.date.clone(),
397                    filename: directive.filename.clone(),
398                    lineno: directive.lineno,
399                    data: DirectiveData::Transaction(TransactionData {
400                        flag: txn.flag.clone(),
401                        payee: txn.payee.clone(),
402                        narration: txn.narration.clone(),
403                        tags: txn.tags.clone(),
404                        links: txn.links.clone(),
405                        metadata: txn.metadata.clone(),
406                        postings: new_postings,
407                    }),
408                });
409            } else {
410                new_directives.push(directive);
411            }
412        } else {
413            new_directives.push(directive);
414        }
415    }
416
417    // Create Open directives for new accounts
418    let earliest_date = new_directives
419        .iter()
420        .map(|d| d.date.as_str())
421        .min()
422        .unwrap_or("1970-01-01")
423        .to_string();
424
425    let mut open_directives: Vec<DirectiveWrapper> = new_accounts
426        .iter()
427        .map(|account| DirectiveWrapper {
428            directive_type: "open".to_string(),
429            date: earliest_date.clone(),
430            filename: Some("<gain_loss>".to_string()),
431            lineno: Some(0),
432            data: DirectiveData::Open(OpenData {
433                account: account.clone(),
434                currencies: vec![],
435                booking: None,
436                metadata: vec![],
437            }),
438        })
439        .collect();
440
441    open_directives.extend(new_directives);
442
443    PluginOutput {
444        directives: open_directives,
445        errors: Vec::new(),
446    }
447}
448
449/// Parse `long_short` configuration.
450/// Format: `{'pattern': ['to_replace', 'short_repl', 'long_repl']}`
451fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
452    // Parse pattern: 'key': ['val1', 'val2', 'val3']
453    let cap = CONFIG_ENTRY_RE.captures(config)?;
454    let pattern = Regex::new(&cap[1]).ok()?;
455    let account_to_replace = cap[2].to_string();
456    let short_replacement = cap[3].to_string();
457    let long_replacement = cap[4].to_string();
458
459    Some(LongShortConfig {
460        pattern,
461        account_to_replace,
462        short_replacement,
463        long_replacement,
464    })
465}
466
467/// Parse `gain_loss` configuration.
468/// Format: `{'pattern': ['to_replace', 'gains_repl', 'losses_repl']}`
469fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
470    // Parse pattern: 'key': ['val1', 'val2', 'val3']
471    let cap = CONFIG_ENTRY_RE.captures(config)?;
472    let pattern = Regex::new(&cap[1]).ok()?;
473    let account_to_replace = cap[2].to_string();
474    let gains_replacement = cap[3].to_string();
475    let losses_replacement = cap[4].to_string();
476
477    Some(GainLossConfig {
478        pattern,
479        account_to_replace,
480        gains_replacement,
481        losses_replacement,
482    })
483}
484
485/// Format a decimal number.
486fn format_decimal(d: Decimal) -> String {
487    let s = d.to_string();
488    if s.contains('.') {
489        s.trim_end_matches('0').trim_end_matches('.').to_string()
490    } else {
491        s
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    const LONG_SHORT_CONFIG: &str = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
    const GAIN_LOSS_CONFIG: &str =
        "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}";

    /// Build a USD posting with no cost, price, flag or metadata.
    fn plain_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    /// Wrap `postings` in a single-transaction plugin input using `config`.
    fn single_transaction_input(postings: Vec<PostingData>, config: &str) -> PluginInput {
        PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings,
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(config.to_string()),
        }
    }

    #[test]
    fn test_parse_long_short_config() {
        let cfg = parse_long_short_config(LONG_SHORT_CONFIG)
            .expect("well-formed long_short config should parse");
        assert!(cfg.pattern.is_match("Income:Capital-Gains"));
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("well-formed gain_loss config should parse");
        assert!(cfg.pattern.is_match("Income:Long"));
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_parse_config_rejects_malformed_input() {
        assert!(parse_long_short_config("").is_none());
        // Only two replacement strings instead of three.
        assert!(parse_long_short_config("{'Income': [':A', ':B']}").is_none());
        // Entry shape is fine but the pattern is not a valid regex.
        assert!(parse_gain_loss_config("{'Income(': [':A', ':B', ':C']}").is_none());
    }

    #[test]
    fn test_format_decimal_trims_trailing_zeros() {
        assert_eq!(format_decimal(Decimal::new(15_000, 4)), "1.5");
        assert_eq!(format_decimal(Decimal::new(100, 0)), "100");
        assert_eq!(format_decimal(Decimal::new(1_000, 1)), "100");
        assert_eq!(format_decimal(Decimal::new(-2_500, 2)), "-25");
    }

    #[test]
    fn test_gain_loss_classification() {
        let plugin = CapitalGainsGainLossPlugin;
        let input = single_transaction_input(
            vec![
                plain_posting("Assets:Broker", "1000"),
                plain_posting("Income:Capital-Gains:Long", "-100"),
            ],
            GAIN_LOSS_CONFIG,
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());

        let txn = output
            .directives
            .iter()
            .find_map(|d| match &d.data {
                DirectiveData::Transaction(t) => Some(t),
                _ => None,
            })
            .expect("the transaction should be preserved");

        // The negative (gain) posting moves to :Long:Gains with its amount intact.
        let gains = txn
            .postings
            .iter()
            .find(|p| p.account == "Income:Capital-Gains:Long:Gains")
            .expect("negative posting should be rebooked to :Long:Gains");
        assert_eq!(gains.units.as_ref().map(|u| u.number.as_str()), Some("-100"));
        assert!(txn.postings.iter().any(|p| p.account == "Assets:Broker"));

        // The new account gets an `open` directive.
        let opened: Vec<&str> = output
            .directives
            .iter()
            .filter_map(|d| match &d.data {
                DirectiveData::Open(open) => Some(open.account.as_str()),
                _ => None,
            })
            .collect();
        assert_eq!(opened, ["Income:Capital-Gains:Long:Gains"]);
    }

    #[test]
    fn test_long_short_without_reductions_is_unchanged() {
        let plugin = CapitalGainsLongShortPlugin;
        let input = single_transaction_input(
            vec![
                plain_posting("Assets:Broker", "1000"),
                plain_posting("Income:Capital-Gains", "-100"),
            ],
            LONG_SHORT_CONFIG,
        );

        let output = plugin.process(input);
        assert!(output.errors.is_empty());

        // No posting has cost + price, so nothing is classified and nothing opened.
        assert_eq!(output.directives.len(), 1);
        let DirectiveData::Transaction(txn) = &output.directives[0].data else {
            panic!("expected the original transaction back");
        };
        let accounts: Vec<&str> = txn.postings.iter().map(|p| p.account.as_str()).collect();
        assert_eq!(accounts, ["Assets:Broker", "Income:Capital-Gains"]);
    }
}