1use scirs2_text::{MLSentimentAnalyzer, MLSentimentConfig, TextDataset};
4
5#[allow(dead_code)]
6fn main() -> Result<(), Box<dyn std::error::Error>> {
7 println!("ML-based Sentiment Analysis Demo");
8 println!("================================\n");
9
10 let (train_dataset, test_dataset) = create_sentiment_dataset()?;
12
13 println!("1. Dataset Information");
14 println!("--------------------");
15 println!("Training examples: {}", train_dataset.texts.len());
16 println!("Test examples: {}", test_dataset.texts.len());
17
18 let labels_train: std::collections::HashSet<_> = train_dataset.labels.iter().cloned().collect();
19 println!("Labels: {labels_train:?}\n");
20
21 println!("2. Training ML Sentiment Analyzer");
23 println!("------------------------------");
24
25 let config = MLSentimentConfig {
26 learning_rate: 0.05,
27 epochs: 200,
28 regularization: 0.01,
29 batch_size: 32,
30 random_seed: Some(42),
31 };
32
33 let mut analyzer = MLSentimentAnalyzer::new().with_config(config);
34
35 println!("Training...");
36 let training_metrics = analyzer.train(&train_dataset)?;
37
38 println!("Training complete!");
39 println!("Final accuracy: {:.4}", training_metrics.accuracy);
40 println!(
41 "Final loss: {:.4}",
42 training_metrics
43 .loss_history
44 .last()
45 .expect("Operation failed")
46 );
47
48 println!("\nLoss history (first 10 epochs):");
50 print!(" ");
51 for i in 0..10 {
52 print!("{:.2} ", training_metrics.loss_history[i]);
53 }
54 println!("...");
55
56 println!("\n3. Evaluation");
58 println!("-----------");
59
60 let eval_metrics = analyzer.evaluate(&test_dataset)?;
61
62 println!("Accuracy: {:.4}", eval_metrics.accuracy);
63 println!("Precision: {:.4}", eval_metrics.precision);
64 println!("Recall: {:.4}", eval_metrics.recall);
65 println!("F1 Score: {:.4}", eval_metrics.f1_score);
66
67 println!("\nClass metrics:");
68 for (label, metrics) in &eval_metrics.class_metrics {
69 println!(
70 " {}: Precision={:.4}, Recall={:.4}, F1={:.4}",
71 label, metrics.precision, metrics.recall, metrics.f1_score
72 );
73 }
74
75 println!("\nConfusion Matrix:");
77 for i in 0..eval_metrics.confusion_matrix.nrows() {
78 print!(" ");
79 for j in 0..eval_metrics.confusion_matrix.ncols() {
80 print!("{:4} ", eval_metrics.confusion_matrix[[i, j]]);
81 }
82 println!();
83 }
84
85 println!("\n4. Sentiment Predictions");
87 println!("----------------------");
88
89 let testtexts = vec![
90 "This product is amazing! I absolutely love it and would recommend it to everyone.",
91 "Terrible experience. The customer service was awful and the product doesn't work.",
92 "It's okay. Not great, not terrible, just average.",
93 "Good value for money, but there are some issues with the packaging.",
94 "Worst purchase ever. Complete waste of money.",
95 ];
96
97 println!("Sample text predictions:");
98 for text in testtexts {
99 let result = analyzer.predict(text)?;
100 println!(
101 "\"{}...\"\n → {} (Score: {:.2}, Confidence: {:.2})\n",
102 text.chars().take(40).collect::<String>(),
103 result.sentiment,
104 result.score,
105 result.confidence
106 );
107 }
108
109 println!("5. Batch Prediction");
111 println!("----------------");
112
113 let batchtexts = vec![
114 "Excellent quality product",
115 "Poor performance for the price",
116 "Somewhat satisfied with purchase",
117 ];
118
119 let batch_results = analyzer.predict_batch(&batchtexts)?;
120
121 for (i, result) in batch_results.iter().enumerate() {
122 println!(
123 "Text {}: {} (Confidence: {:.2})",
124 i + 1,
125 result.sentiment,
126 result.confidence
127 );
128 }
129
130 println!("\n6. Hyperparameter Comparison");
132 println!("--------------------------");
133
134 let configs = vec![
135 (0.01, 100, "Low learning rate"),
136 (0.1, 100, "High learning rate"),
137 (0.05, 50, "Medium rate, fewer epochs"),
138 (0.05, 200, "Medium rate, more epochs"),
139 ];
140
141 for (lr, epochs, desc) in configs {
142 let config = MLSentimentConfig {
143 learning_rate: lr,
144 epochs,
145 regularization: 0.01,
146 batch_size: 32,
147 random_seed: Some(42),
148 };
149
150 let mut temp_analyzer = MLSentimentAnalyzer::new().with_config(config);
151 let _metrics = temp_analyzer.train(&train_dataset)?;
152 let eval = temp_analyzer.evaluate(&test_dataset)?;
153
154 println!(
155 "{}: Accuracy={:.4}, F1={:.4}",
156 desc, eval.accuracy, eval.f1_score
157 );
158 }
159
160 Ok(())
161}
162
163#[allow(dead_code)]
164fn create_sentiment_dataset() -> Result<(TextDataset, TextDataset), Box<dyn std::error::Error>> {
165 let traintexts = vec![
167 "I absolutely loved this movie! The acting was superb.",
168 "Terrible experience, would not recommend to anyone.",
169 "The product was okay, nothing special but it works.",
170 "Great customer service and fast delivery.",
171 "Disappointing quality for the price paid.",
172 "This is the best purchase I've made all year!",
173 "Waste of money, doesn't work as advertised.",
174 "Mixed feelings about this. Some parts good, others bad.",
175 "Pleasantly surprised by how well this performs.",
176 "Not worth the price. Broke after two weeks.",
177 "Amazing value for the price. Highly recommended!",
178 "Mediocre at best. Wouldn't buy again.",
179 "Fantastic product that exceeds expectations.",
180 "Poor construction quality, arrived damaged.",
181 "It's decent but there are better options available.",
182 "This changed my life! Can't imagine living without it.",
183 "Regret buying this. Customer service was unhelpful.",
184 "Satisfied with my purchase, does what it claims.",
185 "Best in its class. Outstanding performance.",
186 "Very disappointed, doesn't match the description.",
187 "Just average, nothing to write home about.",
188 "Exceeded my expectations in every way.",
189 "One of the worst products I've ever bought.",
190 "Good enough for the price, but has limitations.",
191 "Incredible value! Works perfectly for my needs.",
192 "Would not purchase again. Many flaws.",
193 "Does the job fine, but nothing spectacular.",
194 "Absolutely worthless. Don't waste your money.",
195 "A solid choice. Reliable and well-designed.",
196 "Not impressed at all. Many issues from day one.",
197 ];
198
199 let train_labels = vec![
200 "positive", "negative", "neutral", "positive", "negative", "positive", "negative",
201 "neutral", "positive", "negative", "positive", "negative", "positive", "negative",
202 "neutral", "positive", "negative", "neutral", "positive", "negative", "neutral",
203 "positive", "negative", "neutral", "positive", "negative", "neutral", "negative",
204 "positive", "negative",
205 ];
206
207 let testtexts = [
209 "Loved every minute of it. Highly recommended!",
210 "Terrible product. Complete waste of money.",
211 "It's okay, nothing special but gets the job done.",
212 "Outstanding quality and service. Will buy again!",
213 "Very poor experience. Many issues encountered.",
214 "Adequate for basic needs, but lacks advanced features.",
215 "Couldn't be happier with this purchase.",
216 "Avoid at all costs. Terrible quality.",
217 "Average performance. Neither good nor bad.",
218 "Top-notch quality and design. Very impressed!",
219 ];
220
221 let test_labels = [
222 "positive", "negative", "neutral", "positive", "negative", "neutral", "positive",
223 "negative", "neutral", "positive",
224 ];
225
226 let traintexts = traintexts.iter().map(|t| t.to_string()).collect();
228 let train_labels = train_labels.iter().map(|l| l.to_string()).collect();
229 let testtexts = testtexts.iter().map(|t| t.to_string()).collect();
230 let test_labels = test_labels.iter().map(|l| l.to_string()).collect();
231
232 let train_dataset = TextDataset::new(traintexts, train_labels)?;
234 let test_dataset = TextDataset::new(testtexts, test_labels)?;
235
236 Ok((train_dataset, test_dataset))
237}