1use rustledger_plugin_types::{DirectiveData, DirectiveWrapper};
23use std::collections::HashMap;
24
/// Text classifier that suggests an account name for a transaction description.
///
/// Instances are produced by `CategorizationModel::train`, which fills every field.
pub struct CategorizationModel {
    /// Fitted naive Bayes classifier operating on TF-IDF feature rows.
    model: MultinomialNB,
    /// Maps each token seen during training to its feature index.
    vocabulary: HashMap<String, usize>,
    /// Inverse-document-frequency weight for each feature index.
    idf: Vec<f64>,
    /// Sorted, de-duplicated account names; position `i` is the name of class `i`.
    labels: Vec<String>,
}
39
/// Errors reported while training a categorization model.
#[derive(Debug)]
#[non_exhaustive]
pub enum MlError {
    /// The training input was too small or too uniform to fit a model; the
    /// payload is a human-readable explanation.
    InsufficientData(String),
}

impl std::fmt::Display for MlError {
    /// Writes the fixed prefix `insufficient training data: ` followed by the
    /// explanation carried by the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InsufficientData(reason) => {
                write!(f, "insufficient training data: {reason}")
            }
        }
    }
}

impl std::error::Error for MlError {}
62
63impl CategorizationModel {
64 pub fn train(directives: &[DirectiveWrapper]) -> Result<Self, MlError> {
75 let mut samples: Vec<(String, String)> = Vec::new();
77
78 for d in directives {
79 if let DirectiveData::Transaction(txn) = &d.data {
80 if txn.postings.len() < 2 {
82 continue;
83 }
84
85 let account = &txn.postings[1].account;
87
88 let mut text = String::new();
90 if let Some(ref payee) = txn.payee {
91 text.push_str(payee);
92 text.push(' ');
93 }
94 text.push_str(&txn.narration);
95
96 if !text.trim().is_empty() {
97 samples.push((text.to_lowercase(), account.clone()));
98 }
99 }
100 }
101
102 if samples.len() < 2 {
103 return Err(MlError::InsufficientData(format!(
104 "need at least 2 transactions, got {}",
105 samples.len()
106 )));
107 }
108
109 let mut label_set: Vec<String> = samples.iter().map(|(_, a)| a.clone()).collect();
111 label_set.sort();
112 label_set.dedup();
113
114 if label_set.len() < 2 {
115 return Err(MlError::InsufficientData(
116 "need at least 2 distinct categories".to_string(),
117 ));
118 }
119
120 let label_to_idx: HashMap<&str, usize> = label_set
121 .iter()
122 .enumerate()
123 .map(|(i, s)| (s.as_str(), i))
124 .collect();
125
126 let mut vocab: HashMap<String, usize> = HashMap::new();
128 let tokenized: Vec<Vec<String>> = samples.iter().map(|(text, _)| tokenize(text)).collect();
129
130 for tokens in &tokenized {
131 for token in tokens {
132 let len = vocab.len();
133 vocab.entry(token.clone()).or_insert(len);
134 }
135 }
136
137 if vocab.is_empty() {
138 return Err(MlError::InsufficientData(
139 "no tokens found in training data".to_string(),
140 ));
141 }
142
143 let n_docs = samples.len() as f64;
145 let mut doc_freq = vec![0u32; vocab.len()];
146 for tokens in &tokenized {
147 let mut seen = std::collections::HashSet::new();
148 for token in tokens {
149 if let Some(&idx) = vocab.get(token)
150 && seen.insert(idx)
151 {
152 doc_freq[idx] += 1;
153 }
154 }
155 }
156 let idf: Vec<f64> = doc_freq
157 .iter()
158 .map(|&df| (n_docs / (1.0 + f64::from(df))).ln() + 1.0)
159 .collect();
160
161 let n_features = vocab.len();
166 let mut features: Vec<Vec<(usize, f64)>> = Vec::with_capacity(samples.len());
167 let mut targets: Vec<usize> = Vec::with_capacity(samples.len());
168
169 for (tokens, (_, account)) in tokenized.iter().zip(samples.iter()) {
170 features.push(tfidf_row(tokens, &vocab, &idf));
171 targets.push(label_to_idx[account.as_str()]);
172 }
173
174 let model = MultinomialNB::fit(&features, &targets, label_set.len(), n_features);
177
178 Ok(Self {
179 model,
180 vocabulary: vocab,
181 idf,
182 labels: label_set,
183 })
184 }
185
186 #[must_use]
194 pub fn predict(&self, narration: &str, payee: Option<&str>) -> Vec<(String, f64)> {
195 let mut text = String::new();
196 if let Some(p) = payee {
197 text.push_str(p);
198 text.push(' ');
199 }
200 text.push_str(narration);
201
202 let features = self.vectorize(&text.to_lowercase());
203
204 let mut results: Vec<(String, f64)> = self
208 .labels
209 .iter()
210 .cloned()
211 .zip(self.model.predict_proba(&features))
212 .collect();
213
214 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
215 results
216 }
217
218 fn vectorize(&self, text: &str) -> Vec<(usize, f64)> {
221 tfidf_row(&tokenize(text), &self.vocabulary, &self.idf)
222 }
223
224 #[must_use]
226 pub const fn num_categories(&self) -> usize {
227 self.labels.len()
228 }
229
230 #[must_use]
232 pub fn vocab_size(&self) -> usize {
233 self.vocabulary.len()
234 }
235}
236
/// Multinomial naive Bayes classifier over sparse `(feature index, weight)` rows.
struct MultinomialNB {
    /// Natural log of each class's share of the training samples, by class index.
    class_log_prior: Vec<f64>,
    /// `feature_log_prob[c][j]` is the log of the smoothed share of feature
    /// `j` within the total feature weight observed for class `c`.
    feature_log_prob: Vec<Vec<f64>>,
}

