Skip to main content

scirs2_text/
classification.rs

1//! Text classification functionality
2//!
3//! This module provides comprehensive tools for text classification:
4//!
5//! - **Naive Bayes**: Multinomial and Bernoulli variants
6//! - **TF-IDF + Cosine Similarity**: k-NN style classification
7//! - **Feature Hashing**: Memory-efficient hashing trick for large vocabularies
8//! - **Multi-label**: Support for texts belonging to multiple categories
9//! - **Cross-validation**: k-fold evaluation utilities
10//! - **Metrics**: Precision, recall, F1, accuracy
11//! - **Feature Selection**: Document frequency based filtering
12//! - **Pipelines**: End-to-end TF-IDF + classify pipelines
13
14use crate::error::{Result, TextError};
15use crate::tokenize::{Tokenizer, WordTokenizer};
16use crate::vectorize::{TfidfVectorizer, Vectorizer};
17use scirs2_core::ndarray::{Array1, Array2, Axis};
18use scirs2_core::random::prelude::*;
19use scirs2_core::random::seq::SliceRandom;
20use scirs2_core::random::SeedableRng;
21use std::collections::{HashMap, HashSet};
22
23// ─── Feature Selector ────────────────────────────────────────────────────────
24
25/// Text feature selector
26///
27/// Filters features based on document frequency.
28#[derive(Debug, Clone)]
29pub struct TextFeatureSelector {
30    /// Minimum document frequency (fraction or count)
31    min_df: f64,
32    /// Maximum document frequency (fraction or count)
33    max_df: f64,
34    /// Whether to use raw counts instead of fractions
35    use_counts: bool,
36    /// Selected feature indices
37    selected_features: Option<Vec<usize>>,
38}
39
40impl Default for TextFeatureSelector {
41    fn default() -> Self {
42        Self {
43            min_df: 0.0,
44            max_df: 1.0,
45            use_counts: false,
46            selected_features: None,
47        }
48    }
49}
50
51impl TextFeatureSelector {
52    /// Create a new feature selector
53    pub fn new() -> Self {
54        Self::default()
55    }
56
57    /// Set minimum document frequency
58    pub fn set_min_df(mut self, mindf: f64) -> Result<Self> {
59        if mindf < 0.0 {
60            return Err(TextError::InvalidInput(
61                "min_df must be non-negative".to_string(),
62            ));
63        }
64        self.min_df = mindf;
65        Ok(self)
66    }
67
68    /// Set maximum document frequency
69    pub fn set_max_df(mut self, maxdf: f64) -> Result<Self> {
70        if !(0.0..=1.0).contains(&maxdf) {
71            return Err(TextError::InvalidInput(
72                "max_df must be between 0 and 1 for fractions".to_string(),
73            ));
74        }
75        self.max_df = maxdf;
76        Ok(self)
77    }
78
79    /// Set maximum document frequency (alias for set_max_df)
80    pub fn set_max_features(self, maxfeatures: f64) -> Result<Self> {
81        self.set_max_df(maxfeatures)
82    }
83
84    /// Set to use absolute counts instead of fractions
85    pub fn use_counts(mut self, usecounts: bool) -> Self {
86        self.use_counts = usecounts;
87        self
88    }
89
90    /// Fit the feature selector to data
91    pub fn fit(&mut self, x: &Array2<f64>) -> Result<&mut Self> {
92        let n_samples = x.nrows();
93        let n_features = x.ncols();
94
95        let mut document_frequencies = vec![0; n_features];
96
97        for sample in x.axis_iter(Axis(0)) {
98            for (feature_idx, &value) in sample.iter().enumerate() {
99                if value > 0.0 {
100                    document_frequencies[feature_idx] += 1;
101                }
102            }
103        }
104
105        let min_count = if self.use_counts {
106            self.min_df
107        } else {
108            self.min_df * n_samples as f64
109        };
110
111        let max_count = if self.use_counts {
112            self.max_df
113        } else {
114            self.max_df * n_samples as f64
115        };
116
117        let mut selected_features = Vec::new();
118        for (idx, &df) in document_frequencies.iter().enumerate() {
119            let df_f64 = df as f64;
120            if df_f64 >= min_count && df_f64 <= max_count {
121                selected_features.push(idx);
122            }
123        }
124
125        self.selected_features = Some(selected_features);
126        Ok(self)
127    }
128
129    /// Transform data using selected features
130    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
131        let selected_features = self
132            .selected_features
133            .as_ref()
134            .ok_or_else(|| TextError::ModelNotFitted("Feature selector not fitted".to_string()))?;
135
136        if selected_features.is_empty() {
137            return Err(TextError::InvalidInput(
138                "No features selected. Try adjusting min_df and max_df".to_string(),
139            ));
140        }
141
142        let n_samples = x.nrows();
143        let n_selected = selected_features.len();
144
145        let mut result = Array2::zeros((n_samples, n_selected));
146
147        for (i, row) in x.axis_iter(Axis(0)).enumerate() {
148            for (j, &feature_idx) in selected_features.iter().enumerate() {
149                result[[i, j]] = row[feature_idx];
150            }
151        }
152
153        Ok(result)
154    }
155
156    /// Fit and transform in one step
157    pub fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
158        self.fit(x)?;
159        self.transform(x)
160    }
161
162    /// Get selected feature indices
163    pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
164        self.selected_features.as_ref()
165    }
166}
167
168// ─── Classification Metrics ──────────────────────────────────────────────────
169
170/// Text classification metrics
171#[derive(Debug, Clone)]
172pub struct TextClassificationMetrics;
173
174impl Default for TextClassificationMetrics {
175    fn default() -> Self {
176        Self
177    }
178}
179
180impl TextClassificationMetrics {
181    /// Create a new metrics calculator
182    pub fn new() -> Self {
183        Self
184    }
185
186    /// Calculate precision score
187    pub fn precision<T>(
188        &self,
189        predictions: &[T],
190        true_labels: &[T],
191        class_idx: Option<T>,
192    ) -> Result<f64>
193    where
194        T: PartialEq + Copy + Default,
195    {
196        let positive_class = class_idx.unwrap_or_default();
197
198        if predictions.len() != true_labels.len() {
199            return Err(TextError::InvalidInput(
200                "Predictions and labels must have the same length".to_string(),
201            ));
202        }
203
204        let mut true_positives = 0;
205        let mut predicted_positives = 0;
206
207        for i in 0..predictions.len() {
208            if predictions[i] == positive_class {
209                predicted_positives += 1;
210                if true_labels[i] == positive_class {
211                    true_positives += 1;
212                }
213            }
214        }
215
216        if predicted_positives == 0 {
217            return Ok(0.0);
218        }
219
220        Ok(true_positives as f64 / predicted_positives as f64)
221    }
222
223    /// Calculate recall score
224    pub fn recall<T>(
225        &self,
226        predictions: &[T],
227        true_labels: &[T],
228        class_idx: Option<T>,
229    ) -> Result<f64>
230    where
231        T: PartialEq + Copy + Default,
232    {
233        let positive_class = class_idx.unwrap_or_default();
234
235        if predictions.len() != true_labels.len() {
236            return Err(TextError::InvalidInput(
237                "Predictions and labels must have the same length".to_string(),
238            ));
239        }
240
241        let mut true_positives = 0;
242        let mut actual_positives = 0;
243
244        for i in 0..predictions.len() {
245            if true_labels[i] == positive_class {
246                actual_positives += 1;
247                if predictions[i] == positive_class {
248                    true_positives += 1;
249                }
250            }
251        }
252
253        if actual_positives == 0 {
254            return Ok(0.0);
255        }
256
257        Ok(true_positives as f64 / actual_positives as f64)
258    }
259
260    /// Calculate F1 score
261    pub fn f1_score<T>(
262        &self,
263        predictions: &[T],
264        true_labels: &[T],
265        class_idx: Option<T>,
266    ) -> Result<f64>
267    where
268        T: PartialEq + Copy + Default,
269    {
270        let precision = self.precision(predictions, true_labels, class_idx)?;
271        let recall = self.recall(predictions, true_labels, class_idx)?;
272
273        if precision + recall == 0.0 {
274            return Ok(0.0);
275        }
276
277        Ok(2.0 * precision * recall / (precision + recall))
278    }
279
280    /// Calculate accuracy from predictions and true labels
281    pub fn accuracy<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<f64>
282    where
283        T: PartialEq,
284    {
285        if predictions.len() != truelabels.len() {
286            return Err(TextError::InvalidInput(
287                "Predictions and labels must have the same length".to_string(),
288            ));
289        }
290
291        if predictions.is_empty() {
292            return Err(TextError::InvalidInput(
293                "Cannot calculate accuracy for empty arrays".to_string(),
294            ));
295        }
296
297        let correct = predictions
298            .iter()
299            .zip(truelabels.iter())
300            .filter(|(pred, true_label)| pred == true_label)
301            .count();
302
303        Ok(correct as f64 / predictions.len() as f64)
304    }
305
306    /// Calculate precision, recall, and F1 score for binary classification
307    pub fn binary_metrics<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<(f64, f64, f64)>
308    where
309        T: PartialEq + Copy + Default + PartialEq<usize>,
310    {
311        if predictions.len() != truelabels.len() {
312            return Err(TextError::InvalidInput(
313                "Predictions and labels must have the same length".to_string(),
314            ));
315        }
316
317        let mut tp = 0;
318        let mut fp = 0;
319        let mut fn_ = 0;
320
321        for (pred, true_label) in predictions.iter().zip(truelabels.iter()) {
322            if *pred == 1 && *true_label == 1 {
323                tp += 1;
324            } else if *pred == 1 && *true_label == 0 {
325                fp += 1;
326            } else if *pred == 0 && *true_label == 1 {
327                fn_ += 1;
328            }
329        }
330
331        let precision = if tp + fp > 0 {
332            tp as f64 / (tp + fp) as f64
333        } else {
334            0.0
335        };
336
337        let recall = if tp + fn_ > 0 {
338            tp as f64 / (tp + fn_) as f64
339        } else {
340            0.0
341        };
342
343        let f1 = if precision + recall > 0.0 {
344            2.0 * precision * recall / (precision + recall)
345        } else {
346            0.0
347        };
348
349        Ok((precision, recall, f1))
350    }
351}
352
353// ─── Text Dataset ────────────────────────────────────────────────────────────
354
355/// Text classification dataset
356#[derive(Debug, Clone)]
357pub struct TextDataset {
358    /// The text samples
359    pub texts: Vec<String>,
360    /// The labels for each text
361    pub labels: Vec<String>,
362    /// Index mapping for labels
363    label_index: Option<HashMap<String, usize>>,
364}
365
366impl TextDataset {
367    /// Create a new text dataset
368    pub fn new(texts: Vec<String>, labels: Vec<String>) -> Result<Self> {
369        if texts.len() != labels.len() {
370            return Err(TextError::InvalidInput(
371                "Texts and labels must have the same length".to_string(),
372            ));
373        }
374
375        Ok(Self {
376            texts,
377            labels,
378            label_index: None,
379        })
380    }
381
382    /// Get the number of samples
383    pub fn len(&self) -> usize {
384        self.texts.len()
385    }
386
387    /// Check if the dataset is empty
388    pub fn is_empty(&self) -> bool {
389        self.texts.is_empty()
390    }
391
392    /// Get the unique labels in the dataset
393    pub fn unique_labels(&self) -> Vec<String> {
394        let mut unique = HashSet::new();
395        for label in &self.labels {
396            unique.insert(label.clone());
397        }
398        unique.into_iter().collect()
399    }
400
401    /// Build a label index mapping
402    pub fn build_label_index(&mut self) -> Result<&mut Self> {
403        let mut index = HashMap::new();
404        let unique_labels = self.unique_labels();
405
406        for (i, label) in unique_labels.iter().enumerate() {
407            index.insert(label.clone(), i);
408        }
409
410        self.label_index = Some(index);
411        Ok(self)
412    }
413
414    /// Get label indices
415    pub fn get_label_indices(&self) -> Result<Vec<usize>> {
416        let index = self
417            .label_index
418            .as_ref()
419            .ok_or_else(|| TextError::ModelNotFitted("Label index not built".to_string()))?;
420
421        self.labels
422            .iter()
423            .map(|label| {
424                index
425                    .get(label)
426                    .copied()
427                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown label: {label}")))
428            })
429            .collect()
430    }
431
432    /// Split the dataset into train and test sets
433    pub fn train_test_split(
434        &self,
435        test_size: f64,
436        random_seed: Option<u64>,
437    ) -> Result<(Self, Self)> {
438        if test_size <= 0.0 || test_size >= 1.0 {
439            return Err(TextError::InvalidInput(
440                "test_size must be between 0 and 1".to_string(),
441            ));
442        }
443
444        if self.is_empty() {
445            return Err(TextError::InvalidInput("Dataset is empty".to_string()));
446        }
447
448        let mut indices: Vec<usize> = (0..self.len()).collect();
449
450        if let Some(seed) = random_seed {
451            let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(seed);
452            indices.shuffle(&mut rng);
453        } else {
454            let mut rng = scirs2_core::random::rng();
455            indices.shuffle(&mut rng);
456        }
457
458        let test_count = (self.len() as f64 * test_size).ceil() as usize;
459        let test_indices = indices[0..test_count].to_vec();
460        let train_indices = indices[test_count..].to_vec();
461
462        let train_texts = train_indices
463            .iter()
464            .map(|&i| self.texts[i].clone())
465            .collect();
466        let train_labels = train_indices
467            .iter()
468            .map(|&i| self.labels[i].clone())
469            .collect();
470        let test_texts = test_indices
471            .iter()
472            .map(|&i| self.texts[i].clone())
473            .collect();
474        let test_labels = test_indices
475            .iter()
476            .map(|&i| self.labels[i].clone())
477            .collect();
478
479        let mut train_dataset = Self::new(train_texts, train_labels)?;
480        let mut test_dataset = Self::new(test_texts, test_labels)?;
481
482        if self.label_index.is_some() {
483            train_dataset.build_label_index()?;
484            test_dataset.build_label_index()?;
485        }
486
487        Ok((train_dataset, test_dataset))
488    }
489}
490
491// ─── Classification Pipeline ─────────────────────────────────────────────────
492
493/// Pipeline for text classification
494pub struct TextClassificationPipeline {
495    /// The vectorizer to use
496    vectorizer: TfidfVectorizer,
497    /// Optional feature selector
498    feature_selector: Option<TextFeatureSelector>,
499}
500
501impl TextClassificationPipeline {
502    /// Create a new pipeline with a default TF-IDF vectorizer
503    pub fn with_tfidf() -> Self {
504        Self::new(TfidfVectorizer::default())
505    }
506
507    /// Create a new pipeline with the given vectorizer
508    pub fn new(vectorizer: TfidfVectorizer) -> Self {
509        Self {
510            vectorizer,
511            feature_selector: None,
512        }
513    }
514
515    /// Add a feature selector to the pipeline
516    pub fn with_feature_selector(mut self, selector: TextFeatureSelector) -> Self {
517        self.feature_selector = Some(selector);
518        self
519    }
520
521    /// Fit the pipeline to training data
522    pub fn fit(&mut self, dataset: &TextDataset) -> Result<&mut Self> {
523        let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
524        self.vectorizer.fit(&texts)?;
525        Ok(self)
526    }
527
528    /// Transform text data using the pipeline
529    pub fn transform(&self, dataset: &TextDataset) -> Result<Array2<f64>> {
530        let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
531        let mut features = self.vectorizer.transform_batch(&texts)?;
532
533        if let Some(selector) = &self.feature_selector {
534            features = selector.transform(&features)?;
535        }
536
537        Ok(features)
538    }
539
540    /// Fit and transform in one step
541    pub fn fit_transform(&mut self, dataset: &TextDataset) -> Result<Array2<f64>> {
542        self.fit(dataset)?;
543        self.transform(dataset)
544    }
545}
546
547// ─── Multinomial Naive Bayes Classifier ──────────────────────────────────────
548
549/// Multinomial Naive Bayes classifier for text
550///
551/// Suitable for text classification with word count / TF-IDF features.
552/// Implements Laplace smoothing.
553#[derive(Debug, Clone)]
554pub struct MultinomialNaiveBayes {
555    /// Word log-probabilities per class: class -> feature_idx -> log(P(w|c))
556    feature_log_probs: HashMap<String, Vec<f64>>,
557    /// Prior log-probabilities: class -> log(P(c))
558    class_log_priors: HashMap<String, f64>,
559    /// Number of features
560    n_features: usize,
561    /// Laplace smoothing parameter
562    alpha: f64,
563    /// Classes
564    classes: Vec<String>,
565}
566
567impl MultinomialNaiveBayes {
568    /// Create a new multinomial Naive Bayes classifier
569    pub fn new(alpha: f64) -> Self {
570        Self {
571            feature_log_probs: HashMap::new(),
572            class_log_priors: HashMap::new(),
573            n_features: 0,
574            alpha,
575            classes: Vec::new(),
576        }
577    }
578
579    /// Train the classifier
580    ///
581    /// # Arguments
582    /// * `features` - Feature matrix (n_samples x n_features), e.g. TF-IDF
583    /// * `labels` - Class labels for each sample
584    pub fn fit(&mut self, features: &Array2<f64>, labels: &[String]) -> Result<()> {
585        if features.nrows() != labels.len() {
586            return Err(TextError::InvalidInput(
587                "Features and labels must have the same number of rows".into(),
588            ));
589        }
590
591        let n_samples = features.nrows();
592        self.n_features = features.ncols();
593
594        // Determine classes
595        let mut class_set = HashSet::new();
596        for label in labels {
597            class_set.insert(label.clone());
598        }
599        self.classes = class_set.into_iter().collect();
600        self.classes.sort();
601
602        // Compute per-class statistics
603        for class in &self.classes {
604            // Gather rows belonging to this class
605            let class_indices: Vec<usize> = labels
606                .iter()
607                .enumerate()
608                .filter(|(_, l)| *l == class)
609                .map(|(i, _)| i)
610                .collect();
611
612            let class_count = class_indices.len();
613
614            // Prior log-probability
615            let log_prior = (class_count as f64 / n_samples as f64).ln();
616            self.class_log_priors.insert(class.clone(), log_prior);
617
618            // Sum features for this class
619            let mut feature_sums = vec![0.0; self.n_features];
620            for &idx in &class_indices {
621                for j in 0..self.n_features {
622                    feature_sums[j] += features[[idx, j]];
623                }
624            }
625
626            // Total count for this class (with smoothing)
627            let total: f64 = feature_sums.iter().sum::<f64>() + self.alpha * self.n_features as f64;
628
629            // Log probabilities with Laplace smoothing
630            let log_probs: Vec<f64> = feature_sums
631                .iter()
632                .map(|&count| ((count + self.alpha) / total).ln())
633                .collect();
634
635            self.feature_log_probs.insert(class.clone(), log_probs);
636        }
637
638        Ok(())
639    }
640
641    /// Predict class labels for feature matrix
642    pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<String>> {
643        let mut predictions = Vec::with_capacity(features.nrows());
644
645        for row in features.axis_iter(Axis(0)) {
646            let (label, _) = self.predict_single(&row.to_owned())?;
647            predictions.push(label);
648        }
649
650        Ok(predictions)
651    }
652
653    /// Predict a single sample
654    fn predict_single(&self, features: &Array1<f64>) -> Result<(String, f64)> {
655        if self.classes.is_empty() {
656            return Err(TextError::ModelNotFitted("Classifier not trained".into()));
657        }
658
659        let mut best_class = String::new();
660        let mut best_score = f64::NEG_INFINITY;
661
662        for class in &self.classes {
663            let log_prior = self
664                .class_log_priors
665                .get(class)
666                .copied()
667                .unwrap_or(f64::NEG_INFINITY);
668
669            let log_probs = self
670                .feature_log_probs
671                .get(class)
672                .ok_or_else(|| TextError::RuntimeError("Missing feature probs".into()))?;
673
674            let log_likelihood: f64 = features
675                .iter()
676                .zip(log_probs.iter())
677                .map(|(&feat, &log_p)| feat * log_p)
678                .sum();
679
680            let score = log_prior + log_likelihood;
681            if score > best_score {
682                best_score = score;
683                best_class = class.clone();
684            }
685        }
686
687        Ok((best_class, best_score))
688    }
689}
690
691// ─── Bernoulli Naive Bayes Classifier ────────────────────────────────────────
692
693/// Bernoulli Naive Bayes classifier
694///
695/// Works with binary features (word present/absent).
696/// Suitable for short texts or bag-of-words with binary encoding.
697#[derive(Debug, Clone)]
698pub struct BernoulliNaiveBayes {
699    /// Log probability that feature is present for each class
700    feature_log_probs: HashMap<String, Vec<f64>>,
701    /// Log probability that feature is absent for each class
702    feature_log_neg_probs: HashMap<String, Vec<f64>>,
703    /// Prior log-probabilities
704    class_log_priors: HashMap<String, f64>,
705    /// Number of features
706    n_features: usize,
707    /// Smoothing parameter
708    alpha: f64,
709    /// Binarization threshold
710    binarize_threshold: f64,
711    /// Classes
712    classes: Vec<String>,
713}
714
715impl BernoulliNaiveBayes {
716    /// Create a new Bernoulli Naive Bayes classifier
717    pub fn new(alpha: f64) -> Self {
718        Self {
719            feature_log_probs: HashMap::new(),
720            feature_log_neg_probs: HashMap::new(),
721            class_log_priors: HashMap::new(),
722            n_features: 0,
723            alpha,
724            binarize_threshold: 0.0,
725            classes: Vec::new(),
726        }
727    }
728
729    /// Set the binarization threshold
730    pub fn with_binarize_threshold(mut self, threshold: f64) -> Self {
731        self.binarize_threshold = threshold;
732        self
733    }
734
735    /// Train the classifier
736    pub fn fit(&mut self, features: &Array2<f64>, labels: &[String]) -> Result<()> {
737        if features.nrows() != labels.len() {
738            return Err(TextError::InvalidInput(
739                "Features and labels must have the same number of rows".into(),
740            ));
741        }
742
743        let n_samples = features.nrows();
744        self.n_features = features.ncols();
745
746        let mut class_set = HashSet::new();
747        for label in labels {
748            class_set.insert(label.clone());
749        }
750        self.classes = class_set.into_iter().collect();
751        self.classes.sort();
752
753        for class in &self.classes {
754            let class_indices: Vec<usize> = labels
755                .iter()
756                .enumerate()
757                .filter(|(_, l)| *l == class)
758                .map(|(i, _)| i)
759                .collect();
760
761            let class_count = class_indices.len() as f64;
762
763            let log_prior = (class_count / n_samples as f64).ln();
764            self.class_log_priors.insert(class.clone(), log_prior);
765
766            // Count documents where each feature is present
767            let mut feature_present = vec![0.0; self.n_features];
768            for &idx in &class_indices {
769                for j in 0..self.n_features {
770                    if features[[idx, j]] > self.binarize_threshold {
771                        feature_present[j] += 1.0;
772                    }
773                }
774            }
775
776            // P(feature_j = 1 | class) with smoothing
777            let log_probs: Vec<f64> = feature_present
778                .iter()
779                .map(|&count| ((count + self.alpha) / (class_count + 2.0 * self.alpha)).ln())
780                .collect();
781
782            let log_neg_probs: Vec<f64> = feature_present
783                .iter()
784                .map(|&count| {
785                    ((class_count - count + self.alpha) / (class_count + 2.0 * self.alpha)).ln()
786                })
787                .collect();
788
789            self.feature_log_probs.insert(class.clone(), log_probs);
790            self.feature_log_neg_probs
791                .insert(class.clone(), log_neg_probs);
792        }
793
794        Ok(())
795    }
796
797    /// Predict class labels
798    pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<String>> {
799        let mut predictions = Vec::with_capacity(features.nrows());
800
801        for row in features.axis_iter(Axis(0)) {
802            let label = self.predict_single(&row.to_owned())?;
803            predictions.push(label);
804        }
805
806        Ok(predictions)
807    }
808
809    fn predict_single(&self, features: &Array1<f64>) -> Result<String> {
810        if self.classes.is_empty() {
811            return Err(TextError::ModelNotFitted("Classifier not trained".into()));
812        }
813
814        let mut best_class = String::new();
815        let mut best_score = f64::NEG_INFINITY;
816
817        for class in &self.classes {
818            let log_prior = self
819                .class_log_priors
820                .get(class)
821                .copied()
822                .unwrap_or(f64::NEG_INFINITY);
823
824            let log_probs = self
825                .feature_log_probs
826                .get(class)
827                .ok_or_else(|| TextError::RuntimeError("Missing probs".into()))?;
828            let log_neg_probs = self
829                .feature_log_neg_probs
830                .get(class)
831                .ok_or_else(|| TextError::RuntimeError("Missing neg probs".into()))?;
832
833            let mut log_likelihood = 0.0;
834            for j in 0..self.n_features {
835                if features[j] > self.binarize_threshold {
836                    log_likelihood += log_probs[j];
837                } else {
838                    log_likelihood += log_neg_probs[j];
839                }
840            }
841
842            let score = log_prior + log_likelihood;
843            if score > best_score {
844                best_score = score;
845                best_class = class.clone();
846            }
847        }
848
849        Ok(best_class)
850    }
851}
852
853// ─── TF-IDF Cosine Similarity Classifier ─────────────────────────────────────
854
855/// TF-IDF + cosine similarity k-NN classifier
856///
857/// Classifies text by finding the k nearest training examples
858/// (by cosine similarity) and taking a majority vote.
859pub struct TfidfCosineClassifier {
860    /// Training TF-IDF vectors
861    train_vectors: Option<Array2<f64>>,
862    /// Training labels
863    train_labels: Vec<String>,
864    /// Number of neighbors
865    k: usize,
866}
867
868impl TfidfCosineClassifier {
869    /// Create a new TF-IDF cosine similarity classifier
870    pub fn new(k: usize) -> Self {
871        Self {
872            train_vectors: None,
873            train_labels: Vec::new(),
874            k,
875        }
876    }
877
878    /// Fit the classifier with pre-computed TF-IDF features
879    pub fn fit(&mut self, features: &Array2<f64>, labels: &[String]) -> Result<()> {
880        if features.nrows() != labels.len() {
881            return Err(TextError::InvalidInput(
882                "Features and labels must have the same number of rows".into(),
883            ));
884        }
885
886        self.train_vectors = Some(features.clone());
887        self.train_labels = labels.to_vec();
888        Ok(())
889    }
890
891    /// Predict class labels for test features
892    pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<String>> {
893        let train_vectors = self
894            .train_vectors
895            .as_ref()
896            .ok_or_else(|| TextError::ModelNotFitted("Classifier not trained".into()))?;
897
898        let mut predictions = Vec::with_capacity(features.nrows());
899
900        for row in features.axis_iter(Axis(0)) {
901            let query = row.to_owned();
902
903            // Compute cosine similarity with all training samples
904            let mut similarities: Vec<(usize, f64)> = Vec::with_capacity(train_vectors.nrows());
905
906            for (idx, train_row) in train_vectors.axis_iter(Axis(0)).enumerate() {
907                let sim = cosine_similarity(&query, &train_row.to_owned());
908                similarities.push((idx, sim));
909            }
910
911            // Sort by similarity descending
912            similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
913
914            // Take top-k and do majority vote
915            let mut class_votes: HashMap<&str, usize> = HashMap::new();
916            let k = self.k.min(similarities.len());
917
918            for &(idx, _) in similarities.iter().take(k) {
919                *class_votes.entry(&self.train_labels[idx]).or_insert(0) += 1;
920            }
921
922            let best_class = class_votes
923                .iter()
924                .max_by_key(|(_, &count)| count)
925                .map(|(label, _)| label.to_string())
926                .unwrap_or_default();
927
928            predictions.push(best_class);
929        }
930
931        Ok(predictions)
932    }
933}
934
935/// Compute cosine similarity between two vectors
936fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
937    let dot: f64 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
938    let norm_a = a.iter().map(|x| x * x).sum::<f64>().sqrt();
939    let norm_b = b.iter().map(|x| x * x).sum::<f64>().sqrt();
940
941    if norm_a > 0.0 && norm_b > 0.0 {
942        dot / (norm_a * norm_b)
943    } else {
944        0.0
945    }
946}
947
948// ─── Feature Hashing ─────────────────────────────────────────────────────────
949
950/// Feature hasher (hashing trick) for text classification
951///
952/// Maps tokens to a fixed-size feature vector using hashing,
953/// avoiding the need to maintain a vocabulary dictionary.
954/// This is memory-efficient for large vocabularies.
955pub struct FeatureHasher {
956    /// Number of output features (hash buckets)
957    n_features: usize,
958    /// Tokenizer
959    tokenizer: Box<dyn Tokenizer + Send + Sync>,
960}
961
962impl std::fmt::Debug for FeatureHasher {
963    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
964        f.debug_struct("FeatureHasher")
965            .field("n_features", &self.n_features)
966            .finish()
967    }
968}
969
970impl FeatureHasher {
971    /// Create a new feature hasher with the specified number of features
972    pub fn new(n_features: usize) -> Self {
973        Self {
974            n_features,
975            tokenizer: Box::new(WordTokenizer::default()),
976        }
977    }
978
979    /// Hash a string to a feature index using FNV-1a
980    fn hash_feature(&self, token: &str) -> usize {
981        let mut hash: u64 = 2166136261;
982        for byte in token.bytes() {
983            hash ^= u64::from(byte);
984            hash = hash.wrapping_mul(16777619);
985        }
986        (hash % (self.n_features as u64)) as usize
987    }
988
989    /// Determine sign from hash (for signed hashing to reduce collision bias)
990    fn hash_sign(&self, token: &str) -> f64 {
991        let mut hash: u64 = 84696351;
992        for byte in token.bytes() {
993            hash ^= u64::from(byte);
994            hash = hash.wrapping_mul(16777619);
995        }
996        if hash.is_multiple_of(2) {
997            1.0
998        } else {
999            -1.0
1000        }
1001    }
1002
1003    /// Transform a single text into a hashed feature vector
1004    pub fn transform_text(&self, text: &str) -> Result<Array1<f64>> {
1005        let tokens = self.tokenizer.tokenize(text)?;
1006        let mut features = Array1::zeros(self.n_features);
1007
1008        for token in &tokens {
1009            let idx = self.hash_feature(&token.to_lowercase());
1010            let sign = self.hash_sign(&token.to_lowercase());
1011            features[idx] += sign;
1012        }
1013
1014        Ok(features)
1015    }
1016
1017    /// Transform multiple texts into a feature matrix
1018    pub fn transform_batch(&self, texts: &[&str]) -> Result<Array2<f64>> {
1019        let mut matrix = Array2::zeros((texts.len(), self.n_features));
1020
1021        for (i, &text) in texts.iter().enumerate() {
1022            let features = self.transform_text(text)?;
1023            for j in 0..self.n_features {
1024                matrix[[i, j]] = features[j];
1025            }
1026        }
1027
1028        Ok(matrix)
1029    }
1030
1031    /// Get the number of features
1032    pub fn num_features(&self) -> usize {
1033        self.n_features
1034    }
1035}
1036
1037// ─── Multi-Label Classification ──────────────────────────────────────────────
1038
/// Multi-label prediction result for a single sample.
///
/// Derives `PartialEq` so results can be compared directly (e.g. `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct MultiLabelPrediction {
    /// The predicted labels (zero, one, or several)
    pub labels: Vec<String>,
    /// Per-label confidence scores, keyed by label name
    pub scores: HashMap<String, f64>,
}
1047
/// Multi-label classifier using binary relevance approach
///
/// For every label seen during `fit`, a separate two-class
/// (`"positive"` / `"negative"`) `MultinomialNaiveBayes` model is trained,
/// allowing a text to be assigned any subset of the known labels.
#[derive(Debug, Clone)]
pub struct MultiLabelClassifier {
    /// One binary Naive Bayes model per label, keyed by label name
    classifiers: HashMap<String, MultinomialNaiveBayes>,
    /// Prediction threshold supplied to `new`
    // NOTE(review): `predict` does not currently read this field; labels come
    // from each binary model's hard decision. Confirm whether threshold-based
    // scoring was intended.
    threshold: f64,
    /// All labels seen during `fit`, sorted ascending
    all_labels: Vec<String>,
}
1061
1062impl MultiLabelClassifier {
1063    /// Create a new multi-label classifier
1064    pub fn new(threshold: f64) -> Self {
1065        Self {
1066            classifiers: HashMap::new(),
1067            threshold,
1068            all_labels: Vec::new(),
1069        }
1070    }
1071
1072    /// Train the classifier
1073    ///
1074    /// # Arguments
1075    /// * `features` - Feature matrix (n_samples x n_features)
1076    /// * `label_sets` - For each sample, a set of labels it belongs to
1077    pub fn fit(&mut self, features: &Array2<f64>, label_sets: &[Vec<String>]) -> Result<()> {
1078        if features.nrows() != label_sets.len() {
1079            return Err(TextError::InvalidInput(
1080                "Features and label_sets must have the same length".into(),
1081            ));
1082        }
1083
1084        // Collect all unique labels
1085        let mut all_labels_set = HashSet::new();
1086        for labels in label_sets {
1087            for label in labels {
1088                all_labels_set.insert(label.clone());
1089            }
1090        }
1091        self.all_labels = all_labels_set.into_iter().collect();
1092        self.all_labels.sort();
1093
1094        // Train one binary classifier per label
1095        for label in &self.all_labels {
1096            let binary_labels: Vec<String> = label_sets
1097                .iter()
1098                .map(|ls| {
1099                    if ls.contains(label) {
1100                        "positive".to_string()
1101                    } else {
1102                        "negative".to_string()
1103                    }
1104                })
1105                .collect();
1106
1107            let mut clf = MultinomialNaiveBayes::new(1.0);
1108            clf.fit(features, &binary_labels)?;
1109            self.classifiers.insert(label.clone(), clf);
1110        }
1111
1112        Ok(())
1113    }
1114
1115    /// Predict labels for feature matrix
1116    pub fn predict(&self, features: &Array2<f64>) -> Result<Vec<MultiLabelPrediction>> {
1117        let mut predictions = Vec::with_capacity(features.nrows());
1118
1119        for row in features.axis_iter(Axis(0)) {
1120            let row_arr = row.to_owned();
1121            let mut labels = Vec::new();
1122            let mut scores = HashMap::new();
1123
1124            // Create a 1-row matrix for the classifier
1125            let single_row = Array2::from_shape_fn((1, row_arr.len()), |(_, j)| row_arr[j]);
1126
1127            for label in &self.all_labels {
1128                if let Some(clf) = self.classifiers.get(label) {
1129                    let pred = clf.predict(&single_row)?;
1130                    if !pred.is_empty() && pred[0] == "positive" {
1131                        labels.push(label.clone());
1132                        scores.insert(label.clone(), 1.0);
1133                    } else {
1134                        scores.insert(label.clone(), 0.0);
1135                    }
1136                }
1137            }
1138
1139            predictions.push(MultiLabelPrediction { labels, scores });
1140        }
1141
1142        Ok(predictions)
1143    }
1144}
1145
1146// ─── Cross-Validation ────────────────────────────────────────────────────────
1147
/// Result of a single cross-validation fold.
///
/// Derives `PartialEq` so results can be compared directly (e.g. `assert_eq!`).
#[derive(Debug, Clone, PartialEq)]
pub struct FoldResult {
    /// Zero-based fold index
    pub fold: usize,
    /// Fraction of this fold's test samples predicted correctly
    pub accuracy: f64,
    /// Predicted labels for this fold's test samples
    pub predictions: Vec<String>,
    /// True labels for this fold's test samples, aligned with `predictions`
    pub true_labels: Vec<String>,
}

