Skip to main content

rustledger_plugin/native/plugins/
capital_gains_classifier.rs

1//! Capital gains classifier plugin.
2//!
3//! This plugin rebooks capital gains into separate accounts based on:
4//! - **`long_short`**: Whether gains are long-term (held > 1 year) or short-term
5//! - **`gain_loss`**: Whether the posting is a gain (negative income) or loss (positive income)
6//!
7//! Usage for `long_short`:
8//! ```text
9//! plugin "beancount_reds_plugins.capital_gains_classifier.long_short" "{
10//!   'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']
11//! }"
12//! ```
13//!
14//! Usage for `gain_loss`:
15//! ```text
16//! plugin "beancount_reds_plugins.capital_gains_classifier.gain_loss" "{
17//!   'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']
18//! }"
19//! ```
20
21use chrono::{Datelike, NaiveDate};
22use regex::Regex;
23use rust_decimal::Decimal;
24use std::collections::HashSet;
25use std::str::FromStr;
26
27use crate::types::{
28    AmountData, DirectiveData, DirectiveWrapper, OpenData, PluginInput, PluginOutput, PostingData,
29    TransactionData,
30};
31
32use super::super::NativePlugin;
33
/// Plugin for classifying capital gains into long/short term categories.
///
/// Stateless: all behavior is driven by the plugin's configuration string, so
/// the value is trivially copyable and has a `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsLongShortPlugin;
36
/// Plugin for classifying capital gains into gains/losses categories.
///
/// Stateless: all behavior is driven by the plugin's configuration string, so
/// the value is trivially copyable and has a `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct CapitalGainsGainLossPlugin;
39
40impl NativePlugin for CapitalGainsLongShortPlugin {
41    fn name(&self) -> &'static str {
42        "long_short"
43    }
44
45    fn description(&self) -> &'static str {
46        "Classify capital gains into long-term vs short-term based on holding period"
47    }
48
49    fn process(&self, input: PluginInput) -> PluginOutput {
50        process_long_short(input)
51    }
52}
53
54impl NativePlugin for CapitalGainsGainLossPlugin {
55    fn name(&self) -> &'static str {
56        "gain_loss"
57    }
58
59    fn description(&self) -> &'static str {
60        "Classify capital gains into gains vs losses based on posting amount"
61    }
62
63    fn process(&self, input: PluginInput) -> PluginOutput {
64        process_gain_loss(input)
65    }
66}
67
68/// Configuration for `long_short` classification.
69struct LongShortConfig {
70    pattern: Regex,
71    account_to_replace: String,
72    short_replacement: String,
73    long_replacement: String,
74}
75
76/// Configuration for `gain_loss` classification.
77struct GainLossConfig {
78    pattern: Regex,
79    account_to_replace: String,
80    gains_replacement: String,
81    losses_replacement: String,
82}
83
84/// Process entries with `long_short` classification.
85fn process_long_short(input: PluginInput) -> PluginOutput {
86    let config = match &input.config {
87        Some(c) => match parse_long_short_config(c) {
88            Some(cfg) => cfg,
89            None => {
90                return PluginOutput {
91                    directives: input.directives,
92                    errors: Vec::new(),
93                };
94            }
95        },
96        None => {
97            return PluginOutput {
98                directives: input.directives,
99                errors: Vec::new(),
100            };
101        }
102    };
103
104    let mut new_accounts: HashSet<String> = HashSet::new();
105    let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
106
107    for directive in input.directives {
108        if directive.directive_type != "transaction" {
109            new_directives.push(directive);
110            continue;
111        }
112
113        if let DirectiveData::Transaction(txn) = &directive.data {
114            // Check if transaction has matching capital gains postings
115            let has_generic = txn
116                .postings
117                .iter()
118                .any(|p| config.pattern.is_match(&p.account));
119            let has_specific = txn.postings.iter().any(|p| {
120                p.account.contains(&config.short_replacement)
121                    || p.account.contains(&config.long_replacement)
122            });
123
124            if !has_generic || has_specific {
125                new_directives.push(directive);
126                continue;
127            }
128
129            // Find reduction postings (sales with cost and price)
130            let reductions: Vec<&PostingData> = txn
131                .postings
132                .iter()
133                .filter(|p| p.cost.is_some() && p.units.is_some() && p.price.is_some())
134                .collect();
135
136            if reductions.is_empty() {
137                new_directives.push(directive);
138                continue;
139            }
140
141            // Calculate short vs long gains
142            let entry_date = if let Ok(d) = NaiveDate::parse_from_str(&directive.date, "%Y-%m-%d") {
143                d
144            } else {
145                new_directives.push(directive);
146                continue;
147            };
148
149            let mut short_gains = Decimal::ZERO;
150            let mut long_gains = Decimal::ZERO;
151
152            for posting in &reductions {
153                if let (Some(cost), Some(units), Some(price)) =
154                    (&posting.cost, &posting.units, &posting.price)
155                {
156                    // Get cost date
157                    let cost_date = cost
158                        .date
159                        .as_ref()
160                        .and_then(|d| NaiveDate::parse_from_str(d, "%Y-%m-%d").ok());
161
162                    if let Some(cost_date) = cost_date {
163                        // Calculate gain
164                        let cost_number = cost
165                            .number_per
166                            .as_ref()
167                            .and_then(|n| Decimal::from_str(n).ok())
168                            .unwrap_or(Decimal::ZERO);
169                        let price_number = price
170                            .amount
171                            .as_ref()
172                            .and_then(|a| Decimal::from_str(&a.number).ok())
173                            .unwrap_or(Decimal::ZERO);
174                        let units_number =
175                            Decimal::from_str(&units.number).unwrap_or(Decimal::ZERO);
176
177                        let gain = (cost_number - price_number) * units_number.abs();
178
179                        // Check if long-term (> 1 year)
180                        let years_held = entry_date.years_since(cost_date).unwrap_or(0);
181                        let is_long_term = years_held > 1
182                            || (years_held == 1
183                                && (entry_date.month() > cost_date.month()
184                                    || (entry_date.month() == cost_date.month()
185                                        && entry_date.day() >= cost_date.day())));
186
187                        if is_long_term {
188                            long_gains += gain;
189                        } else {
190                            short_gains += gain;
191                        }
192                    }
193                }
194            }
195
196            // Find and remove original capital gains postings
197            let orig_postings: Vec<&PostingData> = txn
198                .postings
199                .iter()
200                .filter(|p| config.pattern.is_match(&p.account))
201                .collect();
202
203            if orig_postings.is_empty() {
204                new_directives.push(directive);
205                continue;
206            }
207
208            let orig_sum: Decimal = orig_postings
209                .iter()
210                .filter_map(|p| p.units.as_ref())
211                .filter_map(|u| Decimal::from_str(&u.number).ok())
212                .sum();
213
214            // Adjust for rounding differences
215            let diff = orig_sum - (short_gains + long_gains);
216            if diff.abs() > Decimal::new(1, 6) {
217                let total = short_gains + long_gains;
218                if total != Decimal::ZERO {
219                    short_gains += (short_gains / total) * diff;
220                    long_gains += (long_gains / total) * diff;
221                }
222            }
223
224            // Create new postings
225            let mut new_postings: Vec<PostingData> = txn
226                .postings
227                .iter()
228                .filter(|p| !config.pattern.is_match(&p.account))
229                .cloned()
230                .collect();
231
232            let template = orig_postings[0];
233
234            if short_gains != Decimal::ZERO {
235                let new_account = template
236                    .account
237                    .replace(&config.account_to_replace, &config.short_replacement);
238                new_accounts.insert(new_account.clone());
239                new_postings.push(PostingData {
240                    account: new_account,
241                    units: template.units.as_ref().map(|u| AmountData {
242                        number: format_decimal(short_gains),
243                        currency: u.currency.clone(),
244                    }),
245                    cost: None,
246                    price: None,
247                    flag: template.flag.clone(),
248                    metadata: vec![],
249                });
250            }
251
252            if long_gains != Decimal::ZERO {
253                let new_account = template
254                    .account
255                    .replace(&config.account_to_replace, &config.long_replacement);
256                new_accounts.insert(new_account.clone());
257                new_postings.push(PostingData {
258                    account: new_account,
259                    units: template.units.as_ref().map(|u| AmountData {
260                        number: format_decimal(long_gains),
261                        currency: u.currency.clone(),
262                    }),
263                    cost: None,
264                    price: None,
265                    flag: template.flag.clone(),
266                    metadata: vec![],
267                });
268            }
269
270            new_directives.push(DirectiveWrapper {
271                directive_type: "transaction".to_string(),
272                date: directive.date.clone(),
273                filename: directive.filename.clone(),
274                lineno: directive.lineno,
275                data: DirectiveData::Transaction(TransactionData {
276                    flag: txn.flag.clone(),
277                    payee: txn.payee.clone(),
278                    narration: txn.narration.clone(),
279                    tags: txn.tags.clone(),
280                    links: txn.links.clone(),
281                    metadata: txn.metadata.clone(),
282                    postings: new_postings,
283                }),
284            });
285        } else {
286            new_directives.push(directive);
287        }
288    }
289
290    // Create Open directives for new accounts
291    let earliest_date = new_directives
292        .iter()
293        .map(|d| d.date.as_str())
294        .min()
295        .unwrap_or("1970-01-01")
296        .to_string();
297
298    let mut open_directives: Vec<DirectiveWrapper> = new_accounts
299        .iter()
300        .map(|account| DirectiveWrapper {
301            directive_type: "open".to_string(),
302            date: earliest_date.clone(),
303            filename: Some("<long_short>".to_string()),
304            lineno: Some(0),
305            data: DirectiveData::Open(OpenData {
306                account: account.clone(),
307                currencies: vec![],
308                booking: None,
309                metadata: vec![],
310            }),
311        })
312        .collect();
313
314    open_directives.extend(new_directives);
315
316    PluginOutput {
317        directives: open_directives,
318        errors: Vec::new(),
319    }
320}
321
322/// Process entries with `gain_loss` classification.
323fn process_gain_loss(input: PluginInput) -> PluginOutput {
324    let config = match &input.config {
325        Some(c) => match parse_gain_loss_config(c) {
326            Some(cfg) => cfg,
327            None => {
328                return PluginOutput {
329                    directives: input.directives,
330                    errors: Vec::new(),
331                };
332            }
333        },
334        None => {
335            return PluginOutput {
336                directives: input.directives,
337                errors: Vec::new(),
338            };
339        }
340    };
341
342    let mut new_accounts: HashSet<String> = HashSet::new();
343    let mut new_directives: Vec<DirectiveWrapper> = Vec::new();
344
345    for directive in input.directives {
346        if directive.directive_type != "transaction" {
347            new_directives.push(directive);
348            continue;
349        }
350
351        if let DirectiveData::Transaction(txn) = &directive.data {
352            let mut modified = false;
353            let mut new_postings: Vec<PostingData> = Vec::new();
354
355            for posting in &txn.postings {
356                if config.pattern.is_match(&posting.account)
357                    && let Some(units) = &posting.units
358                    && let Ok(number) = Decimal::from_str(&units.number)
359                {
360                    let new_account = if number < Decimal::ZERO {
361                        // Negative = gains (income is negative)
362                        posting
363                            .account
364                            .replace(&config.account_to_replace, &config.gains_replacement)
365                    } else {
366                        // Positive = losses
367                        posting
368                            .account
369                            .replace(&config.account_to_replace, &config.losses_replacement)
370                    };
371
372                    new_accounts.insert(new_account.clone());
373                    new_postings.push(PostingData {
374                        account: new_account,
375                        units: posting.units.clone(),
376                        cost: posting.cost.clone(),
377                        price: posting.price.clone(),
378                        flag: posting.flag.clone(),
379                        metadata: posting.metadata.clone(),
380                    });
381                    modified = true;
382                    continue;
383                }
384                new_postings.push(posting.clone());
385            }
386
387            if modified {
388                new_directives.push(DirectiveWrapper {
389                    directive_type: "transaction".to_string(),
390                    date: directive.date.clone(),
391                    filename: directive.filename.clone(),
392                    lineno: directive.lineno,
393                    data: DirectiveData::Transaction(TransactionData {
394                        flag: txn.flag.clone(),
395                        payee: txn.payee.clone(),
396                        narration: txn.narration.clone(),
397                        tags: txn.tags.clone(),
398                        links: txn.links.clone(),
399                        metadata: txn.metadata.clone(),
400                        postings: new_postings,
401                    }),
402                });
403            } else {
404                new_directives.push(directive);
405            }
406        } else {
407            new_directives.push(directive);
408        }
409    }
410
411    // Create Open directives for new accounts
412    let earliest_date = new_directives
413        .iter()
414        .map(|d| d.date.as_str())
415        .min()
416        .unwrap_or("1970-01-01")
417        .to_string();
418
419    let mut open_directives: Vec<DirectiveWrapper> = new_accounts
420        .iter()
421        .map(|account| DirectiveWrapper {
422            directive_type: "open".to_string(),
423            date: earliest_date.clone(),
424            filename: Some("<gain_loss>".to_string()),
425            lineno: Some(0),
426            data: DirectiveData::Open(OpenData {
427                account: account.clone(),
428                currencies: vec![],
429                booking: None,
430                metadata: vec![],
431            }),
432        })
433        .collect();
434
435    open_directives.extend(new_directives);
436
437    PluginOutput {
438        directives: open_directives,
439        errors: Vec::new(),
440    }
441}
442
443/// Parse `long_short` configuration.
444/// Format: `{'pattern': ['to_replace', 'short_repl', 'long_repl']}`
445fn parse_long_short_config(config: &str) -> Option<LongShortConfig> {
446    // Parse pattern: 'key': ['val1', 'val2', 'val3']
447    let re =
448        Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]").ok()?;
449
450    let cap = re.captures(config)?;
451    let pattern = Regex::new(&cap[1]).ok()?;
452    let account_to_replace = cap[2].to_string();
453    let short_replacement = cap[3].to_string();
454    let long_replacement = cap[4].to_string();
455
456    Some(LongShortConfig {
457        pattern,
458        account_to_replace,
459        short_replacement,
460        long_replacement,
461    })
462}
463
464/// Parse `gain_loss` configuration.
465/// Format: `{'pattern': ['to_replace', 'gains_repl', 'losses_repl']}`
466fn parse_gain_loss_config(config: &str) -> Option<GainLossConfig> {
467    // Parse pattern: 'key': ['val1', 'val2', 'val3']
468    let re =
469        Regex::new(r"'([^']+)'\s*:\s*\[\s*'([^']*)'\s*,\s*'([^']*)'\s*,\s*'([^']*)'\s*\]").ok()?;
470
471    let cap = re.captures(config)?;
472    let pattern = Regex::new(&cap[1]).ok()?;
473    let account_to_replace = cap[2].to_string();
474    let gains_replacement = cap[3].to_string();
475    let losses_replacement = cap[4].to_string();
476
477    Some(GainLossConfig {
478        pattern,
479        account_to_replace,
480        gains_replacement,
481        losses_replacement,
482    })
483}
484
485/// Format a decimal number.
486fn format_decimal(d: Decimal) -> String {
487    let s = d.to_string();
488    if s.contains('.') {
489        s.trim_end_matches('0').trim_end_matches('.').to_string()
490    } else {
491        s
492    }
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    const GAIN_LOSS_CONFIG: &str =
        "{'Income.*:Capital-Gains:Long': [':Long', ':Long:Gains', ':Long:Losses']}";

    /// A plain USD posting with no cost or price.
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    /// Single-transaction `gain_loss` input whose capital-gains posting books
    /// `gain_amount` USD.
    fn gain_loss_input(gain_amount: &str) -> PluginInput {
        PluginInput {
            directives: vec![DirectiveWrapper {
                directive_type: "transaction".to_string(),
                date: "2024-01-15".to_string(),
                filename: None,
                lineno: None,
                data: DirectiveData::Transaction(TransactionData {
                    flag: "*".to_string(),
                    payee: None,
                    narration: "Sell stock".to_string(),
                    tags: vec![],
                    links: vec![],
                    metadata: vec![],
                    postings: vec![
                        usd_posting("Assets:Broker", "1000"),
                        usd_posting("Income:Capital-Gains:Long", gain_amount),
                    ],
                }),
            }],
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(GAIN_LOSS_CONFIG.to_string()),
        }
    }

    /// The (first) transaction in `output`; panics if there is none.
    fn transaction(output: &PluginOutput) -> &TransactionData {
        let directive = output
            .directives
            .iter()
            .find(|d| d.directive_type == "transaction")
            .expect("output should contain the transaction");
        let DirectiveData::Transaction(txn) = &directive.data else {
            panic!("transaction directive should carry transaction data");
        };
        txn
    }

    /// Accounts opened by `Open` directives in `output`.
    fn opened_accounts(output: &PluginOutput) -> Vec<&str> {
        output
            .directives
            .iter()
            .filter_map(|d| {
                if let DirectiveData::Open(open) = &d.data {
                    Some(open.account.as_str())
                } else {
                    None
                }
            })
            .collect()
    }

    #[test]
    fn test_parse_long_short_config() {
        let config = "{'Income.*:Capital-Gains': [':Capital-Gains', ':Capital-Gains:Short', ':Capital-Gains:Long']}";
        let cfg = parse_long_short_config(config).expect("config should parse");
        assert!(cfg.pattern.is_match("Income:Broker:Capital-Gains"));
        assert_eq!(cfg.account_to_replace, ":Capital-Gains");
        assert_eq!(cfg.short_replacement, ":Capital-Gains:Short");
        assert_eq!(cfg.long_replacement, ":Capital-Gains:Long");
    }

    #[test]
    fn test_parse_gain_loss_config() {
        let config = "{'Income.*:Long': [':Long', ':Long:Gains', ':Long:Losses']}";
        let cfg = parse_gain_loss_config(config).expect("config should parse");
        assert_eq!(cfg.account_to_replace, ":Long");
        assert_eq!(cfg.gains_replacement, ":Long:Gains");
        assert_eq!(cfg.losses_replacement, ":Long:Losses");
    }

    #[test]
    fn test_parse_config_rejects_malformed_input() {
        assert!(parse_long_short_config("{}").is_none());
        assert!(parse_gain_loss_config("not a config").is_none());
        // The key must be a valid regular expression.
        assert!(parse_long_short_config("{'Income(': [':a', ':b', ':c']}").is_none());
    }

    #[test]
    fn test_gain_loss_classification() {
        let output = CapitalGainsGainLossPlugin.process(gain_loss_input("-100"));
        assert!(output.errors.is_empty());

        let txn = transaction(&output);
        let accounts: Vec<&str> = txn.postings.iter().map(|p| p.account.as_str()).collect();
        // The negative posting is renamed to :Long:Gains; others are untouched.
        assert_eq!(accounts, ["Assets:Broker", "Income:Capital-Gains:Long:Gains"]);
        assert_eq!(
            txn.postings[1].units.as_ref().map(|u| u.number.as_str()),
            Some("-100")
        );
        assert!(opened_accounts(&output).contains(&"Income:Capital-Gains:Long:Gains"));
    }

    #[test]
    fn test_gain_loss_positive_amount_is_loss() {
        let output = CapitalGainsGainLossPlugin.process(gain_loss_input("100"));
        assert!(output.errors.is_empty());

        let txn = transaction(&output);
        assert!(
            txn.postings
                .iter()
                .any(|p| p.account == "Income:Capital-Gains:Long:Losses")
        );
        assert!(opened_accounts(&output).contains(&"Income:Capital-Gains:Long:Losses"));
    }

    #[test]
    fn test_missing_config_passes_input_through() {
        let mut input = gain_loss_input("-100");
        input.config = None;
        let output = CapitalGainsGainLossPlugin.process(input);

        assert!(output.errors.is_empty());
        assert_eq!(output.directives.len(), 1);
        assert!(
            transaction(&output)
                .postings
                .iter()
                .any(|p| p.account == "Income:Capital-Gains:Long")
        );
    }
}