//! Training utilities for the [`FastGRNN`] model: in-memory datasets, mini-batching,
//! Adam optimizer state, and a [`Trainer`] loop with optional knowledge distillation
//! from a teacher model.

use crate::error::{Result, TinyDancerError};
use crate::model::{FastGRNN, FastGRNNConfig};
use ndarray::{Array1, Array2};
use rand::seq::SliceRandom;
use serde::{Deserialize, Serialize};
use std::path::Path;

/// Hyperparameters controlling the training loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingConfig {
    /// Initial learning rate for the Adam optimizer.
    pub learning_rate: f32,
    /// Number of samples per mini-batch.
    pub batch_size: usize,
    /// Maximum number of training epochs.
    pub epochs: usize,
    /// Fraction of the dataset held out for validation (0.0..=1.0).
    pub validation_split: f32,
    /// Stop after this many epochs without validation improvement.
    pub early_stopping_patience: Option<usize>,
    /// Multiplicative learning-rate decay factor.
    pub lr_decay: f32,
    /// Apply the decay factor every `lr_decay_step` epochs.
    pub lr_decay_step: usize,
    /// Gradient clipping threshold.
    pub grad_clip: f32,
    /// Adam first-moment decay rate.
    pub adam_beta1: f32,
    /// Adam second-moment decay rate.
    pub adam_beta2: f32,
    /// Adam numerical-stability constant.
    pub adam_epsilon: f32,
    /// L2 regularization strength.
    pub l2_reg: f32,
    /// Whether to blend in a distillation loss against soft targets.
    pub enable_distillation: bool,
    /// Temperature used to soften teacher outputs.
    pub distillation_temperature: f32,
    /// Weight of the soft-target loss relative to the hard-target loss.
    pub distillation_alpha: f32,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            batch_size: 32,
            epochs: 100,
            validation_split: 0.2,
            early_stopping_patience: Some(10),
            lr_decay: 0.5,
            lr_decay_step: 20,
            grad_clip: 5.0,
            adam_beta1: 0.9,
            adam_beta2: 0.999,
            adam_epsilon: 1e-8,
            l2_reg: 1e-5,
            enable_distillation: false,
            distillation_temperature: 3.0,
            distillation_alpha: 0.5,
        }
    }
}
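
// Illustrative sketch (not part of the original module): individual hyperparameters
// can be overridden with struct-update syntax while keeping the remaining defaults, e.g.
//
//     let config = TrainingConfig {
//         learning_rate: 0.01,
//         epochs: 50,
//         ..TrainingConfig::default()
//     };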

/// In-memory training data: per-sample feature vectors, hard labels, and optional
/// teacher-provided soft targets for distillation.
#[derive(Debug, Clone)]
pub struct TrainingDataset {
    /// One feature vector per sample.
    pub features: Vec<Vec<f32>>,
    /// One label per sample (0.0 or 1.0 for binary classification).
    pub labels: Vec<f32>,
    /// Optional teacher probabilities, aligned with `labels`.
    pub soft_targets: Option<Vec<f32>>,
}

impl TrainingDataset {
    /// Creates a dataset, validating that features and labels are non-empty and aligned.
    pub fn new(features: Vec<Vec<f32>>, labels: Vec<f32>) -> Result<Self> {
        if features.len() != labels.len() {
            return Err(TinyDancerError::InvalidInput(
                "Features and labels must have the same length".to_string(),
            ));
        }
        if features.is_empty() {
            return Err(TinyDancerError::InvalidInput(
                "Dataset cannot be empty".to_string(),
            ));
        }

        Ok(Self {
            features,
            labels,
            soft_targets: None,
        })
    }

    /// Attaches teacher soft targets (one per sample) for knowledge distillation.
    pub fn with_soft_targets(mut self, soft_targets: Vec<f32>) -> Result<Self> {
        if soft_targets.len() != self.labels.len() {
            return Err(TinyDancerError::InvalidInput(
                "Soft targets must match dataset size".to_string(),
            ));
        }
        self.soft_targets = Some(soft_targets);
        Ok(self)
    }

    /// Randomly splits the dataset into training and validation subsets.
    pub fn split(&self, val_ratio: f32) -> Result<(Self, Self)> {
        if !(0.0..=1.0).contains(&val_ratio) {
            return Err(TinyDancerError::InvalidInput(
                "Validation ratio must be between 0.0 and 1.0".to_string(),
            ));
        }

        let n_samples = self.features.len();
        let n_val = (n_samples as f32 * val_ratio) as usize;
        let n_train = n_samples - n_val;

        // Shuffle indices so both splits draw from the same distribution.
        let mut indices: Vec<usize> = (0..n_samples).collect();
        let mut rng = rand::thread_rng();
        indices.shuffle(&mut rng);

        let train_indices = &indices[..n_train];
        let val_indices = &indices[n_train..];

        let train_features: Vec<Vec<f32>> = train_indices
            .iter()
            .map(|&i| self.features[i].clone())
            .collect();
        let train_labels: Vec<f32> = train_indices.iter().map(|&i| self.labels[i]).collect();

        let val_features: Vec<Vec<f32>> = val_indices
            .iter()
            .map(|&i| self.features[i].clone())
            .collect();
        let val_labels: Vec<f32> = val_indices.iter().map(|&i| self.labels[i]).collect();

        let mut train_dataset = Self::new(train_features, train_labels)?;
        let mut val_dataset = Self::new(val_features, val_labels)?;

        // Carry soft targets along with their samples when present.
        if let Some(soft_targets) = &self.soft_targets {
            let train_soft: Vec<f32> = train_indices.iter().map(|&i| soft_targets[i]).collect();
            let val_soft: Vec<f32> = val_indices.iter().map(|&i| soft_targets[i]).collect();
            train_dataset.soft_targets = Some(train_soft);
            val_dataset.soft_targets = Some(val_soft);
        }

        Ok((train_dataset, val_dataset))
    }

    /// Normalizes each feature column to zero mean and unit variance in place,
    /// returning the per-feature means and standard deviations so the same
    /// transform can be applied at inference time.
    pub fn normalize(&mut self) -> Result<(Vec<f32>, Vec<f32>)> {
        if self.features.is_empty() {
            return Err(TinyDancerError::InvalidInput(
                "Cannot normalize empty dataset".to_string(),
            ));
        }

        let n_features = self.features[0].len();
        let mut means = vec![0.0; n_features];
        let mut stds = vec![0.0; n_features];

        // Per-feature means.
        for feature in &self.features {
            for (i, &val) in feature.iter().enumerate() {
                means[i] += val;
            }
        }
        for mean in &mut means {
            *mean /= self.features.len() as f32;
        }

        // Per-feature standard deviations, guarding against division by zero.
        for feature in &self.features {
            for (i, &val) in feature.iter().enumerate() {
                stds[i] += (val - means[i]).powi(2);
            }
        }
        for std in &mut stds {
            *std = (*std / self.features.len() as f32).sqrt();
            if *std < 1e-8 {
                *std = 1.0;
            }
        }

        // Apply the transform in place.
        for feature in &mut self.features {
            for (i, val) in feature.iter_mut().enumerate() {
                *val = (*val - means[i]) / stds[i];
            }
        }

        Ok((means, stds))
    }

    /// Number of samples in the dataset.
    pub fn len(&self) -> usize {
        self.features.len()
    }

    /// Returns `true` if the dataset contains no samples.
    pub fn is_empty(&self) -> bool {
        self.features.is_empty()
    }
}

/// Iterates over a [`TrainingDataset`] in mini-batches, optionally shuffling once per epoch.
pub struct BatchIterator<'a> {
    dataset: &'a TrainingDataset,
    batch_size: usize,
    indices: Vec<usize>,
    current_idx: usize,
}

impl<'a> BatchIterator<'a> {
    /// Creates a new iterator over `dataset`; when `shuffle` is true the sample
    /// order is randomized before batching.
    pub fn new(dataset: &'a TrainingDataset, batch_size: usize, shuffle: bool) -> Self {
        let mut indices: Vec<usize> = (0..dataset.len()).collect();
        if shuffle {
            let mut rng = rand::thread_rng();
            indices.shuffle(&mut rng);
        }

        Self {
            dataset,
            batch_size,
            indices,
            current_idx: 0,
        }
    }
}

impl<'a> Iterator for BatchIterator<'a> {
    // Each item is (features, labels, optional soft targets) for one mini-batch.
    type Item = (Vec<Vec<f32>>, Vec<f32>, Option<Vec<f32>>);

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_idx >= self.indices.len() {
            return None;
        }

        // The final batch may be smaller than `batch_size`.
        let end_idx = (self.current_idx + self.batch_size).min(self.indices.len());
        let batch_indices = &self.indices[self.current_idx..end_idx];

        let features: Vec<Vec<f32>> = batch_indices
            .iter()
            .map(|&i| self.dataset.features[i].clone())
            .collect();

        let labels: Vec<f32> = batch_indices
            .iter()
            .map(|&i| self.dataset.labels[i])
            .collect();

        let soft_targets = self
            .dataset
            .soft_targets
            .as_ref()
            .map(|targets| batch_indices.iter().map(|&i| targets[i]).collect());

        self.current_idx = end_idx;

        Some((features, labels, soft_targets))
    }
}

/// Adam optimizer state: first- and second-moment estimates for the FastGRNN weight
/// matrices and bias vectors, plus the timestep counter used for bias correction.
#[derive(Debug)]
struct AdamOptimizer {
    /// First-moment estimates for the weight matrices.
    m_weights: Vec<Array2<f32>>,
    /// First-moment estimates for the bias vectors.
    m_biases: Vec<Array1<f32>>,
    /// Second-moment estimates for the weight matrices.
    v_weights: Vec<Array2<f32>>,
    /// Second-moment estimates for the bias vectors.
    v_biases: Vec<Array1<f32>>,
    /// Timestep counter.
    t: usize,
    beta1: f32,
    beta2: f32,
    epsilon: f32,
}

impl AdamOptimizer {
    /// Initializes all moment estimates to zero, shaped to match the model's parameters.
    fn new(model_config: &FastGRNNConfig, training_config: &TrainingConfig) -> Self {
        let hidden_dim = model_config.hidden_dim;
        let input_dim = model_config.input_dim;
        let output_dim = model_config.output_dim;

        Self {
            m_weights: vec![
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, hidden_dim)),
                Array2::zeros((output_dim, hidden_dim)),
            ],
            m_biases: vec![
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(output_dim),
            ],
            v_weights: vec![
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, input_dim)),
                Array2::zeros((hidden_dim, hidden_dim)),
                Array2::zeros((output_dim, hidden_dim)),
            ],
            v_biases: vec![
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(hidden_dim),
                Array1::zeros(output_dim),
            ],
            t: 0,
            beta1: training_config.adam_beta1,
            beta2: training_config.adam_beta2,
            epsilon: training_config.adam_epsilon,
        }
    }
}

/// Per-epoch metrics recorded during training.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f32,
    pub val_loss: f32,
    pub train_accuracy: f32,
    pub val_accuracy: f32,
    pub learning_rate: f32,
}

/// Drives the training loop: batching, optimization, evaluation, learning-rate decay,
/// and early stopping.
pub struct Trainer {
    config: TrainingConfig,
    optimizer: AdamOptimizer,
    best_val_loss: f32,
    patience_counter: usize,
    metrics_history: Vec<TrainingMetrics>,
}

impl Trainer {
    /// Creates a trainer with freshly initialized optimizer state.
    pub fn new(model_config: &FastGRNNConfig, config: TrainingConfig) -> Self {
        let optimizer = AdamOptimizer::new(model_config, &config);

        Self {
            config,
            optimizer,
            best_val_loss: f32::INFINITY,
            patience_counter: 0,
            metrics_history: Vec::new(),
        }
    }

    /// Trains `model` on `dataset`, returning the per-epoch metrics history.
    pub fn train(
        &mut self,
        model: &mut FastGRNN,
        dataset: &TrainingDataset,
    ) -> Result<Vec<TrainingMetrics>> {
        let (train_dataset, val_dataset) = dataset.split(self.config.validation_split)?;

        println!("Training FastGRNN model");
        println!(
            "Train samples: {}, Val samples: {}",
            train_dataset.len(),
            val_dataset.len()
        );
        println!("Hyperparameters: {:?}", self.config);

        let mut current_lr = self.config.learning_rate;

        for epoch in 0..self.config.epochs {
            // Step-wise learning-rate decay.
            if epoch > 0 && epoch % self.config.lr_decay_step == 0 {
                current_lr *= self.config.lr_decay;
                println!("Decaying learning rate to {:.6}", current_lr);
            }

            let train_loss = self.train_epoch(model, &train_dataset, current_lr)?;

            let (val_loss, val_accuracy) = self.evaluate(model, &val_dataset)?;
            let (_, train_accuracy) = self.evaluate(model, &train_dataset)?;

            let metrics = TrainingMetrics {
                epoch,
                train_loss,
                val_loss,
                train_accuracy,
                val_accuracy,
                learning_rate: current_lr,
            };
            self.metrics_history.push(metrics.clone());

            println!(
                "Epoch {}/{}: train_loss={:.4}, val_loss={:.4}, train_acc={:.4}, val_acc={:.4}",
                epoch + 1,
                self.config.epochs,
                train_loss,
                val_loss,
                train_accuracy,
                val_accuracy
            );

            // Early stopping on validation loss.
            if let Some(patience) = self.config.early_stopping_patience {
                if val_loss < self.best_val_loss {
                    self.best_val_loss = val_loss;
                    self.patience_counter = 0;
                    println!("New best validation loss: {:.4}", val_loss);
                } else {
                    self.patience_counter += 1;
                    if self.patience_counter >= patience {
                        println!("Early stopping triggered at epoch {}", epoch + 1);
                        break;
                    }
                }
            }
        }

        Ok(self.metrics_history.clone())
    }

    /// Runs one pass over the training set and returns the mean batch loss.
    fn train_epoch(
        &mut self,
        model: &mut FastGRNN,
        dataset: &TrainingDataset,
        learning_rate: f32,
    ) -> Result<f32> {
        let mut total_loss = 0.0;
        let mut n_batches = 0;

        let batch_iter = BatchIterator::new(dataset, self.config.batch_size, true);

        for (features, labels, soft_targets) in batch_iter {
            let batch_loss = self.train_batch(
                model,
                &features,
                &labels,
                soft_targets.as_ref(),
                learning_rate,
            )?;
            total_loss += batch_loss;
            n_batches += 1;
        }

        Ok(total_loss / n_batches as f32)
    }

    /// Computes the loss for one mini-batch and advances the optimizer.
    fn train_batch(
        &mut self,
        model: &mut FastGRNN,
        features: &[Vec<f32>],
        labels: &[f32],
        soft_targets: Option<&Vec<f32>>,
        learning_rate: f32,
    ) -> Result<f32> {
        let batch_size = features.len();
        let mut total_loss = 0.0;

        for (i, feature) in features.iter().enumerate() {
            let prediction = model.forward(feature, None)?;
            let target = labels[i];

            // With distillation enabled, blend the hard-label loss with a loss against
            // the teacher's soft target, weighted by `distillation_alpha`.
            let loss = if self.config.enable_distillation {
                if let Some(soft_targets) = soft_targets {
                    let hard_loss = binary_cross_entropy(prediction, target);
                    let soft_loss = binary_cross_entropy(prediction, soft_targets[i]);
                    self.config.distillation_alpha * soft_loss
                        + (1.0 - self.config.distillation_alpha) * hard_loss
                } else {
                    binary_cross_entropy(prediction, target)
                }
            } else {
                binary_cross_entropy(prediction, target)
            };

            total_loss += loss;
        }

        self.apply_gradients(model, learning_rate)?;

        Ok(total_loss / batch_size as f32)
    }

    /// Placeholder for the optimizer step: gradients are not yet backpropagated through
    /// the FastGRNN cell, so this currently only advances the Adam timestep. The stored
    /// moment estimates (`m_*`, `v_*`) are in place for a full Adam update once
    /// per-parameter gradients are available.
    fn apply_gradients(&mut self, _model: &mut FastGRNN, _learning_rate: f32) -> Result<()> {
        self.optimizer.t += 1;

        Ok(())
    }

    /// Evaluates `model` on `dataset`, returning (mean BCE loss, accuracy).
    fn evaluate(&self, model: &FastGRNN, dataset: &TrainingDataset) -> Result<(f32, f32)> {
        let mut total_loss = 0.0;
        let mut correct = 0;

        for (i, feature) in dataset.features.iter().enumerate() {
            let prediction = model.forward(feature, None)?;
            let target = dataset.labels[i];

            let loss = binary_cross_entropy(prediction, target);
            total_loss += loss;

            // Threshold both prediction and label at 0.5 for accuracy.
            let predicted_class = if prediction >= 0.5 { 1.0_f32 } else { 0.0_f32 };
            let target_class = if target >= 0.5 { 1.0_f32 } else { 0.0_f32 };
            if (predicted_class - target_class).abs() < 0.01_f32 {
                correct += 1;
            }
        }

        let avg_loss = total_loss / dataset.len() as f32;
        let accuracy = correct as f32 / dataset.len() as f32;

        Ok((avg_loss, accuracy))
    }

    /// Returns the metrics recorded so far.
    pub fn metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    /// Writes the metrics history to `path` as pretty-printed JSON.
    pub fn save_metrics<P: AsRef<Path>>(&self, path: P) -> Result<()> {
        let json = serde_json::to_string_pretty(&self.metrics_history)
            .map_err(|e| TinyDancerError::SerializationError(e.to_string()))?;
        std::fs::write(path, json)?;
        Ok(())
    }
}

/// Binary cross-entropy between a predicted probability and a 0/1 target.
/// The prediction is clamped away from 0 and 1 to keep the logs finite.
fn binary_cross_entropy(prediction: f32, target: f32) -> f32 {
    let eps = 1e-7;
    let pred = prediction.clamp(eps, 1.0 - eps);
    -target * pred.ln() - (1.0 - target) * (1.0 - pred).ln()
}

/// Temperature-scaled sigmoid used to soften a teacher logit for distillation:
/// `sigmoid(logit / temperature)`, computed in a numerically stable form for
/// both signs of the scaled logit.
pub fn temperature_softmax(logit: f32, temperature: f32) -> f32 {
    let scaled = logit / temperature;
    if scaled > 0.0 {
        1.0 / (1.0 + (-scaled).exp())
    } else {
        let ex = scaled.exp();
        ex / (1.0 + ex)
    }
}

/// Runs the teacher model over `features` and returns temperature-softened
/// probabilities suitable for use as distillation soft targets.
pub fn generate_teacher_predictions(
    teacher: &FastGRNN,
    features: &[Vec<f32>],
    temperature: f32,
) -> Result<Vec<f32>> {
    features
        .iter()
        .map(|feature| {
            let logit = teacher.forward(feature, None)?;
            Ok(temperature_softmax(logit, temperature))
        })
        .collect()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dataset_creation() {
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let labels = vec![0.0, 1.0, 0.0];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        assert_eq!(dataset.len(), 3);
    }

    #[test]
    fn test_dataset_split() {
        let features = vec![vec![1.0; 5]; 100];
        let labels = vec![0.0; 100];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        let (train, val) = dataset.split(0.2).unwrap();
        assert_eq!(train.len(), 80);
        assert_eq!(val.len(), 20);
    }

    #[test]
    fn test_batch_iterator() {
        let features = vec![vec![1.0; 5]; 10];
        let labels = vec![0.0; 10];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        let mut iter = BatchIterator::new(&dataset, 3, false);

        let batch1 = iter.next().unwrap();
        assert_eq!(batch1.0.len(), 3);

        let batch2 = iter.next().unwrap();
        assert_eq!(batch2.0.len(), 3);

        let batch3 = iter.next().unwrap();
        assert_eq!(batch3.0.len(), 3);

        // The final batch holds the single remaining sample, then the iterator is exhausted.
        let batch4 = iter.next().unwrap();
        assert_eq!(batch4.0.len(), 1);
        assert!(iter.next().is_none());
    }

    #[test]
    fn test_normalization() {
        let features = vec![
            vec![1.0, 2.0, 3.0],
            vec![4.0, 5.0, 6.0],
            vec![7.0, 8.0, 9.0],
        ];
        let labels = vec![0.0, 1.0, 0.0];
        let mut dataset = TrainingDataset::new(features, labels).unwrap();
        let (means, stds) = dataset.normalize().unwrap();

        assert_eq!(means.len(), 3);
        assert_eq!(stds.len(), 3);

        // After normalization the first feature column should have (near-)zero mean.
        let sum: f32 = dataset.features.iter().map(|f| f[0]).sum();
        let mean = sum / dataset.len() as f32;
        assert!(mean.abs() < 1e-5);
    }

    #[test]
    fn test_bce_loss() {
        // A confident correct prediction incurs less loss than a confident wrong one.
        let loss1 = binary_cross_entropy(0.9, 1.0);
        let loss2 = binary_cross_entropy(0.1, 1.0);
        assert!(loss1 < loss2);
    }

    #[test]
    fn test_temperature_softmax() {
        // A higher temperature pulls the softened output towards 0.5.
        let logit = 2.0;
        let soft1 = temperature_softmax(logit, 1.0);
        let soft2 = temperature_softmax(logit, 2.0);

        assert!((soft1 - 0.5).abs() > (soft2 - 0.5).abs());
    }
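
    // Additional edge-case tests added during editing (a sketch exercising only the
    // helpers defined in this module).
    #[test]
    fn test_soft_targets_length_mismatch() {
        let features = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let labels = vec![0.0, 1.0];
        let dataset = TrainingDataset::new(features, labels).unwrap();
        // Two samples but three soft targets: attaching them must fail.
        assert!(dataset.with_soft_targets(vec![0.1, 0.9, 0.5]).is_err());
    }

    #[test]
    fn test_bce_extreme_predictions_are_finite() {
        // Clamping keeps the loss finite even for saturated predictions.
        assert!(binary_cross_entropy(0.0, 1.0).is_finite());
        assert!(binary_cross_entropy(1.0, 0.0).is_finite());
    }

    #[test]
    fn test_temperature_softmax_extreme_logits() {
        // The numerically stable sigmoid form should not overflow for large-magnitude logits.
        let hi = temperature_softmax(1000.0, 1.0);
        let lo = temperature_softmax(-1000.0, 1.0);
        assert!(hi > 0.99 && hi <= 1.0);
        assert!(lo < 0.01 && lo >= 0.0);
    }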
}