scirs2_optimize/neural_integration.rs

//! Integration with scirs2-neural for machine learning optimization
//!
//! This module provides wrappers and utilities to use scirs2-optimize's
//! stochastic optimizers with scirs2-neural's neural network models.
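//!
//! # Example
//!
//! A minimal usage sketch (kept out of the compiled doc tests). It assumes the
//! module is exposed as `scirs2_optimize::neural_integration` and uses toy
//! closures in place of a real scirs2-neural model:
//!
//! ```ignore
//! use ndarray::Array1;
//! use scirs2_optimize::neural_integration::{NeuralOptimizer, NeuralParameters, NeuralTrainer};
//!
//! // One parameter vector standing in for a network layer.
//! let mut params = NeuralParameters::<f64>::new();
//! params.add_parameter("weights".to_string(), Array1::from_vec(vec![0.5, -0.3]));
//!
//! // Toy quadratic loss and its gradient.
//! let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0].mapv(|x| x * x).sum();
//! let mut grad_fn = |p: &NeuralParameters<f64>| vec![p.parameters[0].mapv(|x| 2.0 * x)];
//!
//! let optimizer = NeuralOptimizer::adam(0.01_f64, 1000);
//! let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(10);
//!
//! for _ in 0..100 {
//!     trainer.train_epoch(&mut params, &mut loss_fn, &mut grad_fn).unwrap();
//!     if trainer.should_stop_early() {
//!         break;
//!     }
//! }
//! ```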

use crate::error::OptimizeError;
use crate::stochastic::{StochasticMethod, StochasticOptions};
use ndarray::{s, Array1, ScalarOperand};
use num_traits::Float;
use std::collections::HashMap;

/// Neural network parameter container
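///
/// Parameters and their gradients are stored as one flat vector per layer; the
/// `flatten_*` helpers concatenate them in insertion order. A small sketch of
/// the round trip (module path assumed, not compiled as a doc test):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optimize::neural_integration::NeuralParameters;
///
/// let mut params = NeuralParameters::<f64>::new();
/// params.add_parameter("w1".to_string(), Array1::from_vec(vec![1.0, 2.0]));
/// params.add_parameter("b1".to_string(), Array1::from_vec(vec![0.5]));
///
/// // Layers are concatenated in the order they were added.
/// assert_eq!(params.flatten_parameters().to_vec(), vec![1.0, 2.0, 0.5]);
///
/// // A flat update is scattered back into the per-layer vectors.
/// params.update_from_flat(&Array1::from_vec(vec![3.0, 4.0, 5.0]));
/// assert_eq!(params.parameters[1][0], 5.0);
/// ```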
#[derive(Debug, Clone)]
pub struct NeuralParameters<F: Float + ScalarOperand> {
    /// Parameter vectors for each layer
    pub parameters: Vec<Array1<F>>,
    /// Gradient vectors for each layer
    pub gradients: Vec<Array1<F>>,
    /// Parameter names/ids for tracking
    pub names: Vec<String>,
}

impl<F: Float + ScalarOperand> Default for NeuralParameters<F> {
    fn default() -> Self {
        Self {
            parameters: Vec::new(),
            gradients: Vec::new(),
            names: Vec::new(),
        }
    }
}

impl<F: Float + ScalarOperand> NeuralParameters<F> {
    /// Create a new neural parameters container
    pub fn new() -> Self {
        Self::default()
    }

    /// Add a parameter vector
    pub fn add_parameter(&mut self, name: String, param: Array1<F>) {
        self.names.push(name);
        self.gradients.push(Array1::zeros(param.raw_dim()));
        self.parameters.push(param);
    }

    /// Get the total number of parameters
    pub fn total_parameters(&self) -> usize {
        self.parameters.iter().map(|p| p.len()).sum()
    }

    /// Flatten all parameters into a single vector
    pub fn flatten_parameters(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for param in &self.parameters {
            let len = param.len();
            flat.slice_mut(s![offset..offset + len]).assign(param);
            offset += len;
        }

        flat
    }

    /// Flatten all gradients into a single vector
    pub fn flatten_gradients(&self) -> Array1<F> {
        let total_len = self.total_parameters();
        let mut flat = Array1::zeros(total_len);
        let mut offset = 0;

        for grad in &self.gradients {
            let len = grad.len();
            flat.slice_mut(s![offset..offset + len]).assign(grad);
            offset += len;
        }

        flat
    }

    /// Update parameters from a flattened vector
    pub fn update_from_flat(&mut self, flat_params: &Array1<F>) {
        let mut offset = 0;

        for param in &mut self.parameters {
            let len = param.len();
            param.assign(&flat_params.slice(s![offset..offset + len]));
            offset += len;
        }
    }

    /// Update gradients from a flattened vector
    pub fn update_gradients_from_flat(&mut self, flat_grads: &Array1<F>) {
        let mut offset = 0;

        for grad in &mut self.gradients {
            let len = grad.len();
            grad.assign(&flat_grads.slice(s![offset..offset + len]));
            offset += len;
        }
    }
}

/// Neural network optimizer that uses scirs2-optimize stochastic methods
pub struct NeuralOptimizer<F: Float + ScalarOperand> {
    method: StochasticMethod,
    options: StochasticOptions,
    /// Internal state for momentum-based optimizers
    momentum_buffers: HashMap<String, Array1<F>>,
    /// Internal state for Adam-family optimizers
    first_moment: HashMap<String, Array1<F>>,
    second_moment: HashMap<String, Array1<F>>,
    /// Step counter for bias correction
    step_count: usize,
}

impl<F: Float + ScalarOperand> NeuralOptimizer<F>
where
    F: 'static + Send + Sync,
{
    /// Create a new neural optimizer
    pub fn new(method: StochasticMethod, options: StochasticOptions) -> Self {
        Self {
            method,
            options,
            momentum_buffers: HashMap::new(),
            first_moment: HashMap::new(),
            second_moment: HashMap::new(),
            step_count: 0,
        }
    }

    /// Create an SGD optimizer for neural networks
    pub fn sgd(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.01),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::SGD, options)
    }

    /// Create an Adam optimizer for neural networks
    pub fn adam(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::Adam, options)
    }

    /// Create an AdamW optimizer for neural networks
    pub fn adamw(learning_rate: F, max_iter: usize) -> Self {
        let options = StochasticOptions {
            learning_rate: learning_rate.to_f64().unwrap_or(0.001),
            max_iter,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: crate::stochastic::LearningRateSchedule::Constant,
            gradient_clip: Some(1.0),
            early_stopping_patience: None,
        };

        Self::new(StochasticMethod::AdamW, options)
    }

    /// Update neural network parameters using the selected optimizer
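    ///
    /// Gradients must already be stored in `params.gradients` (for example,
    /// written there by a backward pass); `step` only applies the update rule.
    /// A sketch of a single manual step (module path assumed, not compiled as
    /// a doc test):
    ///
    /// ```ignore
    /// use ndarray::Array1;
    /// use scirs2_optimize::neural_integration::{NeuralOptimizer, NeuralParameters};
    ///
    /// let mut opt = NeuralOptimizer::sgd(0.1_f64, 100);
    /// let mut params = NeuralParameters::<f64>::new();
    /// params.add_parameter("w".to_string(), Array1::from_vec(vec![1.0, 2.0]));
    /// params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);
    ///
    /// opt.step(&mut params).unwrap();
    /// // SGD: w <- w - lr * g, so the parameters become [0.95, 1.9].
    /// assert!((params.parameters[0][0] - 0.95).abs() < 1e-12);
    /// ```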
    pub fn step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        self.step_count += 1;

        match self.method {
            StochasticMethod::SGD => self.sgd_step(params),
            StochasticMethod::Momentum => self.momentum_step(params),
            StochasticMethod::Adam => self.adam_step(params),
            StochasticMethod::AdamW => self.adamw_step(params),
            StochasticMethod::RMSProp => self.rmsprop_step(params),
        }
    }

    /// SGD parameter update
    fn sgd_step(&self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.01).unwrap());

        for (param, grad) in params.parameters.iter_mut().zip(params.gradients.iter()) {
            *param = param.clone() - &(grad.clone() * lr);
        }

        Ok(())
    }

    /// Momentum SGD parameter update
    fn momentum_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.01).unwrap());
        let momentum = F::from(0.9).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Initialize the momentum buffer if it does not exist yet
            if !self.momentum_buffers.contains_key(&param_name) {
                self.momentum_buffers
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let momentum_buffer = self.momentum_buffers.get_mut(&param_name).unwrap();

            // v = momentum * v + grad
            *momentum_buffer = momentum_buffer.clone() * momentum + grad;

            // param = param - lr * v
            *param = param.clone() - &(momentum_buffer.clone() * lr);
        }

        Ok(())
    }

    /// Adam parameter update
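    ///
    /// Implements the standard Adam update with hard-coded beta1 = 0.9,
    /// beta2 = 0.999 and epsilon = 1e-8, where `t` is the step count:
    ///
    /// ```text
    /// m = beta1 * m + (1 - beta1) * grad
    /// v = beta2 * v + (1 - beta2) * grad^2
    /// m_hat = m / (1 - beta1^t),  v_hat = v / (1 - beta2^t)
    /// param = param - lr * m_hat / (sqrt(v_hat) + epsilon)
    /// ```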
    fn adam_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.001).unwrap());
        let beta1 = F::from(0.9).unwrap();
        let beta2 = F::from(0.999).unwrap();
        let epsilon = F::from(1e-8).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Initialize the moment buffers if they do not exist yet
            if !self.first_moment.contains_key(&param_name) {
                self.first_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let m = self.first_moment.get_mut(&param_name).unwrap();
            let v = self.second_moment.get_mut(&param_name).unwrap();

            // m = beta1 * m + (1 - beta1) * grad
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            // v = beta2 * v + (1 - beta2) * grad^2
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction
            let step_f = F::from(self.step_count).unwrap();
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // param = param - lr * m_hat / (sqrt(v_hat) + epsilon)
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let update = m_hat / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// AdamW parameter update (Adam with decoupled weight decay)
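    ///
    /// Uses the same moment estimates as `adam_step`, but the weight decay term
    /// (hard-coded at 0.01) is applied directly to the parameters instead of
    /// being folded into the gradient:
    ///
    /// ```text
    /// param = param - lr * (m_hat / (sqrt(v_hat) + epsilon) + weight_decay * param)
    /// ```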
    fn adamw_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.001).unwrap());
        let beta1 = F::from(0.9).unwrap();
        let beta2 = F::from(0.999).unwrap();
        let epsilon = F::from(1e-8).unwrap();
        let weight_decay = F::from(0.01).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Initialize the moment buffers if they do not exist yet
            if !self.first_moment.contains_key(&param_name) {
                self.first_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let m = self.first_moment.get_mut(&param_name).unwrap();
            let v = self.second_moment.get_mut(&param_name).unwrap();

            // m = beta1 * m + (1 - beta1) * grad
            *m = m.clone() * beta1 + &(grad.clone() * (F::one() - beta1));

            // v = beta2 * v + (1 - beta2) * grad^2
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * beta2 + &(grad_squared * (F::one() - beta2));

            // Bias correction
            let step_f = F::from(self.step_count).unwrap();
            let m_hat = m.clone() / (F::one() - beta1.powf(step_f));
            let v_hat = v.clone() / (F::one() - beta2.powf(step_f));

            // AdamW update: param = param - lr * (m_hat / (sqrt(v_hat) + epsilon) + weight_decay * param)
            let denominator = v_hat.mapv(|x| x.sqrt()) + epsilon;
            let adam_update = m_hat / denominator;
            let weight_decay_update = param.clone() * weight_decay;
            let total_update = (adam_update + weight_decay_update) * lr;

            *param = param.clone() - &total_update;
        }

        Ok(())
    }

    /// RMSprop parameter update
    fn rmsprop_step(&mut self, params: &mut NeuralParameters<F>) -> Result<(), OptimizeError> {
        let lr = F::from(self.options.learning_rate).unwrap_or_else(|| F::from(0.001).unwrap());
        let alpha = F::from(0.99).unwrap(); // decay rate
        let epsilon = F::from(1e-8).unwrap();

        for (i, (param, grad)) in params
            .parameters
            .iter_mut()
            .zip(params.gradients.iter())
            .enumerate()
        {
            let param_name = format!("param_{}", i);

            // Initialize the squared-average buffer if it does not exist yet
            if !self.second_moment.contains_key(&param_name) {
                self.second_moment
                    .insert(param_name.clone(), Array1::zeros(param.raw_dim()));
            }

            let v = self.second_moment.get_mut(&param_name).unwrap();

            // v = alpha * v + (1 - alpha) * grad^2
            let grad_squared = grad.mapv(|x| x * x);
            *v = v.clone() * alpha + &(grad_squared * (F::one() - alpha));

            // param = param - lr * grad / (sqrt(v) + epsilon)
            let denominator = v.mapv(|x| x.sqrt()) + epsilon;
            let update = grad.clone() / denominator * lr;
            *param = param.clone() - &update;
        }

        Ok(())
    }

    /// Get the current learning rate
    pub fn get_learning_rate(&self) -> f64 {
        self.options.learning_rate
    }

    /// Set the learning rate
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.options.learning_rate = lr;
    }

    /// Reset the optimizer's internal state
    pub fn reset(&mut self) {
        self.momentum_buffers.clear();
        self.first_moment.clear();
        self.second_moment.clear();
        self.step_count = 0;
    }

    /// Get the optimizer method name
    pub fn method_name(&self) -> &'static str {
        match self.method {
            StochasticMethod::SGD => "SGD",
            StochasticMethod::Momentum => "SGD with Momentum",
            StochasticMethod::Adam => "Adam",
            StochasticMethod::AdamW => "AdamW",
            StochasticMethod::RMSProp => "RMSprop",
        }
    }
}

/// Neural network trainer that combines optimization with training loops
pub struct NeuralTrainer<F: Float + ScalarOperand> {
    optimizer: NeuralOptimizer<F>,
    loss_history: Vec<F>,
    early_stopping_patience: Option<usize>,
    best_loss: Option<F>,
    patience_counter: usize,
}

impl<F: Float + ScalarOperand> NeuralTrainer<F>
where
    F: 'static + Send + Sync + std::fmt::Display,
{
    /// Create a new neural trainer
    pub fn new(optimizer: NeuralOptimizer<F>) -> Self {
        Self {
            optimizer,
            loss_history: Vec::new(),
            early_stopping_patience: None,
            best_loss: None,
            patience_counter: 0,
        }
    }

    /// Set the early stopping patience
    pub fn with_early_stopping(mut self, patience: usize) -> Self {
        self.early_stopping_patience = Some(patience);
        self
    }

    /// Train for one epoch: compute gradients, optionally clip them, apply one
    /// optimizer step, and record the loss (evaluated after the update).
    pub fn train_epoch<LossFn, GradFn>(
        &mut self,
        params: &mut NeuralParameters<F>,
        loss_fn: &mut LossFn,
        grad_fn: &mut GradFn,
    ) -> Result<F, OptimizeError>
    where
        LossFn: FnMut(&NeuralParameters<F>) -> F,
        GradFn: FnMut(&NeuralParameters<F>) -> Vec<Array1<F>>,
    {
        // Compute gradients
        let gradients = grad_fn(params);
        params.gradients = gradients;

        // Apply gradient clipping if specified
        if let Some(max_norm) = self.optimizer.options.gradient_clip {
            self.clip_gradients(params, max_norm);
        }

        // Update parameters
        self.optimizer.step(params)?;

        // Compute the loss after the update
        let loss = loss_fn(params);
        self.loss_history.push(loss);

        // Check early stopping
        if let Some(_patience) = self.early_stopping_patience {
            if let Some(best_loss) = self.best_loss {
                if loss < best_loss {
                    self.best_loss = Some(loss);
                    self.patience_counter = 0;
                } else {
                    self.patience_counter += 1;
                }
            } else {
                self.best_loss = Some(loss);
            }
        }

        Ok(loss)
    }

    /// Check if training should stop early
    pub fn should_stop_early(&self) -> bool {
        if let Some(patience) = self.early_stopping_patience {
            self.patience_counter >= patience
        } else {
            false
        }
    }

    /// Get the loss history
    pub fn loss_history(&self) -> &[F] {
        &self.loss_history
    }

    /// Get the current learning rate
    pub fn learning_rate(&self) -> f64 {
        self.optimizer.get_learning_rate()
    }

    /// Set the learning rate
    pub fn set_learning_rate(&mut self, lr: f64) {
        self.optimizer.set_learning_rate(lr);
    }

    /// Clip gradients to prevent exploding gradients
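    ///
    /// Uses global-norm clipping: if the L2 norm of all gradients taken
    /// together exceeds `max_norm`, every gradient vector is scaled by
    /// `max_norm / total_norm` so that the combined norm equals `max_norm`.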
    fn clip_gradients(&self, params: &mut NeuralParameters<F>, max_norm: f64) {
        let max_norm_f = F::from(max_norm).unwrap();

        // Compute total gradient norm
        let mut total_norm_sq = F::zero();
        for grad in &params.gradients {
            total_norm_sq = total_norm_sq + grad.mapv(|x| x * x).sum();
        }
        let total_norm = total_norm_sq.sqrt();

        if total_norm > max_norm_f {
            let scale = max_norm_f / total_norm;
            for grad in &mut params.gradients {
                grad.mapv_inplace(|x| x * scale);
            }
        }
    }
}

/// Convenience functions for creating neural optimizers
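///
/// A short sketch (module path assumed, not compiled as a doc test). Each
/// helper fixes `max_iter` at 1000 and otherwise uses the defaults of the
/// corresponding `NeuralOptimizer` constructor:
///
/// ```ignore
/// use scirs2_optimize::neural_integration::optimizers;
///
/// let opt = optimizers::adam(0.001_f64);
/// assert_eq!(opt.method_name(), "Adam");
/// ```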
pub mod optimizers {
    use super::*;

    /// Create SGD optimizer with default settings for neural networks
    pub fn sgd<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::sgd(learning_rate, 1000)
    }

    /// Create Adam optimizer with default settings for neural networks
    pub fn adam<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adam(learning_rate, 1000)
    }

    /// Create AdamW optimizer with default settings for neural networks
    pub fn adamw<F>(learning_rate: F) -> NeuralOptimizer<F>
    where
        F: Float + ScalarOperand + 'static + Send + Sync,
    {
        NeuralOptimizer::adamw(learning_rate, 1000)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neural_parameters() {
        let mut params = NeuralParameters::<f64>::new();

        // Add some parameters
        params.add_parameter("layer1".to_string(), Array1::from_vec(vec![1.0, 2.0, 3.0]));
        params.add_parameter("layer2".to_string(), Array1::from_vec(vec![4.0, 5.0]));

        assert_eq!(params.total_parameters(), 5);

        // Test flattening
        let flat = params.flatten_parameters();
        assert_eq!(flat.as_slice().unwrap(), &[1.0, 2.0, 3.0, 4.0, 5.0]);

        // Test updating from flat
        let new_flat = Array1::from_vec(vec![6.0, 7.0, 8.0, 9.0, 10.0]);
        params.update_from_flat(&new_flat);

        assert_eq!(params.parameters[0].as_slice().unwrap(), &[6.0, 7.0, 8.0]);
        assert_eq!(params.parameters[1].as_slice().unwrap(), &[9.0, 10.0]);
    }

    #[test]
    fn test_sgd_optimizer() {
        let mut optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut params = NeuralParameters::<f64>::new();

        // Add parameter
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        // Set gradient
        params.gradients[0] = Array1::from_vec(vec![0.5, 1.0]);

        // Perform one step
        optimizer.step(&mut params).unwrap();

        // Check update: param = param - lr * grad
        let expected = [1.0 - 0.1 * 0.5, 2.0 - 0.1 * 1.0];
        assert_abs_diff_eq!(params.parameters[0][0], expected[0], epsilon = 1e-10);
        assert_abs_diff_eq!(params.parameters[0][1], expected[1], epsilon = 1e-10);
    }

    #[test]
    fn test_adam_optimizer() {
        let mut optimizer = NeuralOptimizer::adam(0.001, 100);
        let mut params = NeuralParameters::<f64>::new();

        // Add parameter
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0, 2.0]));
        // Set gradient
        params.gradients[0] = Array1::from_vec(vec![0.1, 0.2]);

        let original_params = params.parameters[0].clone();

        // Perform one step
        optimizer.step(&mut params).unwrap();

        // Parameters should have changed
        assert_ne!(params.parameters[0][0], original_params[0]);
        assert_ne!(params.parameters[0][1], original_params[1]);

        // Parameters should have decreased (since gradients are positive)
        assert!(params.parameters[0][0] < original_params[0]);
        assert!(params.parameters[0][1] < original_params[1]);
    }

    #[test]
    fn test_neural_trainer() {
        let optimizer = NeuralOptimizer::sgd(0.1, 100);
        let mut trainer = NeuralTrainer::new(optimizer).with_early_stopping(5);

        let mut params = NeuralParameters::<f64>::new();
        params.add_parameter("test".to_string(), Array1::from_vec(vec![1.0]));
        params.gradients[0] = Array1::from_vec(vec![1.0]);

        // Simple quadratic loss function
        let mut loss_fn = |p: &NeuralParameters<f64>| p.parameters[0][0] * p.parameters[0][0];
        let mut grad_fn =
            |p: &NeuralParameters<f64>| vec![Array1::from_vec(vec![2.0 * p.parameters[0][0]])];

        // Train for one epoch
        let loss = trainer
            .train_epoch(&mut params, &mut loss_fn, &mut grad_fn)
            .unwrap();

        // Loss should be computed
        assert_eq!(trainer.loss_history().len(), 1);
        assert_eq!(trainer.loss_history()[0], loss);
    }

    #[test]
    fn test_optimizer_convenience_functions() {
        let sgd_opt = optimizers::sgd(0.01);
        assert_eq!(sgd_opt.method_name(), "SGD");

        let adam_opt = optimizers::adam(0.001);
        assert_eq!(adam_opt.method_name(), "Adam");

        let adamw_opt = optimizers::adamw(0.001);
        assert_eq!(adamw_opt.method_name(), "AdamW");
    }
}