scirs2_optimize/stochastic/
sgd.rs

//! Stochastic Gradient Descent (SGD) optimization
//!
//! This module implements the basic SGD algorithm and its variants for
//! stochastic optimization problems.
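//!
//! Each iteration draws a batch of samples, computes a gradient estimate `g_t`
//! on it, and updates the parameters as `x_{t+1} = x_t - eta_t * g_t`, where the
//! step size `eta_t` follows the configured `LearningRateSchedule`.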

use crate::error::OptimizeError;
use crate::stochastic::{
    clip_gradients, generate_batch_indices, update_learning_rate, DataProvider,
    LearningRateSchedule, StochasticGradientFunction,
};
use crate::unconstrained::result::OptimizeResult;
use scirs2_core::ndarray::Array1;
use scirs2_core::random::prelude::*;

/// Options for SGD optimization
#[derive(Debug, Clone)]
pub struct SGDOptions {
    /// Learning rate (step size)
    pub learning_rate: f64,
    /// Maximum number of iterations
    pub max_iter: usize,
    /// Convergence tolerance
    pub tol: f64,
    /// Learning rate schedule
    pub lr_schedule: LearningRateSchedule,
    /// Gradient clipping threshold
    pub gradient_clip: Option<f64>,
    /// Batch size for mini-batch SGD
    pub batch_size: Option<usize>,
}

impl Default for SGDOptions {
    fn default() -> Self {
        Self {
            learning_rate: 0.01,
            max_iter: 1000,
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::Constant,
            gradient_clip: None,
            batch_size: None,
        }
    }
}

/// Stochastic Gradient Descent optimizer
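///
/// Takes (mini-)batch gradient steps using the batch size and learning-rate
/// schedule from `SGDOptions`, periodically evaluating the full-data loss to
/// track the best iterate and check convergence.
///
/// # Example
///
/// A minimal usage sketch (not compiled as a doctest); `MyGradFn` is a
/// placeholder for any user type implementing `StochasticGradientFunction`:
///
/// ```ignore
/// use scirs2_core::ndarray::Array1;
///
/// let x0 = Array1::from_vec(vec![1.0, -1.0]);
/// let data = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));
/// let options = SGDOptions { learning_rate: 0.1, ..Default::default() };
/// let result = minimize_sgd(MyGradFn, x0, data, options).unwrap();
/// println!("minimum at {:?}, f = {}", result.x, result.fun);
/// ```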
#[allow(dead_code)]
pub fn minimize_sgd<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: SGDOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    let batch_size = options.batch_size.unwrap_or(num_samples);
    let actual_batch_size = batch_size.min(num_samples);

    // Track the best solution found
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Convergence tracking
    let mut prev_loss = f64::INFINITY;
    let mut stagnant_iterations = 0;

    println!("Starting SGD optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", actual_batch_size);
    println!("  Initial learning rate: {}", options.learning_rate);

    #[allow(clippy::explicit_counter_loop)]
    for iteration in 0..options.max_iter {
        // Update learning rate according to schedule
        let current_lr = update_learning_rate(
            options.learning_rate,
            iteration,
            options.max_iter,
            &options.lr_schedule,
        );

        // Generate batch indices
        let batch_indices = if actual_batch_size < num_samples {
            generate_batch_indices(num_samples, actual_batch_size, true)
        } else {
            (0..num_samples).collect()
        };

        // Get batch data
        let batch_data = data_provider.get_batch(&batch_indices);

        // Compute gradient on batch
        let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
        _grad_evals += 1;

        // Apply gradient clipping if specified
        if let Some(clip_threshold) = options.gradient_clip {
            clip_gradients(&mut gradient, clip_threshold);
        }

        // SGD update: x = x - lr * gradient
        x = &x - &(&gradient * current_lr);

        // Evaluate on full dataset periodically for convergence check
        if iteration % 10 == 0 || iteration == options.max_iter - 1 {
            let full_data = data_provider.get_full_data();
            let current_loss = grad_func.compute_value(&x.view(), &full_data);
            func_evals += 1;

            // Update best solution
            if current_loss < best_f {
                best_f = current_loss;
                best_x = x.clone();
                stagnant_iterations = 0;
            } else {
                stagnant_iterations += 1;
            }

            // Progress reporting
            if iteration % 100 == 0 {
                let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
                println!(
                    "  Iteration {}: loss = {:.6e}, |grad| = {:.3e}, lr = {:.3e}",
                    iteration, current_loss, grad_norm, current_lr
                );
            }

            // Check convergence
            let loss_change = (prev_loss - current_loss).abs();
            if loss_change < options.tol {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: true,
                    message: format!(
                        "SGD converged: loss change {:.2e} < {:.2e}",
                        loss_change, options.tol
                    ),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }

            prev_loss = current_loss;

            // Early stopping for stagnation
            if stagnant_iterations > 50 {
                return Ok(OptimizeResult {
                    x: best_x,
                    fun: best_f,
                    nit: iteration,
                    func_evals,
                    nfev: func_evals,
                    success: false,
                    message: "SGD stopped due to stagnation".to_string(),
                    jacobian: Some(gradient),
                    hessian: None,
                });
            }
        }
    }

    // Final evaluation
    let full_data = data_provider.get_full_data();
    let final_loss = grad_func.compute_value(&best_x.view(), &full_data);
    func_evals += 1;

    Ok(OptimizeResult {
        x: best_x,
        fun: final_loss.min(best_f),
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "SGD reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

/// Variance-reduced SGD using SVRG (Stochastic Variance Reduced Gradient)
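///
/// Each outer epoch stores a snapshot of the parameters together with the
/// full-data gradient at that snapshot, then runs roughly one pass of
/// mini-batch updates that use the snapshot gradient as a control variate.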
#[allow(dead_code)]
pub fn minimize_svrg<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: SGDOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    let batch_size = options.batch_size.unwrap_or(1);
    let update_frequency = (num_samples / batch_size).max(1); // Inner iterations per epoch (~one full data pass)

    // Compute full gradient initially
    let full_data = data_provider.get_full_data();
    let mut full_gradient = grad_func.compute_gradient(&x.view(), &full_data);
    _grad_evals += 1;

    let mut x_snapshot = x.clone();
    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    println!("Starting SVRG optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", batch_size);
    println!("  Update frequency: {}", update_frequency);

    for epoch in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            epoch,
            options.max_iter,
            &options.lr_schedule,
        );

        // Inner loop: one pass through data
        for _inner_iter in 0..update_frequency {
            // Generate batch indices
            let batch_indices = generate_batch_indices(num_samples, batch_size, true);
            let batch_data = data_provider.get_batch(&batch_indices);

            // Compute stochastic gradient
            let stoch_grad = grad_func.compute_gradient(&x.view(), &batch_data);
            _grad_evals += 1;

            // Compute control variate gradient at snapshot
            let control_grad = grad_func.compute_gradient(&x_snapshot.view(), &batch_data);
            _grad_evals += 1;

            // SVRG gradient estimate: g_i - g_i(snapshot) + full_gradient
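            // This estimate is unbiased (the control term matches full_gradient in
            // expectation over the sampled batch) and its variance shrinks as x
            // approaches the snapshot.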
            let mut svrg_gradient = &stoch_grad - &control_grad + &full_gradient;

            // Apply gradient clipping
            if let Some(clip_threshold) = options.gradient_clip {
                clip_gradients(&mut svrg_gradient, clip_threshold);
            }

            // Update parameters
            x = &x - &(&svrg_gradient * current_lr);
        }

        // Update snapshot and full gradient
        x_snapshot = x.clone();
        full_gradient = grad_func.compute_gradient(&x_snapshot.view(), &full_data);
        _grad_evals += 1;

        // Evaluate progress
        let current_loss = grad_func.compute_value(&x.view(), &full_data);
        func_evals += 1;

        if current_loss < best_f {
            best_f = current_loss;
            best_x = x.clone();
        }

        if epoch % 10 == 0 {
            let grad_norm = full_gradient.mapv(|g| g * g).sum().sqrt();
            println!(
                "  Epoch {}: loss = {:.6e}, |grad| = {:.3e}, lr = {:.3e}",
                epoch, current_loss, grad_norm, current_lr
            );
        }

        // Check convergence
        let grad_norm = full_gradient.mapv(|g| g * g).sum().sqrt();
        if grad_norm < options.tol {
            return Ok(OptimizeResult {
                x: best_x,
                fun: best_f,
                nit: epoch,
                func_evals,
                nfev: func_evals,
                success: true,
                message: format!(
                    "SVRG converged: gradient norm {:.2e} < {:.2e}",
                    grad_norm, options.tol
                ),
                jacobian: Some(full_gradient),
                hessian: None,
            });
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "SVRG reached maximum iterations".to_string(),
        jacobian: Some(full_gradient),
        hessian: None,
    })
}

/// Mini-batch SGD with averaging for better convergence
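///
/// Shuffles the data each epoch and steps through it in mini-batches. After the
/// first quarter of the iteration budget, a Polyak (running) average of the
/// iterates is maintained and used for the full-data evaluations and the
/// returned best solution.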
#[allow(dead_code)]
pub fn minimize_mini_batch_sgd<F>(
    mut grad_func: F,
    mut x: Array1<f64>,
    data_provider: Box<dyn DataProvider>,
    options: SGDOptions,
) -> Result<OptimizeResult<f64>, OptimizeError>
where
    F: StochasticGradientFunction,
{
    let mut func_evals = 0;
    let mut _grad_evals = 0;

    let num_samples = data_provider.num_samples();
    // Default to min(32, n/10), clamped to at least 1 so tiny datasets never yield a zero batch size
    let batch_size = options.batch_size.unwrap_or(32.min(num_samples / 10)).max(1);
    let batches_per_epoch = num_samples.div_ceil(batch_size);

    let mut best_x = x.clone();
    let mut best_f = f64::INFINITY;

    // Moving average for parameters (Polyak averaging)
    let mut x_avg = x.clone();
    let avg_start_epoch = options.max_iter / 4; // Start averaging after 25% of iterations

    println!("Starting Mini-batch SGD optimization:");
    println!("  Parameters: {}", x.len());
    println!("  Dataset size: {}", num_samples);
    println!("  Batch size: {}", batch_size);
    println!("  Batches per epoch: {}", batches_per_epoch);

    for epoch in 0..options.max_iter {
        let current_lr = update_learning_rate(
            options.learning_rate,
            epoch,
            options.max_iter,
            &options.lr_schedule,
        );

        // Shuffle data indices for this epoch
        let mut all_indices: Vec<usize> = (0..num_samples).collect();
        use scirs2_core::random::seq::SliceRandom;
        all_indices.shuffle(&mut thread_rng());

        let mut _epoch_loss = 0.0;
        let mut epoch_grad_norm = 0.0;

        // Process all batches in epoch
        for batch_idx in 0..batches_per_epoch {
            let start_idx = batch_idx * batch_size;
            let end_idx = (start_idx + batch_size).min(num_samples);
            let batch_indices = &all_indices[start_idx..end_idx];

            let batch_data = data_provider.get_batch(batch_indices);

            // Compute gradient on batch
            let mut gradient = grad_func.compute_gradient(&x.view(), &batch_data);
            _grad_evals += 1;

            // Apply gradient clipping
            if let Some(clip_threshold) = options.gradient_clip {
                clip_gradients(&mut gradient, clip_threshold);
            }

            // Update parameters
            x = &x - &(&gradient * current_lr);

            // Update running averages
            let grad_norm = gradient.mapv(|g| g * g).sum().sqrt();
            epoch_grad_norm += grad_norm;

            let batch_loss = grad_func.compute_value(&x.view(), &batch_data);
            func_evals += 1;
            _epoch_loss += batch_loss;
        }

        // Update Polyak averaging
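        // (incremental mean over the epochs since averaging started:
        //  x_avg <- x_avg + (x - x_avg) / k, written below as a convex combination)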
        if epoch >= avg_start_epoch {
            let weight = 1.0 / (epoch - avg_start_epoch + 1) as f64;
            x_avg = &x_avg * (1.0 - weight) + &x * weight;
        }

        // Use averaged parameters for evaluation after averaging starts
        let eval_x = if epoch >= avg_start_epoch { &x_avg } else { &x };

        // Evaluate on full dataset
        let full_data = data_provider.get_full_data();
        let current_loss = grad_func.compute_value(&eval_x.view(), &full_data);
        func_evals += 1;

        if current_loss < best_f {
            best_f = current_loss;
            best_x = eval_x.clone();
        }

        // Progress reporting
        if epoch % 10 == 0 {
            let avg_grad_norm = epoch_grad_norm / batches_per_epoch as f64;
            println!(
                "  Epoch {}: loss = {:.6e}, avg |grad| = {:.3e}, lr = {:.3e}{}",
                epoch,
                current_loss,
                avg_grad_norm,
                current_lr,
                if epoch >= avg_start_epoch {
                    " (averaged)"
                } else {
                    ""
                }
            );
        }

        // Check convergence
        let avg_grad_norm = epoch_grad_norm / batches_per_epoch as f64;
        if avg_grad_norm < options.tol {
            return Ok(OptimizeResult {
                x: best_x,
                fun: best_f,
                nit: epoch,
                func_evals,
                nfev: func_evals,
                success: true,
                message: format!(
                    "Mini-batch SGD converged: avg gradient norm {:.2e} < {:.2e}",
                    avg_grad_norm, options.tol
                ),
                jacobian: None,
                hessian: None,
            });
        }
    }

    Ok(OptimizeResult {
        x: best_x,
        fun: best_f,
        nit: options.max_iter,
        func_evals,
        nfev: func_evals,
        success: false,
        message: "Mini-batch SGD reached maximum iterations".to_string(),
        jacobian: None,
        hessian: None,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::stochastic::InMemoryDataProvider;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::ArrayView1;

    // Simple quadratic function for testing
    struct QuadraticFunction;

    impl StochasticGradientFunction for QuadraticFunction {
        fn compute_gradient(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> Array1<f64> {
            // Gradient of f(x) = sum(x_i^2) is 2*x
            x.mapv(|xi| 2.0 * xi)
        }

        fn compute_value(&mut self, x: &ArrayView1<f64>, _batchdata: &[f64]) -> f64 {
            // f(x) = sum(x_i^2)
            x.mapv(|xi| xi * xi).sum()
        }
    }

    #[test]
    fn test_sgd_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, 2.0, -1.5]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 100]));

        let options = SGDOptions {
            learning_rate: 0.1,
            max_iter: 100,
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_sgd(grad_func, x0, data_provider, options).unwrap();

        // Should converge to zero
        assert!(result.success || result.fun < 1e-4);
        for &xi in result.x.iter() {
            assert_abs_diff_eq!(xi, 0.0, epsilon = 1e-2);
        }
    }

    #[test]
    fn test_svrg_quadratic() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![1.0, -1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 50]));

        let options = SGDOptions {
            learning_rate: 0.05,
            max_iter: 50,
            batch_size: Some(5),
            tol: 1e-6,
            ..Default::default()
        };

        let result = minimize_svrg(grad_func, x0, data_provider, options).unwrap();

        // SVRG should drive the quadratic objective close to its minimum at the origin
        assert!(result.success || result.fun < 1e-4);
    }

    #[test]
    fn test_mini_batch_sgd() {
        let grad_func = QuadraticFunction;
        let x0 = Array1::from_vec(vec![2.0, -2.0, 1.0]);
        let data_provider = Box::new(InMemoryDataProvider::new(vec![1.0; 200]));

        let options = SGDOptions {
            learning_rate: 0.01,
            max_iter: 100,
            batch_size: Some(10),
            tol: 1e-6,
            lr_schedule: LearningRateSchedule::ExponentialDecay { decay_rate: 0.99 },
            ..Default::default()
        };

        let result = minimize_mini_batch_sgd(grad_func, x0, data_provider, options).unwrap();

        // Should converge with Polyak averaging
        assert!(result.success || result.fun < 1e-3);
    }
}