Skip to main content

rustledger_ops/
ml.rs

1//! ML-based transaction categorization.
2//!
3//! Trains a Multinomial Naive Bayes classifier on existing ledger transactions
4//! to predict the expense/income account for new transactions based on their
5//! payee and narration text.
6//!
7//! Uses TF-IDF vectorization plus a small, self-contained Multinomial
8//! Naive Bayes classifier (see `MultinomialNB`) implemented in pure
9//! `std` — no external ML or linear-algebra crates. Earlier versions
10//! delegated to `linfa-bayes` and then `ferrolearn-bayes`, but both
11//! dragged heavy, occasionally wasm-incompatible dependencies in for an
12//! algorithm that is ~80 lines of textbook arithmetic.
13//!
14//! # Example
15//!
16//! ```rust,ignore
17//! let model = CategorizationModel::train(&existing_directives)?;
//! let predictions = model.predict("groceries", Some("WHOLE FOODS"));
19//! // → [("Expenses:Groceries", 0.92), ("Expenses:Dining", 0.05), ...]
20//! ```
21
22use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
23use std::collections::HashMap;
24
25/// A trained categorization model.
26///
27/// Wraps a Multinomial Naive Bayes classifier trained on TF-IDF features
28/// extracted from transaction payee/narration text.
29pub struct CategorizationModel {
30    /// The trained classifier.
31    model: MultinomialNB,
32    /// Vocabulary: word → column index in the feature matrix.
33    vocabulary: HashMap<String, usize>,
34    /// IDF weights for each word in the vocabulary.
35    idf: Vec<f64>,
36    /// Label map: index → account name.
37    labels: Vec<String>,
38}
39
/// Failures that can occur while building a categorization model.
///
/// Only training can fail, and only because the input does not contain
/// enough usable data; fitting the classifier itself cannot fail.
///
/// The enum is `#[non_exhaustive]` so that new failure modes can be added
/// later without breaking downstream `match` expressions.
#[derive(Debug)]
#[non_exhaustive]
pub enum MlError {
    /// The training set is too small: too few transactions, too few
    /// distinct categories, or no tokens. The payload says which.
    InsufficientData(String),
}

impl std::fmt::Display for MlError {
    /// Render the error as a short, human-readable sentence.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force an update here.
        match self {
            Self::InsufficientData(detail) => {
                write!(f, "insufficient training data: {detail}")
            }
        }
    }
}

