scirs2_text/
classification.rs

1//! Text classification functionality
2//!
3//! This module provides tools for text classification including
4//! metrics, feature selection, and classification pipelines.
5
6use crate::error::{Result, TextError};
7use crate::vectorize::{TfidfVectorizer, Vectorizer};
8use scirs2_core::ndarray::{Array2, Axis};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::SeedableRng;
12
/// Document-frequency based feature filter for term matrices.
///
/// A column (feature) is kept when the number of rows (documents) in which it has a
/// strictly positive value lies within `[min_df, max_df]`. The bounds are read as
/// fractions of the number of documents by default, or as absolute document counts
/// when `use_counts` is enabled.
#[derive(Debug, Clone)]
pub struct TextFeatureSelector {
    /// Lower document-frequency bound (fraction of documents, or a count in count mode)
    min_df: f64,
    /// Upper document-frequency bound (fraction of documents, or a count in count mode)
    max_df: f64,
    /// When `true`, `min_df`/`max_df` are absolute document counts instead of fractions
    use_counts: bool,
    /// Column indices kept by the most recent `fit`; `None` until the selector is fitted
    selected_features: Option<Vec<usize>>,
}

impl Default for TextFeatureSelector {
    /// Unfitted selector in fraction mode whose bounds (`0.0..=1.0`) accept every
    /// document frequency.
    fn default() -> Self {
        let (lower, upper) = (0.0, 1.0);
        TextFeatureSelector {
            selected_features: None,
            use_counts: false,
            max_df: upper,
            min_df: lower,
        }
    }
}
39
40impl TextFeatureSelector {
41    /// Create a new feature selector
42    pub fn new() -> Self {
43        Self::default()
44    }
45
46    /// Set minimum document frequency
47    pub fn set_min_df(mut self, mindf: f64) -> Result<Self> {
48        if mindf < 0.0 {
49            return Err(TextError::InvalidInput(
50                "min_df must be non-negative".to_string(),
51            ));
52        }
53        self.min_df = mindf;
54        Ok(self)
55    }
56
57    /// Set maximum document frequency
58    pub fn set_max_df(mut self, maxdf: f64) -> Result<Self> {
59        if !(0.0..=1.0).contains(&maxdf) {
60            return Err(TextError::InvalidInput(
61                "max_df must be between 0 and 1 for fractions".to_string(),
62            ));
63        }
64        self.max_df = maxdf;
65        Ok(self)
66    }
67
68    /// Set maximum document frequency (alias for set_max_df)
69    pub fn set_max_features(self, maxfeatures: f64) -> Result<Self> {
70        self.set_max_df(maxfeatures)
71    }
72
73    /// Set to use absolute counts instead of fractions
74    pub fn use_counts(mut self, usecounts: bool) -> Self {
75        self.use_counts = usecounts;
76        self
77    }
78
79    /// Fit the feature selector to data
80    pub fn fit(&mut self, x: &Array2<f64>) -> Result<&mut Self> {
81        let n_samples = x.nrows();
82        let n_features = x.ncols();
83
84        let mut document_frequencies = vec![0; n_features];
85
86        // Count document frequency for each feature
87        for sample in x.axis_iter(Axis(0)) {
88            for (feature_idx, &value) in sample.iter().enumerate() {
89                if value > 0.0 {
90                    document_frequencies[feature_idx] += 1;
91                }
92            }
93        }
94
95        // Calculate min and max document counts
96        let min_count = if self.use_counts {
97            self.min_df
98        } else {
99            self.min_df * n_samples as f64
100        };
101
102        let max_count = if self.use_counts {
103            self.max_df
104        } else {
105            self.max_df * n_samples as f64
106        };
107
108        // Select features based on document frequency
109        let mut selected_features = Vec::new();
110        for (idx, &df) in document_frequencies.iter().enumerate() {
111            let df_f64 = df as f64;
112            if df_f64 >= min_count && df_f64 <= max_count {
113                selected_features.push(idx);
114            }
115        }
116
117        self.selected_features = Some(selected_features);
118        Ok(self)
119    }
120
121    /// Transform data using selected features
122    pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
123        let selected_features = self
124            .selected_features
125            .as_ref()
126            .ok_or_else(|| TextError::ModelNotFitted("Feature selector not fitted".to_string()))?;
127
128        if selected_features.is_empty() {
129            return Err(TextError::InvalidInput(
130                "No features selected. Try adjusting min_df and max_df".to_string(),
131            ));
132        }
133
134        let n_samples = x.nrows();
135        let n_selected = selected_features.len();
136
137        let mut result = Array2::zeros((n_samples, n_selected));
138
139        for (i, row) in x.axis_iter(Axis(0)).enumerate() {
140            for (j, &feature_idx) in selected_features.iter().enumerate() {
141                result[[i, j]] = row[feature_idx];
142            }
143        }
144
145        Ok(result)
146    }
147
148    /// Fit and transform in one step
149    pub fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
150        self.fit(x)?;
151        self.transform(x)
152    }
153
154    /// Get selected feature indices
155    pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
156        self.selected_features.as_ref()
157    }
158}
159
/// Stateless calculator for text-classification scores
/// (accuracy, precision, recall and F1).
#[derive(Debug, Clone, Default)]
pub struct TextClassificationMetrics;
169
170impl TextClassificationMetrics {
171    /// Create a new metrics calculator
172    pub fn new() -> Self {
173        Self
174    }
175
176    /// Calculate precision score
177    pub fn precision<T>(
178        &self,
179        predictions: &[T],
180        true_labels: &[T],
181        class_idx: Option<T>,
182    ) -> Result<f64>
183    where
184        T: PartialEq + Copy + Default,
185    {
186        let positive_class = class_idx.unwrap_or_default();
187
188        if predictions.len() != true_labels.len() {
189            return Err(TextError::InvalidInput(
190                "Predictions and _labels must have the same length".to_string(),
191            ));
192        }
193
194        let mut true_positives = 0;
195        let mut predicted_positives = 0;
196
197        for i in 0..predictions.len() {
198            if predictions[i] == positive_class {
199                predicted_positives += 1;
200                if true_labels[i] == positive_class {
201                    true_positives += 1;
202                }
203            }
204        }
205
206        if predicted_positives == 0 {
207            return Ok(0.0);
208        }
209
210        Ok(true_positives as f64 / predicted_positives as f64)
211    }
212
213    /// Calculate recall score
214    pub fn recall<T>(
215        &self,
216        predictions: &[T],
217        true_labels: &[T],
218        class_idx: Option<T>,
219    ) -> Result<f64>
220    where
221        T: PartialEq + Copy + Default,
222    {
223        let positive_class = class_idx.unwrap_or_default();
224
225        if predictions.len() != true_labels.len() {
226            return Err(TextError::InvalidInput(
227                "Predictions and _labels must have the same length".to_string(),
228            ));
229        }
230
231        let mut true_positives = 0;
232        let mut actual_positives = 0;
233
234        for i in 0..predictions.len() {
235            if true_labels[i] == positive_class {
236                actual_positives += 1;
237                if predictions[i] == positive_class {
238                    true_positives += 1;
239                }
240            }
241        }
242
243        if actual_positives == 0 {
244            return Ok(0.0);
245        }
246
247        Ok(true_positives as f64 / actual_positives as f64)
248    }
249
250    /// Calculate F1 score
251    pub fn f1_score<T>(
252        &self,
253        predictions: &[T],
254        true_labels: &[T],
255        class_idx: Option<T>,
256    ) -> Result<f64>
257    where
258        T: PartialEq + Copy + Default,
259    {
260        let precision = self.precision(predictions, true_labels, class_idx)?;
261        let recall = self.recall(predictions, true_labels, class_idx)?;
262
263        if precision + recall == 0.0 {
264            return Ok(0.0);
265        }
266
267        Ok(2.0 * precision * recall / (precision + recall))
268    }
269
270    /// Calculate accuracy from predictions and true labels
271    pub fn accuracy<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<f64>
272    where
273        T: PartialEq,
274    {
275        if predictions.len() != truelabels.len() {
276            return Err(TextError::InvalidInput(
277                "Predictions and _labels must have the same length".to_string(),
278            ));
279        }
280
281        if predictions.is_empty() {
282            return Err(TextError::InvalidInput(
283                "Cannot calculate accuracy for empty arrays".to_string(),
284            ));
285        }
286
287        let correct = predictions
288            .iter()
289            .zip(truelabels.iter())
290            .filter(|(pred, true_label)| pred == true_label)
291            .count();
292
293        Ok(correct as f64 / predictions.len() as f64)
294    }
295
296    /// Calculate precision, recall, and F1 score for binary classification
297    pub fn binary_metrics<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<(f64, f64, f64)>
298    where
299        T: PartialEq + Copy + Default + PartialEq<usize>,
300    {
301        if predictions.len() != truelabels.len() {
302            return Err(TextError::InvalidInput(
303                "Predictions and _labels must have the same length".to_string(),
304            ));
305        }
306
307        // Count true positives, false positives, false negatives
308        let mut tp = 0;
309        let mut fp = 0;
310        let mut fn_ = 0;
311
312        for (pred, true_label) in predictions.iter().zip(truelabels.iter()) {
313            if *pred == 1 && *true_label == 1 {
314                tp += 1;
315            } else if *pred == 1 && *true_label == 0 {
316                fp += 1;
317            } else if *pred == 0 && *true_label == 1 {
318                fn_ += 1;
319            }
320        }
321
322        // Calculate precision, recall, F1
323        let precision = if tp + fp > 0 {
324            tp as f64 / (tp + fp) as f64
325        } else {
326            0.0
327        };
328
329        let recall = if tp + fn_ > 0 {
330            tp as f64 / (tp + fn_) as f64
331        } else {
332            0.0
333        };
334
335        let f1 = if precision + recall > 0.0 {
336            2.0 * precision * recall / (precision + recall)
337        } else {
338            0.0
339        };
340
341        Ok((precision, recall, f1))
342    }
343}
344
/// A labelled corpus for text classification, intended to hold one label per text.
#[derive(Debug, Clone)]
pub struct TextDataset {
    /// Raw documents
    pub texts: Vec<String>,
    /// Class label of each document (`labels[i]` belongs to `texts[i]`)
    pub labels: Vec<String>,
    /// Label-to-integer mapping; `None` until `build_label_index` has been called
    label_index: Option<std::collections::HashMap<String, usize>>,
}
355
356impl TextDataset {
357    /// Create a new text dataset
358    pub fn new(texts: Vec<String>, labels: Vec<String>) -> Result<Self> {
359        if texts.len() != labels.len() {
360            return Err(TextError::InvalidInput(
361                "Texts and labels must have the same length".to_string(),
362            ));
363        }
364
365        Ok(Self {
366            texts,
367            labels,
368            label_index: None,
369        })
370    }
371
372    /// Get the number of samples
373    pub fn len(&self) -> usize {
374        self.texts.len()
375    }
376
377    /// Check if the dataset is empty
378    pub fn is_empty(&self) -> bool {
379        self.texts.is_empty()
380    }
381
382    /// Get the unique labels in the dataset
383    pub fn unique_labels(&self) -> Vec<String> {
384        let mut unique = std::collections::HashSet::new();
385        for label in &self.labels {
386            unique.insert(label.clone());
387        }
388        unique.into_iter().collect()
389    }
390
391    /// Build a label index mapping
392    pub fn build_label_index(&mut self) -> Result<&mut Self> {
393        let mut index = std::collections::HashMap::new();
394        let unique_labels = self.unique_labels();
395
396        for (i, label) in unique_labels.iter().enumerate() {
397            index.insert(label.clone(), i);
398        }
399
400        self.label_index = Some(index);
401        Ok(self)
402    }
403
404    /// Get label indices
405    pub fn get_label_indices(&self) -> Result<Vec<usize>> {
406        let index = self
407            .label_index
408            .as_ref()
409            .ok_or_else(|| TextError::ModelNotFitted("Label index not built".to_string()))?;
410
411        self.labels
412            .iter()
413            .map(|label| {
414                index
415                    .get(label)
416                    .copied()
417                    .ok_or_else(|| TextError::InvalidInput(format!("Unknown label: {label}")))
418            })
419            .collect()
420    }
421
422    /// Split the dataset into train and test sets
423    pub fn train_test_split(
424        &self,
425        test_size: f64,
426        random_seed: Option<u64>,
427    ) -> Result<(Self, Self)> {
428        if test_size <= 0.0 || test_size >= 1.0 {
429            return Err(TextError::InvalidInput(
430                "test_size must be between 0 and 1".to_string(),
431            ));
432        }
433
434        if self.is_empty() {
435            return Err(TextError::InvalidInput("Dataset is empty".to_string()));
436        }
437
438        // Create indices and shuffle them
439        let mut indices: Vec<usize> = (0..self.len()).collect();
440
441        // Shuffle indices
442        if let Some(_seed) = random_seed {
443            // Use deterministic RNG with _seed
444            let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(_seed);
445            indices.shuffle(&mut rng);
446        } else {
447            // Use standard rng
448            let mut rng = scirs2_core::random::rng();
449            indices.shuffle(&mut rng);
450        }
451
452        // Split indices
453        let test_size = (self.len() as f64 * test_size).ceil() as usize;
454        let test_indices = indices[0..test_size].to_vec();
455        let train_indices = indices[test_size..].to_vec();
456
457        // Create datasets
458        let traintexts = train_indices
459            .iter()
460            .map(|&i| self.texts[i].clone())
461            .collect();
462        let train_labels = train_indices
463            .iter()
464            .map(|&i| self.labels[i].clone())
465            .collect();
466
467        let testtexts = test_indices
468            .iter()
469            .map(|&i| self.texts[i].clone())
470            .collect();
471        let test_labels = test_indices
472            .iter()
473            .map(|&i| self.labels[i].clone())
474            .collect();
475
476        let mut train_dataset = Self::new(traintexts, train_labels)?;
477        let mut test_dataset = Self::new(testtexts, test_labels)?;
478
479        // If we have a label index, build it for the split datasets
480        if self.label_index.is_some() {
481            train_dataset.build_label_index()?;
482            test_dataset.build_label_index()?;
483        }
484
485        Ok((train_dataset, test_dataset))
486    }
487}
488
489/// Pipeline for text classification
490pub struct TextClassificationPipeline {
491    /// The vectorizer to use
492    vectorizer: TfidfVectorizer,
493    /// Optional feature selector
494    feature_selector: Option<TextFeatureSelector>,
495}
496
497impl TextClassificationPipeline {
498    /// Create a new pipeline with a default TF-IDF vectorizer
499    pub fn with_tfidf() -> Self {
500        Self::new(TfidfVectorizer::default())
501    }
502
503    /// Create a new pipeline with the given vectorizer
504    pub fn new(vectorizer: TfidfVectorizer) -> Self {
505        Self {
506            vectorizer,
507            feature_selector: None,
508        }
509    }
510
511    /// Add a feature selector to the pipeline
512    pub fn with_feature_selector(mut self, selector: TextFeatureSelector) -> Self {
513        self.feature_selector = Some(selector);
514        self
515    }
516
517    /// Fit the pipeline to training data
518    pub fn fit(&mut self, dataset: &TextDataset) -> Result<&mut Self> {
519        let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
520        self.vectorizer.fit(&texts)?;
521
522        Ok(self)
523    }
524
525    /// Transform text data using the pipeline
526    pub fn transform(&self, dataset: &TextDataset) -> Result<Array2<f64>> {
527        let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
528        let mut features = self.vectorizer.transform_batch(&texts)?;
529
530        if let Some(selector) = &self.feature_selector {
531            features = selector.transform(&features)?;
532        }
533
534        Ok(features)
535    }
536
537    /// Fit and transform in one step
538    pub fn fit_transform(&mut self, dataset: &TextDataset) -> Result<Array2<f64>> {
539        self.fit(dataset)?;
540        self.transform(dataset)
541    }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_text_dataset() {
        let texts = vec![
            "This is document 1".to_string(),
            "Another document".to_string(),
            "A third document".to_string(),
        ];
        let labels = vec!["A".to_string(), "B".to_string(), "A".to_string()];

        let mut dataset = TextDataset::new(texts, labels).unwrap();

        // Set the label index by hand so the expected index values are explicit.
        let mut label_index = std::collections::HashMap::new();
        label_index.insert("A".to_string(), 0);
        label_index.insert("B".to_string(), 1);
        dataset.label_index = Some(label_index);

        let label_indices = dataset.get_label_indices().unwrap();
        assert_eq!(label_indices, vec![0, 1, 0]);

        let unique_labels = dataset.unique_labels();
        assert_eq!(unique_labels.len(), 2);
        assert!(unique_labels.contains(&"A".to_string()));
        assert!(unique_labels.contains(&"B".to_string()));
    }

    #[test]
    fn test_dataset_rejects_mismatched_lengths() {
        assert!(TextDataset::new(vec!["only text".to_string()], Vec::new()).is_err());
    }

    #[test]
    fn test_train_test_split() {
        let texts: Vec<String> = (0..10).map(|i| format!("Text {i}")).collect();
        let labels: Vec<String> = (0..10).map(|i| format!("L{i}")).collect();

        let dataset = TextDataset::new(texts.clone(), labels).unwrap();
        let (train, test) = dataset.train_test_split(0.3, Some(42)).unwrap();

        assert_eq!(train.len(), 7);
        assert_eq!(test.len(), 3);

        // Every sample lands in exactly one split...
        let mut seen: Vec<String> = train.texts.iter().chain(&test.texts).cloned().collect();
        seen.sort();
        let mut expected = texts;
        expected.sort();
        assert_eq!(seen, expected);

        // ...and stays paired with its own label.
        for split in [&train, &test] {
            for (text, label) in split.texts.iter().zip(&split.labels) {
                assert_eq!(text.trim_start_matches("Text "), label.trim_start_matches('L'));
            }
        }
    }

    #[test]
    fn test_train_test_split_rejects_out_of_range_test_size() {
        let dataset = TextDataset::new(
            vec!["a".to_string(), "b".to_string()],
            vec!["x".to_string(), "y".to_string()],
        )
        .unwrap();
        assert!(dataset.train_test_split(0.0, Some(1)).is_err());
        assert!(dataset.train_test_split(1.0, Some(1)).is_err());
    }

    #[test]
    fn test_feature_selector() {
        let mut features = Array2::zeros((5, 3));
        // Feature 0: appears in docs 0, 1, 2 (60% of docs)
        features[[0, 0]] = 1.0;
        features[[1, 0]] = 1.0;
        features[[2, 0]] = 1.0;

        // Feature 1: appears in all docs (100% of docs)
        for i in 0..5 {
            features[[i, 1]] = 1.0;
        }

        // Feature 2: appears in doc 0 only (20% of docs)
        features[[0, 2]] = 1.0;

        let mut selector = TextFeatureSelector::new()
            .set_min_df(0.25)
            .unwrap()
            .set_max_df(0.75)
            .unwrap();

        let filtered = selector.fit_transform(&features).unwrap();
        assert_eq!(filtered.ncols(), 1); // Only feature 0 passes the filters
        assert_eq!(selector.get_selected_features(), Some(&vec![0]));
    }

    #[test]
    fn test_feature_selector_requires_fit() {
        let selector = TextFeatureSelector::new();
        let features: Array2<f64> = Array2::zeros((2, 2));
        assert!(selector.transform(&features).is_err());
    }

    #[test]
    fn test_feature_selector_rejects_invalid_bounds() {
        assert!(TextFeatureSelector::new().set_max_df(1.5).is_err());
        assert!(TextFeatureSelector::new().set_min_df(-0.1).is_err());
    }

    #[test]
    fn test_classification_metrics() {
        let predictions = vec![1_usize, 0, 1, 1, 0];
        let true_labels = vec![1_usize, 0, 1, 0, 0];

        let metrics = TextClassificationMetrics::new();
        let accuracy = metrics.accuracy(&predictions, &true_labels).unwrap();
        assert!((accuracy - 0.8).abs() < 1e-12);

        let (precision, recall, f1) = metrics.binary_metrics(&predictions, &true_labels).unwrap();
        assert!((precision - 2.0 / 3.0).abs() < 1e-12);
        assert!((recall - 1.0).abs() < 1e-12);
        assert!((f1 - 0.8).abs() < 1e-12);
    }

    #[test]
    fn test_per_class_metrics() {
        let predictions = vec![1_usize, 0, 1, 1, 0];
        let true_labels = vec![1_usize, 0, 1, 0, 0];
        let metrics = TextClassificationMetrics::new();

        // An explicit positive class of 1 agrees with `binary_metrics`.
        let p1 = metrics.precision(&predictions, &true_labels, Some(1)).unwrap();
        let r1 = metrics.recall(&predictions, &true_labels, Some(1)).unwrap();
        let f1 = metrics.f1_score(&predictions, &true_labels, Some(1)).unwrap();
        assert!((p1 - 2.0 / 3.0).abs() < 1e-12);
        assert!((r1 - 1.0).abs() < 1e-12);
        assert!((f1 - 0.8).abs() < 1e-12);

        // Without an explicit class the positive class is `usize::default()`, i.e. 0.
        let p0 = metrics.precision(&predictions, &true_labels, None).unwrap();
        let r0 = metrics.recall(&predictions, &true_labels, None).unwrap();
        assert!((p0 - 1.0).abs() < 1e-12);
        assert!((r0 - 2.0 / 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_metrics_reject_bad_input() {
        let metrics = TextClassificationMetrics::new();
        assert!(metrics.accuracy(&[1_usize, 0], &[1_usize]).is_err());
        let empty: [usize; 0] = [];
        assert!(metrics.accuracy(&empty, &empty).is_err());
        assert!(metrics.precision(&[1_usize], &empty, Some(1)).is_err());
    }
}