1use crate::error::{Result, TextError};
7use crate::vectorize::{TfidfVectorizer, Vectorizer};
8use scirs2_core::ndarray::{Array2, Axis};
9use scirs2_core::random::prelude::*;
10use scirs2_core::random::seq::SliceRandom;
11use scirs2_core::random::SeedableRng;
12
/// Document-frequency based feature selector for vectorized text.
///
/// After fitting, only the feature columns whose document frequency lies within
/// `[min_df, max_df]` are kept when transforming a feature matrix.
#[derive(Debug, Clone)]
pub struct TextFeatureSelector {
    /// Lower document-frequency bound (fraction of documents, or a count in count mode).
    min_df: f64,
    /// Upper document-frequency bound (fraction of documents, or a count in count mode).
    max_df: f64,
    /// When `true`, `min_df` / `max_df` are absolute document counts instead of fractions.
    use_counts: bool,
    /// Indices of the retained feature columns; `None` until the selector is fitted.
    selected_features: Option<Vec<usize>>,
}
28
29impl Default for TextFeatureSelector {
30 fn default() -> Self {
31 Self {
32 min_df: 0.0,
33 max_df: 1.0,
34 use_counts: false,
35 selected_features: None,
36 }
37 }
38}
39
40impl TextFeatureSelector {
41 pub fn new() -> Self {
43 Self::default()
44 }
45
46 pub fn set_min_df(mut self, mindf: f64) -> Result<Self> {
48 if mindf < 0.0 {
49 return Err(TextError::InvalidInput(
50 "min_df must be non-negative".to_string(),
51 ));
52 }
53 self.min_df = mindf;
54 Ok(self)
55 }
56
57 pub fn set_max_df(mut self, maxdf: f64) -> Result<Self> {
59 if !(0.0..=1.0).contains(&maxdf) {
60 return Err(TextError::InvalidInput(
61 "max_df must be between 0 and 1 for fractions".to_string(),
62 ));
63 }
64 self.max_df = maxdf;
65 Ok(self)
66 }
67
68 pub fn set_max_features(self, maxfeatures: f64) -> Result<Self> {
70 self.set_max_df(maxfeatures)
71 }
72
73 pub fn use_counts(mut self, usecounts: bool) -> Self {
75 self.use_counts = usecounts;
76 self
77 }
78
79 pub fn fit(&mut self, x: &Array2<f64>) -> Result<&mut Self> {
81 let n_samples = x.nrows();
82 let n_features = x.ncols();
83
84 let mut document_frequencies = vec![0; n_features];
85
86 for sample in x.axis_iter(Axis(0)) {
88 for (feature_idx, &value) in sample.iter().enumerate() {
89 if value > 0.0 {
90 document_frequencies[feature_idx] += 1;
91 }
92 }
93 }
94
95 let min_count = if self.use_counts {
97 self.min_df
98 } else {
99 self.min_df * n_samples as f64
100 };
101
102 let max_count = if self.use_counts {
103 self.max_df
104 } else {
105 self.max_df * n_samples as f64
106 };
107
108 let mut selected_features = Vec::new();
110 for (idx, &df) in document_frequencies.iter().enumerate() {
111 let df_f64 = df as f64;
112 if df_f64 >= min_count && df_f64 <= max_count {
113 selected_features.push(idx);
114 }
115 }
116
117 self.selected_features = Some(selected_features);
118 Ok(self)
119 }
120
121 pub fn transform(&self, x: &Array2<f64>) -> Result<Array2<f64>> {
123 let selected_features = self
124 .selected_features
125 .as_ref()
126 .ok_or_else(|| TextError::ModelNotFitted("Feature selector not fitted".to_string()))?;
127
128 if selected_features.is_empty() {
129 return Err(TextError::InvalidInput(
130 "No features selected. Try adjusting min_df and max_df".to_string(),
131 ));
132 }
133
134 let n_samples = x.nrows();
135 let n_selected = selected_features.len();
136
137 let mut result = Array2::zeros((n_samples, n_selected));
138
139 for (i, row) in x.axis_iter(Axis(0)).enumerate() {
140 for (j, &feature_idx) in selected_features.iter().enumerate() {
141 result[[i, j]] = row[feature_idx];
142 }
143 }
144
145 Ok(result)
146 }
147
148 pub fn fit_transform(&mut self, x: &Array2<f64>) -> Result<Array2<f64>> {
150 self.fit(x)?;
151 self.transform(x)
152 }
153
154 pub fn get_selected_features(&self) -> Option<&Vec<usize>> {
156 self.selected_features.as_ref()
157 }
158}
159
/// Stateless calculator for common classification metrics (accuracy, precision, recall,
/// F1).
///
/// `Default` is derived: the type has no fields, so a hand-written impl adds nothing.
#[derive(Debug, Clone, Default)]
pub struct TextClassificationMetrics;
169
170impl TextClassificationMetrics {
171 pub fn new() -> Self {
173 Self
174 }
175
176 pub fn precision<T>(
178 &self,
179 predictions: &[T],
180 true_labels: &[T],
181 class_idx: Option<T>,
182 ) -> Result<f64>
183 where
184 T: PartialEq + Copy + Default,
185 {
186 let positive_class = class_idx.unwrap_or_default();
187
188 if predictions.len() != true_labels.len() {
189 return Err(TextError::InvalidInput(
190 "Predictions and _labels must have the same length".to_string(),
191 ));
192 }
193
194 let mut true_positives = 0;
195 let mut predicted_positives = 0;
196
197 for i in 0..predictions.len() {
198 if predictions[i] == positive_class {
199 predicted_positives += 1;
200 if true_labels[i] == positive_class {
201 true_positives += 1;
202 }
203 }
204 }
205
206 if predicted_positives == 0 {
207 return Ok(0.0);
208 }
209
210 Ok(true_positives as f64 / predicted_positives as f64)
211 }
212
213 pub fn recall<T>(
215 &self,
216 predictions: &[T],
217 true_labels: &[T],
218 class_idx: Option<T>,
219 ) -> Result<f64>
220 where
221 T: PartialEq + Copy + Default,
222 {
223 let positive_class = class_idx.unwrap_or_default();
224
225 if predictions.len() != true_labels.len() {
226 return Err(TextError::InvalidInput(
227 "Predictions and _labels must have the same length".to_string(),
228 ));
229 }
230
231 let mut true_positives = 0;
232 let mut actual_positives = 0;
233
234 for i in 0..predictions.len() {
235 if true_labels[i] == positive_class {
236 actual_positives += 1;
237 if predictions[i] == positive_class {
238 true_positives += 1;
239 }
240 }
241 }
242
243 if actual_positives == 0 {
244 return Ok(0.0);
245 }
246
247 Ok(true_positives as f64 / actual_positives as f64)
248 }
249
250 pub fn f1_score<T>(
252 &self,
253 predictions: &[T],
254 true_labels: &[T],
255 class_idx: Option<T>,
256 ) -> Result<f64>
257 where
258 T: PartialEq + Copy + Default,
259 {
260 let precision = self.precision(predictions, true_labels, class_idx)?;
261 let recall = self.recall(predictions, true_labels, class_idx)?;
262
263 if precision + recall == 0.0 {
264 return Ok(0.0);
265 }
266
267 Ok(2.0 * precision * recall / (precision + recall))
268 }
269
270 pub fn accuracy<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<f64>
272 where
273 T: PartialEq,
274 {
275 if predictions.len() != truelabels.len() {
276 return Err(TextError::InvalidInput(
277 "Predictions and _labels must have the same length".to_string(),
278 ));
279 }
280
281 if predictions.is_empty() {
282 return Err(TextError::InvalidInput(
283 "Cannot calculate accuracy for empty arrays".to_string(),
284 ));
285 }
286
287 let correct = predictions
288 .iter()
289 .zip(truelabels.iter())
290 .filter(|(pred, true_label)| pred == true_label)
291 .count();
292
293 Ok(correct as f64 / predictions.len() as f64)
294 }
295
296 pub fn binary_metrics<T>(&self, predictions: &[T], truelabels: &[T]) -> Result<(f64, f64, f64)>
298 where
299 T: PartialEq + Copy + Default + PartialEq<usize>,
300 {
301 if predictions.len() != truelabels.len() {
302 return Err(TextError::InvalidInput(
303 "Predictions and _labels must have the same length".to_string(),
304 ));
305 }
306
307 let mut tp = 0;
309 let mut fp = 0;
310 let mut fn_ = 0;
311
312 for (pred, true_label) in predictions.iter().zip(truelabels.iter()) {
313 if *pred == 1 && *true_label == 1 {
314 tp += 1;
315 } else if *pred == 1 && *true_label == 0 {
316 fp += 1;
317 } else if *pred == 0 && *true_label == 1 {
318 fn_ += 1;
319 }
320 }
321
322 let precision = if tp + fp > 0 {
324 tp as f64 / (tp + fp) as f64
325 } else {
326 0.0
327 };
328
329 let recall = if tp + fn_ > 0 {
330 tp as f64 / (tp + fn_) as f64
331 } else {
332 0.0
333 };
334
335 let f1 = if precision + recall > 0.0 {
336 2.0 * precision * recall / (precision + recall)
337 } else {
338 0.0
339 };
340
341 Ok((precision, recall, f1))
342 }
343}
344
/// In-memory collection of texts paired with string class labels.
#[derive(Debug, Clone)]
pub struct TextDataset {
    /// Raw documents.
    pub texts: Vec<String>,
    /// Class label of each document; `labels[i]` belongs to `texts[i]`.
    pub labels: Vec<String>,
    /// Optional label -> integer id mapping; `None` until an index is built.
    label_index: Option<std::collections::HashMap<String, usize>>,
}
355
356impl TextDataset {
357 pub fn new(texts: Vec<String>, labels: Vec<String>) -> Result<Self> {
359 if texts.len() != labels.len() {
360 return Err(TextError::InvalidInput(
361 "Texts and labels must have the same length".to_string(),
362 ));
363 }
364
365 Ok(Self {
366 texts,
367 labels,
368 label_index: None,
369 })
370 }
371
372 pub fn len(&self) -> usize {
374 self.texts.len()
375 }
376
377 pub fn is_empty(&self) -> bool {
379 self.texts.is_empty()
380 }
381
382 pub fn unique_labels(&self) -> Vec<String> {
384 let mut unique = std::collections::HashSet::new();
385 for label in &self.labels {
386 unique.insert(label.clone());
387 }
388 unique.into_iter().collect()
389 }
390
391 pub fn build_label_index(&mut self) -> Result<&mut Self> {
393 let mut index = std::collections::HashMap::new();
394 let unique_labels = self.unique_labels();
395
396 for (i, label) in unique_labels.iter().enumerate() {
397 index.insert(label.clone(), i);
398 }
399
400 self.label_index = Some(index);
401 Ok(self)
402 }
403
404 pub fn get_label_indices(&self) -> Result<Vec<usize>> {
406 let index = self
407 .label_index
408 .as_ref()
409 .ok_or_else(|| TextError::ModelNotFitted("Label index not built".to_string()))?;
410
411 self.labels
412 .iter()
413 .map(|label| {
414 index
415 .get(label)
416 .copied()
417 .ok_or_else(|| TextError::InvalidInput(format!("Unknown label: {label}")))
418 })
419 .collect()
420 }
421
422 pub fn train_test_split(
424 &self,
425 test_size: f64,
426 random_seed: Option<u64>,
427 ) -> Result<(Self, Self)> {
428 if test_size <= 0.0 || test_size >= 1.0 {
429 return Err(TextError::InvalidInput(
430 "test_size must be between 0 and 1".to_string(),
431 ));
432 }
433
434 if self.is_empty() {
435 return Err(TextError::InvalidInput("Dataset is empty".to_string()));
436 }
437
438 let mut indices: Vec<usize> = (0..self.len()).collect();
440
441 if let Some(_seed) = random_seed {
443 let mut rng = scirs2_core::random::rngs::StdRng::seed_from_u64(_seed);
445 indices.shuffle(&mut rng);
446 } else {
447 let mut rng = scirs2_core::random::rng();
449 indices.shuffle(&mut rng);
450 }
451
452 let test_size = (self.len() as f64 * test_size).ceil() as usize;
454 let test_indices = indices[0..test_size].to_vec();
455 let train_indices = indices[test_size..].to_vec();
456
457 let traintexts = train_indices
459 .iter()
460 .map(|&i| self.texts[i].clone())
461 .collect();
462 let train_labels = train_indices
463 .iter()
464 .map(|&i| self.labels[i].clone())
465 .collect();
466
467 let testtexts = test_indices
468 .iter()
469 .map(|&i| self.texts[i].clone())
470 .collect();
471 let test_labels = test_indices
472 .iter()
473 .map(|&i| self.labels[i].clone())
474 .collect();
475
476 let mut train_dataset = Self::new(traintexts, train_labels)?;
477 let mut test_dataset = Self::new(testtexts, test_labels)?;
478
479 if self.label_index.is_some() {
481 train_dataset.build_label_index()?;
482 test_dataset.build_label_index()?;
483 }
484
485 Ok((train_dataset, test_dataset))
486 }
487}
488
489pub struct TextClassificationPipeline {
491 vectorizer: TfidfVectorizer,
493 feature_selector: Option<TextFeatureSelector>,
495}
496
497impl TextClassificationPipeline {
498 pub fn with_tfidf() -> Self {
500 Self::new(TfidfVectorizer::default())
501 }
502
503 pub fn new(vectorizer: TfidfVectorizer) -> Self {
505 Self {
506 vectorizer,
507 feature_selector: None,
508 }
509 }
510
511 pub fn with_feature_selector(mut self, selector: TextFeatureSelector) -> Self {
513 self.feature_selector = Some(selector);
514 self
515 }
516
517 pub fn fit(&mut self, dataset: &TextDataset) -> Result<&mut Self> {
519 let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
520 self.vectorizer.fit(&texts)?;
521
522 Ok(self)
523 }
524
525 pub fn transform(&self, dataset: &TextDataset) -> Result<Array2<f64>> {
527 let texts: Vec<&str> = dataset.texts.iter().map(AsRef::as_ref).collect();
528 let mut features = self.vectorizer.transform_batch(&texts)?;
529
530 if let Some(selector) = &self.feature_selector {
531 features = selector.transform(&features)?;
532 }
533
534 Ok(features)
535 }
536
537 pub fn fit_transform(&mut self, dataset: &TextDataset) -> Result<Array2<f64>> {
539 self.fit(dataset)?;
540 self.transform(dataset)
541 }
542}
543
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn testtext_dataset() {
        let texts: Vec<String> = ["This is document 1", "Another document", "A third document"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let labels: Vec<String> = ["A", "B", "A"].iter().map(|s| s.to_string()).collect();
        let mut dataset = TextDataset::new(texts, labels).unwrap();

        // Install a fixed mapping directly so the expected ids are known in advance.
        let mut index: std::collections::HashMap<String, usize> = std::collections::HashMap::new();
        index.insert("A".to_string(), 0);
        index.insert("B".to_string(), 1);
        dataset.label_index = Some(index);

        let ids = dataset.get_label_indices().unwrap();
        assert_eq!(ids, vec![0, 1, 0]);

        let unique_labels = dataset.unique_labels();
        assert_eq!(unique_labels.len(), 2);
        for expected in ["A", "B"] {
            assert!(unique_labels.contains(&expected.to_string()));
        }
    }

    #[test]
    fn test_train_test_split() {
        let texts: Vec<String> = (0..10).map(|i| format!("Text {i}")).collect();
        let labels: Vec<String> = std::iter::repeat("A".to_string()).take(10).collect();
        let dataset = TextDataset::new(texts, labels).unwrap();

        let (train, test) = dataset.train_test_split(0.3, Some(42)).unwrap();
        assert_eq!((train.len(), test.len()), (7, 3));
    }

    #[test]
    fn test_feature_selector() {
        // Column document frequencies: col 0 -> 3/5, col 1 -> 5/5, col 2 -> 1/5.
        let features = Array2::from_shape_vec(
            (5, 3),
            vec![
                1.0, 1.0, 1.0, //
                1.0, 1.0, 0.0, //
                1.0, 1.0, 0.0, //
                0.0, 1.0, 0.0, //
                0.0, 1.0, 0.0,
            ],
        )
        .unwrap();

        let mut selector = TextFeatureSelector::new()
            .set_min_df(0.25)
            .unwrap()
            .set_max_df(0.75)
            .unwrap();

        // Only column 0 lies within [0.25, 0.75] of the documents.
        let filtered = selector.fit_transform(&features).unwrap();
        assert_eq!(filtered.ncols(), 1);
    }

    #[test]
    fn test_classification_metrics() {
        let predictions = [1_usize, 0, 1, 1, 0];
        let true_labels = [1_usize, 0, 1, 0, 0];
        let metrics = TextClassificationMetrics::new();

        assert_eq!(metrics.accuracy(&predictions, &true_labels).unwrap(), 0.8);

        let (precision, recall, f1) = metrics.binary_metrics(&predictions, &true_labels).unwrap();
        assert!((precision - 0.667).abs() < 0.001);
        assert_eq!(recall, 1.0);
        assert!((f1 - 0.8).abs() < 0.001);
    }
}