impl std::error::Error for MlError {}
62
63impl CategorizationModel {
64    /// Train a model from existing ledger directives.
65    ///
66    /// Extracts (text, account) pairs from transactions where the second
67    /// posting's account is the categorization target. Requires at least
68    /// 2 distinct categories with at least 1 transaction each.
69    ///
70    /// # Errors
71    ///
72    /// Returns `MlError::InsufficientData` if there aren't enough transactions
73    /// or distinct categories to train a useful model.
74    pub fn train(directives: &[DirectiveWrapper]) -> Result<Self, MlError> {
75        // Extract training data: (text, account) pairs
76        let mut samples: Vec<(String, String)> = Vec::new();
77
78        for d in directives {
79            if let DirectiveData::Transaction(txn) = &d.data {
80                // Skip transactions with fewer than 2 postings
81                if txn.postings.len() < 2 {
82                    continue;
83                }
84
85                // The target account is the second posting (contra-account)
86                let account = &txn.postings[1].account;
87
88                // Build text from payee + narration
89                let mut text = String::new();
90                if let Some(ref payee) = txn.payee {
91                    text.push_str(payee);
92                    text.push(' ');
93                }
94                text.push_str(&txn.narration);
95
96                if !text.trim().is_empty() {
97                    samples.push((text.to_lowercase(), account.clone()));
98                }
99            }
100        }
101
102        if samples.len() < 2 {
103            return Err(MlError::InsufficientData(format!(
104                "need at least 2 transactions, got {}",
105                samples.len()
106            )));
107        }
108
109        // Build label map
110        let mut label_set: Vec<String> = samples.iter().map(|(_, a)| a.clone()).collect();
111        label_set.sort();
112        label_set.dedup();
113
114        if label_set.len() < 2 {
115            return Err(MlError::InsufficientData(
116                "need at least 2 distinct categories".to_string(),
117            ));
118        }
119
120        let label_to_idx: HashMap<&str, usize> = label_set
121            .iter()
122            .enumerate()
123            .map(|(i, s)| (s.as_str(), i))
124            .collect();
125
126        // Build vocabulary from all tokens
127        let mut vocab: HashMap<String, usize> = HashMap::new();
128        let tokenized: Vec<Vec<String>> = samples.iter().map(|(text, _)| tokenize(text)).collect();
129
130        for tokens in &tokenized {
131            for token in tokens {
132                let len = vocab.len();
133                vocab.entry(token.clone()).or_insert(len);
134            }
135        }
136
137        if vocab.is_empty() {
138            return Err(MlError::InsufficientData(
139                "no tokens found in training data".to_string(),
140            ));
141        }
142
143        // Compute IDF weights
144        let n_docs = samples.len() as f64;
145        let mut doc_freq = vec![0u32; vocab.len()];
146        for tokens in &tokenized {
147            let mut seen = std::collections::HashSet::new();
148            for token in tokens {
149                if let Some(&idx) = vocab.get(token)
150                    && seen.insert(idx)
151                {
152                    doc_freq[idx] += 1;
153                }
154            }
155        }
156        let idf: Vec<f64> = doc_freq
157            .iter()
158            .map(|&df| (n_docs / (1.0 + f64::from(df))).ln() + 1.0)
159            .collect();
160
161        // Build sparse TF-IDF rows — only the non-zero `(vocab index,
162        // weight)` entries, sorted by index for deterministic summation.
163        // TF-IDF vectors are mostly zero, so a dense matrix would cost
164        // O(n_samples × vocab); sparse rows cost O(total tokens).
165        let n_features = vocab.len();
166        let mut features: Vec<Vec<(usize, f64)>> = Vec::with_capacity(samples.len());
167        let mut targets: Vec<usize> = Vec::with_capacity(samples.len());
168
169        for (tokens, (_, account)) in tokenized.iter().zip(samples.iter()) {
170            features.push(tfidf_row(tokens, &vocab, &idf));
171            targets.push(label_to_idx[account.as_str()]);
172        }
173
174        // Train the classifier. Laplace smoothing (alpha = 1.0) matches
175        // the linfa-bayes / ferrolearn defaults this replaced.
176        let model = MultinomialNB::fit(&features, &targets, label_set.len(), n_features);
177
178        Ok(Self {
179            model,
180            vocabulary: vocab,
181            idf,
182            labels: label_set,
183        })
184    }
185
186    /// Predict the account for a transaction.
187    ///
188    /// Returns predictions sorted by confidence (highest first). Each
189    /// prediction is an `(account, probability)` pair. The probabilities
190    /// are the class-conditional posteriors from the underlying
191    /// Multinomial Naive Bayes model (they sum to 1.0 across all
192    /// classes), so callers can treat them as honest scores.
193    #[must_use]
194    pub fn predict(&self, narration: &str, payee: Option<&str>) -> Vec<(String, f64)> {
195        let mut text = String::new();
196        if let Some(p) = payee {
197            text.push_str(p);
198            text.push(' ');
199        }
200        text.push_str(narration);
201
202        let features = self.vectorize(&text.to_lowercase());
203
204        // Pair each class posterior with its label, then sort descending;
205        // callers take the top-k. Softmax posteriors are all > 0, so
206        // every known category is returned.
207        let mut results: Vec<(String, f64)> = self
208            .labels
209            .iter()
210            .cloned()
211            .zip(self.model.predict_proba(&features))
212            .collect();
213
214        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
215        results
216    }
217
218    /// Vectorize text into a sparse TF-IDF feature vector (the non-zero
219    /// `(vocab index, weight)` entries).
220    fn vectorize(&self, text: &str) -> Vec<(usize, f64)> {
221        tfidf_row(&tokenize(text), &self.vocabulary, &self.idf)
222    }
223
224    /// Number of distinct categories the model was trained on.
225    #[must_use]
226    pub const fn num_categories(&self) -> usize {
227        self.labels.len()
228    }
229
230    /// Number of features (vocabulary size).
231    #[must_use]
232    pub fn vocab_size(&self) -> usize {
233        self.vocabulary.len()
234    }
235}
236
/// Multinomial Naive Bayes over sparse rows, with additive (Laplace)
/// smoothing.
///
/// Each sample is a list of `(feature index, value)` pairs. Values are
/// accumulated as fractional counts, so TF-IDF weights can be fed in
/// directly.
///
/// For every class `c`, `fit` stores the log prior `ln(n_c / n)` and the
/// smoothed feature log-probabilities
/// `ln((Σᵢ xᵢⱼ + α) / (Σⱼ Σᵢ xᵢⱼ + α·n_features))`. `predict_proba` adds
/// `Σⱼ xⱼ · feature_log_prob[c][j]` to the prior of each class and turns
/// the resulting joint log-likelihoods into probabilities with a
/// max-shifted softmax.
struct MultinomialNB {
    /// Log prior of each class, indexed `[class]`.
    class_log_prior: Vec<f64>,
    /// Smoothed log-probability of each feature given the class, indexed
    /// `[class][feature]`.
    feature_log_prob: Vec<Vec<f64>>,
}

