Skip to main content

ml_sentiment_demo/
ml_sentiment_demo.rs

1//! Machine learning based sentiment analysis demonstration
2
3use scirs2_text::{MLSentimentAnalyzer, MLSentimentConfig, TextDataset};
4
5#[allow(dead_code)]
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7    println!("ML-based Sentiment Analysis Demo");
8    println!("================================\n");
9
10    // Create a sample dataset
11    let (train_dataset, test_dataset) = create_sentiment_dataset()?;
12
13    println!("1. Dataset Information");
14    println!("--------------------");
15    println!("Training examples: {}", train_dataset.texts.len());
16    println!("Test examples: {}", test_dataset.texts.len());
17
18    let labels_train: std::collections::HashSet<_> = train_dataset.labels.iter().cloned().collect();
19    println!("Labels: {labels_train:?}\n");
20
21    // Configure and train ML sentiment analyzer
22    println!("2. Training ML Sentiment Analyzer");
23    println!("------------------------------");
24
25    let config = MLSentimentConfig {
26        learning_rate: 0.05,
27        epochs: 200,
28        regularization: 0.01,
29        batch_size: 32,
30        random_seed: Some(42),
31    };
32
33    let mut analyzer = MLSentimentAnalyzer::new().with_config(config);
34
35    println!("Training...");
36    let training_metrics = analyzer.train(&train_dataset)?;
37
38    println!("Training complete!");
39    println!("Final accuracy: {:.4}", training_metrics.accuracy);
40    println!(
41        "Final loss: {:.4}",
42        training_metrics
43            .loss_history
44            .last()
45            .expect("Operation failed")
46    );
47
48    // Plot loss history
49    println!("\nLoss history (first 10 epochs):");
50    print!("  ");
51    for i in 0..10 {
52        print!("{:.2} ", training_metrics.loss_history[i]);
53    }
54    println!("...");
55
56    // Evaluate on test data
57    println!("\n3. Evaluation");
58    println!("-----------");
59
60    let eval_metrics = analyzer.evaluate(&test_dataset)?;
61
62    println!("Accuracy: {:.4}", eval_metrics.accuracy);
63    println!("Precision: {:.4}", eval_metrics.precision);
64    println!("Recall: {:.4}", eval_metrics.recall);
65    println!("F1 Score: {:.4}", eval_metrics.f1_score);
66
67    println!("\nClass metrics:");
68    for (label, metrics) in &eval_metrics.class_metrics {
69        println!(
70            "  {}: Precision={:.4}, Recall={:.4}, F1={:.4}",
71            label, metrics.precision, metrics.recall, metrics.f1_score
72        );
73    }
74
75    // Display confusion matrix
76    println!("\nConfusion Matrix:");
77    for i in 0..eval_metrics.confusion_matrix.nrows() {
78        print!("  ");
79        for j in 0..eval_metrics.confusion_matrix.ncols() {
80            print!("{:4} ", eval_metrics.confusion_matrix[[i, j]]);
81        }
82        println!();
83    }
84
85    // Use the model for predictions
86    println!("\n4. Sentiment Predictions");
87    println!("----------------------");
88
89    let testtexts = vec![
90        "This product is amazing! I absolutely love it and would recommend it to everyone.",
91        "Terrible experience. The customer service was awful and the product doesn't work.",
92        "It's okay. Not great, not terrible, just average.",
93        "Good value for money, but there are some issues with the packaging.",
94        "Worst purchase ever. Complete waste of money.",
95    ];
96
97    println!("Sample text predictions:");
98    for text in testtexts {
99        let result = analyzer.predict(text)?;
100        println!(
101            "\"{}...\"\n  → {} (Score: {:.2}, Confidence: {:.2})\n",
102            text.chars().take(40).collect::<String>(),
103            result.sentiment,
104            result.score,
105            result.confidence
106        );
107    }
108
109    // Batch prediction
110    println!("5. Batch Prediction");
111    println!("----------------");
112
113    let batchtexts = vec![
114        "Excellent quality product",
115        "Poor performance for the price",
116        "Somewhat satisfied with purchase",
117    ];
118
119    let batch_results = analyzer.predict_batch(&batchtexts)?;
120
121    for (i, result) in batch_results.iter().enumerate() {
122        println!(
123            "Text {}: {} (Confidence: {:.2})",
124            i + 1,
125            result.sentiment,
126            result.confidence
127        );
128    }
129
130    // Compare with different configurations
131    println!("\n6. Hyperparameter Comparison");
132    println!("--------------------------");
133
134    let configs = vec![
135        (0.01, 100, "Low learning rate"),
136        (0.1, 100, "High learning rate"),
137        (0.05, 50, "Medium rate, fewer epochs"),
138        (0.05, 200, "Medium rate, more epochs"),
139    ];
140
141    for (lr, epochs, desc) in configs {
142        let config = MLSentimentConfig {
143            learning_rate: lr,
144            epochs,
145            regularization: 0.01,
146            batch_size: 32,
147            random_seed: Some(42),
148        };
149
150        let mut temp_analyzer = MLSentimentAnalyzer::new().with_config(config);
151        let _metrics = temp_analyzer.train(&train_dataset)?;
152        let eval = temp_analyzer.evaluate(&test_dataset)?;
153
154        println!(
155            "{}: Accuracy={:.4}, F1={:.4}",
156            desc, eval.accuracy, eval.f1_score
157        );
158    }
159
160    Ok(())
161}
162
163#[allow(dead_code)]
164fn create_sentiment_dataset() -> Result<(TextDataset, TextDataset), Box<dyn std::error::Error>> {
165    // Training data
166    let traintexts = vec![
167        "I absolutely loved this movie! The acting was superb.",
168        "Terrible experience, would not recommend to anyone.",
169        "The product was okay, nothing special but it works.",
170        "Great customer service and fast delivery.",
171        "Disappointing quality for the price paid.",
172        "This is the best purchase I've made all year!",
173        "Waste of money, doesn't work as advertised.",
174        "Mixed feelings about this. Some parts good, others bad.",
175        "Pleasantly surprised by how well this performs.",
176        "Not worth the price. Broke after two weeks.",
177        "Amazing value for the price. Highly recommended!",
178        "Mediocre at best. Wouldn't buy again.",
179        "Fantastic product that exceeds expectations.",
180        "Poor construction quality, arrived damaged.",
181        "It's decent but there are better options available.",
182        "This changed my life! Can't imagine living without it.",
183        "Regret buying this. Customer service was unhelpful.",
184        "Satisfied with my purchase, does what it claims.",
185        "Best in its class. Outstanding performance.",
186        "Very disappointed, doesn't match the description.",
187        "Just average, nothing to write home about.",
188        "Exceeded my expectations in every way.",
189        "One of the worst products I've ever bought.",
190        "Good enough for the price, but has limitations.",
191        "Incredible value! Works perfectly for my needs.",
192        "Would not purchase again. Many flaws.",
193        "Does the job fine, but nothing spectacular.",
194        "Absolutely worthless. Don't waste your money.",
195        "A solid choice. Reliable and well-designed.",
196        "Not impressed at all. Many issues from day one.",
197    ];
198
199    let train_labels = vec![
200        "positive", "negative", "neutral", "positive", "negative", "positive", "negative",
201        "neutral", "positive", "negative", "positive", "negative", "positive", "negative",
202        "neutral", "positive", "negative", "neutral", "positive", "negative", "neutral",
203        "positive", "negative", "neutral", "positive", "negative", "neutral", "negative",
204        "positive", "negative",
205    ];
206
207    // Test data (different examples)
208    let testtexts = [
209        "Loved every minute of it. Highly recommended!",
210        "Terrible product. Complete waste of money.",
211        "It's okay, nothing special but gets the job done.",
212        "Outstanding quality and service. Will buy again!",
213        "Very poor experience. Many issues encountered.",
214        "Adequate for basic needs, but lacks advanced features.",
215        "Couldn't be happier with this purchase.",
216        "Avoid at all costs. Terrible quality.",
217        "Average performance. Neither good nor bad.",
218        "Top-notch quality and design. Very impressed!",
219    ];
220
221    let test_labels = [
222        "positive", "negative", "neutral", "positive", "negative", "neutral", "positive",
223        "negative", "neutral", "positive",
224    ];
225
226    // Convert to strings
227    let traintexts = traintexts.iter().map(|t| t.to_string()).collect();
228    let train_labels = train_labels.iter().map(|l| l.to_string()).collect();
229    let testtexts = testtexts.iter().map(|t| t.to_string()).collect();
230    let test_labels = test_labels.iter().map(|l| l.to_string()).collect();
231
232    // Create datasets
233    let train_dataset = TextDataset::new(traintexts, train_labels)?;
234    let test_dataset = TextDataset::new(testtexts, test_labels)?;
235
236    Ok((train_dataset, test_dataset))
237}