// File: text_classification_demo/text_classification_demo.rs

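//! Text classification example: builds a small labelled sentiment dataset, extracts TF-IDF
//! features, applies feature selection, and reports evaluation metrics.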
2
use scirs2_text::{
    TextClassificationMetrics, TextClassificationPipeline, TextDataset, TextFeatureSelector,
};
6
7#[allow(dead_code)]
8fn main() -> Result<(), Box<dyn std::error::Error>> {
9    println!("Text Classification Demo");
10    println!("=======================\n");
11
12    // Create sample dataset
13    let texts = vec![
14        "This movie is absolutely fantastic and amazing!".to_string(),
15        "I really hated this film, it was terrible.".to_string(),
16        "The acting was superb and the plot was engaging.".to_string(),
17        "Worst movie I've ever seen, complete waste of time.".to_string(),
18        "A masterpiece of cinema, truly exceptional work.".to_string(),
19        "Boring, predictable, and poorly executed.".to_string(),
20    ];
21
22    let labels = vec![
23        "positive".to_string(),
24        "negative".to_string(),
25        "positive".to_string(),
26        "negative".to_string(),
27        "positive".to_string(),
28        "negative".to_string(),
29    ];
30
31    // Create dataset
32    let dataset = TextDataset::new(texts, labels)?;
33    println!("Dataset Statistics:");
34    println!("  Total samples: {}", dataset.len());
35    println!("  Number of classes: {}", dataset.unique_labels().len());
36    println!();
37
38    // Split into train and test
39    let (train_dataset, test_dataset) = dataset.train_test_split(0.33, Some(42))?;
40    println!("Train/Test Split:");
41    println!("  Training samples: {}", train_dataset.len());
42    println!("  Test samples: {}", test_dataset.len());
43    println!();
44
45    // Create text processing pipeline
46    let mut pipeline = TextClassificationPipeline::with_tfidf();
47
48    // Fit the pipeline
49    pipeline.fit(&train_dataset)?;
50
51    // Transform to features
52    let train_features = pipeline.transform(&train_dataset)?;
53    let test_features = pipeline.transform(&test_dataset)?;
54
55    println!("Feature Extraction:");
56    println!(
57        "  Train feature shape: ({}, {})",
58        train_features.nrows(),
59        train_features.ncols()
60    );
61    println!(
62        "  Test feature shape: ({}, {})",
63        test_features.nrows(),
64        test_features.ncols()
65    );
66    println!();
67
68    // Demonstrate feature selection
69    let mut feature_selector = TextFeatureSelector::new()
70        .set_max_features(10.0)?
71        .set_min_df(0.1)?
72        .set_max_df(0.9)?;
73
74    let selected_train_features = feature_selector.fit_transform(&train_features)?;
75    println!("Feature Selection:");
76    println!("  Selected features: {}", selected_train_features.ncols());
77    println!();
78
79    // Simulate classification results (in a real scenario, you'd use a classifier)
80    // For demo purposes, we'll create mock predictions based on simple heuristics
81    let _unique_labels = train_dataset.unique_labels();
82
83    // Create binary labels (0 for negative, 1 for positive) for this demo
84    let mut train_labels = Vec::new();
85    let mut test_labels = Vec::new();
86
87    for label in &train_dataset.labels {
88        train_labels.push(if label == "positive" { 1 } else { 0 });
89    }
90
91    for label in &test_dataset.labels {
92        test_labels.push(if label == "positive" { 1 } else { 0 });
93    }
94
95    // Mock predictions (in practice, use a real classifier)
96    let predictions = test_labels.clone(); // Perfect predictions for demo
97
98    // Calculate metrics
99    let metrics = TextClassificationMetrics::new();
100    let accuracy = metrics.accuracy(&predictions, &test_labels)?;
101    let (precision, recall, f1) = metrics.binary_metrics(&predictions, &test_labels)?;
102
103    println!("Classification Metrics:");
104    println!("  Accuracy: {:.2}%", accuracy * 100.0);
105    println!("  Precision: {:.2}%", precision * 100.0);
106    println!("  Recall: {:.2}%", recall * 100.0);
107    println!("  F1 Score: {:.2}%", f1 * 100.0);
108    println!();
109
110    // Create a simple confusion matrix manually since the method isn't available
111    let mut true_positive = 0;
112    let mut true_negative = 0;
113    let mut false_positive = 0;
114    let mut false_negative = 0;
115
116    for (pred, actual) in predictions.iter().zip(test_labels.iter()) {
117        match (pred, actual) {
118            (1, 1) => true_positive += 1,
119            (0, 0) => true_negative += 1,
120            (1, 0) => false_positive += 1,
121            (0, 1) => false_negative += 1,
122            _ => {}
123        }
124    }
125
126    println!("Confusion Matrix:");
127    println!("[ {true_negative} {false_positive} ]");
128    println!("[ {false_negative} {true_positive} ]");
129
130    Ok(())
131}