use crate::error::OptimizeError;
use crate::stochastic::{StochasticMethod, StochasticOptions};
use scirs2_core::ndarray::{s, Array1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;

/// Container for a model's trainable parameters and their gradients.
#[derive(Debug, Clone)]
pub struct NeuralParameters<F: Float + ScalarOperand> {
    /// Parameter vectors, one per named parameter group.
    pub parameters: Vec<Array1<F>>,
    /// Gradients, one per parameter group, with matching shapes.
    pub gradients: Vec<Array1<F>>,
    /// Human-readable names, one per parameter group.
    pub names: Vec<String>,
}

impl<F: Float + ScalarOperand> Default for NeuralParameters<F> {
    fn default() -> Self {
        Self {
            parameters: Vec::new(),
            gradients: Vec::new(),
            names: Vec::new(),
        }
    }
}

impl<F: Float + ScalarOperand> NeuralParameters<F> {
    /// Creates an empty parameter container.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a named parameter and allocates a zero-initialized gradient of the same shape.
    pub fn add_parameter(&mut self, name: String, param: Array1<F>) {
        self.names.push(name);
        self.gradients.push(Array1::zeros(param.raw_dim()));
        self.parameters.push(param);
    }

    /// Returns the total number of scalar parameters across all groups.
    pub fn total_parameters(&self) -> usize {
        self.parameters.iter().map(|p| p.len()).sum()
    }

    /// Concatenates all parameter groups into a single flat vector.
    pub fn flatten_parameters(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for param in &self.parameters {
            let len = param.len();
            flat.slice_mut(s![offset..offset + len]).assign(param);
            offset += len;
        }

        flat
    }

    /// Concatenates all gradients into a single flat vector.
    pub fn flatten_gradients(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for grad in &self.gradients {
            let len = grad.len();
            flat.slice_mut(s![offset..offset + len]).assign(grad);
            offset += len;
        }

        flat
    }

    /// Writes a flat parameter vector back into the individual parameter groups.
    pub fn update_from_flat(&mut self, flat_params: &Array1<F>) {
        let mut offset = 0;

        for param in &mut self.parameters {
            let len = param.len();
            param.assign(&flat_params.slice(s![offset..offset + len]));
            offset += len;
        }
    }

    /// Writes a flat gradient vector back into the individual gradient groups.
    pub fn update_gradients_from_flat(&mut self, flat_grads: &Array1<F>) {
        let mut offset = 0;

        for grad in &mut self.gradients {
            let len = grad.len();
            grad.assign(&flat_grads.slice(s![offset..offset + len]));
            offset += len;
        }
    }
}

/// Optimizer for neural-network-style parameter sets, dispatching to one of the
/// supported stochastic methods (SGD, momentum, Adam, AdamW, RMSProp).
pub struct NeuralOptimizer<F: Float + ScalarOperand> {
    /// Which stochastic update rule to apply.
    method: StochasticMethod,
    /// Shared options (learning rate, gradient clipping, etc.).
    options: StochasticOptions,
    /// Momentum buffers, one per parameter group.
    momentum_buffers: HashMap<String, Array1<F>>,
    /// First-moment estimates (Adam/AdamW), one per parameter group.
    first_moment: HashMap<String, Array1<F>>,
    /// Second-moment estimates (Adam/AdamW/RMSProp), one per parameter group.
    second_moment: HashMap<String, Array1<F>>,
    /// Number of optimization steps taken so far.
    step_count: usize,
}

impl<F: Float + ScalarOperand> NeuralOptimizer<F>
where
    F: 'static + Send + Sync,
{
    /// Creates an optimizer for the given method and options.
    pub fn new(method: StochasticMethod, options: StochasticOptions) -> Self {
        Self {
            method,
            options,
            momentum_buffers: HashMap::new(),
            first_moment: HashMap::new(),
            second_moment: HashMap::new(),
            step_count: 0,
        }
    }

    /// Convenience constructor for plain SGD.
    pub fn sgd(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.01),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::SGD, options)
    }

    /// Convenience constructor for Adam with gradient clipping at norm 1.0.
    pub fn adam(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::Adam, options)
    }

    /// Convenience constructor for AdamW (Adam with decoupled weight decay),
    /// with gradient clipping at norm 1.0.
    pub fn adamw(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::AdamW, options)
    }

    /// Applies one optimization step in place, using the configured method.
    pub fn step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        self.step_count += 1;

        match self.method {
            StochasticMethod::SGD => self.sgd_step(params),
            StochasticMethod::Momentum => self.momentum_step(params),
            StochasticMethod::Adam => self.adam_step(params),
            StochasticMethod::AdamW => self.adamw_step(params),
            StochasticMethod::RMSProp => self.rmsprop_step(params),
        }
    }

    /// Plain SGD update: `param -= lr * grad`.
    fn sgd_step(&self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.01).expect("Failed to convert constant to float"));

        for (param, grad) in params.parameters.iter_mut().zip(params.gradients.iter()) {
            *param = param.clone() - &(grad.clone() * lr);
        }

        Ok(())
    }

    /// SGD with momentum: `v = momentum * v + grad`, then `param -= lr * v`.
    fn momentum_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.01).expect("Failed to convert constant to float"));
        let momentum = F::from(0.9).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Lazily create a zero-initialized momentum buffer for this parameter group.
            if !self.momentum_buffers.contains_key(&param_name) {
                self.momentum_buffers
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let momentum_buffer = self
                .momentum_buffers
                .get_mut(&param_name)
                .expect("Operation failed");

            // Update the velocity, then move the parameters against it.
            *momentum_buffer = momentum_buffer.clone() * momentum + grad;

            *param = param.clone() - &(momentum_buffer.clone() * lr);
        }

        Ok(())
    }

    /// Adam update with bias-corrected first and second moment estimates.
    fn adam_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.001).expect("Failed to convert constant to float"));
        let beta1 = F::from(0.9).expect("Failed to convert constant to float");
        let beta2 = F::from(0.999).expect("Failed to convert constant to float");
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Lazily create zero-initialized moment buffers for this parameter group.
            if !self.first_moment.contains_key(&param_name) {
                self.first_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let m = self
                .first_moment
                .get_mut(&param_name)
                .expect("Operation failed");
            let v = self
                .second_moment
                .get_mut(&param_name)
                .expect("Operation failed");

            // Exponential moving averages of the gradient and squared gradient.
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction: m_hat = m / (1 - beta1^t), v_hat = v / (1 - beta2^t).
            let step_f = F::from(self.step_count).expect("Failed to convert to float");
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let update = m_hat / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// AdamW update: Adam moments plus decoupled weight decay applied directly to the parameters.
    fn adamw_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.001).expect("Failed to convert constant to float"));
        let beta1 = F::from(0.9).expect("Failed to convert constant to float");
        let beta2 = F::from(0.999).expect("Failed to convert constant to float");
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");
        let weight_decay = F::from(0.01).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Lazily create zero-initialized moment buffers for this parameter group.
            if !self.first_moment.contains_key(&param_name) {
                self.first_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let m = self
                .first_moment
                .get_mut(&param_name)
                .expect("Operation failed");
            let v = self
                .second_moment
                .get_mut(&param_name)
                .expect("Operation failed");

            // Exponential moving averages of the gradient and squared gradient.
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction, as in Adam.
            let step_f = F::from(self.step_count).expect("Failed to convert to float");
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // Decoupled weight decay: the decay term is added to the Adam update
            // rather than folded into the gradient.
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let adam_update = m_hat / denominator;
            let weight_decay_update = param.clone() * weight_decay;
            let total_update = (adam_update + weight_decay_update) * lr;

            *param = param.clone() - &total_update;
        }

        Ok(())
    }

    /// RMSProp update: `v = alpha * v + (1 - alpha) * grad^2`, then `param -= lr * grad / (sqrt(v) + eps)`.
    fn rmsprop_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.001).expect("Failed to convert constant to float"));
        let alpha = F::from(0.99).expect("Failed to convert constant to float");
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Lazily create a zero-initialized running average of squared gradients.
            if !self.second_moment.contains_key(&param_name) {
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let v = self
                .second_moment
                .get_mut(&param_name)
                .expect("Operation failed");

            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * alpha + &(grad_squared * (F::one() - alpha));

            let denominator = v.mapv(|x| x.sqrt()) + epsilon;
            let update = grad.clone() / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// Returns the current learning rate.
    pub fn get_learning_rate(&self) -> f64 {
        self.options.learning_rate
    }

    /// Sets the learning rate.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.options.learning_rate = lr;
    }

    /// Clears all internal state (momentum and moment buffers, step counter).
    pub fn reset(&mut self) {
        self.momentum_buffers.clear();
        self.first_moment.clear();
        self.second_moment.clear();
        self.step_count = 0;
    }

    /// Returns a human-readable name for the configured method.
    pub fn method_name(&self) -> &'static str {
        match self.method {
            StochasticMethod::SGD => "SGD",
            StochasticMethod::Momentum => "SGD with Momentum",
            StochasticMethod::Adam => "Adam",
            StochasticMethod::AdamW => "AdamW",
            StochasticMethod::RMSProp => "RMSprop",
        }
    }
}

/// High-level training loop helper that combines an optimizer with loss tracking,
/// gradient clipping, and optional early stopping.
pub struct NeuralTrainer<F: Float + ScalarOperand> {
    /// The underlying optimizer.
    optimizer: NeuralOptimizer<F>,
    /// Loss recorded after each training epoch.
    loss_history: Vec<F>,
    /// Number of non-improving epochs tolerated before stopping, if enabled.
    early_stopping_patience: Option<usize>,
    /// Best loss observed so far.
    best_loss: Option<F>,
    /// Consecutive epochs without improvement.
    patience_counter: usize,
}

impl<F: Float + ScalarOperand> NeuralTrainer<F>
where
    F: 'static + Send + Sync + std::fmt::Display,
{
    /// Creates a trainer around the given optimizer.
    pub fn new(optimizer: NeuralOptimizer<F>) -> Self {
        Self {
            optimizer,
            loss_history: Vec::new(),
            early_stopping_patience: None,
            best_loss: None,
            patience_counter: 0,
        }
    }

    /// Enables early stopping with the given patience (in epochs).
    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    /// Runs one training epoch: computes gradients, optionally clips them,
    /// applies an optimizer step, and records the resulting loss.
    pub fn train_epoch<LossFn, GradFn>(
        &mut self,
        params: &mut NeuralParameters<F>,
        loss_fn: &mut LossFn,
        grad_fn: &mut GradFn,
    ) -> Result<F, OptimizeError>
    where
        LossFn: FnMut(&NeuralParameters<F>) -> F,
        GradFn: FnMut(&NeuralParameters<F>) -> Vec<Array1<F>>,
    {
        // Compute and store fresh gradients.
        let gradients = grad_fn(params);
        params.gradients = gradients;

        // Clip gradients by global norm if configured.
        if let Some(max_norm) = self.optimizer.options.gradient_clip {
            self.clip_gradients(params, max_norm);
        }

        self.optimizer.step(params)?;

        // Evaluate and record the post-step loss.
        let loss = loss_fn(params);
        self.loss_history.push(loss);

        // Track the best loss for early stopping.
        if let Some(_patience) = self.early_stopping_patience {
            if let Some(best_loss) = self.best_loss {
                if loss < best_loss {
                    self.best_loss = Some(loss);
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;
                }
            } else {
                self.best_loss = Some(loss);
            }
        }

        Ok(loss)
    }

    /// Returns true if early stopping is enabled and the patience has been exhausted.
    pub fn should_stop_early(&self) -> bool {
        if let Some(patience) = self.early_stopping_patience {
            self.patience_counter >= patience
        } else {
            false
        }
    }

    /// Returns the loss recorded after each epoch.
    pub fn loss_history(&self) -> &[F] {
        &self.loss_history
    }

    /// Returns the optimizer's current learning rate.
    pub fn learning_rate(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }

    /// Sets the optimizer's learning rate.
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.optimizer.set_learning_rate(lr);
    }

    /// Scales all gradients so that their combined L2 norm does not exceed `max_norm`.
    fn clip_gradients(&self, params: &mut NeuralParameters<F>, max_norm: f64) {
        let max_norm_f = F::from(max_norm).expect("Failed to convert to float");

        // Global L2 norm across all gradient groups.
        let mut total_norm_sq = F::zero();
        for grad in &params.gradients {
            total_norm_sq = total_norm_sq + grad.mapv(|x| x * x).sum();
        }
        let total_norm = total_norm_sq.sqrt();

        if total_norm > max_norm_f {
            let scale = max_norm_f / total_norm;
            for grad in &mut params.gradients {
                grad.mapv_inplace(|x| x * scale);
            }
        }
    }
}

/// Convenience constructors mirroring common deep-learning defaults.
pub mod optimizers {
    use super::*;

    /// SGD optimizer with the given learning rate and a default iteration budget.
    pub fn sgd<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::sgd(learning_rate, 1000)
    }

    /// Adam optimizer with the given learning rate and a default iteration budget.
    pub fn adam<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adam(learning_rate, 1000)
    }

    /// AdamW optimizer with the given learning rate and a default iteration budget.
    pub fn adamw<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adamw(learning_rate, 1000)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neural_parameters() {
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("layer1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.add_parameter("layer2".to_string(), Array1::from_vec(vec![4.0, 5.0]));

        assert_eq!(params.total_parameters(), 5);

        let flat = params.flatten_parameters();
        assert_eq!(
            flat.as_slice().expect("Operation failed"),
            &[1.0, 2.0, 3.0, 4.0, 5.0]
        );

        let new_flat = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0]);
        params.update_from_flat(&new_flat);

        assert_eq!(
            params.parameters[0].as_slice().expect("Operation failed"),
            &[6.0, 7.0, 8.0]
        );
        assert_eq!(
            params.parameters[1].as_slice().expect("Operation failed"),
            &[9.0, 10.0]
        );
    }
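
    // A minimal companion sketch of the gradient round trip (not in the original
    // tests): update_gradients_from_flat() should scatter a flat vector back into
    // the per-group gradients, and flatten_gradients() should invert it.
    #[test]
    fn test_gradient_flatten_round_trip() {
        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("layer1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.add_parameter("layer2".to_string(), Array1::from_vec(vec![4.0, 5.0]));

        let flat_grads = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4, 0.5]);
        params.update_gradients_from_flat(&flat_grads);

        assert_eq!(
            params.gradients[0].as_slice().expect("Operation failed"),
            &[0.1, 0.2, 0.3]
        );
        assert_eq!(
            params.gradients[1].as_slice().expect("Operation failed"),
            &[0.4, 0.5]
        );
        assert_eq!(
            params.flatten_gradients().as_slice().expect("Operation failed"),
            &[0.1, 0.2, 0.3, 0.4, 0.5]
        );
    }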

    #[test]
    fn test_sgd_optimizer() {
        let mut optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);

        optimizer.step(&mut params).expect("Operation failed");

        let expected = [1.0 - 0.1 * 0.5, 2.0 - 0.1 * 1.0];
        assert_abs_diff_eq!(params.parameters[0][0], expected[0], epsilon = 1e-10);
        assert_abs_diff_eq!(params.parameters[0][1], expected[1], epsilon = 1e-10);
    }
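
    // A minimal sketch of the momentum path (not covered by the original tests).
    // It hand-builds StochasticOptions with the same field set the sgd()
    // constructor uses; on the first step the momentum buffer starts at zero,
    // so the update is expected to match plain SGD.
    #[test]
    fn test_momentum_first_step_matches_sgd() {
        let options = StochasticOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };
        let mut optimizer = NeuralOptimizer::<f64>::new(StochasticMethod::Momentum, options);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);

        optimizer.step(&mut params).expect("Operation failed");

        assert_abs_diff_eq!(params.parameters[0][0], 1.0 - 0.1 * 0.5, epsilon = 1e-10);
        assert_abs_diff_eq!(params.parameters[0][1], 2.0 - 0.1 * 1.0, epsilon = 1e-10);
    }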

    #[test]
    fn test_adam_optimizer() {
        let mut optimizer = NeuralOptimizer::adam(0.001, 100);
        let mut params = NeuralParameters::<f64>::new();

        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        params.gradients[0] = Array1::from_vec(vec![0.1, 0.2]);

        let original_params = params.parameters[0].clone();

        optimizer.step(&mut params).expect("Operation failed");

        assert_ne!(params.parameters[0][0], original_params[0]);
        assert_ne!(params.parameters[0][1], original_params[1]);

        assert!(params.parameters[0][0] < original_params[0]);
        assert!(params.parameters[0][1] < original_params[1]);
    }
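
    // A minimal sketch of the RMSProp path (not covered by the original tests),
    // constructed directly through new() with hand-built options. With the
    // hard-coded alpha = 0.99, the first step gives v = 0.01 * g^2, so
    // sqrt(v) = 0.1 * |g| and the update is roughly lr * 10 * sign(g): the
    // parameter should move from 1.0 to about 0.0 for lr = 0.1 and g = 1.0.
    #[test]
    fn test_rmsprop_first_step() {
        let options = StochasticOptions {
            learning_rate: 0.1,
            max_iter: 100,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };
        let mut optimizer = NeuralOptimizer::<f64>::new(StochasticMethod::RMSProp, options);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![1.0]);

        optimizer.step(&mut params).expect("Operation failed");

        assert_abs_diff_eq!(params.parameters[0][0], 0.0, epsilon = 1e-6);
    }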

    #[test]
    fn test_neural_trainer() {
        let optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(5);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![1.0]);

        let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0][0] * p.parameters[0][0];
        let mut grad_fn =
            |p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![2.0 * p.parameters[0][0]])];

        let loss = trainer
            .train_epoch(&mut params, &mut loss_fn, &mut grad_fn)
            .expect("Operation failed");

        assert_eq!(trainer.loss_history().len(), 1);
        assert_eq!(trainer.loss_history()[0], loss);
    }
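
    // Two minimal sketches around NeuralTrainer (not in the original tests).
    // The first checks the early-stopping counter: with a zero gradient the loss
    // never improves, so after the baseline epoch every further epoch counts
    // against the patience budget.
    #[test]
    fn test_early_stopping_counter() {
        let optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(2);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));

        let mut loss_fn = |_p: &NeuralParameters<f64>| 1.0;
        let mut grad_fn = |_p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![0.0])];

        for _ in 0..3 {
            trainer
                .train_epoch(&mut params, &mut loss_fn, &mut grad_fn)
                .expect("Operation failed");
        }

        // Epoch 1 sets the baseline loss; epochs 2 and 3 fail to improve on it.
        assert!(trainer.should_stop_early());
    }

    // The second checks gradient clipping through train_epoch(): the adam()
    // constructor sets gradient_clip = Some(1.0), so a gradient of norm 10
    // should be rescaled to unit norm before the update is applied.
    #[test]
    fn test_gradient_clipping_via_trainer() {
        let optimizer = NeuralOptimizer::adam(0.001, 100);
        let mut trainer = NeuralTrainer::new(optimizer);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));

        let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0][0];
        let mut grad_fn = |_p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![10.0])];

        trainer
            .train_epoch(&mut params, &mut loss_fn, &mut grad_fn)
            .expect("Operation failed");

        assert_abs_diff_eq!(params.gradients[0][0], 1.0, epsilon = 1e-10);
    }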

    #[test]
    fn test_optimizer_convenience_functions() {
        let sgd_opt = optimizers::sgd(0.01);
        assert_eq!(sgd_opt.method_name(), "SGD");

        let adam_opt = optimizers::adam(0.001);
        assert_eq!(adam_opt.method_name(), "Adam");

        let adamw_opt = optimizers::adamw(0.001);
        assert_eq!(adamw_opt.method_name(), "AdamW");
    }
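
    // A minimal sketch of reset() (not in the original tests). Because this test
    // module is a child of the module defining NeuralOptimizer, it can inspect
    // the private moment buffers and step counter directly.
    #[test]
    fn test_optimizer_reset() {
        let mut optimizer = NeuralOptimizer::adam(0.001, 100);
        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![0.5]);

        optimizer.step(&mut params).expect("Operation failed");
        assert_eq!(optimizer.step_count, 1);
        assert!(!optimizer.first_moment.is_empty());

        optimizer.reset();
        assert_eq!(optimizer.step_count, 0);
        assert!(optimizer.first_moment.is_empty());
        assert!(optimizer.second_moment.is_empty());
    }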
}