scirs2_optimize/neural_integration.rs

//! Integration with scirs2-neural for machine learning optimization
//!
//! This module provides wrappers and utilities to use scirs2-optimize's
//! stochastic optimizers with scirs2-neural's neural network models.
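//!
//! # Example
//!
//! A minimal usage sketch (fenced as `ignore`, so it is not compiled as a
//! doctest; it assumes the items below are in scope and that gradients are
//! produced elsewhere, e.g. by scirs2-neural's backward pass):
//!
//! ```ignore
//! use scirs2_core::ndarray::Array1;
//!
//! // Collect per-layer parameters; the optimizer updates them in place.
//! let mut params = NeuralParameters::<f64>::new();
//! params.add_parameter("layer1".to_string(), Array1::from_vec(vec![0.5, -0.3]));
//!
//! // Gradients are supplied by the caller before each step.
//! params.gradients[0] = Array1::from_vec(vec![0.1, -0.2]);
//!
//! let mut optimizer = NeuralOptimizer::adam(0.001_f64, 1000);
//! optimizer.step(&mut params).expect("optimizer step failed");
//! ```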

use crate::error::OptimizeError;
use crate::stochastic::{StochasticMethod, StochasticOptions};
use scirs2_core::ndarray::{s, Array1, ScalarOperand};
use scirs2_core::numeric::Float;
use std::collections::HashMap;

/// Neural network parameter container
#[derive(Debug, Clone)]
pub struct NeuralParameters<F: Float + ScalarOperand> {
    /// Parameter vectors for each layer
    pub parameters: Vec<Array1<F>>,
    /// Gradient vectors for each layer
    pub gradients: Vec<Array1<F>>,
    /// Parameter names/IDs for tracking
    pub names: Vec<String>,
}

impl<F: Float + ScalarOperand> Default for NeuralParameters<F> {
    fn default() -> Self {
        Self {
            parameters: Vec::new(),
            gradients: Vec::new(),
            names: Vec::new(),
        }
    }
}

impl<F: Float + ScalarOperand> NeuralParameters<F> {
    /// Create a new neural parameters container
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a parameter vector
    pub fn add_parameter(&mut self, name: String, param: Array1<F>) {
        self.names.push(name);
        self.gradients.push(Array1::zeros(param.raw_dim()));
        self.parameters.push(param);
    }

    /// Get the total number of parameters
    pub fn total_parameters(&self) -> usize {
        self.parameters.iter().map(|p| p.len()).sum()
    }

    /// Flatten all parameters into a single vector
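    ///
    /// # Example
    ///
    /// Illustrative round-trip with `update_from_flat` (fenced as `ignore`,
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let mut params = NeuralParameters::<f64>::new();
    /// params.add_parameter("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));
    /// params.add_parameter("b".to_string(), Array1::from_vec(vec![3.0]));
    ///
    /// // Layers are concatenated in insertion order: [1.0, 2.0, 3.0]
    /// let flat = params.flatten_parameters();
    /// assert_eq!(flat.len(), params.total_parameters());
    ///
    /// // Writing the flat vector back restores the per-layer layout.
    /// params.update_from_flat(&flat);
    /// ```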
    pub fn flatten_parameters(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for param in &self.parameters {
            let len = param.len();
            flat.slice_mut(s![offset..offset + len]).assign(param);
            offset += len;
        }

        flat
    }

    /// Flatten all gradients into a single vector
    pub fn flatten_gradients(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for grad in &self.gradients {
            let len = grad.len();
            flat.slice_mut(s![offset..offset + len]).assign(grad);
            offset += len;
        }

        flat
    }

    /// Update parameters from a flattened vector
    pub fn update_from_flat(&mut self, flat_params: &Array1<F>) {
        let mut offset = 0;

        for param in &mut self.parameters {
            let len = param.len();
            param.assign(&flat_params.slice(s![offset..offset + len]));
            offset += len;
        }
    }

    /// Update gradients from a flattened vector
    pub fn update_gradients_from_flat(&mut self, flat_grads: &Array1<F>) {
        let mut offset = 0;

        for grad in &mut self.gradients {
            let len = grad.len();
            grad.assign(&flat_grads.slice(s![offset..offset + len]));
            offset += len;
        }
    }
}

/// Neural network optimizer that uses scirs2-optimize stochastic methods
pub struct NeuralOptimizer<F: Float + ScalarOperand> {
    method: StochasticMethod,
    options: StochasticOptions,
    /// Internal state for momentum-based optimizers
    momentum_buffers: HashMap<String, Array1<F>>,
    /// Internal state for Adam-family optimizers
    first_moment: HashMap<String, Array1<F>>,
    second_moment: HashMap<String, Array1<F>>,
    /// Step counter for bias correction
    step_count: usize,
}

impl<F: Float + ScalarOperand> NeuralOptimizer<F>
where
    F: 'static + Send + Sync,
{
    /// Create a new neural optimizer
    pub fn new(method: StochasticMethod, options: StochasticOptions) -> Self {
        Self {
            method,
            options,
            momentum_buffers: HashMap::new(),
            first_moment: HashMap::new(),
            second_moment: HashMap::new(),
            step_count: 0,
        }
    }

    /// Create an SGD optimizer for neural networks
    pub fn sgd(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.01),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::SGD, options)
    }

    /// Create an Adam optimizer for neural networks
    pub fn adam(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::Adam, options)
    }

    /// Create an AdamW optimizer for neural networks
    pub fn adamw(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::AdamW, options)
    }

    /// Update neural network parameters using the selected optimizer
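    ///
    /// Each call reads the gradients currently stored in
    /// `NeuralParameters::gradients` and updates the parameters in place.
    ///
    /// # Example
    ///
    /// A sketch of a single SGD step (fenced as `ignore`; the gradient values
    /// are placeholders chosen for illustration):
    ///
    /// ```ignore
    /// let mut params = NeuralParameters::<f64>::new();
    /// params.add_parameter("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));
    /// params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);
    ///
    /// let mut optimizer = NeuralOptimizer::sgd(0.1_f64, 100);
    /// optimizer.step(&mut params).expect("step failed");
    /// // SGD: param = param - lr * grad, so the first entry is 1.0 - 0.1 * 0.5 = 0.95.
    /// ```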
    pub fn step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        self.step_count += 1;

        match self.method {
            StochasticMethod::SGD => self.sgd_step(params),
            StochasticMethod::Momentum => self.momentum_step(params),
            StochasticMethod::Adam => self.adam_step(params),
            StochasticMethod::AdamW => self.adamw_step(params),
            StochasticMethod::RMSProp => self.rmsprop_step(params),
        }
    }

    /// SGD parameter update
    fn sgd_step(&self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.01).expect("Failed to convert constant to float"));

        for (param, grad) in params.parameters.iter_mut().zip(params.gradients.iter()) {
            *param = param.clone() - &(grad.clone() * lr);
        }

        Ok(())
    }

    /// Momentum SGD parameter update
    fn momentum_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.01).expect("Failed to convert constant to float"));
        let momentum = F::from(0.9).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Fetch the momentum buffer, initializing it lazily on first use
            let momentum_buffer = self
                .momentum_buffers
                .entry(param_name)
                .or_insert_with(|| Array1::zeros(param.raw_dim()));

            // v = momentum * v + grad
            *momentum_buffer = momentum_buffer.clone() * momentum + grad;

            // param = param - lr * v
            *param = param.clone() - &(momentum_buffer.clone() * lr);
        }

        Ok(())
    }

    /// Adam parameter update
    fn adam_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.001).expect("Failed to convert constant to float"));
        let beta1 = F::from(0.9).expect("Failed to convert constant to float");
        let beta2 = F::from(0.999).expect("Failed to convert constant to float");
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Fetch the moment buffers, initializing them lazily on first use
            let m = self
                .first_moment
                .entry(param_name.clone())
                .or_insert_with(|| Array1::zeros(param.raw_dim()));
            let v = self
                .second_moment
                .entry(param_name)
                .or_insert_with(|| Array1::zeros(param.raw_dim()));

            // m = beta1 * m + (1 - beta1) * grad
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            // v = beta2 * v + (1 - beta2) * grad^2
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction
            let step_f = F::from(self.step_count).expect("Failed to convert to float");
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // param = param - lr * m_hat / (sqrt(v_hat) + epsilon)
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let update = m_hat / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// AdamW parameter update (Adam with decoupled weight decay)
    fn adamw_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.001).expect("Failed to convert constant to float"));
        let beta1 = F::from(0.9).expect("Failed to convert constant to float");
        let beta2 = F::from(0.999).expect("Failed to convert constant to float");
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");
        let weight_decay = F::from(0.01).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Fetch the moment buffers, initializing them lazily on first use
            let m = self
                .first_moment
                .entry(param_name.clone())
                .or_insert_with(|| Array1::zeros(param.raw_dim()));
            let v = self
                .second_moment
                .entry(param_name)
                .or_insert_with(|| Array1::zeros(param.raw_dim()));

            // m = beta1 * m + (1 - beta1) * grad
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            // v = beta2 * v + (1 - beta2) * grad^2
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction
            let step_f = F::from(self.step_count).expect("Failed to convert to float");
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // AdamW update: param = param - lr * (m_hat / (sqrt(v_hat) + epsilon) + weight_decay * param)
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let adam_update = m_hat / denominator;
            let weight_decay_update = param.clone() * weight_decay;
            let total_update = (adam_update + weight_decay_update) * lr;

            *param = param.clone() - &total_update;
        }

        Ok(())
    }

    /// RMSprop parameter update
    fn rmsprop_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate)
            .unwrap_or_else(|| F::from(0.001).expect("Failed to convert constant to float"));
        let alpha = F::from(0.99).expect("Failed to convert constant to float"); // decay rate
        let epsilon = F::from(1e-8).expect("Failed to convert constant to float");

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Fetch the squared-average buffer, initializing it lazily on first use
            let v = self
                .second_moment
                .entry(param_name)
                .or_insert_with(|| Array1::zeros(param.raw_dim()));

            // v = alpha * v + (1 - alpha) * grad^2
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * alpha + &(grad_squared * (F::one() - alpha));

            // param = param - lr * grad / (sqrt(v) + epsilon)
            let denominator = v.mapv(|x| x.sqrt()) + epsilon;
            let update = grad.clone() / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// Get the current learning rate
    pub fn get_learning_rate(&self) -> f64 {
        self.options.learning_rate
    }

    /// Set the learning rate
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.options.learning_rate = lr;
    }

    /// Reset internal state
    pub fn reset(&mut self) {
        self.momentum_buffers.clear();
        self.first_moment.clear();
        self.second_moment.clear();
        self.step_count = 0;
    }

    /// Get the optimizer method name
    pub fn method_name(&self) -> &'static str {
        match self.method {
            StochasticMethod::SGD => "SGD",
            StochasticMethod::Momentum => "SGD with Momentum",
            StochasticMethod::Adam => "Adam",
            StochasticMethod::AdamW => "AdamW",
            StochasticMethod::RMSProp => "RMSprop",
        }
    }
}

/// Neural network trainer that combines optimization with training loops
pub struct NeuralTrainer<F: Float + ScalarOperand> {
    optimizer: NeuralOptimizer<F>,
    loss_history: Vec<F>,
    early_stopping_patience: Option<usize>,
    best_loss: Option<F>,
    patience_counter: usize,
}

impl<F: Float + ScalarOperand> NeuralTrainer<F>
where
    F: 'static + Send + Sync + std::fmt::Display,
{
    /// Create a new neural trainer
    pub fn new(optimizer: NeuralOptimizer<F>) -> Self {
        Self {
            optimizer,
            loss_history: Vec::new(),
            early_stopping_patience: None,
            best_loss: None,
            patience_counter: 0,
        }
    }

    /// Set early stopping patience
    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    /// Train for one epoch
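    ///
    /// `grad_fn` is evaluated first to refresh the stored gradients, optional
    /// gradient clipping is applied, the optimizer updates the parameters, and
    /// `loss_fn` is then evaluated on the updated parameters.
    ///
    /// # Example
    ///
    /// A sketch minimizing `f(x) = x^2` with closures supplying the loss and
    /// gradient (fenced as `ignore`, not compiled as a doctest):
    ///
    /// ```ignore
    /// let optimizer = NeuralOptimizer::sgd(0.1_f64, 100);
    /// let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(5);
    ///
    /// let mut params = NeuralParameters::<f64>::new();
    /// params.add_parameter("x".to_string(), Array1::from_vec(vec![1.0]));
    ///
    /// let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0][0] * p.parameters[0][0];
    /// let mut grad_fn =
    ///     |p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![2.0 * p.parameters[0][0]])];
    ///
    /// let loss = trainer.train_epoch(&mut params, &mut loss_fn, &mut grad_fn)?;
    /// ```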
    pub fn train_epoch<LossFn, GradFn>(
        &mut self,
        params: &mut NeuralParameters<F>,
        loss_fn: &mut LossFn,
        grad_fn: &mut GradFn,
    ) -> Result<F, OptimizeError>
    where
        LossFn: FnMut(&NeuralParameters<F>) -> F,
        GradFn: FnMut(&NeuralParameters<F>) -> Vec<Array1<F>>,
    {
        // Compute gradients
        let gradients = grad_fn(params);
        params.gradients = gradients;

        // Apply gradient clipping if specified
        if let Some(max_norm) = self.optimizer.options.gradient_clip {
            self.clip_gradients(params, max_norm);
        }

        // Update parameters
        self.optimizer.step(params)?;

        // Compute the loss on the updated parameters
        let loss = loss_fn(params);
        self.loss_history.push(loss);

        // Track the best loss seen so far for early stopping
        if self.early_stopping_patience.is_some() {
            if let Some(best_loss) = self.best_loss {
                if loss < best_loss {
                    self.best_loss = Some(loss);
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;
                }
            } else {
                self.best_loss = Some(loss);
            }
        }

        Ok(loss)
    }

    /// Check if training should stop early
    pub fn should_stop_early(&self) -> bool {
        if let Some(patience) = self.early_stopping_patience {
            self.patience_counter >= patience
        } else {
            false
        }
    }

    /// Get the loss history
    pub fn loss_history(&self) -> &[F] {
        &self.loss_history
    }

    /// Get the current learning rate
    pub fn learning_rate(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }

    /// Set the learning rate
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.optimizer.set_learning_rate(lr);
    }

    /// Clip gradients to prevent exploding gradients
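    ///
    /// Uses global-norm clipping: if the norm `||g||` of the concatenated
    /// gradient vector exceeds `max_norm`, every layer gradient is scaled by
    /// `max_norm / ||g||`; otherwise the gradients are left unchanged.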
    fn clip_gradients(&self, params: &mut NeuralParameters<F>, max_norm: f64) {
        let max_norm_f = F::from(max_norm).expect("Failed to convert to float");

        // Compute the total gradient norm across all layers
        let mut total_norm_sq = F::zero();
        for grad in &params.gradients {
            total_norm_sq = total_norm_sq + grad.mapv(|x| x * x).sum();
        }
        let total_norm = total_norm_sq.sqrt();

        if total_norm > max_norm_f {
            let scale = max_norm_f / total_norm;
            for grad in &mut params.gradients {
                grad.mapv_inplace(|x| x * scale);
            }
        }
    }
}

/// Convenience functions for creating neural optimizers
pub mod optimizers {
    use super::*;

    /// Create an SGD optimizer with default settings for neural networks
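    ///
    /// # Example
    ///
    /// Illustrative only (fenced as `ignore`); the default budget of 1000
    /// iterations comes from this module:
    ///
    /// ```ignore
    /// let opt = optimizers::sgd(0.01_f64);
    /// assert_eq!(opt.method_name(), "SGD");
    /// ```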
    pub fn sgd<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::sgd(learning_rate, 1000)
    }

    /// Create an Adam optimizer with default settings for neural networks
    pub fn adam<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adam(learning_rate, 1000)
    }

    /// Create an AdamW optimizer with default settings for neural networks
    pub fn adamw<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adamw(learning_rate, 1000)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neural_parameters() {
        let mut params = NeuralParameters::<f64>::new();

        // Add some parameters
        params.add_parameter("layer1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.add_parameter("layer2".to_string(), Array1::from_vec(vec![4.0, 5.0]));

        assert_eq!(params.total_parameters(), 5);

        // Test flattening
        let flat = params.flatten_parameters();
        assert_eq!(
            flat.as_slice().expect("Operation failed"),
            &[1.0, 2.0, 3.0, 4.0, 5.0]
        );

        // Test updating from flat
        let new_flat = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0]);
        params.update_from_flat(&new_flat);

        assert_eq!(
            params.parameters[0].as_slice().expect("Operation failed"),
            &[6.0, 7.0, 8.0]
        );
        assert_eq!(
            params.parameters[1].as_slice().expect("Operation failed"),
            &[9.0, 10.0]
        );
    }

    #[test]
    fn test_sgd_optimizer() {
        let mut optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut params = NeuralParameters::<f64>::new();

        // Add parameter
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        // Set gradient
        params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);

        // Perform one step
        optimizer.step(&mut params).expect("Operation failed");

        // Check update: param = param - lr * grad
        let expected = [1.0 - 0.1 * 0.5, 2.0 - 0.1 * 1.0];
        assert_abs_diff_eq!(params.parameters[0][0], expected[0], epsilon = 1e-10);
        assert_abs_diff_eq!(params.parameters[0][1], expected[1], epsilon = 1e-10);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut optimizer = NeuralOptimizer::adam(0.001, 100);
        let mut params = NeuralParameters::<f64>::new();

        // Add parameter
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        // Set gradient
        params.gradients[0] = Array1::from_vec(vec![0.1, 0.2]);

        let original_params = params.parameters[0].clone();

        // Perform one step
        optimizer.step(&mut params).expect("Operation failed");

        // Parameters should have changed
        assert_ne!(params.parameters[0][0], original_params[0]);
        assert_ne!(params.parameters[0][1], original_params[1]);

        // Parameters should have decreased (since gradients are positive)
        assert!(params.parameters[0][0] < original_params[0]);
        assert!(params.parameters[0][1] < original_params[1]);
    }

    #[test]
    fn test_neural_trainer() {
        let optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(5);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![1.0]);

        // Simple quadratic loss function
        let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0][0] * p.parameters[0][0];
        let mut grad_fn =
            |p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![2.0 * p.parameters[0][0]])];

        // Train for one epoch
        let loss = trainer
            .train_epoch(&mut params, &mut loss_fn, &mut grad_fn)
            .expect("Operation failed");

        // Loss should be computed
        assert_eq!(trainer.loss_history().len(), 1);
        assert_eq!(trainer.loss_history()[0], loss);
    }

    #[test]
    fn test_optimizer_convenience_functions() {
        let sgd_opt = optimizers::sgd(0.01);
        assert_eq!(sgd_opt.method_name(), "SGD");

        let adam_opt = optimizers::adam(0.001);
        assert_eq!(adam_opt.method_name(), "Adam");

        let adamw_opt = optimizers::adamw(0.001);
        assert_eq!(adamw_opt.method_name(), "AdamW");
    }
}