Skip to main content

scirs2_text/
advanced_classification.rs

1//! Advanced text classification and feature extraction.
2//!
3//! Provides:
4//! - [`NaiveBayesClassifier`]: Multinomial Naive Bayes with Laplace smoothing.
5//! - [`FastTextClassifier`]: Averaged word-vector classifier inspired by fastText.
6//! - [`CountVectorizer`]: N-gram count matrix builder.
7//! - [`TfidfTransformer`]: TF-IDF weighting from count matrices.
8
9use std::collections::HashMap;
10
11use crate::error::{Result, TextError};
12
13// ─────────────────────────────────────────────────────────────────────────────
14// NaiveBayesClassifier
15// ─────────────────────────────────────────────────────────────────────────────
16
17/// Multinomial Naive Bayes text classifier with Laplace smoothing.
18///
19/// Uses a bag-of-words representation; each token in a document contributes
20/// to the likelihood estimate.
21#[derive(Debug, Clone, Default)]
22pub struct NaiveBayesClassifier {
23    /// log P(class) for each class
24    class_log_priors: Vec<f64>,
25    /// log P(word | class): indexed as [class_idx][word_idx]
26    log_likelihoods: Vec<Vec<f64>>,
27    /// Ordered class names
28    classes: Vec<String>,
29    /// word → index mapping built during fit
30    vocabulary: HashMap<String, usize>,
31    /// `true` after `fit()` has been called
32    fitted: bool,
33}
34
35impl NaiveBayesClassifier {
36    /// Create an unfitted classifier.
37    pub fn new() -> Self {
38        Self::default()
39    }
40
41    // ── Internal helpers ──────────────────────────────────────────────
42
43    /// Tokenise `text` into words (lower-case, alpha-only).
44    fn tokenize(text: &str) -> Vec<String> {
45        text.split(|c: char| !c.is_alphanumeric())
46            .filter(|s| !s.is_empty())
47            .map(|s| s.to_lowercase())
48            .collect()
49    }
50
51    /// Convert `text` to a word-count vector aligned with `vocabulary`.
52    fn text_to_counts(&self, text: &str) -> Vec<f64> {
53        let mut counts = vec![0.0f64; self.vocabulary.len()];
54        for word in Self::tokenize(text) {
55            if let Some(&idx) = self.vocabulary.get(&word) {
56                counts[idx] += 1.0;
57            }
58        }
59        counts
60    }
61
62    // ── Public API ────────────────────────────────────────────────────
63
64    /// Train on `(text, label)` pairs.
65    ///
66    /// `alpha` is the Laplace smoothing factor (typical: 1.0).
67    pub fn fit(&mut self, corpus: &[(String, String)], alpha: f64) -> Result<()> {
68        if corpus.is_empty() {
69            return Err(TextError::InvalidInput("corpus is empty".to_string()));
70        }
71        if alpha <= 0.0 {
72            return Err(TextError::InvalidInput(
73                "smoothing parameter alpha must be > 0".to_string(),
74            ));
75        }
76
77        // Collect unique classes
78        let mut class_set: Vec<String> = corpus
79            .iter()
80            .map(|(_, label)| label.clone())
81            .collect::<std::collections::HashSet<_>>()
82            .into_iter()
83            .collect();
84        class_set.sort();
85        self.classes = class_set;
86        let n_classes = self.classes.len();
87        let class_to_id: HashMap<String, usize> = self
88            .classes
89            .iter()
90            .enumerate()
91            .map(|(i, c)| (c.clone(), i))
92            .collect();
93
94        // Build vocabulary
95        let mut vocab_set: std::collections::HashSet<String> = std::collections::HashSet::new();
96        for (text, _) in corpus {
97            for word in Self::tokenize(text) {
98                vocab_set.insert(word);
99            }
100        }
101        let mut vocab_sorted: Vec<String> = vocab_set.into_iter().collect();
102        vocab_sorted.sort();
103        self.vocabulary = vocab_sorted
104            .iter()
105            .enumerate()
106            .map(|(i, w)| (w.clone(), i))
107            .collect();
108        let v = self.vocabulary.len();
109
110        // Count per class
111        let mut class_counts = vec![0usize; n_classes];
112        let mut word_counts_per_class: Vec<Vec<f64>> = vec![vec![0.0; v]; n_classes];
113
114        for (text, label) in corpus {
115            let ci = class_to_id[label];
116            class_counts[ci] += 1;
117            for word in Self::tokenize(text) {
118                if let Some(&wi) = self.vocabulary.get(&word) {
119                    word_counts_per_class[ci][wi] += 1.0;
120                }
121            }
122        }
123
124        let total_docs = corpus.len() as f64;
125        self.class_log_priors = class_counts
126            .iter()
127            .map(|&c| (c as f64 / total_docs).ln())
128            .collect();
129
130        // Compute log likelihoods with Laplace smoothing
131        self.log_likelihoods = word_counts_per_class
132            .iter()
133            .map(|counts| {
134                let total: f64 = counts.iter().sum::<f64>() + alpha * v as f64;
135                counts.iter().map(|&c| ((c + alpha) / total).ln()).collect()
136            })
137            .collect();
138
139        self.fitted = true;
140        Ok(())
141    }
142
143    /// Compute log-posterior scores for each class.
144    fn log_scores(&self, text: &str) -> Result<Vec<f64>> {
145        if !self.fitted {
146            return Err(TextError::ModelNotFitted(
147                "NaiveBayesClassifier is not fitted".to_string(),
148            ));
149        }
150        let counts = self.text_to_counts(text);
151        let scores: Vec<f64> = self
152            .class_log_priors
153            .iter()
154            .zip(self.log_likelihoods.iter())
155            .map(|(&prior, likelihoods)| {
156                let ll: f64 = counts
157                    .iter()
158                    .zip(likelihoods.iter())
159                    .map(|(&c, &lp)| c * lp)
160                    .sum();
161                prior + ll
162            })
163            .collect();
164        Ok(scores)
165    }
166
167    /// Predict the most likely class label for `text`.
168    pub fn predict(&self, text: &str) -> Result<Option<String>> {
169        let scores = self.log_scores(text)?;
170        let best = scores
171            .iter()
172            .enumerate()
173            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
174            .map(|(i, _)| self.classes[i].clone());
175        Ok(best)
176    }
177
178    /// Predict posterior probabilities (softmax of log-scores) per class.
179    pub fn predict_proba(&self, text: &str) -> Result<Vec<(String, f64)>> {
180        let log_scores = self.log_scores(text)?;
181        // Softmax
182        let max_s = log_scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
183        let exps: Vec<f64> = log_scores.iter().map(|&s| (s - max_s).exp()).collect();
184        let total: f64 = exps.iter().sum();
185        Ok(self
186            .classes
187            .iter()
188            .zip(exps.iter())
189            .map(|(cls, &e)| (cls.clone(), if total == 0.0 { 0.0 } else { e / total }))
190            .collect())
191    }
192
193    /// Batch predict over multiple texts.
194    pub fn predict_batch(&self, texts: &[String]) -> Result<Vec<Option<String>>> {
195        texts.iter().map(|t| self.predict(t)).collect()
196    }
197
198    /// Compute accuracy on a labelled test set.
199    pub fn accuracy(&self, test_set: &[(String, String)]) -> Result<f64> {
200        if test_set.is_empty() {
201            return Ok(0.0);
202        }
203        let mut correct = 0usize;
204        for (text, gold) in test_set {
205            if let Ok(Some(pred)) = self.predict(text) {
206                if &pred == gold {
207                    correct += 1;
208                }
209            }
210        }
211        Ok(correct as f64 / test_set.len() as f64)
212    }
213
214    /// Return ordered class names.
215    pub fn class_names(&self) -> &[String] {
216        &self.classes
217    }
218}
219
220// ─────────────────────────────────────────────────────────────────────────────
221// FastTextClassifier
222// ─────────────────────────────────────────────────────────────────────────────
223
224/// FastText-inspired averaged word-vector text classifier.
225///
226/// Each document is represented as the average of its per-word embeddings.
227/// A linear layer maps the averaged vector to class logits.  Training uses
228/// SGD with Hogwild-style updates.
229#[derive(Debug, Clone)]
230pub struct FastTextClassifier {
231    n_classes: usize,
232    classes: Vec<String>,
233    word_vectors: HashMap<String, Vec<f32>>,
234    /// Weight matrix [dim × n_classes]
235    weights: Vec<Vec<f32>>,
236    /// Bias vector [n_classes]
237    bias: Vec<f32>,
238    dim: usize,
239    fitted: bool,
240}
241
242impl FastTextClassifier {
243    /// Create a new, unfitted classifier.
244    ///
245    /// * `n_classes` – number of output classes.
246    /// * `dim`       – word-embedding dimension.
247    /// * `classes`   – ordered class name list.
248    pub fn new(n_classes: usize, dim: usize, classes: Vec<String>) -> Self {
249        assert_eq!(
250            classes.len(),
251            n_classes,
252            "classes.len() must equal n_classes"
253        );
254        FastTextClassifier {
255            n_classes,
256            classes,
257            word_vectors: HashMap::new(),
258            weights: vec![vec![0.0f32; n_classes]; dim],
259            bias: vec![0.0f32; n_classes],
260            dim,
261            fitted: false,
262        }
263    }
264
265    // ── Internal helpers ──────────────────────────────────────────────
266
267    /// Initialise a new word vector with small random values via a deterministic
268    /// hash-based scheme (avoids external RNG dependency).
269    fn init_word_vec(word: &str, dim: usize) -> Vec<f32> {
270        let mut v = vec![0.0f32; dim];
271        for (i, c) in word.bytes().enumerate() {
272            let idx = i % dim;
273            v[idx] += (c as f32 - 64.0) / 128.0;
274        }
275        // Normalise
276        let norm: f32 = v.iter().map(|&x| x * x).sum::<f32>().sqrt();
277        if norm > 0.0 {
278            v.iter_mut().for_each(|x| *x /= norm);
279        }
280        v
281    }
282
283    /// Compute the mean embedding for a token list.
284    fn mean_embedding(&self, tokens: &[String]) -> Vec<f32> {
285        let mut sum = vec![0.0f32; self.dim];
286        let mut count = 0usize;
287        for tok in tokens {
288            if let Some(vec) = self.word_vectors.get(tok.as_str()) {
289                for (s, &v) in sum.iter_mut().zip(vec.iter()) {
290                    *s += v;
291                }
292                count += 1;
293            }
294        }
295        if count > 0 {
296            sum.iter_mut().for_each(|s| *s /= count as f32);
297        }
298        sum
299    }
300
301    /// Linear forward: z[k] = sum_d(embedding[d] * weights[d][k]) + bias[k]
302    fn forward(&self, embedding: &[f32]) -> Vec<f32> {
303        let mut logits = self.bias.clone();
304        for (d, &e) in embedding.iter().enumerate() {
305            for k in 0..self.n_classes {
306                logits[k] += e * self.weights[d][k];
307            }
308        }
309        logits
310    }
311
312    /// Softmax in-place.
313    fn softmax(logits: &mut [f32]) {
314        let max_l = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
315        logits.iter_mut().for_each(|x| *x = (*x - max_l).exp());
316        let sum: f32 = logits.iter().sum();
317        if sum > 0.0 {
318            logits.iter_mut().for_each(|x| *x /= sum);
319        }
320    }
321
322    // ── Public API ────────────────────────────────────────────────────
323
324    /// Train the classifier.
325    ///
326    /// `corpus` is `(tokens, class_id)` pairs.  `lr` is the learning rate.
327    pub fn fit(&mut self, corpus: &[(Vec<String>, usize)], n_epochs: usize, lr: f32) -> Result<()> {
328        if corpus.is_empty() {
329            return Err(TextError::InvalidInput("corpus is empty".to_string()));
330        }
331
332        // Ensure all word vectors exist
333        for (tokens, _) in corpus {
334            for tok in tokens {
335                self.word_vectors
336                    .entry(tok.clone())
337                    .or_insert_with(|| Self::init_word_vec(tok, self.dim));
338            }
339        }
340
341        for _epoch in 0..n_epochs {
342            for (tokens, gold_class) in corpus {
343                let gold_class = *gold_class;
344                if gold_class >= self.n_classes {
345                    continue;
346                }
347                let emb = self.mean_embedding(tokens);
348                let mut probs = self.forward(&emb);
349                Self::softmax(&mut probs);
350
351                // Cross-entropy gradient: (probs - one_hot)
352                let mut grad = probs.clone();
353                grad[gold_class] -= 1.0;
354
355                // Update weights and bias
356                for d in 0..self.dim {
357                    for k in 0..self.n_classes {
358                        self.weights[d][k] -= lr * grad[k] * emb[d];
359                    }
360                }
361                for k in 0..self.n_classes {
362                    self.bias[k] -= lr * grad[k];
363                }
364            }
365        }
366
367        self.fitted = true;
368        Ok(())
369    }
370
371    /// Predict the class ID for a token sequence.
372    pub fn predict(&self, tokens: &[String]) -> usize {
373        let probs = self.predict_proba(tokens);
374        probs
375            .iter()
376            .enumerate()
377            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
378            .map(|(i, _)| i)
379            .unwrap_or(0)
380    }
381
382    /// Predict class probability distribution.
383    pub fn predict_proba(&self, tokens: &[String]) -> Vec<f32> {
384        let emb = self.mean_embedding(tokens);
385        let mut logits = self.forward(&emb);
386        Self::softmax(&mut logits);
387        logits
388    }
389
390    /// Return the ordered class names.
391    pub fn class_names(&self) -> &[String] {
392        &self.classes
393    }
394
395    /// `true` if the model has been trained.
396    pub fn is_fitted(&self) -> bool {
397        self.fitted
398    }
399}
400
401// ─────────────────────────────────────────────────────────────────────────────
402// CountVectorizer
403// ─────────────────────────────────────────────────────────────────────────────
404
405/// Count-based document-feature matrix builder with N-gram support.
406#[derive(Debug, Clone)]
407pub struct CountVectorizer {
408    vocabulary: HashMap<String, usize>,
409    max_features: Option<usize>,
410    min_df: usize,
411    max_df_ratio: f64,
412    ngram_range: (usize, usize),
413    fitted: bool,
414}
415
416impl Default for CountVectorizer {
417    fn default() -> Self {
418        CountVectorizer {
419            vocabulary: HashMap::new(),
420            max_features: None,
421            min_df: 1,
422            max_df_ratio: 1.0,
423            ngram_range: (1, 1),
424            fitted: false,
425        }
426    }
427}
428
429impl CountVectorizer {
430    /// Create a new `CountVectorizer` with default settings (unigrams only).
431    pub fn new() -> Self {
432        Self::default()
433    }
434
435    /// Limit the vocabulary to the `n` most frequent features.
436    pub fn with_max_features(mut self, n: usize) -> Self {
437        self.max_features = Some(n);
438        self
439    }
440
441    /// Set N-gram range `(min_n, max_n)`.
442    pub fn with_ngram_range(mut self, min: usize, max: usize) -> Self {
443        self.ngram_range = (min, max);
444        self
445    }
446
447    /// Set minimum document frequency (number of documents a token must appear in).
448    pub fn with_min_df(mut self, min_df: usize) -> Self {
449        self.min_df = min_df;
450        self
451    }
452
453    /// Set maximum document frequency as a fraction of the corpus (0.0–1.0).
454    pub fn with_max_df_ratio(mut self, ratio: f64) -> Self {
455        self.max_df_ratio = ratio;
456        self
457    }
458
459    // ── Internal helpers ──────────────────────────────────────────────
460
461    /// Generate N-grams from a token list according to `ngram_range`.
462    fn ngrams(&self, tokens: &[String]) -> Vec<String> {
463        let (min_n, max_n) = self.ngram_range;
464        let mut grams = Vec::new();
465        for n in min_n..=max_n {
466            for window in tokens.windows(n) {
467                grams.push(window.join(" "));
468            }
469        }
470        grams
471    }
472
473    /// Tokenise `text` into lower-case alphanumeric words.
474    fn tokenize(text: &str) -> Vec<String> {
475        text.split(|c: char| !c.is_alphanumeric())
476            .filter(|s| !s.is_empty())
477            .map(|s| s.to_lowercase())
478            .collect()
479    }
480
481    // ── Public API ────────────────────────────────────────────────────
482
483    /// Fit the vocabulary from `corpus`.
484    pub fn fit(&mut self, corpus: &[String]) -> Result<()> {
485        if corpus.is_empty() {
486            return Err(TextError::InvalidInput("corpus is empty".to_string()));
487        }
488        let n_docs = corpus.len();
489
490        // Count document frequencies
491        let mut df: HashMap<String, usize> = HashMap::new();
492        let mut term_freq: HashMap<String, usize> = HashMap::new();
493
494        for doc in corpus {
495            let tokens = Self::tokenize(doc);
496            let grams = self.ngrams(&tokens);
497            let unique: std::collections::HashSet<String> = grams.iter().cloned().collect();
498            for gram in unique {
499                *df.entry(gram.clone()).or_insert(0) += 1;
500                *term_freq.entry(gram).or_insert(0) += 1;
501            }
502        }
503
504        // Filter by df thresholds
505        let max_df_count = (self.max_df_ratio * n_docs as f64).ceil() as usize;
506        let mut candidates: Vec<(String, usize)> = df
507            .into_iter()
508            .filter(|(_, count)| *count >= self.min_df && *count <= max_df_count)
509            .collect();
510
511        // Sort by total term frequency descending, then alphabetically
512        candidates.sort_by(|a, b| {
513            let fa = term_freq.get(&a.0).copied().unwrap_or(0);
514            let fb = term_freq.get(&b.0).copied().unwrap_or(0);
515            fb.cmp(&fa).then_with(|| a.0.cmp(&b.0))
516        });
517
518        // Apply max_features limit
519        if let Some(max_f) = self.max_features {
520            candidates.truncate(max_f);
521        }
522
523        // Build vocabulary
524        self.vocabulary = candidates
525            .into_iter()
526            .enumerate()
527            .map(|(i, (gram, _))| (gram, i))
528            .collect();
529
530        self.fitted = true;
531        Ok(())
532    }
533
534    /// Transform `texts` into a count matrix.
535    pub fn transform(&self, texts: &[String]) -> Result<Vec<Vec<f64>>> {
536        if !self.fitted {
537            return Err(TextError::ModelNotFitted(
538                "CountVectorizer is not fitted".to_string(),
539            ));
540        }
541        let v = self.vocabulary.len();
542        texts
543            .iter()
544            .map(|text| {
545                let tokens = Self::tokenize(text);
546                let grams = self.ngrams(&tokens);
547                let mut counts = vec![0.0f64; v];
548                for gram in grams {
549                    if let Some(&idx) = self.vocabulary.get(&gram) {
550                        counts[idx] += 1.0;
551                    }
552                }
553                Ok(counts)
554            })
555            .collect()
556    }
557
558    /// Fit then transform in one step.
559    pub fn fit_transform(&mut self, corpus: &[String]) -> Result<Vec<Vec<f64>>> {
560        self.fit(corpus)?;
561        self.transform(corpus)
562    }
563
564    /// Return the current vocabulary size.
565    pub fn vocabulary_size(&self) -> usize {
566        self.vocabulary.len()
567    }
568
569    /// Borrow the vocabulary map.
570    pub fn vocabulary(&self) -> &HashMap<String, usize> {
571        &self.vocabulary
572    }
573}
574
575// ─────────────────────────────────────────────────────────────────────────────
576// TfidfTransformer
577// ─────────────────────────────────────────────────────────────────────────────
578
579/// Transforms a count matrix into a TF-IDF weighted matrix.
580///
581/// IDF formula (with `smooth_idf = true`):
582/// `idf(t) = ln((1 + n) / (1 + df(t))) + 1`
583///
584/// Each row is L2-normalised after weighting.
585#[derive(Debug, Clone)]
586pub struct TfidfTransformer {
587    /// Per-term IDF values.
588    pub idf: Vec<f64>,
589    /// Whether to smooth IDF by adding 1 to numerator and denominator.
590    pub smooth_idf: bool,
591    fitted: bool,
592}
593
594impl TfidfTransformer {
595    /// Create a new transformer.  `smooth_idf = true` avoids division by zero
596    /// for terms seen in all documents.
597    pub fn new(smooth_idf: bool) -> Self {
598        TfidfTransformer {
599            idf: Vec::new(),
600            smooth_idf,
601            fitted: false,
602        }
603    }
604
605    /// Compute IDF values from a count matrix (rows = documents, cols = terms).
606    pub fn fit(&mut self, count_matrix: &[Vec<f64>]) -> Result<()> {
607        if count_matrix.is_empty() {
608            return Err(TextError::InvalidInput("count_matrix is empty".to_string()));
609        }
610        let n_docs = count_matrix.len() as f64;
611        let n_features = count_matrix[0].len();
612
613        let mut df = vec![0.0f64; n_features];
614        for row in count_matrix {
615            for (j, &c) in row.iter().enumerate() {
616                if c > 0.0 {
617                    df[j] += 1.0;
618                }
619            }
620        }
621
622        self.idf = if self.smooth_idf {
623            df.iter()
624                .map(|&d| ((1.0 + n_docs) / (1.0 + d)).ln() + 1.0)
625                .collect()
626        } else {
627            df.iter()
628                .map(|&d| {
629                    if d == 0.0 {
630                        0.0
631                    } else {
632                        (n_docs / d).ln() + 1.0
633                    }
634                })
635                .collect()
636        };
637
638        self.fitted = true;
639        Ok(())
640    }
641
642    /// Apply TF-IDF weighting and L2 normalisation.
643    pub fn transform(&self, count_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
644        if !self.fitted {
645            return Err(TextError::ModelNotFitted(
646                "TfidfTransformer is not fitted".to_string(),
647            ));
648        }
649        count_matrix
650            .iter()
651            .map(|row| {
652                let mut tfidf: Vec<f64> = row
653                    .iter()
654                    .zip(self.idf.iter())
655                    .map(|(&c, &idf)| c * idf)
656                    .collect();
657                // L2 normalise
658                let norm: f64 = tfidf.iter().map(|&x| x * x).sum::<f64>().sqrt();
659                if norm > 0.0 {
660                    tfidf.iter_mut().for_each(|x| *x /= norm);
661                }
662                Ok(tfidf)
663            })
664            .collect()
665    }
666
667    /// Fit then transform in one step.
668    pub fn fit_transform(&mut self, count_matrix: &[Vec<f64>]) -> Result<Vec<Vec<f64>>> {
669        self.fit(count_matrix)?;
670        self.transform(count_matrix)
671    }
672}
673
674// ─────────────────────────────────────────────────────────────────────────────
675// Tests
676// ─────────────────────────────────────────────────────────────────────────────
677
#[cfg(test)]
mod tests {
    use super::*;

