Skip to main content

rustledger_plugin/native/plugins/
zerosum.rs

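//! Zero-sum account matching plugin.
//!
//! Matches postings in "zerosum" accounts that net to zero within a date range,
//! and moves them to a "matched" account. Useful for tracking transfers between
//! accounts.
//!
//! Configuration (as a Python-style dict string):
//! ```text
//! plugin "beancount_reds_plugins.zerosum.zerosum" "{
//!   'zerosum_accounts': {
//!     'Assets:ZeroSum:Transfers': ('Assets:ZeroSum-Matched:Transfers', 30),
//!   },
//!   'account_name_replace': ('ZeroSum', 'ZeroSum-Matched')
//! }"
//! ```
//!
//! Each `zerosum_accounts` entry maps an account to `(target_account, max_days_apart)`.
//! An empty target (`''`) is derived from the account name via `account_name_replace`,
//! or by appending `-Matched` when no replacement is configured. An optional
//! `'tolerance'` key (default `0.0099`) sets how close to zero a matched pair must sum.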
16
17use rust_decimal::Decimal;
18use std::collections::{HashMap, HashSet};
19use std::str::FromStr;
20
21use crate::types::{
22    DirectiveData, DirectiveWrapper, OpenData, PluginError, PluginErrorSeverity, PluginInput,
23    PluginOutput,
24};
25
26use super::super::NativePlugin;
27
/// Default tolerance for matching amounts.
///
/// Two postings are treated as cancelling out when the absolute value of their sum is
/// strictly below this value. It is overridden by a `'tolerance'` key in the plugin config.
const DEFAULT_TOLERANCE: &str = "0.0099";

/// Plugin for matching zero-sum postings.
///
/// Stateless: all behavior is driven by the per-invocation configuration string.
#[derive(Debug, Clone, Copy, Default)]
pub struct ZerosumPlugin;
33
34impl NativePlugin for ZerosumPlugin {
35    fn name(&self) -> &'static str {
36        "zerosum"
37    }
38
39    fn description(&self) -> &'static str {
40        "Match postings in zero-sum accounts and move to matched account"
41    }
42
43    fn process(&self, input: PluginInput) -> PluginOutput {
44        // Parse configuration
45        let config = match &input.config {
46            Some(c) => c,
47            None => {
48                return PluginOutput {
49                    directives: input.directives,
50                    errors: vec![PluginError {
51                        message: "zerosum plugin requires configuration".to_string(),
52                        source_file: None,
53                        line_number: None,
54                        severity: PluginErrorSeverity::Error,
55                    }],
56                };
57            }
58        };
59
60        // Parse the Python-style dict config
61        let (zerosum_accounts, account_replace, tolerance) = match parse_config(config) {
62            Ok(c) => c,
63            Err(e) => {
64                return PluginOutput {
65                    directives: input.directives,
66                    errors: vec![PluginError {
67                        message: format!("Failed to parse zerosum config: {e}"),
68                        source_file: None,
69                        line_number: None,
70                        severity: PluginErrorSeverity::Error,
71                    }],
72                };
73            }
74        };
75
76        let mut new_accounts: HashSet<String> = HashSet::new();
77        let mut earliest_date: Option<String> = None;
78
79        // Collect existing Open accounts to avoid creating duplicates
80        let existing_opens: HashSet<String> = input
81            .directives
82            .iter()
83            .filter_map(|d| {
84                if let DirectiveData::Open(ref open) = d.data {
85                    Some(open.account.clone())
86                } else {
87                    None
88                }
89            })
90            .collect();
91
92        // Index transactions by zerosum account
93        let mut txn_indices: HashMap<String, Vec<usize>> = HashMap::new();
94
95        for (i, directive) in input.directives.iter().enumerate() {
96            if directive.directive_type == "transaction" {
97                if earliest_date.is_none() || directive.date < *earliest_date.as_ref().unwrap() {
98                    earliest_date = Some(directive.date.clone());
99                }
100
101                if let DirectiveData::Transaction(ref txn) = directive.data {
102                    for zs_account in zerosum_accounts.keys() {
103                        if txn.postings.iter().any(|p| &p.account == zs_account) {
104                            txn_indices.entry(zs_account.clone()).or_default().push(i);
105                        }
106                    }
107                }
108            }
109        }
110
111        // Convert to mutable
112        let mut directives = input.directives;
113
114        // For each zerosum account, find matching pairs
115        for (zs_account, (target_account_opt, date_range)) in &zerosum_accounts {
116            // Determine target account
117            let target_account = target_account_opt.clone().unwrap_or_else(|| {
118                if let Some((from, to)) = &account_replace {
119                    zs_account.replace(from, to)
120                } else {
121                    format!("{zs_account}-Matched")
122                }
123            });
124
125            let indices = match txn_indices.get(zs_account) {
126                Some(i) => i.clone(),
127                None => continue,
128            };
129
130            // Track which postings have been matched (by txn_idx, posting_idx)
131            let mut matched: HashSet<(usize, usize)> = HashSet::new();
132
133            // For each transaction in this zerosum account
134            for &txn_i in &indices {
135                let directive = &directives[txn_i];
136                let txn_date = &directive.date;
137
138                if let DirectiveData::Transaction(ref txn) = directive.data {
139                    // Find postings in this transaction that are in the zerosum account
140                    for (post_i, posting) in txn.postings.iter().enumerate() {
141                        if &posting.account != zs_account {
142                            continue;
143                        }
144                        if matched.contains(&(txn_i, post_i)) {
145                            continue;
146                        }
147
148                        // Get the amount
149                        let amount = match &posting.units {
150                            Some(u) => match Decimal::from_str(&u.number) {
151                                Ok(n) => n,
152                                Err(_) => continue,
153                            },
154                            None => continue,
155                        };
156                        let currency = posting.units.as_ref().map(|u| &u.currency);
157
158                        // Look for a matching posting in other transactions
159                        for &other_txn_i in &indices {
160                            if other_txn_i == txn_i {
161                                // Check within same transaction but different posting
162                                if let DirectiveData::Transaction(ref other_txn) =
163                                    directives[other_txn_i].data
164                                {
165                                    for (other_post_i, other_posting) in
166                                        other_txn.postings.iter().enumerate()
167                                    {
168                                        if other_post_i == post_i {
169                                            continue;
170                                        }
171                                        if &other_posting.account != zs_account {
172                                            continue;
173                                        }
174                                        if matched.contains(&(other_txn_i, other_post_i)) {
175                                            continue;
176                                        }
177
178                                        let other_currency =
179                                            other_posting.units.as_ref().map(|u| &u.currency);
180                                        if currency != other_currency {
181                                            continue;
182                                        }
183
184                                        let other_amount = match &other_posting.units {
185                                            Some(u) => match Decimal::from_str(&u.number) {
186                                                Ok(n) => n,
187                                                Err(_) => continue,
188                                            },
189                                            None => continue,
190                                        };
191
192                                        // Check if they sum to zero (within tolerance)
193                                        let sum = (amount + other_amount).abs();
194                                        if sum < tolerance {
195                                            // Found a match!
196                                            matched.insert((txn_i, post_i));
197                                            matched.insert((other_txn_i, other_post_i));
198                                            new_accounts.insert(target_account.clone());
199                                            break;
200                                        }
201                                    }
202                                }
203                                continue;
204                            }
205
206                            // Check date range
207                            let other_date = &directives[other_txn_i].date;
208                            if !within_date_range(txn_date, other_date, *date_range) {
209                                continue;
210                            }
211
212                            if let DirectiveData::Transaction(ref other_txn) =
213                                directives[other_txn_i].data
214                            {
215                                for (other_post_i, other_posting) in
216                                    other_txn.postings.iter().enumerate()
217                                {
218                                    if &other_posting.account != zs_account {
219                                        continue;
220                                    }
221                                    if matched.contains(&(other_txn_i, other_post_i)) {
222                                        continue;
223                                    }
224
225                                    let other_currency =
226                                        other_posting.units.as_ref().map(|u| &u.currency);
227                                    if currency != other_currency {
228                                        continue;
229                                    }
230
231                                    let other_amount = match &other_posting.units {
232                                        Some(u) => match Decimal::from_str(&u.number) {
233                                            Ok(n) => n,
234                                            Err(_) => continue,
235                                        },
236                                        None => continue,
237                                    };
238
239                                    // Check if they sum to zero (within tolerance)
240                                    let sum = (amount + other_amount).abs();
241                                    if sum < tolerance {
242                                        // Found a match!
243                                        matched.insert((txn_i, post_i));
244                                        matched.insert((other_txn_i, other_post_i));
245                                        new_accounts.insert(target_account.clone());
246                                        break;
247                                    }
248                                }
249                            }
250
251                            // If we found a match, break out
252                            if matched.contains(&(txn_i, post_i)) {
253                                break;
254                            }
255                        }
256                    }
257                }
258            }
259
260            // Now update the matched postings to use the target account
261            for (txn_i, post_i) in &matched {
262                if let DirectiveData::Transaction(ref mut txn) = directives[*txn_i].data
263                    && *post_i < txn.postings.len()
264                {
265                    txn.postings[*post_i].account.clone_from(&target_account);
266                }
267            }
268        }
269
270        // Create Open directives for new accounts (only if not already opened)
271        let mut open_directives: Vec<DirectiveWrapper> = Vec::new();
272        if let Some(date) = earliest_date {
273            for account in &new_accounts {
274                // Skip if account already has an Open directive
275                if existing_opens.contains(account) {
276                    continue;
277                }
278                open_directives.push(DirectiveWrapper {
279                    directive_type: "open".to_string(),
280                    date: date.clone(),
281                    filename: Some("<zerosum>".to_string()),
282                    lineno: Some(0),
283                    data: DirectiveData::Open(OpenData {
284                        account: account.clone(),
285                        currencies: vec![],
286                        booking: None,
287                        metadata: vec![],
288                    }),
289                });
290            }
291        }
292
293        // Combine open directives with modified directives
294        let mut all_directives = open_directives;
295        all_directives.extend(directives);
296
297        PluginOutput {
298            directives: all_directives,
299            errors: Vec::new(),
300        }
301    }
302}
303
304/// Parse the Python-style config dict.
305fn parse_config(
306    config: &str,
307) -> Result<
308    (
309        HashMap<String, (Option<String>, i64)>,
310        Option<(String, String)>,
311        Decimal,
312    ),
313    String,
314> {
315    let mut zerosum_accounts = HashMap::new();
316    let mut account_replace: Option<(String, String)> = None;
317    let mut tolerance = Decimal::from_str(DEFAULT_TOLERANCE).unwrap();
318
319    // Simple parsing of Python dict format
320    // 'zerosum_accounts': {'Account': ('Target', 30), ...}
321    // 'account_name_replace': ('From', 'To')
322
323    // Extract zerosum_accounts
324    if let Some(start) = config.find("'zerosum_accounts'")
325        && let Some(dict_offset) = config[start..].find('{')
326    {
327        let dict_start = start + dict_offset;
328        let mut depth = 0;
329        let mut dict_end = dict_start;
330        for (i, c) in config[dict_start..].char_indices() {
331            match c {
332                '{' => depth += 1,
333                '}' => {
334                    depth -= 1;
335                    if depth == 0 {
336                        dict_end = dict_start + i + 1;
337                        break;
338                    }
339                }
340                _ => {}
341            }
342        }
343
344        let dict_str = &config[dict_start..dict_end];
345        // Parse individual account entries
346        // Format: 'AccountName': ('TargetAccount', days)
347        // or: 'AccountName': ('', days)
348        let re = regex::Regex::new(r"'([^']+)'\s*:\s*\(\s*'([^']*)'\s*,\s*(\d+)\s*\)")
349            .map_err(|e| e.to_string())?;
350
351        for cap in re.captures_iter(dict_str) {
352            let account = cap[1].to_string();
353            let target = if cap[2].is_empty() {
354                None
355            } else {
356                Some(cap[2].to_string())
357            };
358            let days: i64 = cap[3].parse().unwrap_or(30);
359            zerosum_accounts.insert(account, (target, days));
360        }
361    }
362
363    // Extract account_name_replace
364    if let Some(start) = config.find("'account_name_replace'") {
365        let re =
366            regex::Regex::new(r"'account_name_replace'\s*:\s*\(\s*'([^']*)'\s*,\s*'([^']*)'\s*\)")
367                .map_err(|e| e.to_string())?;
368        if let Some(cap) = re.captures(&config[start..]) {
369            account_replace = Some((cap[1].to_string(), cap[2].to_string()));
370        }
371    }
372
373    // Extract tolerance
374    if let Some(start) = config.find("'tolerance'") {
375        let re = regex::Regex::new(r"'tolerance'\s*:\s*([0-9.]+)").map_err(|e| e.to_string())?;
376        if let Some(cap) = re.captures(&config[start..])
377            && let Ok(t) = Decimal::from_str(&cap[1])
378        {
379            tolerance = t;
380        }
381    }
382
383    Ok((zerosum_accounts, account_replace, tolerance))
384}
385
386/// Check if two dates are within a given range (in days).
387fn within_date_range(date1: &str, date2: &str, days: i64) -> bool {
388    use chrono::NaiveDate;
389
390    let d1 = match NaiveDate::parse_from_str(date1, "%Y-%m-%d") {
391        Ok(d) => d,
392        Err(_) => return false,
393    };
394    let d2 = match NaiveDate::parse_from_str(date2, "%Y-%m-%d") {
395        Ok(d) => d,
396        Err(_) => return false,
397    };
398
399    let diff = (d2 - d1).num_days().abs();
400    diff <= days
401}
402
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// Builds a transaction that moves `amount` of `currency` out of `from_account`
    /// (negative leg) and into `to_account` (positive leg).
    fn create_transfer_txn(
        date: &str,
        from_account: &str,
        to_account: &str,
        amount: &str,
        currency: &str,
    ) -> DirectiveWrapper {
        let leg = |account: &str, number: String| PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number,
                currency: currency.to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        };

        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "*".to_string(),
                payee: None,
                narration: "Transfer".to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    leg(from_account, format!("-{amount}")),
                    leg(to_account, amount.to_string()),
                ],
            }),
        }
    }

    /// Runs the plugin over `directives` with `config`, using USD as the operating currency.
    fn run_zerosum(config: &str, directives: Vec<DirectiveWrapper>) -> PluginOutput {
        let input = PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: Some(config.to_string()),
        };
        ZerosumPlugin.process(input)
    }

    /// True when some transaction in `output` has a posting booked to `account`.
    fn any_posting_in(output: &PluginOutput, account: &str) -> bool {
        output.directives.iter().any(|directive| match &directive.data {
            DirectiveData::Transaction(txn) => txn.postings.iter().any(|p| p.account == account),
            _ => false,
        })
    }

    #[test]
    fn test_zerosum_matches_transfers() {
        let config = r"{
            'zerosum_accounts': {
                'Assets:ZeroSum:Transfers': ('Assets:ZeroSum-Matched:Transfers', 30)
            }
        }";

        let output = run_zerosum(
            config,
            vec![
                create_transfer_txn(
                    "2024-01-01",
                    "Assets:Bank",
                    "Assets:ZeroSum:Transfers",
                    "100.00",
                    "USD",
                ),
                create_transfer_txn(
                    "2024-01-03",
                    "Assets:ZeroSum:Transfers",
                    "Assets:Investment",
                    "100.00",
                    "USD",
                ),
            ],
        );

        assert!(output.errors.is_empty());
        // The matched postings must have been moved to the target account.
        assert!(
            any_posting_in(&output, "Assets:ZeroSum-Matched:Transfers"),
            "Should have matched postings"
        );
    }

    #[test]
    fn test_zerosum_no_match_outside_range() {
        let config = r"{
            'zerosum_accounts': {
                'Assets:ZeroSum:Transfers': ('Assets:ZeroSum-Matched:Transfers', 5)
            }
        }";

        let output = run_zerosum(
            config,
            vec![
                create_transfer_txn(
                    "2024-01-01",
                    "Assets:Bank",
                    "Assets:ZeroSum:Transfers",
                    "100.00",
                    "USD",
                ),
                // Ten days later: outside the five-day window.
                create_transfer_txn(
                    "2024-01-11",
                    "Assets:ZeroSum:Transfers",
                    "Assets:Investment",
                    "100.00",
                    "USD",
                ),
            ],
        );

        assert!(output.errors.is_empty());
        // The postings must still be booked to the original zero-sum account.
        assert!(
            any_posting_in(&output, "Assets:ZeroSum:Transfers"),
            "Should have unmatched postings"
        );
    }

    #[test]
    fn test_zerosum_does_not_duplicate_open() {
        // Regression test: zerosum should not create duplicate Open directives
        // when the target account already has an Open directive.
        let config = r"{
            'zerosum_accounts': {
                'Assets:Transfer': ('Assets:ZSA-Matched:Transfer', 7)
            }
        }";

        let existing_open = DirectiveWrapper {
            directive_type: "open".to_string(),
            date: "2020-01-01".to_string(),
            filename: Some("accounts.beancount".to_string()),
            lineno: Some(422),
            data: DirectiveData::Open(OpenData {
                account: "Assets:ZSA-Matched:Transfer".to_string(),
                currencies: vec![],
                booking: None,
                metadata: vec![],
            }),
        };

        let output = run_zerosum(
            config,
            vec![
                existing_open,
                create_transfer_txn(
                    "2024-01-01",
                    "Assets:Bank",
                    "Assets:Transfer",
                    "100.00",
                    "USD",
                ),
                create_transfer_txn(
                    "2024-01-02",
                    "Assets:Transfer",
                    "Assets:Investment",
                    "100.00",
                    "USD",
                ),
            ],
        );

        assert!(output.errors.is_empty());

        let open_count = output
            .directives
            .iter()
            .filter(|d| {
                matches!(
                    &d.data,
                    DirectiveData::Open(open) if open.account == "Assets:ZSA-Matched:Transfer"
                )
            })
            .count();

        // Only the pre-existing Open should be present; the plugin must not add another.
        assert_eq!(
            open_count, 1,
            "Should not create duplicate Open directives for existing accounts"
        );
    }

    #[test]
    fn test_parse_config() {
        let config = r"{
            'zerosum_accounts': {
                'Assets:ZeroSum:Transfers': ('Assets:ZeroSum-Matched:Transfers', 30),
                'Assets:ZeroSum:CreditCard': ('', 6)
            },
            'account_name_replace': ('ZeroSum', 'ZeroSum-Matched'),
            'tolerance': 0.01
        }";

        let (accounts, replace, tolerance) = parse_config(config).unwrap();

        assert_eq!(accounts.len(), 2);
        for key in ["Assets:ZeroSum:Transfers", "Assets:ZeroSum:CreditCard"] {
            assert!(accounts.contains_key(key));
        }

        let (target, days) = &accounts["Assets:ZeroSum:Transfers"];
        assert_eq!(target.as_deref(), Some("Assets:ZeroSum-Matched:Transfers"));
        assert_eq!(*days, 30);

        let (target2, days2) = &accounts["Assets:ZeroSum:CreditCard"];
        assert!(target2.is_none()); // An empty target means "derive it via account_name_replace".
        assert_eq!(*days2, 6);

        assert_eq!(
            replace,
            Some(("ZeroSum".to_string(), "ZeroSum-Matched".to_string()))
        );

        assert_eq!(tolerance, Decimal::from_str("0.01").unwrap());
    }
}