//! Neural-network-style optimization utilities built on top of the
//! stochastic optimizers in this crate: a named parameter/gradient container,
//! a stateful optimizer wrapper (SGD, momentum, Adam, AdamW, RMSProp), and a
//! small training-loop helper with gradient clipping and early stopping.

use crate::error::OptimizeError;
use crate::stochastic::{StochasticMethod, StochasticOptions};
use ndarray::{s, Array1, ScalarOperand};
use num_traits::Float;
use std::collections::HashMap;

/// Container for a model's parameters and their gradients, stored as a list
/// of named 1-D arrays.
#[derive(Debug, Clone)]
pub struct NeuralParameters<F: Float + ScalarOperand> {
    /// Parameter arrays, one per named group (e.g. per layer).
    pub parameters: Vec<Array1<F>>,
    /// Gradient arrays, kept in the same order and shape as `parameters`.
    pub gradients: Vec<Array1<F>>,
    /// Names associated with each parameter group.
    pub names: Vec<String>,
}

impl<F: Float + ScalarOperand> Default for NeuralParameters<F> {
    fn default() -> Self {
        Self {
            parameters: Vec::new(),
            gradients: Vec::new(),
            names: Vec::new(),
        }
    }
}

impl<F: Float + ScalarOperand> NeuralParameters<F> {
    /// Creates an empty parameter set.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers a named parameter array and allocates a matching
    /// zero-initialized gradient array.
    pub fn add_parameter(&mut self, name: String, param: Array1<F>) {
        self.names.push(name);
        self.gradients.push(Array1::zeros(param.raw_dim()));
        self.parameters.push(param);
    }

    /// Total number of scalar parameters across all groups.
    pub fn total_parameters(&self) -> usize {
        self.parameters.iter().map(|p| p.len()).sum()
    }

    /// Concatenates all parameter arrays into a single flat vector.
    pub fn flatten_parameters(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for param in &self.parameters {
            let len = param.len();
            flat.slice_mut(s![offset..offset + len]).assign(param);
            offset += len;
        }

        flat
    }

    /// Concatenates all gradient arrays into a single flat vector.
    pub fn flatten_gradients(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for grad in &self.gradients {
            let len = grad.len();
            flat.slice_mut(s![offset..offset + len]).assign(grad);
            offset += len;
        }

        flat
    }

    /// Scatters a flat parameter vector back into the per-group arrays.
    pub fn update_from_flat(&mut self, flat_params: &Array1<F>) {
        let mut offset = 0;

        for param in &mut self.parameters {
            let len = param.len();
            param.assign(&flat_params.slice(s![offset..offset + len]));
            offset += len;
        }
    }

    /// Scatters a flat gradient vector back into the per-group arrays.
    pub fn update_gradients_from_flat(&mut self, flat_grads: &Array1<F>) {
        let mut offset = 0;

        for grad in &mut self.gradients {
            let len = grad.len();
            grad.assign(&flat_grads.slice(s![offset..offset + len]));
            offset += len;
        }
    }
}

/// Stateful optimizer for [`NeuralParameters`], wrapping a [`StochasticMethod`]
/// and keeping per-parameter state between steps.
pub struct NeuralOptimizer<F: Float + ScalarOperand> {
    method: StochasticMethod,
    options: StochasticOptions,
    /// Velocity buffers used by SGD with momentum.
    momentum_buffers: HashMap<String, Array1<F>>,
    /// First-moment (mean) estimates used by Adam/AdamW.
    first_moment: HashMap<String, Array1<F>>,
    /// Second-moment (uncentered variance) estimates used by Adam/AdamW/RMSProp.
    second_moment: HashMap<String, Array1<F>>,
    step_count: usize,
}

impl<F: Float + ScalarOperand> NeuralOptimizer<F>
where
    F: 'static + Send + Sync,
{
    /// Creates an optimizer from an explicit method and option set.
    pub fn new(method: StochasticMethod, options: StochasticOptions) -> Self {
        Self {
            method,
            options,
            momentum_buffers: HashMap::new(),
            first_moment: HashMap::new(),
            second_moment: HashMap::new(),
            step_count: 0,
        }
    }

    /// Convenience constructor for plain SGD with a constant learning rate.
    pub fn sgd(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.01),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::SGD, options)
    }

    /// Convenience constructor for Adam with a constant learning rate and a
    /// gradient-clipping threshold of 1.0.
    pub fn adam(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::Adam, options)
    }

    /// Convenience constructor for AdamW (Adam with decoupled weight decay).
    pub fn adamw(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::AdamW, options)
    }

    /// Applies a single optimization step to `params` in place, using the
    /// gradients currently stored in `params.gradients`.
    pub fn step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        self.step_count += 1;

        match self.method {
            StochasticMethod::SGD => self.sgd_step(params),
            StochasticMethod::Momentum => self.momentum_step(params),
            StochasticMethod::Adam => self.adam_step(params),
            StochasticMethod::AdamW => self.adamw_step(params),
            StochasticMethod::RMSProp => self.rmsprop_step(params),
        }
    }

    /// Plain SGD update: `p <- p - lr * g`.
    fn sgd_step(&self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.01).unwrap());

        for (param, grad) in params.parameters.iter_mut().zip(params.gradients.iter()) {
            *param = param.clone() - &(grad.clone() * lr);
        }

        Ok(())
    }

    /// SGD with momentum: `v <- momentum * v + g`, then `p <- p - lr * v`.
    fn momentum_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.01).unwrap());
        let momentum = F::from(0.9).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Lazily create a zero velocity buffer for each parameter group.
            if !self.momentum_buffers.contains_key(&param_name) {
                self.momentum_buffers
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let momentum_buffer = self.momentum_buffers.get_mut(&param_name).unwrap();

            // Update the velocity, then take a step along it.
            *momentum_buffer = momentum_buffer.clone() * momentum + grad;

            *param = param.clone() - &(momentum_buffer.clone() * lr);
        }

        Ok(())
    }

    /// Adam update with bias-corrected first and second moment estimates.
    fn adam_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.001).unwrap());
        let beta1 = F::from(0.9).unwrap();
        let beta2 = F::from(0.999).unwrap();
        let epsilon = F::from(1e-8).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Lazily create zero moment estimates for each parameter group.
            if !self.first_moment.contains_key(&param_name) {
                self.first_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let m = self.first_moment.get_mut(&param_name).unwrap();
            let v = self.second_moment.get_mut(&param_name).unwrap();

            // Biased first moment: m <- beta1 * m + (1 - beta1) * g.
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            // Biased second moment: v <- beta2 * v + (1 - beta2) * g^2.
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction: m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t).
            let step_f = F::from(self.step_count).unwrap();
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // Parameter update: p <- p - lr * m_hat / (sqrt(v_hat) + eps).
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let update = m_hat / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// AdamW update: the Adam step plus decoupled weight decay applied
    /// directly to the parameters rather than folded into the gradient.
    fn adamw_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.001).unwrap());
        let beta1 = F::from(0.9).unwrap();
        let beta2 = F::from(0.999).unwrap();
        let epsilon = F::from(1e-8).unwrap();
        let weight_decay = F::from(0.01).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            if !self.first_moment.contains_key(&param_name) {
                self.first_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let m = self.first_moment.get_mut(&param_name).unwrap();
            let v = self.second_moment.get_mut(&param_name).unwrap();

            // Biased moment estimates, as in Adam.
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction.
            let step_f = F::from(self.step_count).unwrap();
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // Combined update: p <- p - lr * (m_hat / (sqrt(v_hat) + eps) + wd * p).
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let adam_update = m_hat / denominator;
            let weight_decay_update = param.clone() * weight_decay;
            let total_update = (adam_update + weight_decay_update) * lr;

            *param = param.clone() - &total_update;
        }

        Ok(())
    }

    /// RMSProp update: `v <- alpha * v + (1 - alpha) * g^2`, then
    /// `p <- p - lr * g / (sqrt(v) + eps)`.
    fn rmsprop_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.001).unwrap());
        let alpha = F::from(0.99).unwrap();
        let epsilon = F::from(1e-8).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            if !self.second_moment.contains_key(&param_name) {
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let v = self.second_moment.get_mut(&param_name).unwrap();

            // Exponential moving average of the squared gradient.
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * alpha + &(grad_squared * (F::one() - alpha));

            // Scale the gradient by the running RMS before stepping.
            let denominator = v.mapv(|x| x.sqrt()) + epsilon;
            let update = grad.clone() / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// Returns the current learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.options.learning_rate
    }

    /// Sets a new learning rate.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.options.learning_rate = lr;
    }

    /// Clears all optimizer state (momentum buffers, moment estimates, step count).
    pub fn reset(&mut self) {
        self.momentum_buffers.clear();
        self.first_moment.clear();
        self.second_moment.clear();
        self.step_count = 0;
    }

    /// Human-readable name of the configured optimization method.
    pub fn method_name(&self) -> &'static str {
        match self.method {
            StochasticMethod::SGD => "SGD",
            StochasticMethod::Momentum => "SGD with Momentum",
            StochasticMethod::Adam => "Adam",
            StochasticMethod::AdamW => "AdamW",
            StochasticMethod::RMSProp => "RMSprop",
        }
    }
}

/// Small training-loop helper that runs an optimizer, tracks the loss
/// history, and optionally applies gradient clipping and early stopping.
pub struct NeuralTrainer<F: Float + ScalarOperand> {
    optimizer: NeuralOptimizer<F>,
    loss_history: Vec<F>,
    early_stopping_patience: Option<usize>,
    best_loss: Option<F>,
    patience_counter: usize,
}

impl<F: Float + ScalarOperand> NeuralTrainer<F>
where
    F: 'static + Send + Sync + std::fmt::Display,
{
    /// Creates a trainer around the given optimizer.
    pub fn new(optimizer: NeuralOptimizer<F>) -> Self {
        Self {
            optimizer,
            loss_history: Vec::new(),
            early_stopping_patience: None,
            best_loss: None,
            patience_counter: 0,
        }
    }

    /// Enables early stopping after `patience` epochs without improvement.
    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    /// Runs one training epoch: computes gradients, optionally clips them,
    /// takes an optimizer step, and records the resulting loss.
    pub fn train_epoch<LossFn, GradFn>(
        &mut self,
        params: &mut NeuralParameters<F>,
        loss_fn: &mut LossFn,
        grad_fn: &mut GradFn,
    ) -> Result<F, OptimizeError>
    where
        LossFn: FnMut(&NeuralParameters<F>) -> F,
        GradFn: FnMut(&NeuralParameters<F>) -> Vec<Array1<F>>,
    {
        // Compute and store fresh gradients.
        let gradients = grad_fn(params);
        params.gradients = gradients;

        // Clip the global gradient norm if the optimizer requests it.
        if let Some(max_norm) = self.optimizer.options.gradient_clip {
            self.clip_gradients(params, max_norm);
        }

        // Take an optimization step.
        self.optimizer.step(params)?;

        // Evaluate and record the loss after the update.
        let loss = loss_fn(params);
        self.loss_history.push(loss);

        // Track the best loss seen so far for early stopping.
        if let Some(_patience) = self.early_stopping_patience {
            if let Some(best_loss) = self.best_loss {
                if loss < best_loss {
                    self.best_loss = Some(loss);
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;
                }
            } else {
                self.best_loss = Some(loss);
            }
        }

        Ok(loss)
    }

    /// Returns `true` once the loss has not improved for `patience` epochs.
    pub fn should_stop_early(&self) -> bool {
        if let Some(patience) = self.early_stopping_patience {
            self.patience_counter >= patience
        } else {
            false
        }
    }

    /// Losses recorded after each call to [`train_epoch`](Self::train_epoch).
    pub fn loss_history(&self) -> &[F] {
        &self.loss_history
    }

    /// Current learning rate of the wrapped optimizer.
    pub fn learning_rate(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }

    /// Sets the learning rate of the wrapped optimizer.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.optimizer.set_learning_rate(lr);
    }

    /// Rescales all gradients so their combined L2 norm does not exceed `max_norm`.
    fn clip_gradients(&self, params: &mut NeuralParameters<F>, max_norm: f64) {
        let max_norm_f = F::from(max_norm).unwrap();

        // Global norm across every gradient array.
        let mut total_norm_sq = F::zero();
        for grad in &params.gradients {
            total_norm_sq = total_norm_sq + grad.mapv(|x| x * x).sum();
        }
        let total_norm = total_norm_sq.sqrt();

        if total_norm > max_norm_f {
            let scale = max_norm_f / total_norm;
            for grad in &mut params.gradients {
                grad.mapv_inplace(|x| x * scale);
            }
        }
    }
}

/// Convenience constructors with a default iteration budget of 1000.
pub mod optimizers {
    use super::*;

    /// Plain SGD optimizer.
    pub fn sgd<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::sgd(learning_rate, 1000)
    }

    /// Adam optimizer.
    pub fn adam<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adam(learning_rate, 1000)
    }

    /// AdamW optimizer.
    pub fn adamw<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adamw(learning_rate, 1000)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neural_parameters() {
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("layer1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.add_parameter("layer2".to_string(), Array1::from_vec(vec![4.0, 5.0]));

        assert_eq!(params.total_parameters(), 5);

        let flat = params.flatten_parameters();
        assert_eq!(flat.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);

        let new_flat = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0]);
        params.update_from_flat(&new_flat);

        assert_eq!(params.parameters[0].as_slice().unwrap(), &[6.0, 7.0, 8.0]);
        assert_eq!(params.parameters[1].as_slice().unwrap(), &[9.0, 10.0]);
    }
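
    // Companion test (a minimal sketch using only APIs defined in this module):
    // exercises the gradient counterparts of the parameter flatten/update
    // round-trip checked above.
    #[test]
    fn test_flatten_and_update_gradients() {
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("layer1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.add_parameter("layer2".to_string(), Array1::from_vec(vec![4.0, 5.0]));

        // Gradients are zero-initialized with the same shapes as the parameters.
        assert_eq!(
            params.flatten_gradients().as_slice().unwrap(),
            &[0.0, 0.0, 0.0, 0.0, 0.0]
        );

        // Scatter a flat gradient vector back into the per-layer arrays.
        let flat_grads = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        params.update_gradients_from_flat(&flat_grads);

        assert_eq!(params.gradients[0].as_slice().unwrap(), &[0.1, 0.2, 0.3]);
        assert_eq!(params.gradients[1].as_slice().unwrap(), &[0.4, 0.5]);
        assert_eq!(
            params.flatten_gradients().as_slice().unwrap(),
            &[0.1, 0.2, 0.3, 0.4, 0.5]
        );
    }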

    #[test]
    fn test_sgd_optimizer() {
        let mut optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);

        optimizer.step(&mut params).unwrap();

        let expected = [1.0 - 0.1 * 0.5, 2.0 - 0.1 * 1.0];
        assert_abs_diff_eq!(params.parameters[0][0], expected[0], epsilon = 1e-10);
        assert_abs_diff_eq!(params.parameters[0][1], expected[1], epsilon = 1e-10);
    }
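
    // Minimal sketch of a momentum test, built through the public `new`
    // constructor with the same option fields the `sgd` constructor uses. On
    // the very first step the velocity buffer starts at zero, so the update
    // reduces to a plain SGD step: p <- p - lr * g.
    #[test]
    fn test_momentum_first_step_matches_sgd() {
        let options = StochasticOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };
        let mut optimizer = NeuralOptimizer::<f64>::new(StochasticMethod::Momentum, options);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);

        optimizer.step(&mut params).unwrap();

        assert_abs_diff_eq!(params.parameters[0][0], 1.0 - 0.1 * 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(params.parameters[0][1], 2.0 - 0.1 * 1.0, epsilon = 1e-10);
    }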

    #[test]
    fn test_adam_optimizer() {
        let mut optimizer = NeuralOptimizer::adam(0.001, 100);
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.1, 0.2]);

        let original_params = params.parameters[0].clone();

        optimizer.step(&mut params).unwrap();

        // Parameters must change, and with positive gradients they should decrease.
        assert_ne!(params.parameters[0][0], original_params[0]);
        assert_ne!(params.parameters[0][1], original_params[1]);

        assert!(params.parameters[0][0] < original_params[0]);
        assert!(params.parameters[0][1] < original_params[1]);
    }
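
    // Sketch test for AdamW, mirroring the Adam test above: with positive
    // gradients and positive weights, both the Adam term and the decoupled
    // weight-decay term push the parameters downward on the first step.
    #[test]
    fn test_adamw_optimizer() {
        let mut optimizer = NeuralOptimizer::adamw(0.001, 100);
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.1, 0.2]);

        let original_params = params.parameters[0].clone();

        optimizer.step(&mut params).unwrap();

        assert!(params.parameters[0][0] < original_params[0]);
        assert!(params.parameters[0][1] < original_params[1]);
    }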

    #[test]
    fn test_neural_trainer() {
        let optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(5);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![1.0]);

        // Minimize f(x) = x^2 with its analytic gradient 2x.
        let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0][0] * p.parameters[0][0];
        let mut grad_fn =
            |p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![2.0 * p.parameters[0][0]])];

        let loss = trainer
            .train_epoch(&mut params, &mut loss_fn, &mut grad_fn)
            .unwrap();

        assert_eq!(trainer.loss_history().len(), 1);
        assert_eq!(trainer.loss_history()[0], loss);
    }
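
    // Sketch test for the bookkeeping helpers: learning-rate get/set and
    // `reset`. For Adam, the first-step update depends only on the gradient
    // (m_hat = g, v_hat = g^2 at t = 1), so after `reset` the next step with
    // the same gradient must produce the same parameter change.
    #[test]
    fn test_learning_rate_and_reset() {
        let mut optimizer = NeuralOptimizer::adam(0.001, 100);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.001, epsilon = 1e-12);

        optimizer.set_learning_rate(0.01);
        assert_abs_diff_eq!(optimizer.get_learning_rate(), 0.01, epsilon = 1e-12);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5]);

        optimizer.step(&mut params).unwrap();
        let delta1 = 1.0 - params.parameters[0][0];

        optimizer.reset();
        let before = params.parameters[0][0];
        optimizer.step(&mut params).unwrap();
        let delta2 = before - params.parameters[0][0];

        assert_abs_diff_eq!(delta1, delta2, epsilon = 1e-12);
    }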

    #[test]
    fn test_optimizer_convenience_functions() {
        let sgd_opt = optimizers::sgd(0.01);
        assert_eq!(sgd_opt.method_name(), "SGD");

        let adam_opt = optimizers::adam(0.001);
        assert_eq!(adam_opt.method_name(), "Adam");

        let adamw_opt = optimizers::adamw(0.001);
        assert_eq!(adamw_opt.method_name(), "AdamW");
    }
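
    // Sketch test for a variant reachable only through `NeuralOptimizer::new`:
    // builds an RMSProp optimizer with the same option fields used by the
    // convenience constructors and checks its first update,
    // p <- p - lr * g / (sqrt((1 - alpha) * g^2) + eps) with alpha = 0.99.
    #[test]
    fn test_rmsprop_via_new() {
        let options = StochasticOptions {
            learning_rate: 0.01,
            max_iter: 100,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };
        let mut optimizer = NeuralOptimizer::<f64>::new(StochasticMethod::RMSProp, options);
        assert_eq!(optimizer.method_name(), "RMSprop");

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5]);

        optimizer.step(&mut params).unwrap();

        let v: f64 = (1.0 - 0.99) * 0.5 * 0.5;
        let expected = 1.0 - 0.01 * 0.5 / (v.sqrt() + 1e-8);
        assert_abs_diff_eq!(params.parameters[0][0], expected, epsilon = 1e-10);
    }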
}