Skip to main content

rustledger_plugin/native/plugins/
forecast.rs

1//! Forecast plugin - generate recurring transactions.
2//!
3//! This plugin generates recurring transactions from template transactions
4//! marked with the "#" flag. The periodicity is specified in the narration.
5//!
6//! Example:
7//! ```beancount
8//! 2014-03-08 # "Electricity bill [MONTHLY]"
9//!   Expenses:Electricity  50.10 USD
10//!   Assets:Checking      -50.10 USD
11//! ```
12//!
13//! Supported patterns:
14//! - `[MONTHLY]` - Repeat monthly until end of current year
15//! - `[WEEKLY]` - Repeat weekly until end of current year
16//! - `[DAILY]` - Repeat daily until end of current year
17//! - `[YEARLY]` - Repeat yearly until end of current year
18//! - `[MONTHLY REPEAT 3 TIMES]` - Repeat 3 times
19//! - `[MONTHLY UNTIL 2020-12-31]` - Repeat until specified date
20//! - `[MONTHLY SKIP 1 TIME]` - Skip every other month
21
22use chrono::{Datelike, NaiveDate};
23use regex::Regex;
24use std::sync::LazyLock;
25
26use crate::types::{DirectiveData, PluginInput, PluginOutput};
27
28use super::super::NativePlugin;
29
/// Regex for parsing forecast patterns in narrations.
/// Matches: `[MONTHLY]`, `[WEEKLY SKIP 2 TIMES]`, `[MONTHLY UNTIL 2025-12-31]`, etc.
///
/// Capture groups:
/// 1. the narration text preceding the opening bracket (matched lazily from the
///    start of the string),
/// 2. the interval keyword (`MONTHLY`, `YEARLY`, `WEEKLY` or `DAILY`),
/// 3. the `SKIP` count, if present,
/// 4. the `REPEAT` count, if present,
/// 5. the `UNTIL` date in `YYYY-MM-DD` form, if present.
///
/// The optional clauses are only recognised in the order `SKIP`, `REPEAT`,
/// `UNTIL`, and both `TIME` and `TIMES` are accepted. The pattern is not
/// anchored at the end, so any text after the closing bracket is not captured.
///
/// The pattern is compiled once, on first use; the `expect` can only fire if
/// the literal below is not a valid regex.
static FORECAST_PATTERN_RE: LazyLock<Regex> = LazyLock::new(|| {
    Regex::new(
        r"(?x)
        (^.*?)                             # narration prefix
        \[
        (MONTHLY|YEARLY|WEEKLY|DAILY)     # interval type
        (?:\s+SKIP\s+(\d+)\s+TIMES?)?     # optional SKIP n TIMES
        (?:\s+REPEAT\s+(\d+)\s+TIMES?)?   # optional REPEAT n TIMES
        (?:\s+UNTIL\s+(\d{4}-\d{2}-\d{2}))? # optional UNTIL date
        \]
    ",
    )
    .expect("FORECAST_PATTERN_RE: invalid regex pattern")
});
46
/// Plugin for generating recurring forecast transactions.
///
/// Transactions flagged with `#` whose narration contains a bracketed
/// recurrence pattern (for example `[MONTHLY]`) are replaced by one copy per
/// generated date; all other directives are passed through unchanged.
/// The type carries no state.
pub struct ForecastPlugin;
49
/// Recurrence frequency selected by the keyword inside a forecast pattern.
#[derive(Debug, Clone, Copy, PartialEq)]
enum Interval {
    /// `[DAILY ...]`: occurrences are spaced in days.
    Daily,
    /// `[WEEKLY ...]`: occurrences are spaced in weeks.
    Weekly,
    /// `[MONTHLY ...]`: occurrences are spaced in calendar months.
    Monthly,
    /// `[YEARLY ...]`: occurrences are spaced in multiples of twelve calendar months.
    Yearly,
}
57
58impl NativePlugin for ForecastPlugin {
59    fn name(&self) -> &'static str {
60        "forecast"
61    }
62
63    fn description(&self) -> &'static str {
64        "Generate recurring forecast transactions"
65    }
66
67    fn process(&self, input: PluginInput) -> PluginOutput {
68        let mut forecast_entries = Vec::new();
69        let mut filtered_entries = Vec::new();
70
71        // Separate forecast entries from regular entries
72        for directive in input.directives {
73            if directive.directive_type == "transaction"
74                && let DirectiveData::Transaction(ref txn) = directive.data
75                && txn.flag == "#"
76            {
77                forecast_entries.push(directive);
78            } else {
79                filtered_entries.push(directive);
80            }
81        }
82
83        // Get current year end as default until date
84        let today = chrono::Local::now().naive_local().date();
85        let default_until = NaiveDate::from_ymd_opt(today.year(), 12, 31).unwrap();
86
87        // Generate recurring transactions
88        let mut new_entries = Vec::new();
89
90        for directive in forecast_entries {
91            if let DirectiveData::Transaction(ref txn) = directive.data {
92                if let Some(caps) = FORECAST_PATTERN_RE.captures(&txn.narration) {
93                    let narration_prefix = caps.get(1).map_or("", |m| m.as_str().trim());
94                    let interval_str = caps.get(2).map_or("MONTHLY", |m| m.as_str());
95                    let skip_count: usize = caps
96                        .get(3)
97                        .and_then(|m| m.as_str().parse().ok())
98                        .unwrap_or(0);
99                    let repeat_count: Option<usize> =
100                        caps.get(4).and_then(|m| m.as_str().parse().ok());
101                    let until_date: Option<NaiveDate> = caps
102                        .get(5)
103                        .and_then(|m| NaiveDate::parse_from_str(m.as_str(), "%Y-%m-%d").ok());
104
105                    let interval = match interval_str {
106                        "DAILY" => Interval::Daily,
107                        "WEEKLY" => Interval::Weekly,
108                        "YEARLY" => Interval::Yearly,
109                        _ => Interval::Monthly,
110                    };
111
112                    // Parse start date
113                    let start_date =
114                        if let Ok(date) = NaiveDate::parse_from_str(&directive.date, "%Y-%m-%d") {
115                            date
116                        } else {
117                            // Skip if date is unparsable
118                            new_entries.push(directive);
119                            continue;
120                        };
121
122                    // Determine end condition
123                    let until = until_date.unwrap_or(default_until);
124
125                    // Generate dates
126                    let dates =
127                        generate_dates(start_date, interval, skip_count, repeat_count, until);
128
129                    // Create a transaction for each date
130                    for date in dates {
131                        let mut new_directive = directive.clone();
132                        new_directive.date = date.format("%Y-%m-%d").to_string();
133
134                        if let DirectiveData::Transaction(ref mut new_txn) = new_directive.data {
135                            new_txn.narration = narration_prefix.to_string();
136                        }
137
138                        new_entries.push(new_directive);
139                    }
140                } else {
141                    // No pattern match, keep original
142                    new_entries.push(directive);
143                }
144            }
145        }
146
147        // Sort new entries by date
148        new_entries.sort_by(|a, b| a.date.cmp(&b.date));
149
150        // Combine filtered entries with new entries
151        filtered_entries.extend(new_entries);
152
153        PluginOutput {
154            directives: filtered_entries,
155            errors: Vec::new(),
156        }
157    }
158}
159
160/// Generate dates according to the specified interval and constraints.
161fn generate_dates(
162    start: NaiveDate,
163    interval: Interval,
164    skip: usize,
165    repeat: Option<usize>,
166    until: NaiveDate,
167) -> Vec<NaiveDate> {
168    let mut dates = Vec::new();
169    let mut current = start;
170    let step = skip + 1; // Skip means interval multiplier
171
172    loop {
173        dates.push(current);
174
175        // Check repeat count
176        if let Some(max_count) = repeat
177            && dates.len() >= max_count
178        {
179            break;
180        }
181
182        // Advance to next date
183        current = match interval {
184            Interval::Daily => current + chrono::Duration::days(step as i64),
185            Interval::Weekly => current + chrono::Duration::weeks(step as i64),
186            Interval::Monthly => add_months(current, step as i32),
187            Interval::Yearly => add_months(current, (step * 12) as i32),
188        };
189
190        // Check until date
191        if current > until {
192            break;
193        }
194
195        // Safety limit
196        if dates.len() > 1000 {
197            break;
198        }
199    }
200
201    dates
202}
203
204/// Add months to a date, handling month-end overflow.
205fn add_months(date: NaiveDate, months: i32) -> NaiveDate {
206    let total_months = date.month0() as i32 + months;
207    let new_year = date.year() + total_months / 12;
208    // Normalize total_months to a 0–11 month index even when total_months is negative
209    // (Rust's % operator can return a negative remainder, so we use a double modulo).
210    let new_month = (total_months % 12 + 12) % 12 + 1;
211
212    // Try to keep the same day, but clamp to valid days in the new month
213    let max_day = days_in_month(new_year, new_month as u32);
214    let new_day = date.day().min(max_day);
215
216    NaiveDate::from_ymd_opt(new_year, new_month as u32, new_day).unwrap_or(date)
217}
218
/// Number of days in `month` (1 = January ... 12 = December) of `year`.
///
/// February follows the Gregorian leap-year rule. Any `month` outside `1..=12`
/// falls back to 30.
const fn days_in_month(year: i32, month: u32) -> u32 {
    // Divisible by 4, except century years that are not divisible by 400.
    let is_leap_year = year % 4 == 0 && (year % 100 != 0 || year % 400 == 0);

    match month {
        2 if is_leap_year => 29,
        2 => 28,
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        // April, June, September, November, and the out-of-range fallback.
        _ => 30,
    }
}
234
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// Build a USD posting with no cost, price, flag or metadata.
    fn usd_posting(account: &str, number: &str) -> PostingData {
        PostingData {
            account: account.to_string(),
            units: Some(AmountData {
                number: number.to_string(),
                currency: "USD".to_string(),
            }),
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
        }
    }

    /// Build a balanced two-posting transaction carrying the forecast flag `#`.
    fn create_forecast_transaction(date: &str, narration: &str) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "#".to_string(),
                payee: None,
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    usd_posting("Expenses:Test", "100.00"),
                    usd_posting("Assets:Cash", "-100.00"),
                ],
            }),
        }
    }

    /// Run the plugin over `directives` with the options shared by all tests.
    fn run_plugin(directives: Vec<DirectiveWrapper>) -> PluginOutput {
        ForecastPlugin.process(PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: None,
        })
    }

    /// Dates of all output directives, in output order.
    fn output_dates(output: &PluginOutput) -> Vec<&str> {
        output.directives.iter().map(|d| d.date.as_str()).collect()
    }

    /// Return the transaction payload, failing the test for any other directive
    /// kind so that assertions on it can never be skipped silently.
    fn expect_transaction(directive: &DirectiveWrapper) -> &TransactionData {
        match &directive.data {
            DirectiveData::Transaction(txn) => txn,
            _ => panic!("expected a transaction directive"),
        }
    }

    #[test]
    fn test_forecast_monthly_repeat() {
        let output = run_plugin(vec![create_forecast_transaction(
            "2024-01-15",
            "Electric bill [MONTHLY REPEAT 3 TIMES]",
        )]);

        assert_eq!(output.errors.len(), 0);
        assert_eq!(output.directives.len(), 3);

        // Check dates
        assert_eq!(
            output_dates(&output),
            ["2024-01-15", "2024-02-15", "2024-03-15"]
        );

        // Check narration is cleaned
        assert_eq!(
            expect_transaction(&output.directives[0]).narration,
            "Electric bill"
        );
    }

    #[test]
    fn test_forecast_weekly_repeat() {
        let output = run_plugin(vec![create_forecast_transaction(
            "2024-01-01",
            "Groceries [WEEKLY REPEAT 4 TIMES]",
        )]);

        assert_eq!(output.directives.len(), 4);
        assert_eq!(
            output_dates(&output),
            ["2024-01-01", "2024-01-08", "2024-01-15", "2024-01-22"]
        );
    }

    #[test]
    fn test_forecast_until_date() {
        let output = run_plugin(vec![create_forecast_transaction(
            "2024-01-15",
            "Rent [MONTHLY UNTIL 2024-03-15]",
        )]);

        assert_eq!(output.directives.len(), 3);
        assert_eq!(
            output_dates(&output),
            ["2024-01-15", "2024-02-15", "2024-03-15"]
        );
    }

    #[test]
    fn test_forecast_skip() {
        let output = run_plugin(vec![create_forecast_transaction(
            "2024-01-01",
            "Insurance [MONTHLY SKIP 1 TIME REPEAT 3 TIMES]",
        )]);

        assert_eq!(output.directives.len(), 3);

        // With SKIP 1 TIME, it should skip every other month (bi-monthly)
        assert_eq!(
            output_dates(&output),
            ["2024-01-01", "2024-03-01", "2024-05-01"]
        );
    }

    #[test]
    fn test_forecast_preserves_non_forecast_transactions() {
        let mut regular_txn = create_forecast_transaction("2024-01-15", "Regular purchase");
        if let DirectiveData::Transaction(ref mut txn) = regular_txn.data {
            txn.flag = "*".to_string(); // Regular transaction, not forecast
        }

        let output = run_plugin(vec![regular_txn]);
        assert_eq!(output.directives.len(), 1);

        let txn = expect_transaction(&output.directives[0]);
        assert_eq!(txn.flag, "*");
        assert_eq!(txn.narration, "Regular purchase");
    }

    #[test]
    fn test_add_months() {
        // Regular case
        assert_eq!(
            add_months(NaiveDate::from_ymd_opt(2024, 1, 15).unwrap(), 1),
            NaiveDate::from_ymd_opt(2024, 2, 15).unwrap()
        );

        // Month-end overflow (Jan 31 -> Feb 28/29)
        assert_eq!(
            add_months(NaiveDate::from_ymd_opt(2024, 1, 31).unwrap(), 1),
            NaiveDate::from_ymd_opt(2024, 2, 29).unwrap() // 2024 is leap year
        );

        // Year overflow
        assert_eq!(
            add_months(NaiveDate::from_ymd_opt(2024, 11, 15).unwrap(), 3),
            NaiveDate::from_ymd_opt(2025, 2, 15).unwrap()
        );
    }
}