    /// Six tiny labelled documents: two each for "sports", "politics" and "tech".
    fn news_corpus() -> Vec<(String, String)> {
        let pairs = [
            ("football game soccer ball", "sports"),
            ("basketball players team score", "sports"),
            ("election president vote campaign", "politics"),
            ("senate congress legislation bill", "politics"),
            ("python rust programming language", "tech"),
            ("software compiler code debug", "tech"),
        ];
        pairs
            .iter()
            .map(|&(text, label)| (text.to_string(), label.to_string()))
            .collect()
    }

    /// A Naive Bayes model trained on `news_corpus()` with `alpha = 1.0`.
    fn trained_nb() -> NaiveBayesClassifier {
        let mut model = NaiveBayesClassifier::new();
        model.fit(&news_corpus(), 1.0).expect("fit failed");
        model
    }

    /// Convert borrowed string slices into the owned form the APIs expect.
    fn owned(texts: &[&str]) -> Vec<String> {
        texts.iter().map(|t| String::from(*t)).collect()
    }

    // ── NaiveBayesClassifier tests ───────────────────────────────────

    #[test]
    fn test_nb_fit_predict() {
        let model = trained_nb();
        // Purely sports vocabulary: the sports class has to win.
        let prediction = model.predict("soccer football game").expect("predict failed");
        assert_eq!(prediction.as_deref(), Some("sports"));
    }

