1pub mod adam;
8pub mod adamw;
9pub mod momentum;
10pub mod rmsprop;
11pub mod sgd;
12
13pub use adam::{minimize_adam, AdamOptions};
15pub use adamw::{minimize_adamw, AdamWOptions};
16pub use momentum::{minimize_sgd_momentum, MomentumOptions};
17pub use rmsprop::{minimize_rmsprop, RMSPropOptions};
18pub use sgd::{minimize_sgd, SGDOptions};
19
20use crate::error::OptimizeError;
21use crate::unconstrained::result::OptimizeResult;
22use scirs2_core::ndarray::{Array1, ArrayView1};
23use scirs2_core::random::prelude::*;
24
/// Stochastic optimization algorithm selected when calling [`minimize_stochastic`].
///
/// The enum is a field-less tag, so besides being `Copy` it can be compared
/// with `==` and used as a key in hashed collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StochasticMethod {
    /// Dispatches to `sgd::minimize_sgd`.
    SGD,
    /// Dispatches to `momentum::minimize_sgd_momentum`.
    Momentum,
    /// Dispatches to `rmsprop::minimize_rmsprop`.
    RMSProp,
    /// Dispatches to `adam::minimize_adam`.
    Adam,
    /// Dispatches to `adamw::minimize_adamw`.
    AdamW,
}
39
40#[derive(Debug, Clone)]
42pub struct StochasticOptions {
43 pub learning_rate: f64,
45 pub max_iter: usize,
47 pub batch_size: Option<usize>,
49 pub tol: f64,
51 pub adaptive_lr: bool,
53 pub lr_decay: f64,
55 pub lr_schedule: LearningRateSchedule,
57 pub gradient_clip: Option<f64>,
59 pub early_stopping_patience: Option<usize>,
61}
62
63impl Default for StochasticOptions {
64 fn default() -> Self {
65 Self {
66 learning_rate: 0.001,
67 max_iter: 1000,
68 batch_size: None,
69 tol: 1e-6,
70 adaptive_lr: false,
71 lr_decay: 0.99,
72 lr_schedule: LearningRateSchedule::Constant,
73 gradient_clip: None,
74 early_stopping_patience: None,
75 }
76 }
77}
78
/// Rule describing how the learning rate changes with the epoch counter.
///
/// The formulas below are the ones evaluated by `update_learning_rate`, where
/// `lr0` is the initial rate, `t` the current epoch and `T` the epoch budget.
/// Schedules can be compared with `==` (variant and parameters must match).
#[derive(Debug, Clone, PartialEq)]
pub enum LearningRateSchedule {
    /// `lr = lr0` for every epoch.
    Constant,
    /// `lr = lr0 * decay_rate^t`.
    ExponentialDecay {
        /// Multiplicative factor applied once per epoch.
        decay_rate: f64,
    },
    /// `lr = lr0 * decay_factor^(t / decay_steps)` using integer division.
    StepDecay {
        /// Multiplicative factor applied once per completed step interval.
        decay_factor: f64,
        /// Number of epochs in one step interval.
        decay_steps: usize,
    },
    /// `lr = lr0 * max(0, 1 - t / T)`.
    LinearDecay,
    /// `lr = lr0 * 0.5 * (1 + cos(pi * t / T))`.
    CosineAnnealing,
    /// `lr = lr0 / (1 + decay_rate * t)`.
    InverseTimeDecay {
        /// Growth rate of the denominator per epoch.
        decay_rate: f64,
    },
}
98
/// Source of training data for the stochastic optimizers.
///
/// Data is exchanged as flat `Vec<f64>` buffers; how one sample maps onto the
/// buffer is up to the implementor.
pub trait DataProvider {
    /// Returns how many samples the provider holds.
    fn num_samples(&self) -> usize;

    /// Returns the data belonging to the samples selected by `indices`.
    ///
    /// Behaviour for indices that are out of range is implementation-defined.
    fn get_batch(&self, indices: &[usize]) -> Vec<f64>;

