rust_lstm/training.rs

use ndarray::Array2;
use crate::models::lstm_network::LSTMNetwork;
use crate::loss::{LossFunction, MSELoss};
use crate::optimizers::{Optimizer, SGD, ScheduledOptimizer};
use crate::schedulers::LearningRateScheduler;
use crate::persistence::SerializableLSTMNetwork;
use std::time::Instant;

/// Configuration for training hyperparameters
pub struct TrainingConfig {
    pub epochs: usize,
    pub print_every: usize,
    pub clip_gradient: Option<f64>,
    pub log_lr_changes: bool,
    pub early_stopping: Option<EarlyStoppingConfig>,
}

/// Configuration for early stopping
#[derive(Debug, Clone)]
pub struct EarlyStoppingConfig {
    /// Number of epochs with no improvement after which training will be stopped
    pub patience: usize,
    /// Minimum change in the monitored quantity to qualify as an improvement
    pub min_delta: f64,
    /// Whether to restore the best weights when early stopping triggers
    pub restore_best_weights: bool,
    /// Metric to monitor for early stopping (`ValidationLoss` or `TrainLoss`)
    pub monitor: EarlyStoppingMetric,
}

/// Metric to monitor for early stopping
#[derive(Debug, Clone, PartialEq)]
pub enum EarlyStoppingMetric {
    ValidationLoss,
    TrainLoss,
}

impl Default for EarlyStoppingConfig {
    fn default() -> Self {
        EarlyStoppingConfig {
            patience: 10,
            min_delta: 1e-4,
            restore_best_weights: true,
            monitor: EarlyStoppingMetric::ValidationLoss,
        }
    }
}

impl Default for TrainingConfig {
    fn default() -> Self {
        TrainingConfig {
            epochs: 100,
            print_every: 10,
            clip_gradient: Some(5.0),
            log_lr_changes: true,
            early_stopping: None,
        }
    }
}
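
// A minimal configuration sketch (values here are illustrative, not tuned
// recommendations); struct-update syntax works because both config types
// implement `Default`:
//
//     let config = TrainingConfig {
//         epochs: 200,
//         clip_gradient: Some(1.0),
//         early_stopping: Some(EarlyStoppingConfig {
//             patience: 5,
//             ..EarlyStoppingConfig::default()
//         }),
//         ..TrainingConfig::default()
//     };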

/// Training metrics tracked during training
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f64,
    pub validation_loss: Option<f64>,
    pub time_elapsed: f64,
    pub learning_rate: f64,
}

/// Early stopping state tracker
#[derive(Debug, Clone)]
pub struct EarlyStopper {
    config: EarlyStoppingConfig,
    best_score: f64,
    wait_count: usize,
    stopped_epoch: Option<usize>,
    best_weights: Option<SerializableLSTMNetwork>, // Serialized network weights
}

impl EarlyStopper {
    pub fn new(config: EarlyStoppingConfig) -> Self {
        EarlyStopper {
            config,
            best_score: f64::INFINITY,
            wait_count: 0,
            stopped_epoch: None,
            best_weights: None,
        }
    }

    /// Check if training should stop based on current metrics
    /// Returns (should_stop, is_best_score)
    pub fn should_stop(&mut self, current_metrics: &TrainingMetrics, network: &LSTMNetwork) -> (bool, bool) {
        let current_score = match self.config.monitor {
            EarlyStoppingMetric::ValidationLoss => {
                match current_metrics.validation_loss {
                    Some(val_loss) => val_loss,
                    None => {
                        // If validation loss is not available, fall back to train loss
                        current_metrics.train_loss
                    }
                }
            }
            EarlyStoppingMetric::TrainLoss => current_metrics.train_loss,
        };

        let is_improvement = current_score < self.best_score - self.config.min_delta;

        if is_improvement {
            self.best_score = current_score;
            self.wait_count = 0;

            // Save best weights if restore_best_weights is enabled
            if self.config.restore_best_weights {
                self.best_weights = Some(network.into());
            }

            (false, true)
        } else {
            self.wait_count += 1;

            if self.wait_count >= self.config.patience {
                self.stopped_epoch = Some(current_metrics.epoch);
                (true, false)
            } else {
                (false, false)
            }
        }
    }

    /// Get the epoch where training was stopped
    pub fn stopped_epoch(&self) -> Option<usize> {
        self.stopped_epoch
    }

    /// Get the best score achieved
    pub fn best_score(&self) -> f64 {
        self.best_score
    }

    /// Restore the best weights to the network if available
    pub fn restore_best_weights(&self, network: &mut LSTMNetwork) -> Result<(), String> {
        if let Some(ref weights) = self.best_weights {
            *network = weights.clone().into();
            Ok(())
        } else {
            Err("No best weights available to restore".to_string())
        }
    }
}
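
// Usage sketch for driving `EarlyStopper` by hand; the trainers below do this
// automatically each epoch. `metrics` and `network` stand for a hypothetical
// `TrainingMetrics` value and a mutable `LSTMNetwork`:
//
//     let mut stopper = EarlyStopper::new(EarlyStoppingConfig::default());
//     let (should_stop, is_best) = stopper.should_stop(&metrics, &network);
//     if should_stop {
//         let _ = stopper.restore_best_weights(&mut network);
//     }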

/// Main trainer for LSTM networks with configurable loss and optimizer
pub struct LSTMTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
    early_stopper: Option<EarlyStopper>,
}

impl<L: LossFunction, O: Optimizer> LSTMTrainer<L, O> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
        LSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
            early_stopper: None,
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        // Initialize early stopper if early stopping is configured
        self.early_stopper = config.early_stopping.as_ref().map(|es_config| {
            EarlyStopper::new(es_config.clone())
        });
        self.config = config;
        self
    }

    /// Train on a single sequence using backpropagation through time (BPTT)
    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    /// Train for multiple epochs with optional validation
    pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>) {

        println!("Starting training for {} epochs...", self.config.epochs);

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            // Training phase
            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let current_lr = self.optimizer.get_learning_rate();
            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: current_lr,
            };

            self.metrics_history.push(metrics.clone());

            // Check early stopping
            let mut should_stop = false;
            let mut is_best = false;
            if let Some(ref mut early_stopper) = self.early_stopper {
                let (stop, best) = early_stopper.should_stop(&metrics, &self.network);
                should_stop = stop;
                is_best = best;
            }

            if epoch % self.config.print_every == 0 {
                let best_indicator = if is_best { " *" } else { "" };
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, val_loss, current_lr, time_elapsed, best_indicator);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, current_lr, time_elapsed, best_indicator);
                }
            }

            if should_stop {
                let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap();
                let best_score = self.early_stopper.as_ref().unwrap().best_score();
                println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score);

                // Restore best weights if configured
                if let Some(ref early_stopper) = self.early_stopper {
                    if let Err(e) = early_stopper.restore_best_weights(&mut self.network) {
                        println!("Warning: Could not restore best weights: {}", e);
                    } else {
                        println!("Restored best weights from epoch with score {:.6}", best_score);
                    }
                }
                break;
            }
        }

        println!("Training completed!");
    }

    /// Evaluate model performance on validation data
    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    /// Generate predictions for input sequences
    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    /// Clip each gradient matrix to a maximum L2 norm to limit exploding gradients
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        // Rescale so the matrix's L2 (Frobenius) norm does not exceed max_norm
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    /// Set network to training or evaluation mode
    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }
}
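
// A compile-checked, end-to-end sketch using the `create_basic_trainer` helper
// defined later in this file. Shapes mirror the tests at the bottom: inputs are
// (input_size x 1) columns, targets are (hidden_size x 1) columns; the data is
// made up for illustration.
#[allow(dead_code)]
fn example_basic_training() {
    let network = LSTMNetwork::new(2, 3, 1);
    let config = TrainingConfig { epochs: 20, print_every: 5, ..TrainingConfig::default() };
    let mut trainer = create_basic_trainer(network, 0.01).with_config(config);

    // One training pair: a 2-step input sequence and its 2-step target sequence.
    let data = vec![(
        vec![ndarray::arr2(&[[1.0], [0.0]]), ndarray::arr2(&[[0.0], [1.0]])],
        vec![ndarray::arr2(&[[0.5], [0.5], [0.0]]), ndarray::arr2(&[[0.0], [0.5], [0.5]])],
    )];
    trainer.train(&data, None);
    let _predictions = trainer.predict(&data[0].0);
}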

/// Specialized trainer for scheduled optimizers that automatically steps the scheduler
pub struct ScheduledLSTMTrainer<L: LossFunction, O: Optimizer, S: LearningRateScheduler> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: ScheduledOptimizer<O, S>,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
    early_stopper: Option<EarlyStopper>,
}

impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: ScheduledOptimizer<O, S>) -> Self {
        ScheduledLSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
            early_stopper: None,
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        // Initialize early stopper if early stopping is configured
        self.early_stopper = config.early_stopping.as_ref().map(|es_config| {
            EarlyStopper::new(es_config.clone())
        });
        self.config = config;
        self
    }

    /// Train on a single sequence using backpropagation through time (BPTT)
    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    /// Train for multiple epochs with automatic scheduler stepping
    pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>) {

        println!("Starting training for {} epochs with {} scheduler...",
                 self.config.epochs, self.optimizer.scheduler_name());

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            // Training phase
            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            // Step the scheduler at the end of each epoch
            let prev_lr = self.optimizer.get_learning_rate();
            if let Some(val_loss) = validation_loss {
                self.optimizer.step_with_val_loss(val_loss);
            } else {
                self.optimizer.step();
            }
            let new_lr = self.optimizer.get_learning_rate();

            // Log learning rate changes if enabled
            if self.config.log_lr_changes && (new_lr - prev_lr).abs() > 1e-10 {
                println!("Learning rate changed from {:.2e} to {:.2e}", prev_lr, new_lr);
            }

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: new_lr,
            };

            self.metrics_history.push(metrics.clone());

            // Check early stopping
            let mut should_stop = false;
            let mut is_best = false;
            if let Some(ref mut early_stopper) = self.early_stopper {
                let (stop, best) = early_stopper.should_stop(&metrics, &self.network);
                should_stop = stop;
                is_best = best;
            }

            if epoch % self.config.print_every == 0 {
                let best_indicator = if is_best { " *" } else { "" };
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, val_loss, new_lr, time_elapsed, best_indicator);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s{}",
                             epoch, epoch_loss, new_lr, time_elapsed, best_indicator);
                }
            }

            if should_stop {
                let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap();
                let best_score = self.early_stopper.as_ref().unwrap().best_score();
                println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score);

                // Restore best weights if configured
                if let Some(ref early_stopper) = self.early_stopper {
                    if let Err(e) = early_stopper.restore_best_weights(&mut self.network) {
                        println!("Warning: Could not restore best weights: {}", e);
                    } else {
                        println!("Restored best weights from epoch with score {:.6}", best_score);
                    }
                }
                break;
            }
        }

        println!("Training completed!");
    }

    /// Evaluate model performance on validation data
    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    /// Generate predictions for input sequences
    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    /// Clip each gradient matrix to a maximum L2 norm to limit exploding gradients
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        // Rescale so the matrix's L2 (Frobenius) norm does not exceed max_norm
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    /// Set network to training or evaluation mode
    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }

    /// Get the current learning rate
    pub fn get_current_lr(&self) -> f64 {
        self.optimizer.get_current_lr()
    }

    /// Get the current epoch from the scheduler
    pub fn get_current_epoch(&self) -> usize {
        self.optimizer.get_current_epoch()
    }

    /// Reset the optimizer and scheduler
    pub fn reset_optimizer(&mut self) {
        self.optimizer.reset();
    }
}
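
// Sketch: a scheduled trainer is built via the factory helpers at the bottom of
// this file; the only behavioral difference from `LSTMTrainer` is that `train`
// steps the scheduler once per epoch. Hyperparameters and the `data` /
// `validation_data` bindings are illustrative placeholders:
//
//     let network = LSTMNetwork::new(2, 3, 1);
//     // Halve the learning rate every 10 epochs (StepLR with gamma = 0.5).
//     let mut trainer = create_step_lr_trainer(network, 0.1, 10, 0.5);
//     trainer.train(&data, Some(&validation_data));
//     println!("final lr: {:.2e}", trainer.get_current_lr());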

/// Batch trainer for LSTM networks with configurable loss and optimizer
/// Processes multiple sequences simultaneously for improved performance
pub struct LSTMBatchTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
    early_stopper: Option<EarlyStopper>,
}

impl<L: LossFunction, O: Optimizer> LSTMBatchTrainer<L, O> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
        LSTMBatchTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
            early_stopper: None,
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        // Initialize early stopper if early stopping is configured
        self.early_stopper = config.early_stopping.as_ref().map(|es_config| {
            EarlyStopper::new(es_config.clone())
        });
        self.config = config;
        self
    }

    /// Train on a batch of sequences using batch processing
    ///
    /// # Arguments
    /// * `batch_inputs` - Vector of input sequences, each sequence is Vec<Array2<f64>>
    /// * `batch_targets` - Vector of target sequences, each sequence is Vec<Array2<f64>>
    ///
    /// # Returns
    /// * Average loss across the batch
    pub fn train_batch(&mut self, batch_inputs: &[Vec<Array2<f64>>], batch_targets: &[Vec<Array2<f64>>]) -> f64 {
        assert_eq!(batch_inputs.len(), batch_targets.len(), "Batch inputs and targets must have same length");

        if batch_inputs.is_empty() {
            return 0.0;
        }

        self.network.train();

        // Find maximum sequence length for padding
        let max_seq_len = batch_inputs.iter().map(|seq| seq.len()).max().unwrap_or(0);
        let batch_size = batch_inputs.len();

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();
        let mut valid_steps = 0;

        // Initialize batch states
        let mut batch_hx = Array2::zeros((self.network.hidden_size, batch_size));
        let mut batch_cx = Array2::zeros((self.network.hidden_size, batch_size));

        // Process each time step
        for t in 0..max_seq_len {
            // Prepare batch input and targets for current time step
            let mut batch_input = Array2::zeros((self.network.input_size, batch_size));
            let mut batch_target = Array2::zeros((self.network.hidden_size, batch_size));
            let mut active_sequences = Vec::new();

            // Collect active sequences for this time step
            for (batch_idx, (input_seq, target_seq)) in batch_inputs.iter().zip(batch_targets.iter()).enumerate() {
                if t < input_seq.len() && t < target_seq.len() {
                    batch_input.column_mut(batch_idx).assign(&input_seq[t].column(0));
                    batch_target.column_mut(batch_idx).assign(&target_seq[t].column(0));
                    active_sequences.push(batch_idx);
                }
            }

            if active_sequences.is_empty() {
                break;
            }

            // Forward pass with caching for active sequences
            let (new_batch_hx, new_batch_cx, cache) = self.network.forward_batch_with_cache(&batch_input, &batch_hx, &batch_cx);

            // Compute loss only for active sequences
            let active_predictions = if active_sequences.len() == batch_size {
                new_batch_hx.clone()
            } else {
                let mut active_preds = Array2::zeros((self.network.hidden_size, active_sequences.len()));
                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
                    active_preds.column_mut(idx).assign(&new_batch_hx.column(batch_idx));
                }
                active_preds
            };

            let active_targets = if active_sequences.len() == batch_size {
                batch_target.clone()
            } else {
                let mut active_targs = Array2::zeros((self.network.hidden_size, active_sequences.len()));
                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
                    active_targs.column_mut(idx).assign(&batch_target.column(batch_idx));
                }
                active_targs
            };

            let step_loss = self.loss_function.compute_batch_loss(&active_predictions, &active_targets);
            total_loss += step_loss;
            valid_steps += 1;

            // Compute gradients
            let dhy = self.loss_function.compute_batch_gradient(&active_predictions, &active_targets);

            // Expand gradients back to full batch size if needed
            let full_dhy = if active_sequences.len() == batch_size {
                dhy
            } else {
                let mut full_grad = Array2::zeros((self.network.hidden_size, batch_size));
                for (idx, &batch_idx) in active_sequences.iter().enumerate() {
                    full_grad.column_mut(batch_idx).assign(&dhy.column(idx));
                }
                full_grad
            };

            let full_dcy = Array2::<f64>::zeros(full_dhy.raw_dim());

            // Backward pass
            let (step_gradients, _) = self.network.backward_batch(&full_dhy, &full_dcy, &cache);

            // Accumulate gradients
            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }

            // Update states
            batch_hx = new_batch_hx;
            batch_cx = new_batch_cx;
        }

        // Apply gradient clipping
        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        // Update parameters
        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        if valid_steps > 0 {
            total_loss / valid_steps as f64
        } else {
            0.0
        }
    }

    /// Train for multiple epochs with batch processing
    ///
    /// # Arguments
    /// * `train_data` - Vector of (input_sequences, target_sequences) tuples for training
    /// * `validation_data` - Optional validation data
    /// * `batch_size` - Number of sequences to process in each batch
    pub fn train(&mut self,
                 train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>,
                 batch_size: usize) {

        println!("Starting batch training for {} epochs with batch size {}...",
                 self.config.epochs, batch_size);

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;
            let mut num_batches = 0;

            // Create batches
            for batch_start in (0..train_data.len()).step_by(batch_size) {
                let batch_end = (batch_start + batch_size).min(train_data.len());
                let batch = &train_data[batch_start..batch_end];

                let batch_inputs: Vec<_> = batch.iter().map(|(inputs, _)| inputs.clone()).collect();
                let batch_targets: Vec<_> = batch.iter().map(|(_, targets)| targets.clone()).collect();

                let batch_loss = self.train_batch(&batch_inputs, &batch_targets);
                epoch_loss += batch_loss;
                num_batches += 1;
            }

            epoch_loss /= num_batches as f64;

            // Validation
            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate_batch(val_data, batch_size))
            } else {
                None
            };

            let time_elapsed = start_time.elapsed().as_secs_f64();
            let current_lr = self.optimizer.get_learning_rate();

            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: current_lr,
            };

            self.metrics_history.push(metrics.clone());

            // Check early stopping
            let mut should_stop = false;
            let mut is_best = false;
            if let Some(ref mut early_stopper) = self.early_stopper {
                let (stop, best) = early_stopper.should_stop(&metrics, &self.network);
                should_stop = stop;
                is_best = best;
            }

            if epoch % self.config.print_every == 0 {
                let best_indicator = if is_best { " *" } else { "" };
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}{}",
                             epoch, epoch_loss, val_loss, current_lr, time_elapsed, num_batches, best_indicator);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s, Batches: {}{}",
                             epoch, epoch_loss, current_lr, time_elapsed, num_batches, best_indicator);
                }
            }

            if should_stop {
                let stopped_epoch = self.early_stopper.as_ref().unwrap().stopped_epoch().unwrap();
                let best_score = self.early_stopper.as_ref().unwrap().best_score();
                println!("Early stopping triggered at epoch {} (best score: {:.6})", stopped_epoch, best_score);

                // Restore best weights if configured
                if let Some(ref early_stopper) = self.early_stopper {
                    if let Err(e) = early_stopper.restore_best_weights(&mut self.network) {
                        println!("Warning: Could not restore best weights: {}", e);
                    } else {
                        println!("Restored best weights from epoch with score {:.6}", best_score);
                    }
                }
                break;
            }
        }

        println!("Batch training completed!");
    }

    /// Evaluate model performance using batch processing
    pub fn evaluate_batch(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)], batch_size: usize) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut num_batches = 0;

        for batch_start in (0..data.len()).step_by(batch_size) {
            let batch_end = (batch_start + batch_size).min(data.len());
            let batch = &data[batch_start..batch_end];

            let batch_inputs: Vec<_> = batch.iter().map(|(inputs, _)| inputs.clone()).collect();
            let batch_targets: Vec<_> = batch.iter().map(|(_, targets)| targets.clone()).collect();

            // Process batch and compute loss (simplified evaluation)
            let batch_outputs = self.network.forward_batch_sequences(&batch_inputs);

            let mut batch_loss = 0.0;
            let mut valid_samples = 0;

            for (outputs, targets) in batch_outputs.iter().zip(batch_targets.iter()) {
                for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                    let loss = self.loss_function.compute_loss(output, target);
                    batch_loss += loss;
                    valid_samples += 1;
                }
            }

            if valid_samples > 0 {
                total_loss += batch_loss / valid_samples as f64;
                num_batches += 1;
            }
        }

        if num_batches > 0 {
            total_loss / num_batches as f64
        } else {
            0.0
        }
    }

    /// Generate predictions using batch processing
    pub fn predict_batch(&mut self, inputs: &[Vec<Array2<f64>>]) -> Vec<Vec<Array2<f64>>> {
        self.network.eval();

        let batch_outputs = self.network.forward_batch_sequences(inputs);
        batch_outputs.into_iter()
            .map(|sequence_outputs| sequence_outputs.into_iter().map(|(output, _)| output).collect())
            .collect()
    }

    /// Clip each gradient matrix to a maximum L2 norm to limit exploding gradients
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        // Rescale so the matrix's L2 (Frobenius) norm does not exceed max_norm
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }
}
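
// Sketch: the batch trainer consumes the same (inputs, targets) tuple layout as
// the per-sequence trainers, plus an explicit batch size; sequences in a batch
// are packed column-wise per time step. `data` is an illustrative placeholder:
//
//     let network = LSTMNetwork::new(2, 3, 1);
//     let mut trainer = create_basic_batch_trainer(network, 0.01);
//     trainer.train(&data, None, 8); // batches of up to 8 sequences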

/// Create a basic trainer with SGD optimizer and MSE loss
pub fn create_basic_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMTrainer<MSELoss, SGD> {
    let loss_function = MSELoss;
    let optimizer = SGD::new(learning_rate);
    LSTMTrainer::new(network, loss_function, optimizer)
}

/// Create a scheduled trainer with SGD and StepLR scheduler
pub fn create_step_lr_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    step_size: usize,
    gamma: f64
) -> ScheduledLSTMTrainer<MSELoss, SGD, crate::schedulers::StepLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::step_lr(SGD::new(learning_rate), learning_rate, step_size, gamma);
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Create a scheduled trainer with Adam and OneCycleLR scheduler
pub fn create_one_cycle_trainer(
    network: LSTMNetwork,
    max_lr: f64,
    total_steps: usize
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::OneCycleLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::one_cycle(
        crate::optimizers::Adam::new(max_lr),
        max_lr,
        total_steps
    );
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Create a scheduled trainer with Adam and CosineAnnealingLR scheduler
pub fn create_cosine_annealing_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    t_max: usize,
    eta_min: f64
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::CosineAnnealingLR> {
    let loss_function = MSELoss;
    let optimizer = crate::optimizers::Adam::new(learning_rate);
    let scheduler = crate::schedulers::CosineAnnealingLR::new(t_max, eta_min);
    let scheduled_optimizer = crate::optimizers::ScheduledOptimizer::new(optimizer, scheduler, learning_rate);

    ScheduledLSTMTrainer::new(network, loss_function, scheduled_optimizer)
}

/// Create a basic batch trainer with SGD optimizer and MSE loss
pub fn create_basic_batch_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMBatchTrainer<MSELoss, SGD> {
    let loss_function = MSELoss;
    let optimizer = SGD::new(learning_rate);
    LSTMBatchTrainer::new(network, loss_function, optimizer)
}

/// Create a batch trainer with Adam optimizer and MSE loss
pub fn create_adam_batch_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMBatchTrainer<MSELoss, crate::optimizers::Adam> {
    let loss_function = MSELoss;
    let optimizer = crate::optimizers::Adam::new(learning_rate);
    LSTMBatchTrainer::new(network, loss_function, optimizer)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_trainer_creation() {
        let network = LSTMNetwork::new(2, 3, 1);
        let trainer = create_basic_trainer(network, 0.01);

        assert_eq!(trainer.network.input_size, 2);
        assert_eq!(trainer.network.hidden_size, 3);
        assert_eq!(trainer.network.num_layers, 1);
    }

    #[test]
    fn test_sequence_training() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_trainer(network, 0.01);

        let inputs = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
        ];
        let targets = vec![
            arr2(&[[1.0], [0.0], [0.0]]),
            arr2(&[[0.0], [1.0], [0.0]]),
        ];

        let loss = trainer.train_sequence(&inputs, &targets);
        assert!(loss >= 0.0);
    }
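
    #[test]
    fn test_early_stopper_patience() {
        // A sketch of the early-stopping contract, using only types defined in
        // this file: a lower loss resets the patience counter, and `patience`
        // consecutive non-improving epochs trigger a stop.
        // `restore_best_weights` is disabled so no network snapshot is taken.
        let network = LSTMNetwork::new(2, 3, 1);
        let config = EarlyStoppingConfig {
            patience: 2,
            min_delta: 0.0,
            restore_best_weights: false,
            monitor: EarlyStoppingMetric::TrainLoss,
        };
        let mut stopper = EarlyStopper::new(config);

        // Helper to build metrics with a given epoch and train loss.
        let metrics = |epoch, loss| TrainingMetrics {
            epoch,
            train_loss: loss,
            validation_loss: None,
            time_elapsed: 0.0,
            learning_rate: 0.01,
        };

        // The first epoch always improves on the initial +infinity score.
        assert_eq!(stopper.should_stop(&metrics(0, 1.0), &network), (false, true));
        // Two stagnant epochs exhaust patience = 2 and trigger the stop.
        assert_eq!(stopper.should_stop(&metrics(1, 1.0), &network), (false, false));
        assert_eq!(stopper.should_stop(&metrics(2, 1.0), &network), (true, false));
        assert_eq!(stopper.stopped_epoch(), Some(2));
    }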
}