    #[test]
    fn test_nb_predict_proba_sums_to_one() {
        let model = trained_nb();
        let proba = model.predict_proba("vote election").expect("proba failed");
        let total: f64 = proba.into_iter().map(|(_, p)| p).sum();
        assert!(
            (total - 1.0).abs() < 1e-9,
            "probabilities should sum to 1, got {}",
            total
        );
    }

    #[test]
    fn test_nb_accuracy() {
        let model = trained_nb();
        let acc = model.accuracy(&news_corpus()).expect("accuracy failed");
        assert!(acc >= 0.5, "Expected accuracy >= 0.5, got {}", acc);
    }

    #[test]
    fn test_nb_class_names() {
        let model = trained_nb();
        let names = model.class_names();
        assert!(names.iter().any(|name| name == "sports"));
        assert!(names.iter().any(|name| name == "tech"));
    }

    #[test]
    fn test_nb_not_fitted_error() {
        let untrained = NaiveBayesClassifier::new();
        assert!(untrained.predict("test").is_err());
    }

    #[test]
    fn test_nb_batch_predict() {
        let model = trained_nb();
        let inputs = owned(&["soccer game", "code compiler"]);
        let preds = model.predict_batch(&inputs).expect("batch predict failed");
        assert_eq!(preds.len(), 2);
        assert!(preds[0].is_some());
    }

