use ndarray::Array2;
use crate::models::lstm_network::LSTMNetwork;
use crate::loss::{LossFunction, MSELoss};
use crate::optimizers::{Optimizer, SGD, ScheduledOptimizer};
use crate::schedulers::LearningRateScheduler;
use std::time::Instant;

/// Configuration options for the training loop.
pub struct TrainingConfig {
    /// Number of passes over the training data.
    pub epochs: usize,
    /// How often (in epochs) to print progress.
    pub print_every: usize,
    /// Optional per-matrix gradient norm clipping threshold.
    pub clip_gradient: Option<f64>,
    /// Whether to log learning-rate changes made by a scheduler.
    pub log_lr_changes: bool,
}

impl Default for TrainingConfig {
    fn default() -> Self {
        TrainingConfig {
            epochs: 100,
            print_every: 10,
            clip_gradient: Some(5.0),
            log_lr_changes: true,
        }
    }
}
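
// A minimal sketch of overriding selected fields while keeping the remaining
// defaults (the values below are illustrative, not recommendations):
//
//     let config = TrainingConfig { epochs: 200, clip_gradient: Some(1.0), ..TrainingConfig::default() };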

/// Metrics recorded for a single training epoch.
#[derive(Debug, Clone)]
pub struct TrainingMetrics {
    pub epoch: usize,
    pub train_loss: f64,
    pub validation_loss: Option<f64>,
    pub time_elapsed: f64,
    pub learning_rate: f64,
}

/// Trains an `LSTMNetwork` with a loss function and an optimizer whose learning
/// rate is not scheduled.
pub struct LSTMTrainer<L: LossFunction, O: Optimizer> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: O,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
}

impl<L: LossFunction, O: Optimizer> LSTMTrainer<L, O> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: O) -> Self {
        LSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.config = config;
        self
    }
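
    // Builder-style configuration sketch (values are illustrative):
    //
    //     let trainer = LSTMTrainer::new(network, MSELoss, SGD::new(0.01))
    //         .with_config(TrainingConfig { print_every: 1, ..TrainingConfig::default() });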

    /// Runs one forward/backward pass over a single sequence and applies a
    /// parameter update. Returns the mean loss over the sequence.
    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        // Forward pass over the whole sequence, caching activations for backprop.
        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        // Walk the sequence in reverse time order and accumulate per-step gradients.
        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            // The loss only contributes a hidden-state gradient; the cell-state
            // gradient at the output is zero.
            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    /// Runs the full training loop, optionally evaluating on validation data
    /// after every epoch.
    pub fn train(
        &mut self,
        train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
        validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>,
    ) {
        println!("Starting training for {} epochs...", self.config.epochs);

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let current_lr = self.optimizer.get_learning_rate();
            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: current_lr,
            };

            self.metrics_history.push(metrics);

            if epoch % self.config.print_every == 0 {
                if let Some(val_loss) = validation_loss {
                    println!(
                        "Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                        epoch, epoch_loss, val_loss, current_lr, time_elapsed
                    );
                } else {
                    println!(
                        "Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                        epoch, epoch_loss, current_lr, time_elapsed
                    );
                }
            }
        }

        println!("Training completed!");
    }

    /// Computes the mean per-step loss over a dataset without updating parameters.
    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    /// Runs the network in evaluation mode and returns the output for each input step.
    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    /// Clips each gradient matrix so that its Frobenius norm does not exceed `max_norm`.
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        let norm = matrix.mapv(|x| x * x).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            matrix.mapv_inplace(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }
}
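
// End-to-end usage sketch for the fixed-learning-rate trainer. The data layout
// matches `train`'s signature: one `(inputs, targets)` pair per sequence, where
// each element is a per-step column vector as in the tests below. `seq_inputs`
// and `seq_targets` are hypothetical placeholders:
//
//     let mut trainer = create_basic_trainer(LSTMNetwork::new(2, 3, 1), 0.01);
//     let train_data = vec![(seq_inputs, seq_targets)];
//     trainer.train(&train_data, None);
//     let final_loss = trainer.get_latest_metrics().map(|m| m.train_loss);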

/// Trains an `LSTMNetwork` with an optimizer wrapped in a learning-rate scheduler.
pub struct ScheduledLSTMTrainer<L: LossFunction, O: Optimizer, S: LearningRateScheduler> {
    pub network: LSTMNetwork,
    pub loss_function: L,
    pub optimizer: ScheduledOptimizer<O, S>,
    pub config: TrainingConfig,
    pub metrics_history: Vec<TrainingMetrics>,
}

impl<L: LossFunction, O: Optimizer, S: LearningRateScheduler> ScheduledLSTMTrainer<L, O, S> {
    pub fn new(network: LSTMNetwork, loss_function: L, optimizer: ScheduledOptimizer<O, S>) -> Self {
        ScheduledLSTMTrainer {
            network,
            loss_function,
            optimizer,
            config: TrainingConfig::default(),
            metrics_history: Vec::new(),
        }
    }

    pub fn with_config(mut self, config: TrainingConfig) -> Self {
        self.config = config;
        self
    }

    /// Runs one forward/backward pass over a single sequence and applies a
    /// parameter update. Returns the mean loss over the sequence.
    pub fn train_sequence(&mut self, inputs: &[Array2<f64>], targets: &[Array2<f64>]) -> f64 {
        if inputs.len() != targets.len() {
            panic!("Inputs and targets must have the same length");
        }

        self.network.train();

        // Forward pass over the whole sequence, caching activations for backprop.
        let (outputs, caches) = self.network.forward_sequence_with_cache(inputs);

        let mut total_loss = 0.0;
        let mut total_gradients = self.network.zero_gradients();

        // Walk the sequence in reverse time order and accumulate per-step gradients.
        for (i, ((output, _), target)) in outputs.iter().zip(targets.iter()).enumerate().rev() {
            let loss = self.loss_function.compute_loss(output, target);
            total_loss += loss;

            // The loss only contributes a hidden-state gradient; the cell-state
            // gradient at the output is zero.
            let dhy = self.loss_function.compute_gradient(output, target);
            let dcy = Array2::zeros(output.raw_dim());

            let (step_gradients, _) = self.network.backward(&dhy, &dcy, &caches[i]);

            for (total_grad, step_grad) in total_gradients.iter_mut().zip(step_gradients.iter()) {
                total_grad.w_ih = &total_grad.w_ih + &step_grad.w_ih;
                total_grad.w_hh = &total_grad.w_hh + &step_grad.w_hh;
                total_grad.b_ih = &total_grad.b_ih + &step_grad.b_ih;
                total_grad.b_hh = &total_grad.b_hh + &step_grad.b_hh;
            }
        }

        if let Some(clip_value) = self.config.clip_gradient {
            self.clip_gradients(&mut total_gradients, clip_value);
        }

        self.network.update_parameters(&total_gradients, &mut self.optimizer);

        total_loss / inputs.len() as f64
    }

    /// Runs the full training loop, stepping the learning-rate scheduler once per
    /// epoch and optionally evaluating on validation data.
    pub fn train(
        &mut self,
        train_data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)],
        validation_data: Option<&[(Vec<Array2<f64>>, Vec<Array2<f64>>)]>,
    ) {
        println!(
            "Starting training for {} epochs with {} scheduler...",
            self.config.epochs,
            self.optimizer.scheduler_name()
        );

        for epoch in 0..self.config.epochs {
            let start_time = Instant::now();
            let mut epoch_loss = 0.0;

            self.network.train();
            for (inputs, targets) in train_data {
                let loss = self.train_sequence(inputs, targets);
                epoch_loss += loss;
            }
            epoch_loss /= train_data.len() as f64;

            let validation_loss = if let Some(val_data) = validation_data {
                self.network.eval();
                Some(self.evaluate(val_data))
            } else {
                None
            };

            // Step the scheduler once per epoch; schedulers that react to the
            // validation loss receive it here.
            let prev_lr = self.optimizer.get_learning_rate();
            if let Some(val_loss) = validation_loss {
                self.optimizer.step_with_val_loss(val_loss);
            } else {
                self.optimizer.step();
            }
            let new_lr = self.optimizer.get_learning_rate();

            if self.config.log_lr_changes && (new_lr - prev_lr).abs() > 1e-10 {
                println!("Learning rate changed from {:.2e} to {:.2e}", prev_lr, new_lr);
            }

            let time_elapsed = start_time.elapsed().as_secs_f64();

            let metrics = TrainingMetrics {
                epoch,
                train_loss: epoch_loss,
                validation_loss,
                time_elapsed,
                learning_rate: new_lr,
            };

            self.metrics_history.push(metrics);

            if epoch % self.config.print_every == 0 {
                if let Some(val_loss) = validation_loss {
                    println!(
                        "Epoch {}: Train Loss: {:.6}, Val Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                        epoch, epoch_loss, val_loss, new_lr, time_elapsed
                    );
                } else {
                    println!(
                        "Epoch {}: Train Loss: {:.6}, LR: {:.2e}, Time: {:.2}s",
                        epoch, epoch_loss, new_lr, time_elapsed
                    );
                }
            }
        }

        println!("Training completed!");
    }

    /// Computes the mean per-step loss over a dataset without updating parameters.
    pub fn evaluate(&mut self, data: &[(Vec<Array2<f64>>, Vec<Array2<f64>>)]) -> f64 {
        self.network.eval();

        let mut total_loss = 0.0;
        let mut total_samples = 0;

        for (inputs, targets) in data {
            if inputs.len() != targets.len() {
                continue;
            }

            let (outputs, _) = self.network.forward_sequence_with_cache(inputs);

            for ((output, _), target) in outputs.iter().zip(targets.iter()) {
                let loss = self.loss_function.compute_loss(output, target);
                total_loss += loss;
                total_samples += 1;
            }
        }

        if total_samples > 0 {
            total_loss / total_samples as f64
        } else {
            0.0
        }
    }

    /// Runs the network in evaluation mode and returns the output for each input step.
    pub fn predict(&mut self, inputs: &[Array2<f64>]) -> Vec<Array2<f64>> {
        self.network.eval();

        let (outputs, _) = self.network.forward_sequence_with_cache(inputs);
        outputs.into_iter().map(|(output, _)| output).collect()
    }

    /// Clips each gradient matrix so that its Frobenius norm does not exceed `max_norm`.
    fn clip_gradients(&self, gradients: &mut [crate::layers::lstm_cell::LSTMCellGradients], max_norm: f64) {
        for gradient in gradients.iter_mut() {
            self.clip_gradient_matrix(&mut gradient.w_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.w_hh, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_ih, max_norm);
            self.clip_gradient_matrix(&mut gradient.b_hh, max_norm);
        }
    }

    fn clip_gradient_matrix(&self, matrix: &mut Array2<f64>, max_norm: f64) {
        let norm = matrix.mapv(|x| x * x).sum().sqrt();
        if norm > max_norm {
            let scale = max_norm / norm;
            matrix.mapv_inplace(|x| x * scale);
        }
    }

    pub fn get_latest_metrics(&self) -> Option<&TrainingMetrics> {
        self.metrics_history.last()
    }

    pub fn get_metrics_history(&self) -> &[TrainingMetrics] {
        &self.metrics_history
    }

    pub fn set_training_mode(&mut self, training: bool) {
        if training {
            self.network.train();
        } else {
            self.network.eval();
        }
    }

    pub fn get_current_lr(&self) -> f64 {
        self.optimizer.get_current_lr()
    }

    pub fn get_current_epoch(&self) -> usize {
        self.optimizer.get_current_epoch()
    }

    pub fn reset_optimizer(&mut self) {
        self.optimizer.reset();
    }
}

/// Creates a trainer that uses MSE loss and plain SGD.
pub fn create_basic_trainer(network: LSTMNetwork, learning_rate: f64) -> LSTMTrainer<MSELoss, SGD> {
    let loss_function = MSELoss;
    let optimizer = SGD::new(learning_rate);
    LSTMTrainer::new(network, loss_function, optimizer)
}

/// Creates a trainer that uses MSE loss, SGD, and a step-decay learning-rate schedule.
pub fn create_step_lr_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    step_size: usize,
    gamma: f64,
) -> ScheduledLSTMTrainer<MSELoss, SGD, crate::schedulers::StepLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::step_lr(SGD::new(learning_rate), learning_rate, step_size, gamma);
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Creates a trainer that uses MSE loss, Adam, and a one-cycle learning-rate schedule.
pub fn create_one_cycle_trainer(
    network: LSTMNetwork,
    max_lr: f64,
    total_steps: usize,
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::OneCycleLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::one_cycle(
        crate::optimizers::Adam::new(max_lr),
        max_lr,
        total_steps,
    );
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}

/// Creates a trainer that uses MSE loss, Adam, and a cosine-annealing learning-rate schedule.
pub fn create_cosine_annealing_trainer(
    network: LSTMNetwork,
    learning_rate: f64,
    t_max: usize,
    eta_min: f64,
) -> ScheduledLSTMTrainer<MSELoss, crate::optimizers::Adam, crate::schedulers::CosineAnnealingLR> {
    let loss_function = MSELoss;
    let optimizer = ScheduledOptimizer::cosine_annealing(
        crate::optimizers::Adam::new(learning_rate),
        learning_rate,
        t_max,
        eta_min,
    );
    ScheduledLSTMTrainer::new(network, loss_function, optimizer)
}
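
// Usage sketch for the scheduled factories (argument values are illustrative;
// `train_data` and `val_data` are hypothetical, laid out as `train` expects):
//
//     let mut trainer = create_cosine_annealing_trainer(LSTMNetwork::new(2, 3, 1), 0.01, 100, 1e-5);
//     trainer.train(&train_data, Some(&val_data));
//     println!("final learning rate: {:.2e}", trainer.get_current_lr());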

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::arr2;

    #[test]
    fn test_trainer_creation() {
        let network = LSTMNetwork::new(2, 3, 1);
        let trainer = create_basic_trainer(network, 0.01);

        assert_eq!(trainer.network.input_size, 2);
        assert_eq!(trainer.network.hidden_size, 3);
        assert_eq!(trainer.network.num_layers, 1);
    }

    #[test]
    fn test_sequence_training() {
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_trainer(network, 0.01);

        let inputs = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
        ];
        let targets = vec![
            arr2(&[[1.0], [0.0], [0.0]]),
            arr2(&[[0.0], [1.0], [0.0]]),
        ];

        let loss = trainer.train_sequence(&inputs, &targets);
        assert!(loss >= 0.0);
    }
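
    // Additional checks, written as sketches under stated assumptions about the
    // rest of the crate (`LSTMNetwork`, `ScheduledOptimizer`, and the schedulers
    // are only exercised through the constructors defined in this file).

    #[test]
    fn test_gradient_clipping_scales_by_norm() {
        // Per-matrix Frobenius-norm clipping: [[3.0, 4.0]] has norm 5.0, so with
        // max_norm = 2.5 it should be scaled by 0.5 to [[1.5, 2.0]].
        let network = LSTMNetwork::new(2, 3, 1);
        let trainer = create_basic_trainer(network, 0.01);

        let mut grad = arr2(&[[3.0, 4.0]]);
        trainer.clip_gradient_matrix(&mut grad, 2.5);

        assert!((grad[[0, 0]] - 1.5).abs() < 1e-12);
        assert!((grad[[0, 1]] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_evaluate_and_predict() {
        // Evaluation on a well-formed batch should yield a non-negative loss, and
        // `predict` is expected to return one output per input step (the same
        // assumption the training loop makes when zipping outputs with targets).
        let network = LSTMNetwork::new(2, 3, 1);
        let mut trainer = create_basic_trainer(network, 0.01);

        let inputs = vec![
            arr2(&[[1.0], [0.0]]),
            arr2(&[[0.0], [1.0]]),
        ];
        let targets = vec![
            arr2(&[[1.0], [0.0], [0.0]]),
            arr2(&[[0.0], [1.0], [0.0]]),
        ];

        let data = vec![(inputs.clone(), targets)];
        let loss = trainer.evaluate(&data);
        assert!(loss >= 0.0);

        let outputs = trainer.predict(&inputs);
        assert_eq!(outputs.len(), inputs.len());
    }

    #[test]
    fn test_scheduled_trainer_creation() {
        // The exact starting learning rate depends on the StepLR implementation,
        // so this only checks that a positive rate is reported after construction.
        let network = LSTMNetwork::new(2, 3, 1);
        let trainer = create_step_lr_trainer(network, 0.1, 10, 0.5);
        assert!(trainer.get_current_lr() > 0.0);
    }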
}