impl MultinomialNB {
    /// Additive smoothing strength α (1.0 = classic Laplace smoothing).
    const ALPHA: f64 = 1.0;

    /// Fit the classifier on sparse rows and their class indices.
    ///
    /// `features[i]` holds sample `i` as `(feature index, value)` pairs and
    /// `targets[i]` is the class of that sample. The caller must ensure
    /// that both slices have the same length, that every feature index is
    /// below `n_features`, that every class index is below `n_classes`,
    /// and that each class has at least one sample (otherwise its prior
    /// would be `-inf`). Out-of-range indices panic.
    fn fit(
        features: &[Vec<(usize, f64)>],
        targets: &[usize],
        n_classes: usize,
        n_features: usize,
    ) -> Self {
        debug_assert_eq!(
            features.len(),
            targets.len(),
            "features and targets must be parallel"
        );

        // Accumulate, per class, the number of samples and the summed
        // weight of every feature.
        let mut samples_in_class = vec![0.0_f64; n_classes];
        let mut weight_sums = vec![vec![0.0_f64; n_features]; n_classes];
        for (row, &class) in features.iter().zip(targets) {
            samples_in_class[class] += 1.0;
            for &(feature, weight) in row {
                weight_sums[class][feature] += weight;
            }
        }

        let total_samples = features.len() as f64;
        let mut class_log_prior = Vec::with_capacity(n_classes);
        for &count in &samples_in_class {
            class_log_prior.push((count / total_samples).ln());
        }

        let feature_slots = n_features as f64;
        let mut feature_log_prob = Vec::with_capacity(n_classes);
        for sums in &weight_sums {
            // Smoothed denominator: α·n_features + total weight in the class.
            let class_total: f64 = sums.iter().sum();
            let denom = Self::ALPHA.mul_add(feature_slots, class_total);
            let log_probs: Vec<f64> = sums
                .iter()
                .map(|&sum| ((sum + Self::ALPHA) / denom).ln())
                .collect();
            feature_log_prob.push(log_probs);
        }

        Self {
            class_log_prior,
            feature_log_prob,
        }
    }