    // ── FastTextClassifier tests ─────────────────────────────────────

    #[test]
    fn test_fasttext_predict_without_training() {
        let model = FastTextClassifier::new(2, 16, owned(&["sports", "tech"]));
        let tokens = owned(&["soccer", "game"]);
        assert!(model.predict(&tokens) < 2);
    }

    #[test]
    fn test_fasttext_fit_and_predict() {
        let mut model = FastTextClassifier::new(2, 8, owned(&["pos", "neg"]));
        let sample = |a: &str, b: &str, label: usize| (owned(&[a, b]), label);
        let corpus = vec![
            sample("good", "great", 0),
            sample("excellent", "wonderful", 0),
            sample("bad", "terrible", 1),
            sample("awful", "horrible", 1),
        ];
        model.fit(&corpus, 10, 0.1).expect("fit failed");
        assert!(model.is_fitted());
        let probs = model.predict_proba(&owned(&["good"]));
        assert_eq!(probs.len(), 2);
        let total = probs.iter().fold(0.0f32, |acc, &p| acc + p);
        assert!((total - 1.0).abs() < 1e-5);
    }

    // ── CountVectorizer tests ────────────────────────────────────────

    #[test]
    fn test_count_vectorizer_basic() {
        let mut vectorizer = CountVectorizer::new();
        let docs = owned(&["hello world", "hello rust", "world rust"]);
        let matrix = vectorizer.fit_transform(&docs).expect("fit_transform failed");
        assert_eq!(matrix.len(), 3);
        assert!(vectorizer.vocabulary_size() > 0);
    }