impl MultinomialNB {
    /// Additive smoothing constant added to every per-class feature weight.
    const ALPHA: f64 = 1.0;

    /// Fits the classifier.
    ///
    /// `features[i]` is the sparse row of sample `i` and `targets[i]` its
    /// class index. Class indices must be below `n_classes` and feature
    /// indices below `n_features`; anything else panics on indexing.
    fn fit(
        features: &[Vec<(usize, f64)>],
        targets: &[usize],
        n_classes: usize,
        n_features: usize,
    ) -> Self {
        debug_assert_eq!(
            features.len(),
            targets.len(),
            "features and targets must be parallel"
        );

        // Per-class sample tallies and per-class accumulated feature weights.
        let mut samples_per_class = vec![0.0_f64; n_classes];
        let mut weight_per_class = vec![vec![0.0_f64; n_features]; n_classes];
        for (&label, sample) in targets.iter().zip(features) {
            samples_per_class[label] += 1.0;
            for &(feature, weight) in sample {
                weight_per_class[label][feature] += weight;
            }
        }

        let total_samples = features.len() as f64;
        let mut class_log_prior = Vec::with_capacity(n_classes);
        for &tally in &samples_per_class {
            class_log_prior.push((tally / total_samples).ln());
        }

        let feature_span = n_features as f64;
        let mut feature_log_prob = Vec::with_capacity(n_classes);
        for weights in &weight_per_class {
            let class_mass: f64 = weights.iter().sum();
            let denominator = Self::ALPHA.mul_add(feature_span, class_mass);
            let log_probs: Vec<f64> = weights
                .iter()
                .map(|&weight| ((weight + Self::ALPHA) / denominator).ln())
                .collect();
            feature_log_prob.push(log_probs);
        }

        Self {
            class_log_prior,
            feature_log_prob,
        }
    }

