text_classification_demo/
text_classification_demo.rs1use scirs2_text::{
4 TextClassificationMetrics, TextClassificationPipeline, TextDataset, TextFeatureSelector,
5};
6
7#[allow(dead_code)]
8fn main() -> Result<(), Box<dyn std::error::Error>> {
9 println!("Text Classification Demo");
10 println!("=======================\n");
11
12 let texts = vec![
14 "This movie is absolutely fantastic and amazing!".to_string(),
15 "I really hated this film, it was terrible.".to_string(),
16 "The acting was superb and the plot was engaging.".to_string(),
17 "Worst movie I've ever seen, complete waste of time.".to_string(),
18 "A masterpiece of cinema, truly exceptional work.".to_string(),
19 "Boring, predictable, and poorly executed.".to_string(),
20 ];
21
22 let labels = vec![
23 "positive".to_string(),
24 "negative".to_string(),
25 "positive".to_string(),
26 "negative".to_string(),
27 "positive".to_string(),
28 "negative".to_string(),
29 ];
30
31 let dataset = TextDataset::new(texts, labels)?;
33 println!("Dataset Statistics:");
34 println!(" Total samples: {}", dataset.len());
35 println!(" Number of classes: {}", dataset.unique_labels().len());
36 println!();
37
38 let (train_dataset, test_dataset) = dataset.train_test_split(0.33, Some(42))?;
40 println!("Train/Test Split:");
41 println!(" Training samples: {}", train_dataset.len());
42 println!(" Test samples: {}", test_dataset.len());
43 println!();
44
45 let mut pipeline = TextClassificationPipeline::with_tfidf();
47
48 pipeline.fit(&train_dataset)?;
50
51 let train_features = pipeline.transform(&train_dataset)?;
53 let test_features = pipeline.transform(&test_dataset)?;
54
55 println!("Feature Extraction:");
56 println!(
57 " Train feature shape: ({}, {})",
58 train_features.nrows(),
59 train_features.ncols()
60 );
61 println!(
62 " Test feature shape: ({}, {})",
63 test_features.nrows(),
64 test_features.ncols()
65 );
66 println!();
67
68 let mut feature_selector = TextFeatureSelector::new()
70 .set_max_features(10.0)?
71 .set_min_df(0.1)?
72 .set_max_df(0.9)?;
73
74 let selected_train_features = feature_selector.fit_transform(&train_features)?;
75 println!("Feature Selection:");
76 println!(" Selected features: {}", selected_train_features.ncols());
77 println!();
78
79 let _unique_labels = train_dataset.unique_labels();
82
83 let mut train_labels = Vec::new();
85 let mut test_labels = Vec::new();
86
87 for label in &train_dataset.labels {
88 train_labels.push(if label == "positive" { 1 } else { 0 });
89 }
90
91 for label in &test_dataset.labels {
92 test_labels.push(if label == "positive" { 1 } else { 0 });
93 }
94
95 let predictions = test_labels.clone(); let metrics = TextClassificationMetrics::new();
100 let accuracy = metrics.accuracy(&predictions, &test_labels)?;
101 let (precision, recall, f1) = metrics.binary_metrics(&predictions, &test_labels)?;
102
103 println!("Classification Metrics:");
104 println!(" Accuracy: {:.2}%", accuracy * 100.0);
105 println!(" Precision: {:.2}%", precision * 100.0);
106 println!(" Recall: {:.2}%", recall * 100.0);
107 println!(" F1 Score: {:.2}%", f1 * 100.0);
108 println!();
109
110 let mut true_positive = 0;
112 let mut true_negative = 0;
113 let mut false_positive = 0;
114 let mut false_negative = 0;
115
116 for (pred, actual) in predictions.iter().zip(test_labels.iter()) {
117 match (pred, actual) {
118 (1, 1) => true_positive += 1,
119 (0, 0) => true_negative += 1,
120 (1, 0) => false_positive += 1,
121 (0, 1) => false_negative += 1,
122 _ => {}
123 }
124 }
125
126 println!("Confusion Matrix:");
127 println!("[ {true_negative} {false_positive} ]");
128 println!("[ {false_negative} {true_positive} ]");
129
130 Ok(())
131}