    #[test]
    fn test_count_vectorizer_ngram() {
        let mut vectorizer = CountVectorizer::new().with_ngram_range(1, 2);
        let docs = owned(&["the quick fox", "the lazy dog"]);
        vectorizer.fit(&docs).expect("fit failed");
        // Unigrams plus bigrams exceed the three words of a single sentence.
        assert!(vectorizer.vocabulary_size() > 3);
    }

    #[test]
    fn test_count_vectorizer_max_features() {
        let mut vectorizer = CountVectorizer::new().with_max_features(2);
        let docs = owned(&["a b c d e f", "a b c d e f"]);
        vectorizer.fit(&docs).expect("fit failed");
        assert_eq!(vectorizer.vocabulary_size(), 2);
    }

    #[test]
    fn test_count_vectorizer_not_fitted_error() {
        let vectorizer = CountVectorizer::new();
        assert!(vectorizer.transform(&owned(&["hello"])).is_err());
    }

    // ── TfidfTransformer tests ───────────────────────────────────────

    #[test]
    fn test_tfidf_transformer_l2_norm() {
        let mut transformer = TfidfTransformer::new(true);
        let counts = vec![vec![1.0, 0.0, 2.0], vec![0.0, 3.0, 1.0]];
        let weighted = transformer.fit_transform(&counts).expect("fit_transform failed");
        // Every row comes back with unit Euclidean length.
        for row in weighted.iter() {
            let norm = row.iter().map(|x| x * x).sum::<f64>().sqrt();
            assert!((norm - 1.0).abs() < 1e-9, "norm = {}", norm);
        }
    }

    #[test]
    fn test_tfidf_transformer_not_fitted_error() {
        let transformer = TfidfTransformer::new(true);
        assert!(transformer.transform(&[vec![1.0, 2.0]]).is_err());
    }

    #[test]
    fn test_tfidf_smooth_vs_no_smooth() {
        let counts = vec![vec![1.0, 2.0], vec![3.0, 0.0]];
        let mut smooth = TfidfTransformer::new(true);
        let mut plain = TfidfTransformer::new(false);
        smooth.fit(&counts).expect("fit");
        plain.fit(&counts).expect("fit");
        // Smoothing must change the IDF values.
        assert_ne!(smooth.idf, plain.idf);
    }
}