scirs2_optimize/stochastic/mod.rs

//! Stochastic optimization methods for machine learning and large-scale problems
//!
//! This module provides stochastic optimization algorithms that are particularly
//! well-suited for machine learning, neural networks, and large-scale problems
//! where exact gradients are expensive or noisy.
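//!
//! # Example
//!
//! A minimal usage sketch: minimizing a simple quadratic with Adam. The crate
//! path `scirs2_optimize` used in the imports is an assumption, so the block is
//! not compiled as a doctest.
//!
//! ```ignore
//! use ndarray::{Array1, ArrayView1};
//! use scirs2_optimize::stochastic::{
//!     minimize_stochastic, InMemoryDataProvider, StochasticGradientFunction,
//!     StochasticMethod, StochasticOptions,
//! };
//!
//! // f(x) = |x|^2 with gradient 2x; the batch data is ignored for this toy objective.
//! struct Quadratic;
//!
//! impl StochasticGradientFunction for Quadratic {
//!     fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batch_data: &[f64]) -> Array1<f64> {
//!         x.mapv(|xi| 2.0 * xi)
//!     }
//!     fn compute_value(&mut self, x: &ArrayView1<f64>, _batch_data: &[f64]) -> f64 {
//!         x.mapv(|xi| xi * xi).sum()
//!     }
//! }
//!
//! let x0 = Array1::from_vec(vec![1.0, -2.0]);
//! let data = Box::new(InMemoryDataProvider::new(vec![0.0; 100]));
//! let result = minimize_stochastic(
//!     StochasticMethod::Adam,
//!     Quadratic,
//!     x0,
//!     data,
//!     StochasticOptions::default(),
//! );
//! assert!(result.is_ok());
//! ```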

pub mod adam;
pub mod adamw;
pub mod momentum;
pub mod rmsprop;
pub mod sgd;

// Re-export commonly used items
pub use adam::{minimize_adam, AdamOptions};
pub use adamw::{minimize_adamw, AdamWOptions};
pub use momentum::{minimize_sgd_momentum, MomentumOptions};
pub use rmsprop::{minimize_rmsprop, RMSPropOptions};
pub use sgd::{minimize_sgd, SGDOptions};

use crate::error::OptimizeError;
use crate::unconstrained::result::OptimizeResult;
use ndarray::{Array1, ArrayView1};
use scirs2_core::random::prelude::*;

/// Stochastic optimization method selection
#[derive(Debug, Clone, Copy)]
pub enum StochasticMethod {
    /// Stochastic Gradient Descent
    SGD,
    /// SGD with Momentum
    Momentum,
    /// Root Mean Square Propagation
    RMSProp,
    /// Adaptive Moment Estimation
    Adam,
    /// Adam with Weight Decay
    AdamW,
}

/// Common options for stochastic optimization
#[derive(Debug, Clone)]
pub struct StochasticOptions {
    /// Learning rate (step size)
    pub learning_rate: f64,
    /// Maximum number of iterations (epochs)
    pub max_iter: usize,
    /// Batch size for mini-batch optimization
    pub batch_size: Option<usize>,
    /// Tolerance for convergence
    pub tol: f64,
    /// Whether to use adaptive learning rate
    pub adaptive_lr: bool,
    /// Learning rate decay factor
    pub lr_decay: f64,
    /// Learning rate decay schedule
    pub lr_schedule: LearningRateSchedule,
    /// Gradient clipping threshold
    pub gradient_clip: Option<f64>,
    /// Early stopping patience
    pub early_stopping_patience: Option<usize>,
}

impl Default for StochasticOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.001,
            max_iter: 1000,
            batch_size: None,
            tol: 1e-6,
            adaptive_lr: false,
            lr_decay: 0.99,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            early_stopping_patience: None,
        }
    }
}

/// Learning rate schedules
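///
/// Each variant corresponds to a closed-form rate used by [`update_learning_rate`].
/// A small sketch (the crate path is assumed, so the block is not compiled):
///
/// ```ignore
/// use scirs2_optimize::stochastic::{update_learning_rate, LearningRateSchedule};
///
/// // Exponential decay: lr * decay_rate^epoch = 0.1 * 0.9^10
/// let schedule = LearningRateSchedule::ExponentialDecay { decay_rate: 0.9 };
/// let lr = update_learning_rate(0.1, 10, 100, &schedule);
/// assert!((lr - 0.1 * 0.9f64.powi(10)).abs() < 1e-12);
/// ```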
#[derive(Debug, Clone)]
pub enum LearningRateSchedule {
    /// Constant learning rate
    Constant,
    /// Exponential decay: lr * decay^epoch
    ExponentialDecay { decay_rate: f64 },
    /// Step decay: lr * decay_factor every decay_steps
    StepDecay {
        decay_factor: f64,
        decay_steps: usize,
    },
    /// Linear decay: lr * (1 - epoch/max_epochs)
    LinearDecay,
    /// Cosine annealing: lr * 0.5 * (1 + cos(π * epoch/max_epochs))
    CosineAnnealing,
    /// Inverse time decay: lr / (1 + decay_rate * epoch)
    InverseTimeDecay { decay_rate: f64 },
}

/// Data provider trait for stochastic optimization
pub trait DataProvider {
    /// Get the total number of samples
    fn num_samples(&self) -> usize;

    /// Get a batch of samples
    fn get_batch(&self, indices: &[usize]) -> Vec<f64>;

    /// Get the full dataset
    fn get_full_data(&self) -> Vec<f64>;
}

/// Simple in-memory data provider
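///
/// Holds the whole dataset in a `Vec<f64>` and serves batches by index. A small
/// sketch (the crate path is assumed, so the block is not compiled):
///
/// ```ignore
/// use scirs2_optimize::stochastic::{DataProvider, InMemoryDataProvider};
///
/// let provider = InMemoryDataProvider::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
/// assert_eq!(provider.num_samples(), 5);
/// assert_eq!(provider.get_batch(&[0, 2, 4]), vec![1.0, 3.0, 5.0]);
/// ```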
#[derive(Clone)]
pub struct InMemoryDataProvider {
    data: Vec<f64>,
}

impl InMemoryDataProvider {
    pub fn new(data: Vec<f64>) -> Self {
        Self { data }
    }
}

impl DataProvider for InMemoryDataProvider {
    fn num_samples(&self) -> usize {
        self.data.len()
    }

    fn get_batch(&self, indices: &[usize]) -> Vec<f64> {
        indices.iter().map(|&i| self.data[i]).collect()
    }

    fn get_full_data(&self) -> Vec<f64> {
        self.data.clone()
    }
}

/// Stochastic gradient function trait
pub trait StochasticGradientFunction {
    /// Compute gradient on a batch of data
    fn compute_gradient(&mut self, x: &ArrayView1<f64>, batch_data: &[f64]) -> Array1<f64>;

    /// Compute function value on a batch of data
    fn compute_value(&mut self, x: &ArrayView1<f64>, batch_data: &[f64]) -> f64;
}

/// Wrapper for regular gradient functions
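///
/// Adapts plain `f(x)` / `grad f(x)` closures to the [`StochasticGradientFunction`]
/// interface by ignoring the batch data, so deterministic full-batch objectives can
/// be fed to the stochastic solvers. A small sketch (the crate path is assumed, so
/// the block is not compiled):
///
/// ```ignore
/// use ndarray::{Array1, ArrayView1};
/// use scirs2_optimize::stochastic::{BatchGradientWrapper, StochasticGradientFunction};
///
/// // f(x) = |x|^2, grad f(x) = 2x
/// let func = |x: &ArrayView1<f64>| x.mapv(|xi| xi * xi).sum();
/// let grad = |x: &ArrayView1<f64>| x.mapv(|xi| 2.0 * xi);
/// let mut wrapper = BatchGradientWrapper::new(func, grad);
///
/// let x = Array1::from_vec(vec![1.0, 2.0]);
/// assert_eq!(wrapper.compute_value(&x.view(), &[]), 5.0);
/// ```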
pub struct BatchGradientWrapper<F, G> {
    func: F,
    grad: G,
}

impl<F, G> BatchGradientWrapper<F, G>
where
    F: FnMut(&ArrayView1<f64>) -> f64,
    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
{
    pub fn new(func: F, grad: G) -> Self {
        Self { func, grad }
    }
}

impl<F, G> StochasticGradientFunction for BatchGradientWrapper<F, G>
where
    F: FnMut(&ArrayView1<f64>) -> f64,
    G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
{
    fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batch_data: &[f64]) -> Array1<f64> {
        (self.grad)(x)
    }

    fn compute_value(&mut self, x: &ArrayView1<f64>, _batch_data: &[f64]) -> f64 {
        (self.func)(x)
    }
}

/// Update learning rate according to schedule
#[allow(dead_code)]
pub fn update_learning_rate(
    initial_lr: f64,
    epoch: usize,
    max_epochs: usize,
    schedule: &LearningRateSchedule,
) -> f64 {
    match schedule {
        LearningRateSchedule::Constant => initial_lr,
        LearningRateSchedule::ExponentialDecay { decay_rate } => {
            initial_lr * decay_rate.powi(epoch as i32)
        }
        LearningRateSchedule::StepDecay {
            decay_factor,
            decay_steps,
        } => initial_lr * decay_factor.powi((epoch / decay_steps) as i32),
        LearningRateSchedule::LinearDecay => {
            initial_lr * (1.0 - epoch as f64 / max_epochs as f64).max(0.0)
        }
        LearningRateSchedule::CosineAnnealing => {
            initial_lr
                * 0.5
                * (1.0 + (std::f64::consts::PI * epoch as f64 / max_epochs as f64).cos())
        }
        LearningRateSchedule::InverseTimeDecay { decay_rate } => {
            initial_lr / (1.0 + decay_rate * epoch as f64)
        }
    }
}

/// Clip gradients to prevent exploding gradients
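///
/// Rescales the gradient in place so that its Euclidean norm does not exceed
/// `max_norm`, preserving its direction. A small sketch (the crate path is assumed,
/// so the block is not compiled):
///
/// ```ignore
/// use ndarray::Array1;
/// use scirs2_optimize::stochastic::clip_gradients;
///
/// let mut grad = Array1::from_vec(vec![3.0, 4.0]); // norm = 5
/// clip_gradients(&mut grad, 2.5);
/// assert!((grad.mapv(|x| x * x).sum().sqrt() - 2.5).abs() < 1e-12);
/// ```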
#[allow(dead_code)]
pub fn clip_gradients(gradient: &mut Array1<f64>, max_norm: f64) {
    let grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
    if grad_norm > max_norm {
        let scale = max_norm / grad_norm;
        gradient.mapv_inplace(|x| x * scale);
    }
}

/// Generate random batch indices
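///
/// Returns `batch_size` indices into a dataset of `num_samples` elements, taking the
/// first `batch_size` positions of an (optionally shuffled) index permutation. A small
/// sketch (the crate path is assumed, so the block is not compiled):
///
/// ```ignore
/// use scirs2_optimize::stochastic::generate_batch_indices;
///
/// let batch = generate_batch_indices(100, 10, true);
/// assert_eq!(batch.len(), 10);
/// assert!(batch.iter().all(|&i| i < 100));
/// ```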
#[allow(dead_code)]
pub fn generate_batch_indices(num_samples: usize, batch_size: usize, shuffle: bool) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..num_samples).collect();

    if shuffle {
        use rand::seq::SliceRandom;
        indices.shuffle(&mut thread_rng());
    }

    indices.into_iter().take(batch_size).collect()
}

/// Main stochastic optimization entry point
///
/// Dispatches to the optimizer selected by `method`, filling in per-method defaults
/// that are not part of [`StochasticOptions`] (momentum 0.9, RMSProp decay 0.99,
/// Adam/AdamW beta1/beta2 of 0.9/0.999, AdamW weight decay 0.01).
#[allow(dead_code)]
pub fn minimize_stochastic<F>(
    method: StochasticMethod,
    grad_func: F,
    x0: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: StochasticOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    match method {
        StochasticMethod::SGD => {
            let sgd_options = SGDOptions {
                learning_rate: options.learning_rate,
                max_iter: options.max_iter,
                tol: options.tol,
                lr_schedule: options.lr_schedule,
                gradient_clip: options.gradient_clip,
                batch_size: options.batch_size,
            };
            sgd::minimize_sgd(grad_func, x0, data_provider, sgd_options)
        }
        StochasticMethod::Momentum => {
            let momentum_options = MomentumOptions {
                learning_rate: options.learning_rate,
                momentum: 0.9, // Default momentum
                max_iter: options.max_iter,
                tol: options.tol,
                lr_schedule: options.lr_schedule,
                gradient_clip: options.gradient_clip,
                batch_size: options.batch_size,
                nesterov: false,
                dampening: 0.0,
            };
            momentum::minimize_sgd_momentum(grad_func, x0, data_provider, momentum_options)
        }
        StochasticMethod::RMSProp => {
            let rmsprop_options = RMSPropOptions {
                learning_rate: options.learning_rate,
                decay_rate: 0.99, // Default RMSProp decay
                epsilon: 1e-8,
                max_iter: options.max_iter,
                tol: options.tol,
                lr_schedule: options.lr_schedule,
                gradient_clip: options.gradient_clip,
                batch_size: options.batch_size,
                centered: false,
                momentum: None,
            };
            rmsprop::minimize_rmsprop(grad_func, x0, data_provider, rmsprop_options)
        }
        StochasticMethod::Adam => {
            let adam_options = AdamOptions {
                learning_rate: options.learning_rate,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                max_iter: options.max_iter,
                tol: options.tol,
                lr_schedule: options.lr_schedule,
                gradient_clip: options.gradient_clip,
                batch_size: options.batch_size,
                amsgrad: false,
            };
            adam::minimize_adam(grad_func, x0, data_provider, adam_options)
        }
        StochasticMethod::AdamW => {
            let adamw_options = AdamWOptions {
                learning_rate: options.learning_rate,
                beta1: 0.9,
                beta2: 0.999,
                epsilon: 1e-8,
                weight_decay: 0.01, // Default weight decay
                max_iter: options.max_iter,
                tol: options.tol,
                lr_schedule: options.lr_schedule,
                gradient_clip: options.gradient_clip,
                batch_size: options.batch_size,
                decouple_weight_decay: true,
            };
            adamw::minimize_adamw(grad_func, x0, data_provider, adamw_options)
        }
    }
}

/// Create stochastic options optimized for specific problem types
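///
/// Recognized `problem_type` values are `"neural_network"`, `"linear_regression"`,
/// `"logistic_regression"`, `"large_scale"`, and `"noisy_gradients"`; any other
/// string falls back to `StochasticOptions::default()`. A small sketch (the crate
/// path is assumed, so the block is not compiled):
///
/// ```ignore
/// use scirs2_optimize::stochastic::create_stochastic_options_for_problem;
///
/// let options = create_stochastic_options_for_problem("neural_network", 1000);
/// assert_eq!(options.learning_rate, 0.001);
/// assert_eq!(options.batch_size, Some(32));
/// ```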
#[allow(dead_code)]
pub fn create_stochastic_options_for_problem(
    problem_type: &str,
    dataset_size: usize,
) -> StochasticOptions {
    match problem_type.to_lowercase().as_str() {
        "neural_network" => StochasticOptions {
            learning_rate: 0.001,
            max_iter: 1000,
            // Clamp to at least 1 so tiny datasets don't yield a zero batch size
            batch_size: Some(32.min(dataset_size / 10).max(1)),
            lr_schedule: LearningRateSchedule::ExponentialDecay { decay_rate: 0.99 },
            gradient_clip: Some(1.0),
            early_stopping_patience: Some(50),
            ..Default::default()
        },
        "linear_regression" => StochasticOptions {
            learning_rate: 0.01,
            max_iter: 500,
            batch_size: Some(64.min(dataset_size / 5).max(1)),
            lr_schedule: LearningRateSchedule::LinearDecay,
            ..Default::default()
        },
        "logistic_regression" => StochasticOptions {
            learning_rate: 0.01,
            max_iter: 200,
            batch_size: Some(32.min(dataset_size / 10).max(1)),
            lr_schedule: LearningRateSchedule::StepDecay {
                decay_factor: 0.9,
                decay_steps: 50,
            },
            ..Default::default()
        },
        "large_scale" => StochasticOptions {
            learning_rate: 0.001,
            max_iter: 2000,
            batch_size: Some(128.min(dataset_size / 20).max(1)),
            lr_schedule: LearningRateSchedule::CosineAnnealing,
            gradient_clip: Some(5.0),
            adaptive_lr: true,
            ..Default::default()
        },
        "noisy_gradients" => StochasticOptions {
            learning_rate: 0.01,
            max_iter: 1000,
            batch_size: Some(64.min(dataset_size / 5).max(1)),
            lr_schedule: LearningRateSchedule::InverseTimeDecay { decay_rate: 1.0 },
            gradient_clip: Some(2.0),
            early_stopping_patience: Some(100),
            ..Default::default()
        },
        _ => StochasticOptions::default(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_learning_rate_schedules() {
        let initial_lr = 0.1;
        let max_epochs = 100;

        // Test constant schedule
        let constant = LearningRateSchedule::Constant;
        assert_abs_diff_eq!(
            update_learning_rate(initial_lr, 50, max_epochs, &constant),
            initial_lr,
            epsilon = 1e-10
        );

        // Test exponential decay
        let exp_decay = LearningRateSchedule::ExponentialDecay { decay_rate: 0.9 };
        let lr_exp = update_learning_rate(initial_lr, 10, max_epochs, &exp_decay);
        assert_abs_diff_eq!(lr_exp, initial_lr * 0.9_f64.powi(10), epsilon = 1e-10);

        // Test linear decay
        let linear = LearningRateSchedule::LinearDecay;
        let lr_linear = update_learning_rate(initial_lr, 50, max_epochs, &linear);
        assert_abs_diff_eq!(lr_linear, initial_lr * 0.5, epsilon = 1e-10);
    }

    #[test]
    fn test_gradient_clipping() {
        let mut grad = Array1::from_vec(vec![3.0, 4.0]); // Norm = 5
        clip_gradients(&mut grad, 2.5);

        let clipped_norm = grad.mapv(|x| x * x).sum().sqrt();
        assert_abs_diff_eq!(clipped_norm, 2.5, epsilon = 1e-10);

        // Check direction is preserved
        assert_abs_diff_eq!(grad[0] / grad[1], 3.0 / 4.0, epsilon = 1e-10);
    }

    #[test]
    fn test_batch_indices_generation() {
        let indices = generate_batch_indices(100, 10, false);
        assert_eq!(indices.len(), 10);
        assert_eq!(indices, (0..10).collect::<Vec<usize>>());

        let shuffled = generate_batch_indices(100, 10, true);
        assert_eq!(shuffled.len(), 10);
        // All indices should be < 100
        assert!(shuffled.iter().all(|&i| i < 100));
    }

    #[test]
    fn test_in_memory_data_provider() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let provider = InMemoryDataProvider::new(data.clone());

        assert_eq!(provider.num_samples(), 5);
        assert_eq!(provider.get_full_data(), data);

        let batch = provider.get_batch(&[0, 2, 4]);
        assert_eq!(batch, vec![1.0, 3.0, 5.0]);
    }

    #[test]
    fn test_problem_specific_options() {
        let nn_options = create_stochastic_options_for_problem("neural_network", 1000);
        assert_eq!(nn_options.learning_rate, 0.001);
        assert!(nn_options.batch_size.is_some());
        assert!(nn_options.gradient_clip.is_some());

        let lr_options = create_stochastic_options_for_problem("linear_regression", 500);
        assert_eq!(lr_options.learning_rate, 0.01);
        assert!(matches!(
            lr_options.lr_schedule,
            LearningRateSchedule::LinearDecay
        ));

        let large_options = create_stochastic_options_for_problem("large_scale", 10000);
        assert!(matches!(
            large_options.lr_schedule,
            LearningRateSchedule::CosineAnnealing
        ));
        assert_eq!(large_options.batch_size, Some(128));
    }
}