Skip to main content

rustledger_plugin/native/plugins/
forecast.rs

1//! Forecast plugin - generate recurring transactions.
2//!
3//! This plugin generates recurring transactions from template transactions
4//! marked with the "#" flag. The periodicity is specified in the narration.
5//!
6//! Example:
7//! ```beancount
8//! 2014-03-08 # "Electricity bill [MONTHLY]"
9//!   Expenses:Electricity  50.10 USD
10//!   Assets:Checking      -50.10 USD
11//! ```
12//!
13//! Supported patterns:
14//! - `[MONTHLY]` - Repeat monthly until end of current year
15//! - `[WEEKLY]` - Repeat weekly until end of current year
16//! - `[DAILY]` - Repeat daily until end of current year
17//! - `[YEARLY]` - Repeat yearly until end of current year
18//! - `[MONTHLY REPEAT 3 TIMES]` - Repeat 3 times
19//! - `[MONTHLY UNTIL 2020-12-31]` - Repeat until specified date
20//! - `[MONTHLY SKIP 1 TIME]` - Skip every other month
21
22use chrono::{Datelike, NaiveDate};
23use regex::Regex;
24
25use crate::types::{DirectiveData, PluginInput, PluginOutput};
26
27use super::super::NativePlugin;
28
/// Native plugin that expands `#`-flagged template transactions into
/// recurring forecast transactions, driven by the `[INTERVAL ...]` tag in the
/// template's narration (see the module documentation for supported tags).
#[derive(Debug, Clone, Copy, Default)]
pub struct ForecastPlugin;
31
/// Base recurrence unit of a forecast series, parsed from the narration tag.
///
/// A `SKIP n TIMES` modifier multiplies this unit by `n + 1`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Interval {
    /// Every day (`[DAILY]`).
    Daily,
    /// Every seven days (`[WEEKLY]`).
    Weekly,
    /// Every calendar month (`[MONTHLY]`).
    Monthly,
    /// Every twelve calendar months (`[YEARLY]`).
    Yearly,
}
39
40impl NativePlugin for ForecastPlugin {
41    fn name(&self) -> &'static str {
42        "forecast"
43    }
44
45    fn description(&self) -> &'static str {
46        "Generate recurring forecast transactions"
47    }
48
49    fn process(&self, input: PluginInput) -> PluginOutput {
50        // Regex to parse forecast patterns
51        let pattern = Regex::new(
52            r"(?x)
53            (^.*?)                             # narration prefix
54            \[
55            (MONTHLY|YEARLY|WEEKLY|DAILY)     # interval type
56            (?:\s+SKIP\s+(\d+)\s+TIMES?)?     # optional SKIP n TIMES
57            (?:\s+REPEAT\s+(\d+)\s+TIMES?)?   # optional REPEAT n TIMES
58            (?:\s+UNTIL\s+(\d{4}-\d{2}-\d{2}))? # optional UNTIL date
59            \]
60        ",
61        )
62        .unwrap();
63
64        let mut forecast_entries = Vec::new();
65        let mut filtered_entries = Vec::new();
66
67        // Separate forecast entries from regular entries
68        for directive in input.directives {
69            if directive.directive_type == "transaction"
70                && let DirectiveData::Transaction(ref txn) = directive.data
71                && txn.flag == "#"
72            {
73                forecast_entries.push(directive);
74            } else {
75                filtered_entries.push(directive);
76            }
77        }
78
79        // Get current year end as default until date
80        let today = chrono::Local::now().naive_local().date();
81        let default_until = NaiveDate::from_ymd_opt(today.year(), 12, 31).unwrap();
82
83        // Generate recurring transactions
84        let mut new_entries = Vec::new();
85
86        for directive in forecast_entries {
87            if let DirectiveData::Transaction(ref txn) = directive.data {
88                if let Some(caps) = pattern.captures(&txn.narration) {
89                    let narration_prefix = caps.get(1).map_or("", |m| m.as_str().trim());
90                    let interval_str = caps.get(2).map_or("MONTHLY", |m| m.as_str());
91                    let skip_count: usize = caps
92                        .get(3)
93                        .and_then(|m| m.as_str().parse().ok())
94                        .unwrap_or(0);
95                    let repeat_count: Option<usize> =
96                        caps.get(4).and_then(|m| m.as_str().parse().ok());
97                    let until_date: Option<NaiveDate> = caps
98                        .get(5)
99                        .and_then(|m| NaiveDate::parse_from_str(m.as_str(), "%Y-%m-%d").ok());
100
101                    let interval = match interval_str {
102                        "DAILY" => Interval::Daily,
103                        "WEEKLY" => Interval::Weekly,
104                        "YEARLY" => Interval::Yearly,
105                        _ => Interval::Monthly,
106                    };
107
108                    // Parse start date
109                    let start_date =
110                        if let Ok(date) = NaiveDate::parse_from_str(&directive.date, "%Y-%m-%d") {
111                            date
112                        } else {
113                            // Skip if date is unparsable
114                            new_entries.push(directive);
115                            continue;
116                        };
117
118                    // Determine end condition
119                    let until = until_date.unwrap_or(default_until);
120
121                    // Generate dates
122                    let dates =
123                        generate_dates(start_date, interval, skip_count, repeat_count, until);
124
125                    // Create a transaction for each date
126                    for date in dates {
127                        let mut new_directive = directive.clone();
128                        new_directive.date = date.format("%Y-%m-%d").to_string();
129
130                        if let DirectiveData::Transaction(ref mut new_txn) = new_directive.data {
131                            new_txn.narration = narration_prefix.to_string();
132                        }
133
134                        new_entries.push(new_directive);
135                    }
136                } else {
137                    // No pattern match, keep original
138                    new_entries.push(directive);
139                }
140            }
141        }
142
143        // Sort new entries by date
144        new_entries.sort_by(|a, b| a.date.cmp(&b.date));
145
146        // Combine filtered entries with new entries
147        filtered_entries.extend(new_entries);
148
149        PluginOutput {
150            directives: filtered_entries,
151            errors: Vec::new(),
152        }
153    }
154}
155
156/// Generate dates according to the specified interval and constraints.
157fn generate_dates(
158    start: NaiveDate,
159    interval: Interval,
160    skip: usize,
161    repeat: Option<usize>,
162    until: NaiveDate,
163) -> Vec<NaiveDate> {
164    let mut dates = Vec::new();
165    let mut current = start;
166    let step = skip + 1; // Skip means interval multiplier
167
168    loop {
169        dates.push(current);
170
171        // Check repeat count
172        if let Some(max_count) = repeat
173            && dates.len() >= max_count
174        {
175            break;
176        }
177
178        // Advance to next date
179        current = match interval {
180            Interval::Daily => current + chrono::Duration::days(step as i64),
181            Interval::Weekly => current + chrono::Duration::weeks(step as i64),
182            Interval::Monthly => add_months(current, step as i32),
183            Interval::Yearly => add_months(current, (step * 12) as i32),
184        };
185
186        // Check until date
187        if current > until {
188            break;
189        }
190
191        // Safety limit
192        if dates.len() > 1000 {
193            break;
194        }
195    }
196
197    dates
198}
199
200/// Add months to a date, handling month-end overflow.
201fn add_months(date: NaiveDate, months: i32) -> NaiveDate {
202    let total_months = date.month0() as i32 + months;
203    let new_year = date.year() + total_months / 12;
204    // Normalize total_months to a 0–11 month index even when total_months is negative
205    // (Rust's % operator can return a negative remainder, so we use a double modulo).
206    let new_month = (total_months % 12 + 12) % 12 + 1;
207
208    // Try to keep the same day, but clamp to valid days in the new month
209    let max_day = days_in_month(new_year, new_month as u32);
210    let new_day = date.day().min(max_day);
211
212    NaiveDate::from_ymd_opt(new_year, new_month as u32, new_day).unwrap_or(date)
213}
214
/// Return the length of `month` (1-based) in `year`.
///
/// February follows the Gregorian leap-year rule: divisible by 4, except
/// centuries that are not divisible by 400. Out-of-range month numbers yield
/// 30.
const fn days_in_month(year: i32, month: u32) -> u32 {
    let is_leap = year % 400 == 0 || (year % 100 != 0 && year % 4 == 0);
    match month {
        2 if is_leap => 29,
        2 => 28,
        4 | 6 | 9 | 11 => 30,
        1 | 3 | 5 | 7 | 8 | 10 | 12 => 31,
        _ => 30,
    }
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::*;

    /// Build a `#`-flagged transaction with a balanced two-posting body.
    fn create_forecast_transaction(date: &str, narration: &str) -> DirectiveWrapper {
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: date.to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(TransactionData {
                flag: "#".to_string(),
                payee: None,
                narration: narration.to_string(),
                tags: vec![],
                links: vec![],
                metadata: vec![],
                postings: vec![
                    PostingData {
                        account: "Expenses:Test".to_string(),
                        units: Some(AmountData {
                            number: "100.00".to_string(),
                            currency: "USD".to_string(),
                        }),
                        cost: None,
                        price: None,
                        flag: None,
                        metadata: vec![],
                    },
                    PostingData {
                        account: "Assets:Cash".to_string(),
                        units: Some(AmountData {
                            number: "-100.00".to_string(),
                            currency: "USD".to_string(),
                        }),
                        cost: None,
                        price: None,
                        flag: None,
                        metadata: vec![],
                    },
                ],
            }),
        }
    }

    /// Build an ordinary `*`-flagged transaction that the plugin must not touch.
    fn create_regular_transaction(date: &str, narration: &str) -> DirectiveWrapper {
        let mut directive = create_forecast_transaction(date, narration);
        let DirectiveData::Transaction(ref mut txn) = directive.data else {
            unreachable!("create_forecast_transaction always builds a transaction");
        };
        txn.flag = "*".to_string();
        directive
    }

    /// Run the plugin over `directives` and assert that it reports no errors.
    fn run(directives: Vec<DirectiveWrapper>) -> PluginOutput {
        let input = PluginInput {
            directives,
            options: PluginOptions {
                operating_currencies: vec!["USD".to_string()],
                title: None,
            },
            config: None,
        };
        let output = ForecastPlugin.process(input);
        assert!(output.errors.is_empty(), "unexpected plugin errors");
        output
    }

    /// Return the date strings of `output`'s directives, in output order.
    fn dates(output: &PluginOutput) -> Vec<&str> {
        output.directives.iter().map(|d| d.date.as_str()).collect()
    }

    /// Return the transaction payload of `directive`.
    ///
    /// Fails the test instead of silently skipping assertions when the
    /// directive is not a transaction.
    fn transaction(directive: &DirectiveWrapper) -> &TransactionData {
        let DirectiveData::Transaction(txn) = &directive.data else {
            panic!(
                "expected a transaction directive, got {}",
                directive.directive_type
            );
        };
        txn
    }

    #[test]
    fn test_forecast_monthly_repeat() {
        let output = run(vec![create_forecast_transaction(
            "2024-01-15",
            "Electric bill [MONTHLY REPEAT 3 TIMES]",
        )]);

        assert_eq!(dates(&output), ["2024-01-15", "2024-02-15", "2024-03-15"]);
        // Every generated entry has the periodicity tag stripped.
        for directive in &output.directives {
            assert_eq!(transaction(directive).narration, "Electric bill");
        }
    }

    #[test]
    fn test_forecast_weekly_repeat() {
        let output = run(vec![create_forecast_transaction(
            "2024-01-01",
            "Groceries [WEEKLY REPEAT 4 TIMES]",
        )]);

        assert_eq!(
            dates(&output),
            ["2024-01-01", "2024-01-08", "2024-01-15", "2024-01-22"]
        );
    }

    #[test]
    fn test_forecast_until_date() {
        let output = run(vec![create_forecast_transaction(
            "2024-01-15",
            "Rent [MONTHLY UNTIL 2024-03-15]",
        )]);

        // The UNTIL date itself is inclusive.
        assert_eq!(dates(&output), ["2024-01-15", "2024-02-15", "2024-03-15"]);
    }

    #[test]
    fn test_forecast_skip() {
        let output = run(vec![create_forecast_transaction(
            "2024-01-01",
            "Insurance [MONTHLY SKIP 1 TIME REPEAT 3 TIMES]",
        )]);

        // With SKIP 1 TIME, it should skip every other month (bi-monthly).
        assert_eq!(dates(&output), ["2024-01-01", "2024-03-01", "2024-05-01"]);
    }

    #[test]
    fn test_forecast_preserves_non_forecast_transactions() {
        let output = run(vec![create_regular_transaction(
            "2024-01-15",
            "Regular purchase",
        )]);

        assert_eq!(dates(&output), ["2024-01-15"]);
        let txn = transaction(&output.directives[0]);
        assert_eq!(txn.flag, "*");
        assert_eq!(txn.narration, "Regular purchase");
    }

    #[test]
    fn test_forecast_without_pattern_is_kept() {
        let output = run(vec![create_forecast_transaction(
            "2024-01-15",
            "One-off expense",
        )]);

        assert_eq!(dates(&output), ["2024-01-15"]);
        let txn = transaction(&output.directives[0]);
        assert_eq!(txn.flag, "#");
        assert_eq!(txn.narration, "One-off expense");
    }

    #[test]
    fn test_regular_entries_precede_generated_entries() {
        let output = run(vec![
            create_forecast_transaction("2024-01-01", "Rent [MONTHLY REPEAT 2 TIMES]"),
            create_regular_transaction("2024-06-01", "Regular purchase"),
        ]);

        // Pass-through entries come first; generated entries follow, sorted.
        assert_eq!(dates(&output), ["2024-06-01", "2024-01-01", "2024-02-01"]);
    }

    #[test]
    fn test_add_months() {
        let ymd = |y, m, d| NaiveDate::from_ymd_opt(y, m, d).unwrap();

        // Regular case
        assert_eq!(add_months(ymd(2024, 1, 15), 1), ymd(2024, 2, 15));

        // Month-end overflow clamps to the last day (2024 is a leap year)...
        assert_eq!(add_months(ymd(2024, 1, 31), 1), ymd(2024, 2, 29));
        // ...and to the 28th in a common year.
        assert_eq!(add_months(ymd(2023, 1, 31), 1), ymd(2023, 2, 28));

        // Year overflow
        assert_eq!(add_months(ymd(2024, 11, 15), 3), ymd(2025, 2, 15));
        assert_eq!(add_months(ymd(2024, 12, 15), 1), ymd(2025, 1, 15));

        // A whole number of years keeps the month and day.
        assert_eq!(add_months(ymd(2024, 3, 10), 12), ymd(2025, 3, 10));
    }
}