    /// Posterior probability of every class for one sparse sample.
    ///
    /// `x` lists the sample's `(feature index, value)` pairs. The result
    /// has one entry per class, in training class order, and sums to 1.0.
    fn predict_proba(&self, x: &[(usize, f64)]) -> Vec<f64> {
        // Joint log-likelihood of the sample under each class.
        let mut scores: Vec<f64> = Vec::with_capacity(self.class_log_prior.len());
        for (&prior, log_prob) in self.class_log_prior.iter().zip(&self.feature_log_prob) {
            let likelihood: f64 = x.iter().map(|&(j, v)| v * log_prob[j]).sum();
            scores.push(prior + likelihood);
        }

        // Softmax, shifted by the largest score so `exp` cannot overflow.
        let peak = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut probs: Vec<f64> = scores.iter().map(|&score| (score - peak).exp()).collect();
        let total: f64 = probs.iter().sum();
        for p in &mut probs {
            *p /= total;
        }
        probs
    }
}
334
/// Tokenize text into lowercase words.
///
/// The text is split on every non-alphanumeric character; pieces shorter
/// than two characters are dropped and the rest are lowercased. Length is
/// measured in characters, not bytes, so a lone non-ASCII letter such as
/// `é` is treated like a lone ASCII letter and discarded.
fn tokenize(text: &str) -> Vec<String> {
    /// Minimum number of characters a piece needs to count as a token.
    const MIN_TOKEN_CHARS: usize = 2;

    text.split(|c: char| !c.is_alphanumeric())
        // `str::len` would count UTF-8 bytes; count characters instead.
        .filter(|piece| piece.chars().count() >= MIN_TOKEN_CHARS)
        .map(str::to_lowercase)
        .collect()
}
342
/// Turn a token list into a sparse TF-IDF row.
///
/// Each token found in `vocab` contributes to its feature's term count;
/// tokens missing from `vocab` are skipped. The result holds one
/// `(vocab index, count × idf[index])` entry per feature that occurred,
/// in ascending index order, so downstream summation order does not
/// depend on hash-map iteration order.
fn tfidf_row(tokens: &[String], vocab: &HashMap<String, usize>, idf: &[f64]) -> Vec<(usize, f64)> {
    // An ordered map keeps the term counts sorted by feature index.
    let mut term_counts = std::collections::BTreeMap::<usize, u32>::new();
    for idx in tokens.iter().filter_map(|token| vocab.get(token).copied()) {
        *term_counts.entry(idx).or_insert(0) += 1;
    }

    let mut row = Vec::with_capacity(term_counts.len());
    for (idx, count) in term_counts {
        row.push((idx, f64::from(count) * idf[idx]));
    }
    row
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::{AmountData, PostingData, TransactionData};

    /// Build a posting on `account` with the given optional `units` and
    /// every other field empty.
    fn posting(account: &str, units: Option<AmountData>) -> PostingData {
        PostingData {
            account: account.to_string(),
            units,
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        }
    }

    /// Build a two-posting transaction: -50.00 USD from `from_account`,
    /// balanced by an amount-less posting on `to_account` (the posting the
    /// model learns to predict).
    fn make_txn(
        payee: Option<&str>,
        narration: &str,
        from_account: &str,
        to_account: &str,
    ) -> DirectiveWrapper {
        let debit = AmountData {
            number: "-50.00".to_string(),
            currency: "USD".to_string(),
        };
        let txn = TransactionData {
            flag: "*".to_string(),
            payee: payee.map(String::from),
            narration: narration.to_string(),
            tags: vec![],
            links: vec![],
            metadata: vec![],
            postings: vec![posting(from_account, Some(debit)), posting(to_account, None)],
        };
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(txn),
        }
    }

    /// Eleven transactions spread over three expense categories.
    fn training_data() -> Vec<DirectiveWrapper> {
        const BANK: &str = "Assets:Bank";
        // (payee, narration, category)
        const ROWS: [(&str, &str, &str); 11] = [
            ("Whole Foods", "Groceries", "Expenses:Groceries"),
            ("Trader Joe's", "Weekly groceries", "Expenses:Groceries"),
            ("Safeway", "Food shopping", "Expenses:Groceries"),
            ("Kroger", "Groceries", "Expenses:Groceries"),
            ("Starbucks", "Coffee", "Expenses:Dining"),
            ("McDonald's", "Lunch", "Expenses:Dining"),
            ("Chipotle", "Dinner", "Expenses:Dining"),
            ("Panera", "Coffee and sandwich", "Expenses:Dining"),
            ("Shell", "Gas", "Expenses:Transport"),
            ("Chevron", "Fuel", "Expenses:Transport"),
            ("Uber", "Ride to airport", "Expenses:Transport"),
        ];

        ROWS.iter()
            .map(|&(payee, narration, category)| make_txn(Some(payee), narration, BANK, category))
            .collect()
    }

    /// Train on `training_data()`; training on this set must succeed.
    fn trained_model() -> CategorizationModel {
        CategorizationModel::train(&training_data()).unwrap()
    }

    #[test]
    fn train_and_predict() {
        let model = trained_model();

        assert_eq!(model.num_categories(), 3);
        assert!(model.vocab_size() > 5);

        let predictions = model.predict("Weekly food shopping at the store", None);
        assert!(!predictions.is_empty());
        // The wording overlaps most with the grocery samples.
        assert_eq!(predictions[0].0, "Expenses:Groceries");
    }

    #[test]
    fn predict_dining() {
        let predictions = trained_model().predict("Coffee", Some("Starbucks"));
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, "Expenses:Dining");
    }

    #[test]
    fn predict_transport() {
        let predictions = trained_model().predict("Fuel for car", Some("Shell"));
        assert!(!predictions.is_empty());
        assert_eq!(predictions[0].0, "Expenses:Transport");
    }

    #[test]
    fn insufficient_data() {
        // A single transaction is not enough to train on.
        let data = vec![make_txn(
            Some("Store"),
            "Stuff",
            "Assets:Bank",
            "Expenses:Misc",
        )];
        assert!(CategorizationModel::train(&data).is_err());
    }

    #[test]
    fn insufficient_categories() {
        // Two transactions, but both in the same category.
        let data = vec![
            make_txn(Some("Store"), "Stuff", "Assets:Bank", "Expenses:Misc"),
            make_txn(Some("Shop"), "Things", "Assets:Bank", "Expenses:Misc"),
        ];
        assert!(CategorizationModel::train(&data).is_err());
    }

    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("WHOLE FOODS MARKET #1234");
        for expected in ["whole", "foods", "market", "1234"] {
            assert!(tokens.contains(&expected.to_string()));
        }
    }

    #[test]
    fn naive_bayes_known_values() {
        // Two classes, two features: the class-0 sample is feature 0 = 2.0,
        // the class-1 sample is feature 1 = 2.0. With alpha = 1.0 and equal
        // priors:
        //   feature_log_prob[0] = ln([3/4, 1/4]),  [1] = ln([1/4, 3/4])
        //   class_log_prior      = ln([1/2, 1/2])
        // For x = [1, 0]: jll = ln(3/8) vs ln(1/8) → softmax = [0.75, 0.25].
        let nb = MultinomialNB::fit(&[vec![(0, 2.0)], vec![(1, 2.0)]], &[0, 1], 2, 2);

        let p = nb.predict_proba(&[(0, 1.0)]);
        assert!((p[0] - 0.75).abs() < 1e-9, "p[0] = {}", p[0]);
        assert!((p[1] - 0.25).abs() < 1e-9, "p[1] = {}", p[1]);
        assert!(
            (p.iter().sum::<f64>() - 1.0).abs() < 1e-12,
            "posteriors must sum to 1.0"
        );

        // The mirrored input mirrors the posteriors.
        let q = nb.predict_proba(&[(1, 1.0)]);
        assert!((q[0] - 0.25).abs() < 1e-9 && (q[1] - 0.75).abs() < 1e-9);
    }
}