use crate::error::{Result, TinyDancerError};
use crate::model::{FastGRNN, FastGRNNConfig};
use ndarray::{Array1, Array2};
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use std::path::Path;

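/// Hyperparameters for training a FastGRNN confidence model.
///
/// A minimal construction sketch, overriding a couple of fields and keeping
/// the defaults for the rest (all field names are defined below):
///
/// ```ignore
/// let config = TrainingConfig {
///     learning_rate: 0.01,
///     epochs: 50,
///     ..Default::default()
/// };
/// ```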
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Initial learning rate for the Adam optimizer.
    pub learning_rate: f32,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Maximum number of training epochs.
    pub epochs: usize,
    /// Fraction of the dataset held out for validation (0.0..=1.0).
    pub validation_split: f32,
    /// Stop after this many epochs without validation improvement.
    pub early_stopping_patience: Option<usize>,
    /// Multiplicative learning-rate decay factor.
    pub lr_decay: f32,
    /// Apply the decay factor every this many epochs.
    pub lr_decay_step: usize,
    /// Gradient clipping threshold.
    pub grad_clip: f32,
    /// Adam first-moment decay rate.
    pub adam_beta1: f32,
    /// Adam second-moment decay rate.
    pub adam_beta2: f32,
    /// Adam numerical-stability constant.
    pub adam_epsilon: f32,
    /// L2 regularization coefficient.
    pub l2_reg: f32,
    /// Whether to use knowledge distillation from a teacher model.
    pub enable_distillation: bool,
    /// Temperature used when generating soft targets.
    pub distillation_temperature: f32,
    /// Weight of the soft-target loss relative to the hard-label loss.
    pub distillation_alpha: f32,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            batch_size: 32,
            epochs: 100,
            validation_split: 0.2,
            early_stopping_patience: Some(10),
            lr_decay: 0.5,
            lr_decay_step: 20,
            grad_clip: 5.0,
            adam_beta1: 0.9,
            adam_beta2: 0.999,
            adam_epsilon: 1e-8,
            l2_reg: 1e-5,
            enable_distillation: false,
            distillation_temperature: 3.0,
            distillation_alpha: 0.5,
        }
    }
}

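/// In-memory training data: per-sample feature vectors, target labels, and
/// optional soft targets produced by a teacher model for distillation.
///
/// A minimal usage sketch (values are illustrative only):
///
/// ```ignore
/// let features = vec![vec![0.1, 0.2], vec![0.3, 0.4], vec![0.5, 0.6], vec![0.7, 0.8], vec![0.9, 1.0]];
/// let labels = vec![0.0, 1.0, 0.0, 1.0, 0.0];
/// let mut dataset = TrainingDataset::new(features, labels)?;
/// let (_means, _stds) = dataset.normalize()?;
/// let (train, val) = dataset.split(0.2)?;
/// ```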
#[derive(Debug, Clone)]
pub struct TrainingDataset {
    /// One feature vector per sample.
    pub features: Vec<Vec<f32>>,
    /// Target label per sample (treated as binary at a 0.5 threshold).
    pub labels: Vec<f32>,
    /// Optional per-sample soft targets from a teacher model (distillation).
    pub soft_targets: Option<Vec<f32>>,
}

impl TrainingDataset {
    /// Creates a dataset, validating that `features` and `labels` are
    /// non-empty and have matching lengths.
    pub fn new(features: Vec<Vec<f32>>, labels: Vec<f32>) -> Result<Self> {
        if features.len() != labels.len() {
            return Err(TinyDancerError::InvalidInput(
                "Features and labels must have the same length".to_string(),
            ));
        }
        if features.is_empty() {
            return Err(TinyDancerError::InvalidInput(
                "Dataset cannot be empty".to_string(),
            ));
        }

        Ok(Self {
            features,
            labels,
            soft_targets: None,
        })
    }

    /// Attaches teacher soft targets; their length must match the labels.
    pub fn with_soft_targets(mut self, soft_targets: Vec<f32>) -> Result<Self> {
        if soft_targets.len() != self.labels.len() {
            return Err(TinyDancerError::InvalidInput(
                "Soft targets must match dataset size".to_string(),
            ));
        }
        self.soft_targets = Some(soft_targets);
        Ok(self)
    }

    /// Randomly splits the dataset into training and validation subsets,
    /// carrying soft targets along when present.
    pub fn split(&self, val_ratio: f32) -> Result<(Self, Self)> {
        if !(0.0..=1.0).contains(&val_ratio) {
            return Err(TinyDancerError::InvalidInput(
                "Validation ratio must be between 0.0 and 1.0".to_string(),
            ));
        }

        let n_samples = self.features.len();
        let n_val = (n_samples as f32 * val_ratio) as usize;
        let n_train = n_samples - n_val;

        let mut indices: Vec<usize> = (0..n_samples).collect();
        let mut rng = rand::thread_rng();
        indices.shuffle(&mut rng);

        let train_indices = &indices[..n_train];
        let val_indices = &indices[n_train..];

        let train_features: Vec<Vec<f32>> =
            train_indices.iter().map(|&i| self.features[i].clone()).collect();
        let train_labels: Vec<f32> = train_indices.iter().map(|&i| self.labels[i]).collect();

        let val_features: Vec<Vec<f32>> =
            val_indices.iter().map(|&i| self.features[i].clone()).collect();
        let val_labels: Vec<f32> = val_indices.iter().map(|&i| self.labels[i]).collect();

        let mut train_dataset = Self::new(train_features, train_labels)?;
        let mut val_dataset = Self::new(val_features, val_labels)?;

        if let Some(soft_targets) = &self.soft_targets {
            let train_soft: Vec<f32> = train_indices.iter().map(|&i| soft_targets[i]).collect();
            let val_soft: Vec<f32> = val_indices.iter().map(|&i| soft_targets[i]).collect();
            train_dataset.soft_targets = Some(train_soft);
            val_dataset.soft_targets = Some(val_soft);
        }

        Ok((train_dataset, val_dataset))
    }

    /// Normalizes each feature column in place to zero mean and unit
    /// variance, returning the per-feature means and standard deviations so
    /// the same transform can be applied at inference time.
    pub fn normalize(&mut self) -> Result<(Vec<f32>, Vec<f32>)> {
        if self.features.is_empty() {
            return Err(TinyDancerError::InvalidInput(
                "Cannot normalize empty dataset".to_string(),
            ));
        }

        let n_features = self.features[0].len();
        let mut means = vec![0.0; n_features];
        let mut stds = vec![0.0; n_features];

        for feature in &self.features {
            for (i, &val) in feature.iter().enumerate() {
                means[i] += val;
            }
        }
        for mean in &mut means {
            *mean /= self.features.len() as f32;
        }

        for feature in &self.features {
            for (i, &val) in feature.iter().enumerate() {
                stds[i] += (val - means[i]).powi(2);
            }
        }
        for std in &mut stds {
            *std = (*std / self.features.len() as f32).sqrt();
            // Guard against division by zero for constant features.
            if *std < 1e-8 {
                *std = 1.0;
            }
        }

        for feature in &mut self.features {
            for (i, val) in feature.iter_mut().enumerate() {
                *val = (*val - means[i]) / stds[i];
            }
        }

        Ok((means, stds))
    }

    pub fn len(&self) -> usize {
        self.features.len()
    }

    pub fn is_empty(&self) -> bool {
        self.features.is_empty()
    }
}

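/// Iterates over a dataset in mini-batches, optionally shuffling the sample
/// order once per pass.
///
/// A minimal sketch of one pass over a dataset (three samples per batch,
/// no shuffling):
///
/// ```ignore
/// for (features, labels, soft_targets) in BatchIterator::new(&dataset, 3, false) {
///     // features: Vec<Vec<f32>>, labels: Vec<f32>, soft_targets: Option<Vec<f32>>
/// }
/// ```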
pub struct BatchIterator<'a> {
    dataset: &'a TrainingDataset,
    batch_size: usize,
    indices: Vec<usize>,
    current_idx: usize,
}

impl<'a> BatchIterator<'a> {
    pub fn new(dataset: &'a TrainingDataset, batch_size: usize, shuffle: bool) -> Self {
        let mut indices: Vec<usize> = (0..dataset.len()).collect();
        if shuffle {
            let mut rng = rand::thread_rng();
            indices.shuffle(&mut rng);
        }

        Self {
            dataset,
            batch_size,
            indices,
            current_idx: 0,
        }
    }
}

impl<'a> Iterator for BatchIterator<'a> {
    type Item = (Vec<Vec<f32>>, Vec<f32>, Option<Vec<f32>>);

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_idx >= self.indices.len() {
            return None;
        }

        let end_idx = (self.current_idx + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current_idx..end_idx];

        let features: Vec<Vec<f32>> = batch_indices
            .iter()
            .map(|&i| self.dataset.features[i].clone())
            .collect();

        let labels: Vec<f32> = batch_indices.iter().map(|&i| self.dataset.labels[i]).collect();

        let soft_targets = self.dataset.soft_targets.as_ref().map(|targets| {
            batch_indices.iter().map(|&i| targets[i]).collect()
        });

        self.current_idx = end_idx;

        Some((features, labels, soft_targets))
    }
}

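/// Adam optimizer state: first- and second-moment buffers sized to match the
/// model's weight matrices and bias vectors, plus the shared time step and
/// the beta/epsilon hyperparameters taken from `TrainingConfig`.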
#[derive(Debug)]
struct AdamOptimizer {
    /// First-moment estimates for the weight matrices.
    m_weights: Vec<Array2<f32>>,
    /// First-moment estimates for the bias vectors.
    m_biases: Vec<Array1<f32>>,
    /// Second-moment estimates for the weight matrices.
    v_weights: Vec<Array2<f32>>,
    /// Second-moment estimates for the bias vectors.
    v_biases: Vec<Array1<f32>>,
    /// Adam time step.
    t: usize,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
}

impl AdamOptimizer {
    fn new(model_config: &FastGRNNConfig, training_config: &TrainingConfig) -> Self {
        let hidden_dim = model_config.hidden_dim;
        let input_dim = model_config.input_dim;
        let output_dim = model_config.output_dim;

        Self {
            m_weights: vec![
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, hidden_dim)),
                Array2::zeros((output_dim, hidden_dim)),
            ],
            m_biases: vec![
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(output_dim),
            ],
            v_weights: vec![
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, hidden_dim)),
                Array2::zeros((output_dim, hidden_dim)),
            ],
            v_biases: vec![
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(output_dim),
            ],
            t: 0,
            beta1: training_config.adam_beta1,
            beta2: training_config.adam_beta2,
            epsilon: training_config.adam_epsilon,
        }
    }
}

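/// Per-epoch metrics recorded during `Trainer::train`.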
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f32,
    pub val_loss: f32,
    pub train_accuracy: f32,
    pub val_accuracy: f32,
    pub learning_rate: f32,
}

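/// Drives the training loop: learning-rate decay, per-epoch evaluation,
/// metric collection, and optional early stopping.
///
/// A minimal end-to-end sketch. How the `FastGRNN` model is constructed is
/// outside this module; `FastGRNNConfig::default()` and `FastGRNN::new` are
/// only assumed placeholders for whatever the model module exposes, and
/// `dataset` is a `TrainingDataset` built as shown above.
///
/// ```ignore
/// let model_config = FastGRNNConfig::default();
/// let mut model = FastGRNN::new(&model_config)?;
/// let mut trainer = Trainer::new(&model_config, TrainingConfig::default());
/// let metrics = trainer.train(&mut model, &dataset)?;
/// trainer.save_metrics("training_metrics.json")?;
/// ```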
pub struct Trainer {
    config: TrainingConfig,
    optimizer: AdamOptimizer,
    best_val_loss: f32,
    patience_counter: usize,
    metrics_history: Vec<TrainingMetrics>,
}

impl Trainer {
    pub fn new(model_config: &FastGRNNConfig, config: TrainingConfig) -> Self {
        let optimizer = AdamOptimizer::new(model_config, &config);

        Self {
            config,
            optimizer,
            best_val_loss: f32::INFINITY,
            patience_counter: 0,
            metrics_history: Vec::new(),
        }
    }

    /// Trains `model` on `dataset`, holding out a validation split and
    /// returning the metrics recorded for every epoch that ran.
    pub fn train(
        &mut self,
        model: &mut FastGRNN,
        dataset: &TrainingDataset,
    ) -> Result<Vec<TrainingMetrics>> {
        let (train_dataset, val_dataset) = dataset.split(self.config.validation_split)?;

        println!("Training FastGRNN model");
        println!("Train samples: {}, Val samples: {}", train_dataset.len(), val_dataset.len());
        println!("Hyperparameters: {:?}", self.config);

        let mut current_lr = self.config.learning_rate;

        for epoch in 0..self.config.epochs {
            // Step-decay the learning rate on a fixed schedule.
            if epoch > 0 && epoch % self.config.lr_decay_step == 0 {
                current_lr *= self.config.lr_decay;
                println!("Decaying learning rate to {:.6}", current_lr);
            }

            let train_loss = self.train_epoch(model, &train_dataset, current_lr)?;

            let (val_loss, val_accuracy) = self.evaluate(model, &val_dataset)?;
            let (_, train_accuracy) = self.evaluate(model, &train_dataset)?;

            let metrics = TrainingMetrics {
                epoch,
                train_loss,
                val_loss,
                train_accuracy,
                val_accuracy,
                learning_rate: current_lr,
            };
            self.metrics_history.push(metrics.clone());

            println!(
                "Epoch {}/{}: train_loss={:.4}, val_loss={:.4}, train_acc={:.4}, val_acc={:.4}",
                epoch + 1,
                self.config.epochs,
                train_loss,
                val_loss,
                train_accuracy,
                val_accuracy
            );

            // Early stopping on validation loss.
            if let Some(patience) = self.config.early_stopping_patience {
                if val_loss < self.best_val_loss {
                    self.best_val_loss = val_loss;
                    self.patience_counter = 0;
                    println!("New best validation loss: {:.4}", val_loss);
                } else {
                    self.patience_counter += 1;
                    if self.patience_counter >= patience {
                        println!("Early stopping triggered at epoch {}", epoch + 1);
                        break;
                    }
                }
            }
        }

        Ok(self.metrics_history.clone())
    }

    fn train_epoch(
        &mut self,
        model: &mut FastGRNN,
        dataset: &TrainingDataset,
        learning_rate: f32,
    ) -> Result<f32> {
        let mut total_loss = 0.0;
        let mut n_batches = 0;

        let batch_iter = BatchIterator::new(dataset, self.config.batch_size, true);

        for (features, labels, soft_targets) in batch_iter {
            let batch_loss =
                self.train_batch(model, &features, &labels, soft_targets.as_ref(), learning_rate)?;
            total_loss += batch_loss;
            n_batches += 1;
        }

        Ok(total_loss / n_batches as f32)
    }

    /// Runs the forward pass for one mini-batch and returns the mean loss.
    /// With distillation enabled, the loss blends the hard-label and
    /// soft-target cross-entropies using `distillation_alpha`.
    fn train_batch(
        &mut self,
        model: &mut FastGRNN,
        features: &[Vec<f32>],
        labels: &[f32],
        soft_targets: Option<&Vec<f32>>,
        learning_rate: f32,
    ) -> Result<f32> {
        let batch_size = features.len();
        let mut total_loss = 0.0;

        for (i, feature) in features.iter().enumerate() {
            let prediction = model.forward(feature, None)?;
            let target = labels[i];

            let loss = if self.config.enable_distillation {
                if let Some(soft_targets) = soft_targets {
                    let hard_loss = binary_cross_entropy(prediction, target);
                    let soft_loss = binary_cross_entropy(prediction, soft_targets[i]);
                    self.config.distillation_alpha * soft_loss
                        + (1.0 - self.config.distillation_alpha) * hard_loss
                } else {
                    binary_cross_entropy(prediction, target)
                }
            } else {
                binary_cross_entropy(prediction, target)
            };

            total_loss += loss;
        }

        self.apply_gradients(model, learning_rate)?;

        Ok(total_loss / batch_size as f32)
    }

    /// Placeholder for the Adam parameter update. It currently only advances
    /// the optimizer time step; gradient computation and weight updates are
    /// not implemented here.
    fn apply_gradients(&mut self, _model: &mut FastGRNN, _learning_rate: f32) -> Result<()> {
        self.optimizer.t += 1;

        Ok(())
    }

    /// Computes the mean binary cross-entropy loss and the accuracy of the
    /// model on `dataset`, thresholding predictions at 0.5.
    fn evaluate(&self, model: &FastGRNN, dataset: &TrainingDataset) -> Result<(f32, f32)> {
        let mut total_loss = 0.0;
        let mut correct = 0;

        for (i, feature) in dataset.features.iter().enumerate() {
            let prediction = model.forward(feature, None)?;
            let target = dataset.labels[i];

            let loss = binary_cross_entropy(prediction, target);
            total_loss += loss;

            let predicted_class = if prediction >= 0.5 { 1.0_f32 } else { 0.0_f32 };
            let target_class = if target >= 0.5 { 1.0_f32 } else { 0.0_f32 };
            if (predicted_class - target_class).abs() < 0.01_f32 {
                correct += 1;
            }
        }

        let avg_loss = total_loss / dataset.len() as f32;
        let accuracy = correct as f32 / dataset.len() as f32;

        Ok((avg_loss, accuracy))
    }

    /// Returns the metrics recorded so far.
    pub fn metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    /// Writes the metrics history to `path` as pretty-printed JSON.
    pub fn save_metrics<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string_pretty(&self.metrics_history)
            .map_err(|e| TinyDancerError::SerializationError(e.to_string()))?;
        std::fs::write(path, json)?;
        Ok(())
    }
}

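/// Binary cross-entropy for a single prediction/target pair:
/// `-t * ln(p) - (1 - t) * ln(1 - p)`, with `p` clamped away from 0 and 1
/// for numerical stability.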
fn binary_cross_entropy(prediction: f32, target: f32) -> f32 {
    let eps = 1e-7;
    let pred = prediction.clamp(eps, 1.0 - eps);
    -target * pred.ln() - (1.0 - target) * (1.0 - pred).ln()
}

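/// Temperature-scaled sigmoid for a single logit: `sigmoid(logit / temperature)`.
/// Higher temperatures flatten the output toward 0.5, which is what makes
/// teacher predictions "soft" for distillation. The two branches below are
/// numerically stable forms of the same sigmoid.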
pub fn temperature_softmax(logit: f32, temperature: f32) -> f32 {
    let scaled = logit / temperature;
    if scaled > 0.0 {
        1.0 / (1.0 + (-scaled).exp())
    } else {
        let ex = scaled.exp();
        ex / (1.0 + ex)
    }
}

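/// Runs the teacher model over every feature vector and converts its outputs
/// into temperature-softened targets for distillation.
///
/// A minimal sketch of wiring the soft targets into a dataset, given a
/// trained `teacher` model already in scope:
///
/// ```ignore
/// let soft = generate_teacher_predictions(&teacher, &dataset.features, 3.0)?;
/// let dataset = dataset.with_soft_targets(soft)?;
/// ```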
pub fn generate_teacher_predictions(
    teacher: &FastGRNN,
    features: &[Vec<f32>],
    temperature: f32,
) -> Result<Vec<f32>> {
    features
        .iter()
        .map(|feature| {
            let logit = teacher.forward(feature, None)?;
            Ok(temperature_softmax(logit, temperature))
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataset_creation() {
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let labels = vec![0.0, 1.0, 0.0];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        assert_eq!(dataset.len(), 3);
    }

    #[test]
    fn test_dataset_split() {
        let features = vec![vec![1.0; 5]; 100];
        let labels = vec![0.0; 100];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        let (train, val) = dataset.split(0.2).unwrap();
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }

    #[test]
    fn test_batch_iterator() {
        let features = vec![vec![1.0; 5]; 10];
        let labels = vec![0.0; 10];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        let mut iter = BatchIterator::new(&dataset, 3, false);

        let batch1 = iter.next().unwrap();
        assert_eq!(batch1.0.len(), 3);

        let batch2 = iter.next().unwrap();
        assert_eq!(batch2.0.len(), 3);

        let batch3 = iter.next().unwrap();
        assert_eq!(batch3.0.len(), 3);

        let batch4 = iter.next().unwrap();
        assert_eq!(batch4.0.len(), 1);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_normalization() {
        let features = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let labels = vec![0.0, 1.0, 0.0];
        let mut dataset = TrainingDataset::new(features, labels).unwrap();
        let (means, stds) = dataset.normalize().unwrap();

        assert_eq!(means.len(), 3);
        assert_eq!(stds.len(), 3);

        let sum: f32 = dataset.features.iter().map(|f| f[0]).sum();
        let mean = sum / dataset.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn test_bce_loss() {
        let loss1 = binary_cross_entropy(0.9, 1.0);
        let loss2 = binary_cross_entropy(0.1, 1.0);
        assert!(loss1 < loss2);
    }

    #[test]
    fn test_temperature_softmax() {
        let logit = 2.0;
        let soft1 = temperature_softmax(logit, 1.0);
        let soft2 = temperature_softmax(logit, 2.0);

        assert!((soft1 - 0.5).abs() > (soft2 - 0.5).abs());
    }
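
    // Sanity checks for soft-target handling: with_soft_targets must reject a
    // mismatched length, and BatchIterator must carry soft targets through
    // batch by batch. Both exercise only behavior defined in this module.
    #[test]
    fn test_soft_targets_length_mismatch() {
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let labels = vec![0.0, 1.0];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        // Two samples, three soft targets: must be rejected.
        assert!(dataset.with_soft_targets(vec![0.1, 0.9, 0.5]).is_err());
    }

    #[test]
    fn test_batch_iterator_soft_targets() {
        let features = vec![vec![1.0; 2]; 4];
        let labels = vec![0.0, 1.0, 0.0, 1.0];
        let dataset = TrainingDataset::new(features, labels)
            .unwrap()
            .with_soft_targets(vec![0.2, 0.8, 0.3, 0.7])
            .unwrap();

        let mut iter = BatchIterator::new(&dataset, 2, false);
        let (_, _, soft) = iter.next().unwrap();
        assert_eq!(soft.unwrap().len(), 2);
    }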
}