use ndarray::Array2;
use crate::models::lstm_network::LSTMNetwork;
use crate::loss::{LossFunction, MSELoss};
use crate::optimizers::{Optimizer, SGD, ScheduledOptimizer};
use crate::schedulers::LearningRateScheduler;
use crate::persistence::SerializableLSTMNetwork;
use std::time::Instant;

pub struct TrainingConfig {
    pub epochs: usize,
    pub print_every: usize,
    pub clip_gradient: Option<f64>,
    pub log_lr_changes: bool,
    pub early_stopping: Option<EarlyStoppingConfig>,
}

#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    pub patience: usize,
    pub min_delta: f64,
    pub restore_best_weights: bool,
    pub monitor: EarlyStoppingMetric,
}

#[derive(Debug, Clone, PartialEq)]
pub enum EarlyStoppingMetric {
    ValidationLoss,
    TrainLoss,
}

impl Default for EarlyStoppingConfig {
    fn default() -> Self {
        EarlyStoppingConfig {
            patience: 10,
            min_delta: 1e-4,
            restore_best_weights: true,
            monitor: EarlyStoppingMetric::ValidationLoss,
        }
    }
}

impl Default for TrainingConfig {
    fn default() -> Self {
        TrainingConfig {
            epochs: 100,
            print_every: 10,
            clip_gradient: Some(5.0),
            log_lr_changes: true,
            early_stopping: None,
        }
    }
}

#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f64,
    pub validation_loss: Option<f64>,
    pub time_elapsed: f64,
    pub learning_rate: f64,
}

#[derive(Debug, Clone)]
pub struct EarlyStopper {
    config: EarlyStoppingConfig,
    best_score: f64,
    wait_count: usize,
    stopped_epoch: Option<usize>,
    best_weights: Option<SerializableLSTMNetwork>,
}

impl EarlyStopper {
    pub fn new(config: EarlyStoppingConfig) -> Self {
        EarlyStopper {
            config,
            best_score: f64::INFINITY,
            wait_count: 0,
            stopped_epoch: None,
            best_weights: None,
        }
    }

    /// Returns `(should_stop, is_best)` for the current epoch.
    pub fn should_stop(&mut self, current_metrics: &TrainingMetrics, network: &LSTMNetwork) -> (bool, bool) {
        let current_score = match self.config.monitor {
            EarlyStoppingMetric::ValidationLoss => {
                match current_metrics.validation_loss {
                    Some(val_loss) => val_loss,
                    None => {
                        // Fall back to the training loss when no validation loss is available.
                        current_metrics.train_loss
                    }
                }
            }
            EarlyStoppingMetric::TrainLoss => current_metrics.train_loss,
        };

        let is_improvement = current_score < self.best_score - self.config.min_delta;

        if is_improvement {
            self.best_score = current_score;
            self.wait_count = 0;

            if self.config.restore_best_weights {
                self.best_weights = Some(network.into());
            }

            (false, true)
        } else {
            self.wait_count += 1;

            if self.wait_count >= self.config.patience {
                self.stopped_epoch = Some(current_metrics.epoch);
                (true, false)
            } else {
                (false, false)
            }
        }
    }

    pub fn stopped_epoch(&self) -> Option<usize> {
        self.stopped_epoch
    }

    pub fn best_score(&self) -> f64 {
        self.best_score
    }

    pub fn restore_best_weights(&self, network: &mut LSTMNetwork) -> Result<(), String> {
        if let Some(ref weights) = self.best_weights {
            *network = weights.clone().into();
            Ok(())
        } else {
            Err("No best weights available to restore".to_string())
        }
    }
}
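
// A minimal sketch of how `EarlyStopper` is expected to behave, kept as an
// in-file test so it stays compilable. `LSTMNetwork::new(input_size, hidden_size,
// num_layers)` is assumed from the tests at the bottom of this file; the network
// is only needed because `should_stop` takes a reference to it, and weight
// restoration is not exercised here since it depends on the persistence layer.
#[cfg(test)]
mod early_stopping_tests {
    use super::*;

    fn metrics(epoch: usize, train_loss: f64) -> TrainingMetrics {
        TrainingMetrics {
            epoch,
            train_loss,
            validation_loss: None,
            time_elapsed: 0.0,
            learning_rate: 0.01,
        }
    }

    #[test]
    fn stops_after_patience_epochs_without_improvement() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut stopper = EarlyStopper::new(EarlyStoppingConfig {
            patience: 2,
            min_delta: 0.0,
            restore_best_weights: false,
            monitor: EarlyStoppingMetric::TrainLoss,
        });

        // The first epoch establishes the best score.
        assert_eq!(stopper.should_stop(&metrics(0, 1.0), &network), (false, true));
        // Two epochs without improvement exhaust the patience budget.
        assert_eq!(stopper.should_stop(&metrics(1, 1.0), &network), (false, false));
        assert_eq!(stopper.should_stop(&metrics(2, 1.0), &network), (true, false));
        assert_eq!(stopper.stopped_epoch(), Some(2));
    }
}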

pub struct LSTMTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
    early_stopper: Option<EarlyStopper>,
}

impl<L: LossFunction, O: Optimizer> LSTMTrainer<L, O> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
        LSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
            early_stopper: None,
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.early_stopper = config.early_stopping.as_ref().map(|es_config| {
            EarlyStopper::new(es_config.clone())
        });
        self.config = config;
        self
    }

    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        // Backpropagate through time in reverse order, accumulating gradients across steps.
        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>) {

        println!("Starting training for {} epochs...", self.config.epochs);

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let current_lr = self.optimizer.get_learning_rate();
            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: current_lr,
            };

            self.metrics_history.push(metrics.clone());

            let mut should_stop = false;
            let mut is_best = false;
            if let Some(ref mut early_stopper) = self.early_stopper {
                let (stop, best) = early_stopper.should_stop(&metrics, &self.network);
                should_stop = stop;
                is_best = best;
            }

            if epoch % self.config.print_every == 0 {
                let best_indicator = if is_best { " *" } else { "" };
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, val_loss, current_lr, time_elapsed, best_indicator);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, current_lr, time_elapsed, best_indicator);
                }
            }

            if should_stop {
                let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap();
                let best_score = self.early_stopper.as_ref().unwrap().best_score();
                println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score);

                if let Some(ref early_stopper) = self.early_stopper {
                    if let Err(e) = early_stopper.restore_best_weights(&mut self.network) {
                        println!("Warning: Could not restore best weights: {}", e);
                    } else {
                        println!("Restored best weights from epoch with score {:.6}", best_score);
                    }
                }
                break;
            }
        }

        println!("Training completed!");
    }

    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        // Rescale the matrix so its Frobenius norm does not exceed `max_norm`.
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }
}
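
// A small usage sketch for `LSTMTrainer`: build a trainer, attach a
// `TrainingConfig` with early stopping, and run one training step. Kept as a
// test so it stays compilable; the `LSTMNetwork::new(2, 3, 1)` signature and
// the input/target shapes are assumed from the tests at the bottom of this file.
#[cfg(test)]
mod trainer_usage_tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn trainer_with_early_stopping_config() {
        let network = LSTMNetwork::new(2, 3, 1);
        let config = TrainingConfig {
            epochs: 5,
            print_every: 1,
            clip_gradient: Some(1.0),
            log_lr_changes: false,
            early_stopping: Some(EarlyStoppingConfig::default()),
        };
        let mut trainer = LSTMTrainer::new(network, MSELoss, SGD::new(0.01)).with_config(config);

        let inputs = vec![arr2(&[[1.0], [0.0]]), arr2(&[[0.0], [1.0]])];
        let targets = vec![arr2(&[[1.0], [0.0], [0.0]]), arr2(&[[0.0], [1.0], [0.0]])];

        // One sequence update should produce a finite, non-negative loss.
        let loss = trainer.train_sequence(&inputs, &targets);
        assert!(loss.is_finite() && loss >= 0.0);
    }
}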

pub struct ScheduledLSTMTrainer<L: LossFunction, O: Optimizer, S: LearningRateScheduler> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: ScheduledOptimizer<O, S>,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
    early_stopper: Option<EarlyStopper>,
}

impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: ScheduledOptimizer<O, S>) -> Self {
        ScheduledLSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
            early_stopper: None,
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.early_stopper = config.early_stopping.as_ref().map(|es_config| {
            EarlyStopper::new(es_config.clone())
        });
        self.config = config;
        self
    }

    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        // Backpropagate through time in reverse order, accumulating gradients across steps.
        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>) {

        println!("Starting training for {} epochs with {} scheduler...",
                 self.config.epochs, self.optimizer.scheduler_name());

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            // Step the scheduler once per epoch, passing the validation loss when
            // available so loss-aware schedules can react to it.
            let prev_lr = self.optimizer.get_learning_rate();
            if let Some(val_loss) = validation_loss {
                self.optimizer.step_with_val_loss(val_loss);
            } else {
                self.optimizer.step();
            }
            let new_lr = self.optimizer.get_learning_rate();

            if self.config.log_lr_changes && (new_lr - prev_lr).abs() > 1e-10 {
                println!("Learning rate changed from {:.2e} to {:.2e}", prev_lr, new_lr);
            }

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: new_lr,
            };

            self.metrics_history.push(metrics.clone());

            let mut should_stop = false;
            let mut is_best = false;
            if let Some(ref mut early_stopper) = self.early_stopper {
                let (stop, best) = early_stopper.should_stop(&metrics, &self.network);
                should_stop = stop;
                is_best = best;
            }

            if epoch % self.config.print_every == 0 {
                let best_indicator = if is_best { " *" } else { "" };
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, val_loss, new_lr, time_elapsed, best_indicator);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, new_lr, time_elapsed, best_indicator);
                }
            }

            if should_stop {
                let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap();
                let best_score = self.early_stopper.as_ref().unwrap().best_score();
                println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score);

                if let Some(ref early_stopper) = self.early_stopper {
                    if let Err(e) = early_stopper.restore_best_weights(&mut self.network) {
                        println!("Warning: Could not restore best weights: {}", e);
                    } else {
                        println!("Restored best weights from epoch with score {:.6}", best_score);
                    }
                }
                break;
            }
        }

        println!("Training completed!");
    }

    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        // Rescale the matrix so its Frobenius norm does not exceed `max_norm`.
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }

    pub fn get_current_lr(&self) -> f64 {
        self.optimizer.get_current_lr()
    }

    pub fn get_current_epoch(&self) -> usize {
        self.optimizer.get_current_epoch()
    }

    pub fn reset_optimizer(&mut self) {
        self.optimizer.reset();
    }
}
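
// A usage sketch for the scheduler-aware trainer: build one via the
// `create_step_lr_trainer` helper defined further down and attach a config.
// The assertions only touch state defined in this file; scheduler internals
// (initial learning rate, epoch counter) are deliberately not asserted because
// they live in the optimizer/scheduler modules.
#[cfg(test)]
mod scheduled_trainer_usage_tests {
    use super::*;

    #[test]
    fn scheduled_trainer_starts_with_empty_history() {
        let network = LSTMNetwork::new(2, 3, 1);
        // step_size = 10, gamma = 0.5 are arbitrary example values.
        let trainer = create_step_lr_trainer(network, 0.01, 10, 0.5)
            .with_config(TrainingConfig::default());

        assert!(trainer.metrics_history.is_empty());
        assert_eq!(trainer.config.epochs, 100);
        assert!(trainer.early_stopper.is_none());
    }
}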

pub struct LSTMBatchTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
    early_stopper: Option<EarlyStopper>,
}

impl<L: LossFunction, O: Optimizer> LSTMBatchTrainer<L, O> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
        LSTMBatchTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
            early_stopper: None,
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.early_stopper = config.early_stopping.as_ref().map(|es_config| {
            EarlyStopper::new(es_config.clone())
        });
        self.config = config;
        self
    }

    pub fn train_batch(&mut self, batch_inputs: &[Vec<Array2<f64>>], batch_targets: &[Vec<Array2<f64>>]) -> f64 {
        assert_eq!(batch_inputs.len(), batch_targets.len(), "Batch inputs and targets must have same length");

        if batch_inputs.is_empty() {
            return 0.0;
        }

        self.network.train();

        let max_seq_len = batch_inputs.iter().map(|seq| seq.len()).max().unwrap_or(0);
        let batch_size = batch_inputs.len();

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();
        let mut valid_steps = 0;

        let mut batch_hx = Array2::zeros((self.network.hidden_size, batch_size));
        let mut batch_cx = Array2::zeros((self.network.hidden_size, batch_size));

        for t in 0..max_seq_len {
            // Assemble the inputs and targets for time step `t`, keeping track of
            // which sequences are still active (shorter sequences simply drop out).
            let mut batch_input = Array2::zeros((self.network.input_size, batch_size));
            let mut batch_target = Array2::zeros((self.network.hidden_size, batch_size));
            let mut active_sequences = Vec::new();

            for (batch_idx, (input_seq, target_seq)) in batch_inputs.iter().zip(batch_targets.iter()).enumerate() {
                if t < input_seq.len() && t < target_seq.len() {
                    batch_input.column_mut(batch_idx).assign(&input_seq[t].column(0));
                    batch_target.column_mut(batch_idx).assign(&target_seq[t].column(0));
                    active_sequences.push(batch_idx);
                }
            }

            if active_sequences.is_empty() {
                break;
            }

            let (new_batch_hx, new_batch_cx, cache) = self.network.forward_batch_with_cache(&batch_input, &batch_hx, &batch_cx);

            // Compute the loss only over the columns that belong to active sequences.
            let active_predictions = if active_sequences.len() == batch_size {
                new_batch_hx.clone()
            } else {
                let mut active_preds = Array2::zeros((self.network.hidden_size, active_sequences.len()));
                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
                    active_preds.column_mut(idx).assign(&new_batch_hx.column(batch_idx));
                }
                active_preds
            };

            let active_targets = if active_sequences.len() == batch_size {
                batch_target.clone()
            } else {
                let mut active_targs = Array2::zeros((self.network.hidden_size, active_sequences.len()));
                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
                    active_targs.column_mut(idx).assign(&batch_target.column(batch_idx));
                }
                active_targs
            };

            let step_loss = self.loss_function.compute_batch_loss(&active_predictions, &active_targets);
            total_loss += step_loss;
            valid_steps += 1;

            let dhy = self.loss_function.compute_batch_gradient(&active_predictions, &active_targets);

            // Scatter the gradient back to full batch width; inactive columns stay zero.
            let full_dhy = if active_sequences.len() == batch_size {
                dhy
            } else {
                let mut full_grad = Array2::zeros((self.network.hidden_size, batch_size));
                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
                    full_grad.column_mut(batch_idx).assign(&dhy.column(idx));
                }
                full_grad
            };

            let full_dcy = Array2::<f64>::zeros(full_dhy.raw_dim());

            let (step_gradients, _) = self.network.backward_batch(&full_dhy, &full_dcy, &cache);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }

            batch_hx = new_batch_hx;
            batch_cx = new_batch_cx;
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        if valid_steps > 0 {
            total_loss / valid_steps as f64
        } else {
            0.0
        }
    }

    pub fn train(&mut self,
                 train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>,
                 batch_size: usize) {

        println!("Starting batch training for {} epochs with batch size {}...",
                 self.config.epochs, batch_size);

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            for batch_start in (0..train_data.len()).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(train_data.len());
                let batch = &train_data[batch_start..batch_end];

                let batch_inputs: Vec<_> = batch.iter().map(|(inputs, _)| inputs.clone()).collect();
                let batch_targets: Vec<_> = batch.iter().map(|(_, targets)| targets.clone()).collect();

                let batch_loss = self.train_batch(&batch_inputs, &batch_targets);
                epoch_loss += batch_loss;
                num_batches += 1;
            }

            epoch_loss /= num_batches as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate_batch(val_data, batch_size))
            } else {
                None
            };

            let time_elapsed = start_time.elapsed().as_secs_f64();
            let current_lr = self.optimizer.get_learning_rate();

            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: current_lr,
            };

            self.metrics_history.push(metrics.clone());

            let mut should_stop = false;
            let mut is_best = false;
            if let Some(ref mut early_stopper) = self.early_stopper {
                let (stop, best) = early_stopper.should_stop(&metrics, &self.network);
                should_stop = stop;
                is_best = best;
            }

            if epoch % self.config.print_every == 0 {
                let best_indicator = if is_best { " *" } else { "" };
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}{}",
                             epoch, epoch_loss, val_loss, current_lr, time_elapsed, num_batches, best_indicator);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}{}",
                             epoch, epoch_loss, current_lr, time_elapsed, num_batches, best_indicator);
                }
            }

            if should_stop {
                let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap();
                let best_score = self.early_stopper.as_ref().unwrap().best_score();
                println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score);

                if let Some(ref early_stopper) = self.early_stopper {
                    if let Err(e) = early_stopper.restore_best_weights(&mut self.network) {
                        println!("Warning: Could not restore best weights: {}", e);
                    } else {
                        println!("Restored best weights from epoch with score {:.6}", best_score);
                    }
                }
                break;
            }
        }

        println!("Batch training completed!");
    }

    pub fn evaluate_batch(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], batch_size: usize) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut num_batches = 0;

        for batch_start in (0..data.len()).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(data.len());
            let batch = &data[batch_start..batch_end];

            let batch_inputs: Vec<_> = batch.iter().map(|(inputs, _)| inputs.clone()).collect();
            let batch_targets: Vec<_> = batch.iter().map(|(_, targets)| targets.clone()).collect();

            let batch_outputs = self.network.forward_batch_sequences(&batch_inputs);

            let mut batch_loss = 0.0;
            let mut valid_samples = 0;

            for (outputs, targets) in batch_outputs.iter().zip(batch_targets.iter()) {
                for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                    let loss = self.loss_function.compute_loss(output, target);
                    batch_loss += loss;
                    valid_samples += 1;
                }
            }

            if valid_samples > 0 {
                total_loss += batch_loss / valid_samples as f64;
                num_batches += 1;
            }
        }

        if num_batches > 0 {
            total_loss / num_batches as f64
        } else {
            0.0
        }
    }

    pub fn predict_batch(&mut self, inputs: &[Vec<Array2<f64>>]) -> Vec<Vec<Array2<f64>>> {
        self.network.eval();

        let batch_outputs = self.network.forward_batch_sequences(inputs);
        batch_outputs.into_iter()
            .map(|sequence_outputs| sequence_outputs.into_iter().map(|(output, _)| output).collect())
            .collect()
    }

    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        // Rescale the matrix so its Frobenius norm does not exceed `max_norm`.
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }
}

/// Creates a simple trainer with mean squared error loss and plain SGD.
pub fn create_basic_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMTrainer<MSELoss, SGD> {
    let loss_function = MSELoss;
    let optimizer = SGD::new(learning_rate);
    LSTMTrainer::new(network, loss_function, optimizer)
}

/// Creates a scheduled trainer with SGD and a `StepLR` learning-rate schedule.
pub fn create_step_lr_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    step_size: usize,
    gamma: f64
) -> ScheduledLSTMTrainer<MSELoss, SGD, crate::schedulers::StepLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::step_lr(SGD::new(learning_rate), learning_rate, step_size, gamma);
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Creates a scheduled trainer with Adam and a `OneCycleLR` learning-rate schedule.
pub fn create_one_cycle_trainer(
    network: LSTMNetwork,
    max_lr: f64,
    total_steps: usize
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::OneCycleLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::one_cycle(
        crate::optimizers::Adam::new(max_lr),
        max_lr,
        total_steps
    );
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Creates a scheduled trainer with Adam and a `CosineAnnealingLR` learning-rate schedule.
pub fn create_cosine_annealing_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    t_max: usize,
    eta_min: f64
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::CosineAnnealingLR> {
    let loss_function = MSELoss;
    let optimizer = crate::optimizers::Adam::new(learning_rate);
    let scheduler = crate::schedulers::CosineAnnealingLR::new(t_max, eta_min);
    let scheduled_optimizer = crate::optimizers::ScheduledOptimizer::new(optimizer, scheduler, learning_rate);

    ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
}

/// Creates a batch trainer with mean squared error loss and plain SGD.
pub fn create_basic_batch_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMBatchTrainer<MSELoss, SGD> {
    let loss_function = MSELoss;
    let optimizer = SGD::new(learning_rate);
    LSTMBatchTrainer::new(network, loss_function, optimizer)
}

/// Creates a batch trainer with mean squared error loss and the Adam optimizer.
pub fn create_adam_batch_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMBatchTrainer<MSELoss, crate::optimizers::Adam> {
    let loss_function = MSELoss;
    let optimizer = crate::optimizers::Adam::new(learning_rate);
    LSTMBatchTrainer::new(network, loss_function, optimizer)
}


#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_trainer_creation() {
        let network = LSTMNetwork::new(2, 3, 1);
        let trainer = create_basic_trainer(network, 0.01);

        assert_eq!(trainer.network.input_size, 2);
        assert_eq!(trainer.network.hidden_size, 3);
        assert_eq!(trainer.network.num_layers, 1);
    }

    #[test]
    fn test_sequence_training() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_trainer(network, 0.01);

        let inputs = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
        ];
        let targets = vec![
            arr2(&[[1.0], [0.0], [0.0]]),
            arr2(&[[0.0], [1.0], [0.0]]),
        ];

        let loss = trainer.train_sequence(&inputs, &targets);
        assert!(loss >= 0.0);
    }
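
    // An additional sketch exercising the batch-trainer path via `train_batch`
    // and `create_basic_batch_trainer` defined above. The network constructor
    // signature and array shapes mirror the tests above; this is a smoke test,
    // not a statement about expected loss values.
    #[test]
    fn test_batch_training_step() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_batch_trainer(network, 0.01);

        // Two sequences of different lengths; the shorter one drops out after step 0.
        let batch_inputs = vec![
            vec![arr2(&[[1.0], [0.0]]), arr2(&[[0.0], [1.0]])],
            vec![arr2(&[[0.5], [0.5]])],
        ];
        let batch_targets = vec![
            vec![arr2(&[[1.0], [0.0], [0.0]]), arr2(&[[0.0], [1.0], [0.0]])],
            vec![arr2(&[[0.0], [0.0], [1.0]])],
        ];

        let loss = trainer.train_batch(&batch_inputs, &batch_targets);
        assert!(loss.is_finite() && loss >= 0.0);
    }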
1054 }
1055}