/// Aggregate result of k-fold cross-validation.
#[derive(Debug, Clone, PartialEq)]
pub struct CrossValidationResult {
    /// Results for each fold, in fold order
    pub fold_results: Vec<FoldResult>,
    /// Mean accuracy across folds
    pub mean_accuracy: f64,
    /// Population (divide-by-k) standard deviation of the per-fold accuracies
    pub std_accuracy: f64,
}
1171
1172/// Perform k-fold cross-validation with multinomial Naive Bayes
1173///
1174/// # Arguments
1175/// * `features` - Feature matrix
1176/// * `labels` - Labels
1177/// * `k` - Number of folds
1178/// * `alpha` - Naive Bayes smoothing parameter
1179/// * `seed` - Optional random seed for reproducibility
1180pub fn cross_validate_nb(
1181    features: &Array2<f64>,
1182    labels: &[String],
1183    k: usize,
1184    alpha: f64,
1185    seed: Option<u64>,
1186) -> Result<CrossValidationResult> {
1187    if features.nrows() != labels.len() {
1188        return Err(TextError::InvalidInput(
1189            "Features and labels must have the same length".into(),
1190        ));
1191    }
1192
1193    let n = features.nrows();
1194    if k < 2 || k > n {
1195        return Err(TextError::InvalidInput(format!(
1196            "k must be between 2 and {} (number of samples)",
1197            n
1198        )));
1199    }
1200
1201    // Create shuffled indices
1202    let mut indices: Vec<usize> = (0..n).collect();
1203    if let Some(s) = seed {
1204        let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(s);
1205        indices.shuffle(&mut rng);
1206    } else {
1207        let mut rng = scirs2_core::random::rng();
1208        indices.shuffle(&mut rng);
1209    }
1210
1211    let fold_size = n / k;
1212    let mut fold_results = Vec::with_capacity(k);
1213
1214    for fold in 0..k {
1215        let test_start = fold * fold_size;
1216        let test_end = if fold == k - 1 {
1217            n
1218        } else {
1219            (fold + 1) * fold_size
1220        };
1221
1222        let test_indices: Vec<usize> = indices[test_start..test_end].to_vec();
1223        let train_indices: Vec<usize> = indices
1224            .iter()
1225            .enumerate()
1226            .filter(|(i, _)| *i < test_start || *i >= test_end)
1227            .map(|(_, &idx)| idx)
1228            .collect();
1229
1230        // Build train/test sets
1231        let n_train = train_indices.len();
1232        let n_test = test_indices.len();
1233        let n_features = features.ncols();
1234
1235        let mut train_features = Array2::zeros((n_train, n_features));
1236        let mut train_labels = Vec::with_capacity(n_train);
1237
1238        for (i, &idx) in train_indices.iter().enumerate() {
1239            for j in 0..n_features {
1240                train_features[[i, j]] = features[[idx, j]];
1241            }
1242            train_labels.push(labels[idx].clone());
1243        }
1244
1245        let mut test_features = Array2::zeros((n_test, n_features));
1246        let mut test_labels = Vec::with_capacity(n_test);
1247
1248        for (i, &idx) in test_indices.iter().enumerate() {
1249            for j in 0..n_features {
1250                test_features[[i, j]] = features[[idx, j]];
1251            }
1252            test_labels.push(labels[idx].clone());
1253        }
1254
1255        // Train and predict
1256        let mut clf = MultinomialNaiveBayes::new(alpha);
1257        clf.fit(&train_features, &train_labels)?;
1258        let predictions = clf.predict(&test_features)?;
1259
1260        // Calculate accuracy
1261        let correct = predictions
1262            .iter()
1263            .zip(test_labels.iter())
1264            .filter(|(p, t)| p == t)
1265            .count();
1266        let accuracy = correct as f64 / n_test as f64;
1267
1268        fold_results.push(FoldResult {
1269            fold,
1270            accuracy,
1271            predictions,
1272            true_labels: test_labels,
1273        });
1274    }
1275
1276    // Compute mean and std
1277    let accuracies: Vec<f64> = fold_results.iter().map(|f| f.accuracy).collect();
1278    let mean_accuracy = accuracies.iter().sum::<f64>() / k as f64;
1279    let variance = accuracies
1280        .iter()
1281        .map(|&a| (a - mean_accuracy).powi(2))
1282        .sum::<f64>()
1283        / k as f64;
1284    let std_accuracy = variance.sqrt();
1285
1286    Ok(CrossValidationResult {
1287        fold_results,
1288        mean_accuracy,
1289        std_accuracy,
1290    })
1291}
1292
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let texts = vec![
            "This is document 1".to_string(),
            "Another document".to_string(),
            "A third document".to_string(),
        ];
        let labels = vec!["A".to_string(), "B".to_string(), "A".to_string()];

        let mut dataset = TextDataset::new(texts, labels).expect("Operation failed");

        let mut label_index = HashMap::new();
        label_index.insert("A".to_string(), 0);
        label_index.insert("B".to_string(), 1);
        dataset.label_index = Some(label_index);

        let label_indices = dataset.get_label_indices().expect("Operation failed");
        assert_eq!(label_indices[0], 0);
        assert_eq!(label_indices[1], 1);
        assert_eq!(label_indices[2], 0);

        let unique_labels = dataset.unique_labels();
        assert_eq!(unique_labels.len(), 2);
    }

    #[test]
    fn test_train_test_split() {
        let texts = (0..10).map(|i| format!("Text {i}")).collect();
        let labels = (0..10).map(|_| "A".to_string()).collect();

        let dataset = TextDataset::new(texts, labels).expect("Operation failed");
        let (train, test) = dataset
            .train_test_split(0.3, Some(42))
            .expect("Operation failed");

        assert_eq!(train.len(), 7);
        assert_eq!(test.len(), 3);
    }

    #[test]
    fn test_feature_selector() {
        let mut features = Array2::zeros((5, 3));
        features[[0, 0]] = 1.0;
        features[[1, 0]] = 1.0;
        features[[2, 0]] = 1.0;

        for i in 0..5 {
            features[[i, 1]] = 1.0;
        }

        features[[0, 2]] = 1.0;

        let mut selector = TextFeatureSelector::new()
            .set_min_df(0.25)
            .expect("Operation failed")
            .set_max_df(0.75)
            .expect("Operation failed");

        let filtered = selector.fit_transform(&features).expect("Operation failed");
        assert_eq!(filtered.ncols(), 1);
    }

    #[test]
    fn test_classification_metrics() {
        let predictions = vec![1_usize, 0, 1, 1, 0];
        let true_labels = vec![1_usize, 0, 1, 0, 0];

        let metrics = TextClassificationMetrics::new();
        let accuracy = metrics
            .accuracy(&predictions, &true_labels)
            .expect("Operation failed");
        assert_eq!(accuracy, 0.8);

        let (precision, recall, f1) = metrics
            .binary_metrics(&predictions, &true_labels)
            .expect("Operation failed");
        assert!((precision - 0.667).abs() < 0.001);
        assert_eq!(recall, 1.0);
        assert!((f1 - 0.8).abs() < 0.001);
    }

    // ─── Multinomial NB Tests ────────────────────────────────────────

    #[test]
    fn test_multinomial_nb_basic() {
        // Simple 2-class problem with 3 features
        let features = Array2::from_shape_vec(
            (6, 3),
            vec![
                3.0, 1.0, 0.0, // positive
                2.0, 2.0, 0.0, // positive
                4.0, 0.0, 1.0, // positive
                0.0, 1.0, 3.0, // negative
                0.0, 2.0, 2.0, // negative
                1.0, 0.0, 4.0, // negative
            ],
        )
        .expect("shape");

        let labels = vec![
            "pos".to_string(),
            "pos".to_string(),
            "pos".to_string(),
            "neg".to_string(),
            "neg".to_string(),
            "neg".to_string(),
        ];

        let mut clf = MultinomialNaiveBayes::new(1.0);
        clf.fit(&features, &labels).expect("fit");

        // Test with clearly positive sample
        let test = Array2::from_shape_vec((1, 3), vec![5.0, 0.0, 0.0]).expect("shape");
        let pred = clf.predict(&test).expect("predict");
        assert_eq!(pred[0], "pos");

        // Test with clearly negative sample
        let test = Array2::from_shape_vec((1, 3), vec![0.0, 0.0, 5.0]).expect("shape");
        let pred = clf.predict(&test).expect("predict");
        assert_eq!(pred[0], "neg");
    }

    // ─── Bernoulli NB Tests ──────────────────────────────────────────

    #[test]
    fn test_bernoulli_nb_basic() {
        let features = Array2::from_shape_vec(
            (6, 4),
            vec![
                1.0, 1.0, 0.0, 0.0, // pos
                1.0, 0.0, 1.0, 0.0, // pos
                0.0, 1.0, 1.0, 0.0, // pos
                0.0, 0.0, 0.0, 1.0, // neg
                0.0, 0.0, 1.0, 1.0, // neg
                0.0, 1.0, 0.0, 1.0, // neg
            ],
        )
        .expect("shape");

        let labels = vec![
            "pos".to_string(),
            "pos".to_string(),
            "pos".to_string(),
            "neg".to_string(),
            "neg".to_string(),
            "neg".to_string(),
        ];

        let mut clf = BernoulliNaiveBayes::new(1.0);
        clf.fit(&features, &labels).expect("fit");

        let test = Array2::from_shape_vec((1, 4), vec![1.0, 1.0, 0.0, 0.0]).expect("shape");
        let pred = clf.predict(&test).expect("predict");
        assert_eq!(pred[0], "pos");

        let test = Array2::from_shape_vec((1, 4), vec![0.0, 0.0, 0.0, 1.0]).expect("shape");
        let pred = clf.predict(&test).expect("predict");
        assert_eq!(pred[0], "neg");
    }

    // ─── TF-IDF Cosine Classifier Tests ──────────────────────────────

    #[test]
    fn test_tfidf_cosine_classifier() {
        let features = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0, 0.0, 0.0, // A
                0.9, 0.1, 0.0, // A
                0.0, 0.0, 1.0, // B
                0.1, 0.0, 0.9, // B
            ],
        )
        .expect("shape");

        let labels = vec![
            "A".to_string(),
            "A".to_string(),
            "B".to_string(),
            "B".to_string(),
        ];

        let mut clf = TfidfCosineClassifier::new(1);
        clf.fit(&features, &labels).expect("fit");

        let test = Array2::from_shape_vec((1, 3), vec![0.8, 0.2, 0.0]).expect("shape");
        let pred = clf.predict(&test).expect("predict");
        assert_eq!(pred[0], "A");
    }

    #[test]
    fn test_cosine_similarity() {
        let a = Array1::from(vec![1.0, 0.0]);
        let b = Array1::from(vec![0.0, 2.0]);
        assert_eq!(cosine_similarity(&a, &b), 0.0);
        assert!((cosine_similarity(&a, &a) - 1.0).abs() < 1e-12);

        // A zero vector is similar to nothing (no division by zero).
        let zero = Array1::zeros(2);
        assert_eq!(cosine_similarity(&a, &zero), 0.0);
    }

    // ─── Feature Hashing Tests ───────────────────────────────────────

    #[test]
    fn test_feature_hasher() {
        let hasher = FeatureHasher::new(100);

        let features = hasher.transform_text("the quick brown fox").expect("hash");
        assert_eq!(features.len(), 100);

        // Should have non-zero entries
        let nnz = features.iter().filter(|&&v| v != 0.0).count();
        assert!(nnz > 0);
    }

    #[test]
    fn test_feature_hasher_batch() {
        let hasher = FeatureHasher::new(50);

        let texts = vec!["hello world", "foo bar baz"];
        let matrix = hasher.transform_batch(&texts).expect("batch");

        assert_eq!(matrix.nrows(), 2);
        assert_eq!(matrix.ncols(), 50);
    }

    #[test]
    fn test_feature_hasher_deterministic() {
        let hasher = FeatureHasher::new(100);

        let f1 = hasher.transform_text("hello world").expect("h1");
        let f2 = hasher.transform_text("hello world").expect("h2");

        for i in 0..100 {
            assert_eq!(f1[i], f2[i]);
        }
    }

    #[test]
    fn test_feature_hasher_case_insensitive() {
        let hasher = FeatureHasher::new(64);

        let upper = hasher.transform_text("Hello World").expect("upper");
        let lower = hasher.transform_text("hello world").expect("lower");
        assert_eq!(upper, lower);
    }

    // ─── Multi-Label Tests ───────────────────────────────────────────

    #[test]
    fn test_multi_label_classifier() {
        let features = Array2::from_shape_vec(
            (4, 3),
            vec![
                3.0, 1.0, 0.0, // sports + positive
                2.0, 2.0, 0.0, // sports
                0.0, 1.0, 3.0, // tech + negative
                0.0, 0.0, 4.0, // tech
            ],
        )
        .expect("shape");

        let label_sets = vec![
            vec!["sports".to_string(), "positive".to_string()],
            vec!["sports".to_string()],
            vec!["tech".to_string(), "negative".to_string()],
            vec!["tech".to_string()],
        ];

        let mut clf = MultiLabelClassifier::new(0.5);
        clf.fit(&features, &label_sets).expect("fit");

        let test = Array2::from_shape_vec((1, 3), vec![4.0, 0.0, 0.0]).expect("shape");
        let preds = clf.predict(&test).expect("predict");
        assert_eq!(preds.len(), 1);

        // The sample only has mass on the "sports" feature.
        let labels = &preds[0].labels;
        assert!(labels.iter().any(|l| l == "sports"), "labels: {labels:?}");
        assert!(!labels.iter().any(|l| l == "tech"), "labels: {labels:?}");

        // Every known label gets a score.
        assert_eq!(preds[0].scores.len(), 4);
    }

    #[test]
    fn test_multi_label_length_mismatch() {
        let features = Array2::zeros((2, 3));
        let label_sets = vec![vec!["a".to_string()]];

        let mut clf = MultiLabelClassifier::new(0.5);
        assert!(clf.fit(&features, &label_sets).is_err());
    }

    // ─── Cross-Validation Tests ──────────────────────────────────────

    #[test]
    fn test_cross_validation() {
        // Create a simple linearly separable dataset
        let n = 20;
        let features = Array2::from_shape_fn((n, 2), |(i, j)| {
            // Class A puts all its mass on feature 0, class B on feature 1.
            let hot_feature = if i < n / 2 { 0 } else { 1 };
            if j == hot_feature {
                3.0
            } else {
                0.0
            }
        });

        let labels: Vec<String> = (0..n)
            .map(|i| {
                if i < n / 2 {
                    "A".to_string()
                } else {
                    "B".to_string()
                }
            })
            .collect();

        let result = cross_validate_nb(&features, &labels, 5, 1.0, Some(42)).expect("cv");

        assert_eq!(result.fold_results.len(), 5);
        for fold in &result.fold_results {
            assert_eq!(fold.predictions.len(), fold.true_labels.len());
        }
        // With linearly separable data, should get high accuracy
        assert!(
            result.mean_accuracy >= 0.9,
            "Mean accuracy: {}",
            result.mean_accuracy
        );
    }

    #[test]
    fn test_cross_validation_invalid_k() {
        let features = Array2::zeros((5, 2));
        let labels = vec!["A".to_string(); 5];

        let result = cross_validate_nb(&features, &labels, 1, 1.0, None);
        assert!(result.is_err());

        let result = cross_validate_nb(&features, &labels, 10, 1.0, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_cross_validation_length_mismatch() {
        let features = Array2::zeros((3, 2));
        let labels = vec!["A".to_string(); 2];

        let result = cross_validate_nb(&features, &labels, 2, 1.0, None);
        assert!(result.is_err());
    }
}