rust_lstm/training.rs

use ndarray::Array2;
use crate::models::lstm_network::LSTMNetwork;
use crate::loss::{LossFunction, MSELoss};
use crate::optimizers::{Optimizer, SGD, ScheduledOptimizer};
use crate::schedulers::LearningRateScheduler;
use std::time::Instant;

/// Configuration for training hyperparameters
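///
/// # Example
///
/// A minimal sketch of overriding selected fields while keeping the other
/// defaults (struct update syntax). It assumes this file is exposed at the
/// crate root as the `training` module of a crate named `rust_lstm`.
///
/// ```no_run
/// use rust_lstm::training::TrainingConfig;
///
/// let config = TrainingConfig {
///     epochs: 200,
///     clip_gradient: Some(1.0),
///     ..TrainingConfig::default()
/// };
/// ```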
pub struct TrainingConfig {
    /// Total number of training epochs.
    pub epochs: usize,
    /// Print progress every `print_every` epochs.
    pub print_every: usize,
    /// Maximum per-matrix gradient norm; `None` disables clipping.
    pub clip_gradient: Option<f64>,
    /// Log learning-rate changes made by the scheduler (scheduled trainer only).
    pub log_lr_changes: bool,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        TrainingConfig {
            epochs: 100,
            print_every: 10,
            clip_gradient: Some(5.0),
            log_lr_changes: true,
        }
    }
}

/// Metrics recorded for each training epoch
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    /// Zero-based epoch index.
    pub epoch: usize,
    /// Average training loss over the epoch.
    pub train_loss: f64,
    /// Average validation loss, if validation data was provided.
    pub validation_loss: Option<f64>,
    /// Wall-clock time spent on the epoch, in seconds.
    pub time_elapsed: f64,
    /// Learning rate recorded at the end of the epoch.
    pub learning_rate: f64,
}

/// Main trainer for LSTM networks with configurable loss and optimizer
pub struct LSTMTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
}

impl<L: LossFunction, O: Optimizer> LSTMTrainer<L, O> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
        LSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.config = config;
        self
    }

    /// Train on a single sequence using backpropagation through time (BPTT)
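    ///
    /// # Example
    ///
    /// A minimal sketch; it assumes the crate root exposes this module as
    /// `rust_lstm::training`. Each input is an `input_size x 1` column vector
    /// and each target a `hidden_size x 1` column vector, matching the tests
    /// at the bottom of this file.
    ///
    /// ```no_run
    /// use ndarray::arr2;
    /// use rust_lstm::models::lstm_network::LSTMNetwork;
    /// use rust_lstm::training::create_basic_trainer;
    ///
    /// let mut trainer = create_basic_trainer(LSTMNetwork::new(2, 3, 1), 0.01);
    /// let inputs = vec![arr2(&[[1.0], [0.0]]), arr2(&[[0.0], [1.0]])];
    /// let targets = vec![arr2(&[[1.0], [0.0], [0.0]]), arr2(&[[0.0], [1.0], [0.0]])];
    /// let _avg_loss = trainer.train_sequence(&inputs, &targets);
    /// ```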
    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

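        // Walk the sequence in reverse, accumulating each timestep's parameter
        // gradients into a single running total before one optimizer update.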
        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    /// Train for multiple epochs with optional validation
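    ///
    /// # Example
    ///
    /// Sketch of the expected dataset layout (same crate-path assumptions as
    /// above): each element of `train_data` is one `(inputs, targets)` pair
    /// describing a full sequence.
    ///
    /// ```no_run
    /// use ndarray::arr2;
    /// use rust_lstm::models::lstm_network::LSTMNetwork;
    /// use rust_lstm::training::create_basic_trainer;
    ///
    /// let mut trainer = create_basic_trainer(LSTMNetwork::new(2, 3, 1), 0.01);
    /// let train_data = vec![(
    ///     vec![arr2(&[[1.0], [0.0]]), arr2(&[[0.0], [1.0]])],
    ///     vec![arr2(&[[1.0], [0.0], [0.0]]), arr2(&[[0.0], [1.0], [0.0]])],
    /// )];
    /// trainer.train(&train_data, None);
    /// ```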
    pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>) {

        println!("Starting training for {} epochs...", self.config.epochs);

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            // Training phase
            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let current_lr = self.optimizer.get_learning_rate();
            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: current_lr,
            };

            self.metrics_history.push(metrics.clone());

            if epoch % self.config.print_every == 0 {
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                             epoch, epoch_loss, val_loss, current_lr, time_elapsed);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                             epoch, epoch_loss, current_lr, time_elapsed);
                }
            }
        }

        println!("Training completed!");
    }

    /// Evaluate model performance on validation data
    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    /// Generate predictions for input sequences
    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    /// Clip each gradient matrix to a maximum Frobenius norm to prevent exploding gradients
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    /// Switch the network between training and evaluation mode
    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }
}

/// Specialized trainer for scheduled optimizers that automatically steps the scheduler
pub struct ScheduledLSTMTrainer<L: LossFunction, O: Optimizer, S: LearningRateScheduler> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: ScheduledOptimizer<O, S>,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
}

impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: ScheduledOptimizer<O, S>) -> Self {
        ScheduledLSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.config = config;
        self
    }

    /// Train on a single sequence using backpropagation through time (BPTT)
    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    /// Train for multiple epochs with automatic scheduler stepping
    pub fn train(&mut self, train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
                 validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>) {

        println!("Starting training for {} epochs with {} scheduler...",
                 self.config.epochs, self.optimizer.scheduler_name());

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            // Training phase
            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            // Step the scheduler at the end of each epoch
            let prev_lr = self.optimizer.get_learning_rate();
            if let Some(val_loss) = validation_loss {
                self.optimizer.step_with_val_loss(val_loss);
            } else {
                self.optimizer.step();
            }
            let new_lr = self.optimizer.get_learning_rate();

            // Log learning rate changes if enabled
            if self.config.log_lr_changes && (new_lr - prev_lr).abs() > 1e-10 {
                println!("Learning rate changed from {:.2e} to {:.2e}", prev_lr, new_lr);
            }

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: new_lr,
            };

            self.metrics_history.push(metrics.clone());

            if epoch % self.config.print_every == 0 {
                if let Some(val_loss) = validation_loss {
                    println!("Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                             epoch, epoch_loss, val_loss, new_lr, time_elapsed);
                } else {
                    println!("Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                             epoch, epoch_loss, new_lr, time_elapsed);
                }
            }
        }

        println!("Training completed!");
    }

    /// Evaluate model performance on validation data
    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    /// Generate predictions for input sequences
    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    /// Clip each gradient matrix to a maximum Frobenius norm to prevent exploding gradients
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        let norm = (&*matrix * &*matrix).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            *matrix = matrix.map(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    /// Switch the network between training and evaluation mode
    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }

    /// Get the current learning rate
    pub fn get_current_lr(&self) -> f64 {
        self.optimizer.get_current_lr()
    }

    /// Get the current epoch from the scheduler
    pub fn get_current_epoch(&self) -> usize {
        self.optimizer.get_current_epoch()
    }

    /// Reset the optimizer and scheduler
    pub fn reset_optimizer(&mut self) {
        self.optimizer.reset();
    }
}

/// Create a basic trainer with SGD optimizer and MSE loss
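///
/// # Example
///
/// A sketch (same crate-path assumptions as the examples above), including the
/// `with_config` builder:
///
/// ```no_run
/// use rust_lstm::models::lstm_network::LSTMNetwork;
/// use rust_lstm::training::{create_basic_trainer, TrainingConfig};
///
/// let trainer = create_basic_trainer(LSTMNetwork::new(2, 3, 1), 0.01)
///     .with_config(TrainingConfig { epochs: 50, ..TrainingConfig::default() });
/// ```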
pub fn create_basic_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMTrainer<MSELoss, SGD> {
    let loss_function = MSELoss;
    let optimizer = SGD::new(learning_rate);
    LSTMTrainer::new(network, loss_function, optimizer)
}

/// Create a scheduled trainer with SGD and StepLR scheduler
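///
/// # Example
///
/// A sketch (same crate-path assumptions as above): in the usual StepLR fashion,
/// the learning rate starts at 0.1 and is scaled by `gamma = 0.5` every
/// `step_size = 10` epochs, with the scheduler stepped automatically at the end
/// of each epoch by `ScheduledLSTMTrainer::train`.
///
/// ```no_run
/// use rust_lstm::models::lstm_network::LSTMNetwork;
/// use rust_lstm::training::create_step_lr_trainer;
///
/// let trainer = create_step_lr_trainer(LSTMNetwork::new(2, 3, 1), 0.1, 10, 0.5);
/// let _initial_lr = trainer.get_current_lr();
/// ```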
pub fn create_step_lr_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    step_size: usize,
    gamma: f64
) -> ScheduledLSTMTrainer<MSELoss, SGD, crate::schedulers::StepLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::step_lr(SGD::new(learning_rate), learning_rate, step_size, gamma);
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Create a scheduled trainer with Adam and OneCycleLR scheduler
pub fn create_one_cycle_trainer(
    network: LSTMNetwork,
    max_lr: f64,
    total_steps: usize
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::OneCycleLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::one_cycle(
        crate::optimizers::Adam::new(max_lr),
        max_lr,
        total_steps
    );
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Create a scheduled trainer with Adam and CosineAnnealingLR scheduler
pub fn create_cosine_annealing_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    t_max: usize,
    eta_min: f64
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::CosineAnnealingLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::cosine_annealing(
        crate::optimizers::Adam::new(learning_rate),
        learning_rate,
        t_max,
        eta_min
    );
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_trainer_creation() {
        let network = LSTMNetwork::new(2, 3, 1);
        let trainer = create_basic_trainer(network, 0.01);

        assert_eq!(trainer.network.input_size, 2);
        assert_eq!(trainer.network.hidden_size, 3);
        assert_eq!(trainer.network.num_layers, 1);
    }

    #[test]
    fn test_sequence_training() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_trainer(network, 0.01);

        let inputs = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
        ];
        let targets = vec![
            arr2(&[[1.0], [0.0], [0.0]]),
            arr2(&[[0.0], [1.0], [0.0]]),
        ];

        let loss = trainer.train_sequence(&inputs, &targets);
        assert!(loss >= 0.0);
    }
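
    // Added sketch: checks that predict() yields one output per input time step,
    // with the hidden-size-by-one shape implied by the targets used above.
    #[test]
    fn test_prediction_shapes() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_trainer(network, 0.01);

        let inputs = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
            arr2(&[[0.5], [0.5]]),
        ];

        let predictions = trainer.predict(&inputs);
        assert_eq!(predictions.len(), inputs.len());
        for prediction in &predictions {
            assert_eq!(prediction.dim(), (3, 1));
        }
    }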
}