    /// Returns the data of all samples.
    fn get_full_data(&self) -> Vec<f64>;
}
110
/// [`DataProvider`] backed by a vector held entirely in memory.
///
/// Every element of the vector counts as one sample.
#[derive(Debug, Clone)]
pub struct InMemoryDataProvider {
    /// One value per sample, in insertion order.
    data: Vec<f64>,
}
116
117impl InMemoryDataProvider {
118 pub fn new(data: Vec<f64>) -> Self {
119 Self { data }
120 }
121}
122
123impl DataProvider for InMemoryDataProvider {
124 fn num_samples(&self) -> usize {
125 self.data.len()
126 }
127
128 fn get_batch(&self, indices: &[usize]) -> Vec<f64> {
129 indices.iter().map(|&i| self.data[i]).collect()
130 }
131
132 fn get_full_data(&self) -> Vec<f64> {
133 self.data.clone()
134 }
135}
136
137pub trait StochasticGradientFunction {
139 fn compute_gradient(&mut self, x: &ArrayView1<f64>, batchdata: &[f64]) -> Array1<f64>;
141
142 fn compute_value(&mut self, x: &ArrayView1<f64>, batchdata: &[f64]) -> f64;
144}
145
/// Pairs an objective callable with its gradient callable.
///
/// The [`StochasticGradientFunction`] implementation for this type forwards to
/// the two callables and ignores the batch data it is given.
pub struct BatchGradientWrapper<F, G> {
    /// Callable evaluating the objective value.
    func: F,
    /// Callable evaluating the gradient.
    grad: G,
}
151
152impl<F, G> BatchGradientWrapper<F, G>
153where
154 F: FnMut(&ArrayView1<f64>) -> f64,
155 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
156{
157 pub fn new(func: F, grad: G) -> Self {
158 Self { func, grad }
159 }
160}
161
162impl<F, G> StochasticGradientFunction for BatchGradientWrapper<F, G>
163where
164 F: FnMut(&ArrayView1<f64>) -> f64,
165 G: FnMut(&ArrayView1<f64>) -> Array1<f64>,
166{
167 fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
168 (self.grad)(x)
169 }
170
171 fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
172 (self.func)(x)
173 }
174}
175
176#[allow(dead_code)]
178pub fn update_learning_rate(
179 initial_lr: f64,
180 epoch: usize,
181 max_epochs: usize,
182 schedule: &LearningRateSchedule,
183) -> f64 {
184 match schedule {
185 LearningRateSchedule::Constant => initial_lr,
186 LearningRateSchedule::ExponentialDecay { decay_rate } => {
187 initial_lr * decay_rate.powi(epoch as i32)
188 }
189 LearningRateSchedule::StepDecay {
190 decay_factor,
191 decay_steps,
192 } => initial_lr * decay_factor.powi((epoch / decay_steps) as i32),
193 LearningRateSchedule::LinearDecay => {
194 initial_lr * (1.0 - epoch as f64 / max_epochs as f64).max(0.0)
195 }
196 LearningRateSchedule::CosineAnnealing => {
197 initial_lr
198 * 0.5
199 * (1.0 + (std::f64::consts::PI * epoch as f64 / max_epochs as f64).cos())
200 }
201 LearningRateSchedule::InverseTimeDecay { decay_rate } => {
202 initial_lr / (1.0 + decay_rate * epoch as f64)
203 }
204 }
205}
206
207#[allow(dead_code)]
209pub fn clip_gradients(gradient: &mut Array1<f64>, maxnorm: f64) {
210 let grad_norm = gradient.mapv(|x| x * x).sum().sqrt();
211 if grad_norm > maxnorm {
212 let scale = maxnorm / grad_norm;
213 gradient.mapv_inplace(|x| x * scale);
214 }
215}
216
217#[allow(dead_code)]
219pub fn generate_batch_indices(_num_samples: usize, batchsize: usize, shuffle: bool) -> Vec<usize> {
220 let mut indices: Vec<usize> = (0.._num_samples).collect();
221
222 if shuffle {
223 use scirs2_core::random::seq::SliceRandom;
224 indices.shuffle(&mut thread_rng());
225 }
226
227 indices.into_iter().take(batchsize).collect()
228}
229
230#[allow(dead_code)]
232pub fn minimize_stochastic<F>(
233 method: StochasticMethod,
234 grad_func: F,
235 x0: Array1<f64>,
236 data_provider: Box<dyn DataProvider>,
237 options: StochasticOptions,
238) -> Result<OptimizeResult<f64>, OptimizeError>
239where
240 F: StochasticGradientFunction,
241{
242 match method {
243 StochasticMethod::SGD => {
244 let sgd_options = SGDOptions {
245 learning_rate: options.learning_rate,
246 max_iter: options.max_iter,
247 tol: options.tol,
248 lr_schedule: options.lr_schedule,
249 gradient_clip: options.gradient_clip,
250 batch_size: options.batch_size,
251 };
252 sgd::minimize_sgd(grad_func, x0, data_provider, sgd_options)
253 }
254 StochasticMethod::Momentum => {
255 let momentum_options = MomentumOptions {
256 learning_rate: options.learning_rate,
257 momentum: 0.9, max_iter: options.max_iter,
259 tol: options.tol,
260 lr_schedule: options.lr_schedule,
261 gradient_clip: options.gradient_clip,
262 batch_size: options.batch_size,
263 nesterov: false,
264 dampening: 0.0,
265 };
266 momentum::minimize_sgd_momentum(grad_func, x0, data_provider, momentum_options)
267 }
268 StochasticMethod::RMSProp => {
269 let rmsprop_options = RMSPropOptions {
270 learning_rate: options.learning_rate,
271 decay_rate: 0.99, epsilon: 1e-8,
273 max_iter: options.max_iter,
274 tol: options.tol,
275 lr_schedule: options.lr_schedule,
276 gradient_clip: options.gradient_clip,
277 batch_size: options.batch_size,
278 centered: false,
279 momentum: None,
280 };
281 rmsprop::minimize_rmsprop(grad_func, x0, data_provider, rmsprop_options)
282 }
283 StochasticMethod::Adam => {
284 let adam_options = AdamOptions {
285 learning_rate: options.learning_rate,
286 beta1: 0.9,
287 beta2: 0.999,
288 epsilon: 1e-8,
289 max_iter: options.max_iter,
290 tol: options.tol,
291 lr_schedule: options.lr_schedule,
292 gradient_clip: options.gradient_clip,
293 batch_size: options.batch_size,
294 amsgrad: false,
295 };
296 adam::minimize_adam(grad_func, x0, data_provider, adam_options)
297 }
298 StochasticMethod::AdamW => {
299 let adamw_options = AdamWOptions {
300 learning_rate: options.learning_rate,
301 beta1: 0.9,
302 beta2: 0.999,
303 epsilon: 1e-8,
304 weight_decay: 0.01, max_iter: options.max_iter,
306 tol: options.tol,
307 lr_schedule: options.lr_schedule,
308 gradient_clip: options.gradient_clip,
309 batch_size: options.batch_size,
310 decouple_weight_decay: true,
311 };
312 adamw::minimize_adamw(grad_func, x0, data_provider, adamw_options)
313 }
314 }
315}
316
317#[allow(dead_code)]
319pub fn create_stochastic_options_for_problem(
320 problem_type: &str,
321 dataset_size: usize,
322) -> StochasticOptions {
323 match problem_type.to_lowercase().as_str() {
324 "neural_network" => StochasticOptions {
325 learning_rate: 0.001,
326 max_iter: 1000,
327 batch_size: Some(32.min(dataset_size / 10)),
328 lr_schedule: LearningRateSchedule::ExponentialDecay { decay_rate: 0.99 },
329 gradient_clip: Some(1.0),
330 early_stopping_patience: Some(50),
331 ..Default::default()
332 },
333 "linear_regression" => StochasticOptions {
334 learning_rate: 0.01,
335 max_iter: 500,
336 batch_size: Some(64.min(dataset_size / 5)),
337 lr_schedule: LearningRateSchedule::LinearDecay,
338 ..Default::default()
339 },
340 "logistic_regression" => StochasticOptions {
341 learning_rate: 0.01,
342 max_iter: 200,
343 batch_size: Some(32.min(dataset_size / 10)),
344 lr_schedule: LearningRateSchedule::StepDecay {
345 decay_factor: 0.9,
346 decay_steps: 50,
347 },
348 ..Default::default()
349 },
350 "large_scale" => StochasticOptions {
351 learning_rate: 0.001,
352 max_iter: 2000,
353 batch_size: Some(128.min(dataset_size / 20)),
354 lr_schedule: LearningRateSchedule::CosineAnnealing,
355 gradient_clip: Some(5.0),
356 adaptive_lr: true,
357 ..Default::default()
358 },
359 "noisy_gradients" => StochasticOptions {
360 learning_rate: 0.01,
361 max_iter: 1000,
362 batch_size: Some(64.min(dataset_size / 5)),
363 lr_schedule: LearningRateSchedule::InverseTimeDecay { decay_rate: 1.0 },
364 gradient_clip: Some(2.0),
365 early_stopping_patience: Some(100),
366 ..Default::default()
367 },
368 _ => StochasticOptions::default(),
369 }
370}
371
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Spot-checks the constant, exponential and linear learning-rate schedules.
    #[test]
    fn test_learning_rate_schedules() {
        let base_lr = 0.1;
        let horizon = 100;

        // Constant: the epoch has no influence.
        let lr_constant =
            update_learning_rate(base_lr, 50, horizon, &LearningRateSchedule::Constant);
        assert_abs_diff_eq!(lr_constant, base_lr, epsilon = 1e-10);

        // Exponential: ten epochs of shrinking by 0.9.
        let exponential = LearningRateSchedule::ExponentialDecay { decay_rate: 0.9 };
        let lr_exponential = update_learning_rate(base_lr, 10, horizon, &exponential);
        assert_abs_diff_eq!(
            lr_exponential,
            base_lr * 0.9_f64.powi(10),
            epsilon = 1e-10
        );

        // Linear: half of the horizon leaves half of the rate.
        let lr_halfway =
            update_learning_rate(base_lr, 50, horizon, &LearningRateSchedule::LinearDecay);
        assert_abs_diff_eq!(lr_halfway, base_lr * 0.5, epsilon = 1e-10);
    }

    /// Clipping a (3, 4) gradient to norm 2.5 shrinks it without turning it.
    #[test]
    fn test_gradient_clipping() {
        let mut gradient = Array1::from_vec(vec![3.0, 4.0]);
        clip_gradients(&mut gradient, 2.5);

        let norm_after = gradient.mapv(|g| g * g).sum().sqrt();
        assert_abs_diff_eq!(norm_after, 2.5, epsilon = 1e-10);

        // The component ratio, and therefore the direction, is unchanged.
        assert_abs_diff_eq!(gradient[0] / gradient[1], 3.0 / 4.0, epsilon = 1e-10);
    }

    /// Unshuffled batches are a prefix; shuffled batches stay in range.
    #[test]
    fn test_batch_indices_generation() {
        let ordered = generate_batch_indices(100, 10, false);
        assert_eq!(ordered.len(), 10);
        assert_eq!(ordered, (0..10).collect::<Vec<usize>>());

        let random = generate_batch_indices(100, 10, true);
        assert_eq!(random.len(), 10);
        assert!(random.iter().all(|&idx| idx < 100));
    }

    /// The in-memory provider reports its size and serves full and partial data.
    #[test]
    fn test_in_memory_data_provider() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let provider = InMemoryDataProvider::new(samples.clone());

        assert_eq!(provider.num_samples(), 5);
        assert_eq!(provider.get_full_data(), samples);

        let picked = provider.get_batch(&[0, 2, 4]);
        assert_eq!(picked, vec![1.0, 3.0, 5.0]);
    }

    /// Presets expose the expected learning rate, schedule and batch size.
    #[test]
    fn test_problem_specific_options() {
        let neural = create_stochastic_options_for_problem("neural_network", 1000);
        assert_eq!(neural.learning_rate, 0.001);
        assert!(neural.batch_size.is_some());
        assert!(neural.gradient_clip.is_some());

        let regression = create_stochastic_options_for_problem("linear_regression", 500);
        assert_eq!(regression.learning_rate, 0.01);
        assert!(matches!(
            regression.lr_schedule,
            LearningRateSchedule::LinearDecay
        ));

        let large = create_stochastic_options_for_problem("large_scale", 10000);
        assert!(matches!(
            large.lr_schedule,
            LearningRateSchedule::CosineAnnealing
        ));
        assert_eq!(large.batch_size, Some(128));
    }
}