    /// Returns one probability per class for the sparse row `x`.
    ///
    /// The per-class joint log-likelihoods are shifted by their maximum
    /// before exponentiation, then normalized so the result sums to one.
    fn predict_proba(&self, x: &[(usize, f64)]) -> Vec<f64> {
        let mut joint = Vec::with_capacity(self.class_log_prior.len());
        for (&prior, log_prob) in self.class_log_prior.iter().zip(&self.feature_log_prob) {
            let likelihood: f64 = x.iter().map(|&(j, v)| v * log_prob[j]).sum();
            joint.push(prior + likelihood);
        }

        let peak = joint.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let mut posterior: Vec<f64> = joint.iter().map(|&score| (score - peak).exp()).collect();
        let normalizer: f64 = posterior.iter().sum();
        for p in &mut posterior {
            *p /= normalizer;
        }
        posterior
    }
}
334
/// Splits `text` into lowercase tokens.
///
/// Every non-alphanumeric character acts as a separator. Pieces shorter than
/// two bytes (`str::len` counts bytes, not characters) are discarded, which
/// also removes the empty pieces produced by adjacent separators.
fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if piece.len() < 2 {
            continue;
        }
        tokens.push(piece.to_lowercase());
    }
    tokens
}
342
/// Builds a sparse TF-IDF row for one tokenized document.
///
/// Tokens missing from `vocab` are ignored. Each remaining feature index
/// appears once, weighted by its occurrence count times `idf[index]`, and the
/// row is ordered by ascending feature index.
fn tfidf_row(tokens: &[String], vocab: &HashMap<String, usize>, idf: &[f64]) -> Vec<(usize, f64)> {
    let mut hits: Vec<usize> = tokens
        .iter()
        .filter_map(|token| vocab.get(token).copied())
        .collect();
    hits.sort_unstable();

    // Run-length encode the sorted indices into (index, occurrence count).
    let mut counted: Vec<(usize, u32)> = Vec::new();
    for idx in hits {
        match counted.last_mut() {
            Some((last, count)) if *last == idx => *count += 1,
            _ => counted.push((idx, 1)),
        }
    }

    counted
        .into_iter()
        .map(|(idx, count)| (idx, f64::from(count) * idf[idx]))
        .collect()
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use rustledger_plugin_types::{AmountData, PostingData, TransactionData};

    /// Builds a posting on `account` with the given units and no other detail.
    fn posting(account: &str, units: Option<AmountData>) -> PostingData {
        PostingData {
            account: account.to_string(),
            units,
            cost: None,
            price: None,
            flag: None,
            metadata: vec![],
            span: None,
        }
    }

    /// Builds a two-posting transaction: -50.00 USD leaves `from_account` and
    /// the second posting on `to_account` carries no amount.
    fn make_txn(
        payee: Option<&str>,
        narration: &str,
        from_account: &str,
        to_account: &str,
    ) -> DirectiveWrapper {
        let debit = AmountData {
            number: "-50.00".to_string(),
            currency: "USD".to_string(),
        };
        let transaction = TransactionData {
            flag: "*".to_string(),
            payee: payee.map(String::from),
            narration: narration.to_string(),
            tags: vec![],
            links: vec![],
            metadata: vec![],
            postings: vec![
                posting(from_account, Some(debit)),
                posting(to_account, None),
            ],
        };
        DirectiveWrapper {
            directive_type: "transaction".to_string(),
            date: "2024-01-15".to_string(),
            filename: None,
            lineno: None,
            data: DirectiveData::Transaction(transaction),
        }
    }

    /// Eleven labelled transactions spread over three expense categories.
    fn training_data() -> Vec<DirectiveWrapper> {
        // (payee, narration, category); every row is paid from Assets:Bank.
        const ROWS: &[(&str, &str, &str)] = &[
            ("Whole Foods", "Groceries", "Expenses:Groceries"),
            ("Trader Joe's", "Weekly groceries", "Expenses:Groceries"),
            ("Safeway", "Food shopping", "Expenses:Groceries"),
            ("Kroger", "Groceries", "Expenses:Groceries"),
            ("Starbucks", "Coffee", "Expenses:Dining"),
            ("McDonald's", "Lunch", "Expenses:Dining"),
            ("Chipotle", "Dinner", "Expenses:Dining"),
            ("Panera", "Coffee and sandwich", "Expenses:Dining"),
            ("Shell", "Gas", "Expenses:Transport"),
            ("Chevron", "Fuel", "Expenses:Transport"),
            ("Uber", "Ride to airport", "Expenses:Transport"),
        ];

        ROWS.iter()
            .map(|&(payee, narration, category)| {
                make_txn(Some(payee), narration, "Assets:Bank", category)
            })
            .collect()
    }

    /// Trains a model on [`training_data`], panicking if training fails.
    fn trained_model() -> CategorizationModel {
        CategorizationModel::train(&training_data()).unwrap()
    }

    /// Asserts that the prediction list is non-empty and returns its top category.
    fn top_category(
        model: &CategorizationModel,
        narration: &str,
        payee: Option<&str>,
    ) -> String {
        let predictions = model.predict(narration, payee);
        assert!(!predictions.is_empty());
        predictions[0].0.clone()
    }

    #[test]
    fn train_and_predict() {
        let model = trained_model();

        assert_eq!(model.num_categories(), 3);
        assert!(model.vocab_size() > 5);

        assert_eq!(
            top_category(&model, "Weekly food shopping at the store", None),
            "Expenses:Groceries"
        );
    }

    #[test]
    fn predict_dining() {
        let model = trained_model();
        assert_eq!(
            top_category(&model, "Coffee", Some("Starbucks")),
            "Expenses:Dining"
        );
    }

    #[test]
    fn predict_transport() {
        let model = trained_model();
        assert_eq!(
            top_category(&model, "Fuel for car", Some("Shell")),
            "Expenses:Transport"
        );
    }

    #[test]
    fn insufficient_data() {
        let lone_sample = vec![make_txn(
            Some("Store"),
            "Stuff",
            "Assets:Bank",
            "Expenses:Misc",
        )];
        assert!(CategorizationModel::train(&lone_sample).is_err());
    }

    #[test]
    fn insufficient_categories() {
        let single_category = vec![
            make_txn(Some("Store"), "Stuff", "Assets:Bank", "Expenses:Misc"),
            make_txn(Some("Shop"), "Things", "Assets:Bank", "Expenses:Misc"),
        ];
        assert!(CategorizationModel::train(&single_category).is_err());
    }

    #[test]
    fn tokenize_basic() {
        let tokens = tokenize("WHOLE FOODS MARKET #1234");
        for expected in ["whole", "foods", "market", "1234"] {
            assert!(tokens.contains(&expected.to_string()));
        }
    }

    #[test]
    fn naive_bayes_known_values() {
        // Two classes with one sample each and mirrored feature weights, so
        // the smoothed per-class feature shares are 3/4 and 1/4.
        let nb = MultinomialNB::fit(&[vec![(0, 2.0)], vec![(1, 2.0)]], &[0, 1], 2, 2);

        let p = nb.predict_proba(&[(0, 1.0)]);
        assert!((p[0] - 0.75).abs() < 1e-9, "p[0] = {}", p[0]);
        assert!((p[1] - 0.25).abs() < 1e-9, "p[1] = {}", p[1]);
        assert!(
            (p.iter().sum::<f64>() - 1.0).abs() < 1e-12,
            "posteriors must sum to 1.0"
        );

        let q = nb.predict_proba(&[(1, 1.0)]);
        assert!((q[0] - 0.25).abs() < 1e-9 && (q[1] - 0.75).abs() < 1e-9);
    }
}