Skip to main content

sklears_simd/
optimization.rs

1//! SIMD-optimized optimization algorithms
2//!
3//! This module implements high-performance optimization algorithms using SIMD instructions
4//! for machine learning applications including gradient descent, coordinate descent,
5//! and Newton-type methods.
6
7use crate::matrix::matrix_vector_multiply_f32;
8use crate::vector::{dot_product, norm_l2};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
10
11// Conditional imports for no-std compatibility
12#[cfg(feature = "no-std")]
13use alloc::string::String;
14#[cfg(not(feature = "no-std"))]
15use std::string::String;
16
/// SIMD-optimized gradient descent optimizer
///
/// Holds the hyper-parameters only; the velocity buffer used for momentum is owned by the
/// caller and passed to every `step` call.
#[derive(Debug, Clone)]
pub struct GradientDescent {
    /// Step size applied to the update direction.
    learning_rate: f32,
    /// Momentum coefficient; `0.0` disables the velocity update entirely.
    momentum: f32,
    // Standard SGD dampening term (dampens gradient contribution to momentum); deferred.
    // The field is stored but not read yet, hence the lint allowance.
    #[allow(dead_code)]
    dampening: f32,
    /// L2 regularization strength; `0.0` disables weight decay.
    weight_decay: f32,
    /// Whether the Nesterov look-ahead variant of momentum is used.
    nesterov: bool,
}
27
28impl GradientDescent {
29    /// Create a new gradient descent optimizer
30    pub fn new(learning_rate: f32) -> Self {
31        Self {
32            learning_rate,
33            momentum: 0.0,
34            dampening: 0.0,
35            weight_decay: 0.0,
36            nesterov: false,
37        }
38    }
39
40    /// Set momentum for the optimizer
41    pub fn with_momentum(mut self, momentum: f32) -> Self {
42        self.momentum = momentum;
43        self
44    }
45
46    /// Set weight decay (L2 regularization)
47    pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
48        self.weight_decay = weight_decay;
49        self
50    }
51
52    /// Enable Nesterov momentum
53    pub fn with_nesterov(mut self) -> Self {
54        self.nesterov = true;
55        self
56    }
57
58    /// Perform a single optimization step
59    pub fn step(
60        &self,
61        params: &mut ArrayViewMut1<f32>,
62        gradient: &ArrayView1<f32>,
63        velocity: &mut ArrayViewMut1<f32>,
64    ) {
65        // Add weight decay to gradient if specified
66        let mut grad = gradient.to_owned();
67        if self.weight_decay != 0.0 {
68            simd_axpy(self.weight_decay, &params.view(), &mut grad.view_mut());
69        }
70
71        if self.momentum != 0.0 {
72            // Update velocity: v = momentum * v + grad
73            simd_momentum_update(self.momentum, &grad.view(), velocity);
74
75            if self.nesterov {
76                // Nesterov momentum: param = param - lr * (momentum * v + grad)
77                let mut nesterov_grad = grad.clone();
78                simd_axpy(
79                    self.momentum,
80                    &velocity.view(),
81                    &mut nesterov_grad.view_mut(),
82                );
83                simd_axpy(-self.learning_rate, &nesterov_grad.view(), params);
84            } else {
85                // Standard momentum: param = param - lr * v
86                simd_axpy(-self.learning_rate, &velocity.view(), params);
87            }
88        } else {
89            // No momentum: param = param - lr * grad
90            simd_axpy(-self.learning_rate, &grad.view(), params);
91        }
92    }
93}
94
/// SIMD-optimized coordinate descent optimizer
#[derive(Debug, Clone)]
pub struct CoordinateDescent {
    /// Soft-thresholding level (L1 penalty strength) used for every coefficient update.
    alpha: f32,
    /// A sweep whose largest coefficient change is below this value ends the optimization.
    tolerance: f32,
    /// Upper bound on the number of full sweeps over all features.
    max_iterations: usize,
}
101
102impl CoordinateDescent {
103    /// Create a new coordinate descent optimizer
104    pub fn new(alpha: f32) -> Self {
105        Self {
106            alpha,
107            tolerance: 1e-4,
108            max_iterations: 1000,
109        }
110    }
111
112    /// Set convergence tolerance
113    pub fn with_tolerance(mut self, tolerance: f32) -> Self {
114        self.tolerance = tolerance;
115        self
116    }
117
118    /// Set maximum iterations
119    pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
120        self.max_iterations = max_iterations;
121        self
122    }
123
124    /// Optimize using coordinate descent for LASSO regression
125    pub fn optimize_lasso(
126        &self,
127        x: &Array2<f32>,
128        y: &Array1<f32>,
129        coeff: &mut Array1<f32>,
130    ) -> Result<(), String> {
131        let n_features = x.ncols();
132        let n_samples = x.nrows();
133
134        // Pre-compute X^T X diagonal for efficiency
135        let mut xtx_diag = Array1::zeros(n_features);
136        for j in 0..n_features {
137            let col = x.column(j).to_owned();
138            xtx_diag[j] = dot_product(
139                col.as_slice().expect("slice operation should succeed"),
140                col.as_slice().expect("slice operation should succeed"),
141            );
142        }
143
144        // Residuals: r = y - X * coeff
145        let mut residuals = y.clone();
146        let pred = matrix_vector_multiply_f32(x, coeff);
147        simd_axpy(-1.0, &pred.view(), &mut residuals.view_mut());
148
149        for _ in 0..self.max_iterations {
150            let mut max_change: f32 = 0.0;
151
152            for j in 0..n_features {
153                let old_coeff = coeff[j];
154
155                // Add back the contribution of feature j to residuals
156                let col = x.column(j);
157                simd_axpy(old_coeff, &col.to_owned().view(), &mut residuals.view_mut());
158
159                // Compute new coefficient
160                let col_slice = col.to_owned();
161                let rho = dot_product(
162                    col_slice
163                        .as_slice()
164                        .expect("slice operation should succeed"),
165                    residuals
166                        .as_slice()
167                        .expect("slice operation should succeed"),
168                );
169                let new_coeff = soft_threshold(rho / n_samples as f32, self.alpha)
170                    / (xtx_diag[j] / n_samples as f32);
171
172                // Update coefficient and residuals
173                coeff[j] = new_coeff;
174                let change = new_coeff - old_coeff;
175                max_change = max_change.max(change.abs());
176
177                // Subtract new contribution from residuals
178                simd_axpy(
179                    -new_coeff,
180                    &col.to_owned().view(),
181                    &mut residuals.view_mut(),
182                );
183            }
184
185            if max_change < self.tolerance {
186                return Ok(());
187            }
188        }
189
190        Ok(())
191    }
192}
193
/// SIMD-optimized quasi-Newton optimizer (L-BFGS)
#[derive(Debug, Clone)]
pub struct QuasiNewton {
    /// L-BFGS history length (number of correction pairs).
    memory_size: usize,
    /// Gradient-norm threshold below which the optimizer reports convergence.
    tolerance: f32,
    /// Upper bound on the number of outer iterations.
    max_iterations: usize,
    /// Upper bound on the number of step sizes tried by one line search.
    line_search_max_iter: usize,
}
201
202impl Default for QuasiNewton {
203    fn default() -> Self {
204        Self::new()
205    }
206}
207
208impl QuasiNewton {
209    /// Create a new quasi-Newton optimizer
210    pub fn new() -> Self {
211        Self {
212            memory_size: 10,
213            tolerance: 1e-6,
214            max_iterations: 1000,
215            line_search_max_iter: 20,
216        }
217    }
218
219    /// Set L-BFGS memory size
220    pub fn with_memory_size(mut self, memory_size: usize) -> Self {
221        self.memory_size = memory_size;
222        self
223    }
224
225    /// Simple L-BFGS implementation for demonstration
226    pub fn optimize<F, G>(
227        &self,
228        mut x: Array1<f32>,
229        objective: F,
230        gradient: G,
231    ) -> Result<Array1<f32>, String>
232    where
233        F: Fn(&Array1<f32>) -> f32,
234        G: Fn(&Array1<f32>) -> Array1<f32>,
235    {
236        let n = x.len();
237        let mut grad = gradient(&x);
238        let h_inv = Array2::eye(n); // Initial Hessian inverse approximation
239
240        for _ in 0..self.max_iterations {
241            let grad_norm = norm_l2(grad.as_slice().expect("slice operation should succeed"));
242            if grad_norm < self.tolerance {
243                return Ok(x);
244            }
245
246            // Compute search direction: d = -H^{-1} * grad
247            let direction = matrix_vector_multiply_f32(&h_inv, &grad);
248            let mut search_dir = direction;
249            simd_scale(-1.0, &mut search_dir.view_mut());
250
251            // Line search to find step size
252            let step_size = self.line_search(&x, &search_dir, &objective, &gradient)?;
253
254            // Update parameters
255            let mut step = search_dir.clone();
256            simd_scale(step_size, &mut step.view_mut());
257            let x_new = &x + &step;
258
259            let grad_new = gradient(&x_new);
260
261            // BFGS update (simplified)
262            let s = &x_new - &x;
263            let y = &grad_new - &grad;
264
265            let sy = dot_product(
266                s.as_slice().expect("slice operation should succeed"),
267                y.as_slice().expect("slice operation should succeed"),
268            );
269            if sy > 1e-10 {
270                // Update Hessian inverse approximation (simplified rank-1 update)
271                // This is a simplified version - full L-BFGS would maintain a history
272            }
273
274            x = x_new;
275            grad = grad_new;
276        }
277
278        Ok(x)
279    }
280
281    /// Simple backtracking line search
282    fn line_search<F, G>(
283        &self,
284        x: &Array1<f32>,
285        direction: &Array1<f32>,
286        objective: &F,
287        gradient: &G,
288    ) -> Result<f32, String>
289    where
290        F: Fn(&Array1<f32>) -> f32,
291        G: Fn(&Array1<f32>) -> Array1<f32>,
292    {
293        let c1 = 1e-4;
294        let mut alpha = 1.0;
295        let f_x = objective(x);
296        let grad_x = gradient(x);
297        let grad_dot_dir = dot_product(
298            grad_x.as_slice().expect("slice operation should succeed"),
299            direction
300                .as_slice()
301                .expect("slice operation should succeed"),
302        );
303
304        for _ in 0..self.line_search_max_iter {
305            let mut x_new = x.clone();
306            let mut step = direction.clone();
307            simd_scale(alpha, &mut step.view_mut());
308            simd_axpy(1.0, &step.view(), &mut x_new.view_mut());
309
310            let f_x_new = objective(&x_new);
311
312            // Armijo condition
313            if f_x_new <= f_x + c1 * alpha * grad_dot_dir {
314                return Ok(alpha);
315            }
316
317            alpha *= 0.5;
318        }
319
320        Ok(alpha)
321    }
322}
323
324/// SIMD-optimized AXPY operation: y = alpha * x + y
325pub fn simd_axpy(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
326    assert_eq!(x.len(), y.len(), "Arrays must have the same length");
327
328    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
329    {
330        if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
331            unsafe { simd_axpy_avx2_fma(alpha, x, y) };
332            return;
333        } else if crate::simd_feature_detected!("avx2") {
334            unsafe { simd_axpy_avx2(alpha, x, y) };
335            return;
336        } else if crate::simd_feature_detected!("sse2") {
337            unsafe { simd_axpy_sse2(alpha, x, y) };
338            return;
339        }
340    }
341
342    // Scalar fallback
343    for i in 0..x.len() {
344        y[i] += alpha * x[i];
345    }
346}
347
348/// SIMD-optimized scaling: x = alpha * x
349pub fn simd_scale(alpha: f32, x: &mut ArrayViewMut1<f32>) {
350    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
351    {
352        if crate::simd_feature_detected!("avx2") {
353            unsafe { simd_scale_avx2(alpha, x) };
354            return;
355        } else if crate::simd_feature_detected!("sse2") {
356            unsafe { simd_scale_sse2(alpha, x) };
357            return;
358        }
359    }
360
361    // Scalar fallback
362    for val in x.iter_mut() {
363        *val *= alpha;
364    }
365}
366
367/// SIMD-optimized momentum update: v = momentum * v + grad
368pub fn simd_momentum_update(
369    momentum: f32,
370    grad: &ArrayView1<f32>,
371    velocity: &mut ArrayViewMut1<f32>,
372) {
373    assert_eq!(
374        grad.len(),
375        velocity.len(),
376        "Arrays must have the same length"
377    );
378
379    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
380    {
381        if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
382            unsafe { simd_momentum_update_avx2_fma(momentum, grad, velocity) };
383            return;
384        } else if crate::simd_feature_detected!("avx2") {
385            unsafe { simd_momentum_update_avx2(momentum, grad, velocity) };
386            return;
387        } else if crate::simd_feature_detected!("sse2") {
388            unsafe { simd_momentum_update_sse2(momentum, grad, velocity) };
389            return;
390        }
391    }
392
393    // Scalar fallback
394    for i in 0..grad.len() {
395        velocity[i] = momentum * velocity[i] + grad[i];
396    }
397}
398
/// Soft thresholding function for LASSO
///
/// Moves `x` towards zero by `threshold`; anything that is neither above `threshold`
/// nor below `-threshold` becomes `0.0`.
fn soft_threshold(x: f32, threshold: f32) -> f32 {
    // Above the band: pull down by the threshold.
    if x > threshold {
        return x - threshold;
    }
    // Below the band: pull up by the threshold.
    if x < -threshold {
        return x + threshold;
    }
    // Inside the band (or not comparable): clamp to zero.
    0.0
}
409
410// SIMD implementations for x86/x86_64
411
412#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
413#[target_feature(enable = "sse2")]
414unsafe fn simd_axpy_sse2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
415    use core::arch::x86_64::*;
416
417    let alpha_vec = _mm_set1_ps(alpha);
418    let len = x.len();
419    let mut i = 0;
420
421    while i + 4 <= len {
422        let x_vec = _mm_loadu_ps(&x[i]);
423        let y_vec = _mm_loadu_ps(&y[i]);
424        let result = _mm_add_ps(_mm_mul_ps(alpha_vec, x_vec), y_vec);
425        _mm_storeu_ps(&mut y[i], result);
426        i += 4;
427    }
428
429    // Handle remaining elements
430    while i < len {
431        y[i] += alpha * x[i];
432        i += 1;
433    }
434}
435
436#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
437#[target_feature(enable = "avx2")]
438unsafe fn simd_axpy_avx2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
439    use core::arch::x86_64::*;
440
441    let alpha_vec = _mm256_set1_ps(alpha);
442    let len = x.len();
443    let mut i = 0;
444
445    while i + 8 <= len {
446        let x_vec = _mm256_loadu_ps(&x[i]);
447        let y_vec = _mm256_loadu_ps(&y[i]);
448        let result = _mm256_add_ps(_mm256_mul_ps(alpha_vec, x_vec), y_vec);
449        _mm256_storeu_ps(&mut y[i], result);
450        i += 8;
451    }
452
453    // Handle remaining elements
454    while i < len {
455        y[i] += alpha * x[i];
456        i += 1;
457    }
458}
459
460#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
461#[target_feature(enable = "avx2", enable = "fma")]
462unsafe fn simd_axpy_avx2_fma(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
463    use core::arch::x86_64::*;
464
465    let alpha_vec = _mm256_set1_ps(alpha);
466    let len = x.len();
467    let mut i = 0;
468
469    while i + 8 <= len {
470        let x_vec = _mm256_loadu_ps(&x[i]);
471        let y_vec = _mm256_loadu_ps(&y[i]);
472        let result = _mm256_fmadd_ps(alpha_vec, x_vec, y_vec);
473        _mm256_storeu_ps(&mut y[i], result);
474        i += 8;
475    }
476
477    // Handle remaining elements
478    while i < len {
479        y[i] += alpha * x[i];
480        i += 1;
481    }
482}
483
484#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
485#[target_feature(enable = "sse2")]
486unsafe fn simd_scale_sse2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
487    use core::arch::x86_64::*;
488
489    let alpha_vec = _mm_set1_ps(alpha);
490    let len = x.len();
491    let mut i = 0;
492
493    while i + 4 <= len {
494        let x_vec = _mm_loadu_ps(&x[i]);
495        let result = _mm_mul_ps(alpha_vec, x_vec);
496        _mm_storeu_ps(&mut x[i], result);
497        i += 4;
498    }
499
500    // Handle remaining elements
501    while i < len {
502        x[i] *= alpha;
503        i += 1;
504    }
505}
506
507#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
508#[target_feature(enable = "avx2")]
509unsafe fn simd_scale_avx2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
510    use core::arch::x86_64::*;
511
512    let alpha_vec = _mm256_set1_ps(alpha);
513    let len = x.len();
514    let mut i = 0;
515
516    while i + 8 <= len {
517        let x_vec = _mm256_loadu_ps(&x[i]);
518        let result = _mm256_mul_ps(alpha_vec, x_vec);
519        _mm256_storeu_ps(&mut x[i], result);
520        i += 8;
521    }
522
523    // Handle remaining elements
524    while i < len {
525        x[i] *= alpha;
526        i += 1;
527    }
528}
529
530#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
531#[target_feature(enable = "sse2")]
532unsafe fn simd_momentum_update_sse2(
533    momentum: f32,
534    grad: &ArrayView1<f32>,
535    velocity: &mut ArrayViewMut1<f32>,
536) {
537    use core::arch::x86_64::*;
538
539    let momentum_vec = _mm_set1_ps(momentum);
540    let len = grad.len();
541    let mut i = 0;
542
543    while i + 4 <= len {
544        let grad_vec = _mm_loadu_ps(&grad[i]);
545        let vel_vec = _mm_loadu_ps(&velocity[i]);
546        let result = _mm_add_ps(_mm_mul_ps(momentum_vec, vel_vec), grad_vec);
547        _mm_storeu_ps(&mut velocity[i], result);
548        i += 4;
549    }
550
551    // Handle remaining elements
552    while i < len {
553        velocity[i] = momentum * velocity[i] + grad[i];
554        i += 1;
555    }
556}
557
558#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
559#[target_feature(enable = "avx2")]
560unsafe fn simd_momentum_update_avx2(
561    momentum: f32,
562    grad: &ArrayView1<f32>,
563    velocity: &mut ArrayViewMut1<f32>,
564) {
565    use core::arch::x86_64::*;
566
567    let momentum_vec = _mm256_set1_ps(momentum);
568    let len = grad.len();
569    let mut i = 0;
570
571    while i + 8 <= len {
572        let grad_vec = _mm256_loadu_ps(&grad[i]);
573        let vel_vec = _mm256_loadu_ps(&velocity[i]);
574        let result = _mm256_add_ps(_mm256_mul_ps(momentum_vec, vel_vec), grad_vec);
575        _mm256_storeu_ps(&mut velocity[i], result);
576        i += 8;
577    }
578
579    // Handle remaining elements
580    while i < len {
581        velocity[i] = momentum * velocity[i] + grad[i];
582        i += 1;
583    }
584}
585
586#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
587#[target_feature(enable = "avx2", enable = "fma")]
588unsafe fn simd_momentum_update_avx2_fma(
589    momentum: f32,
590    grad: &ArrayView1<f32>,
591    velocity: &mut ArrayViewMut1<f32>,
592) {
593    use core::arch::x86_64::*;
594
595    let momentum_vec = _mm256_set1_ps(momentum);
596    let len = grad.len();
597    let mut i = 0;
598
599    while i + 8 <= len {
600        let grad_vec = _mm256_loadu_ps(&grad[i]);
601        let vel_vec = _mm256_loadu_ps(&velocity[i]);
602        let result = _mm256_fmadd_ps(momentum_vec, vel_vec, grad_vec);
603        _mm256_storeu_ps(&mut velocity[i], result);
604        i += 8;
605    }
606
607    // Handle remaining elements
608    while i < len {
609        velocity[i] = momentum * velocity[i] + grad[i];
610        i += 1;
611    }
612}
613
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Lengths that cover the empty case, pure scalar tails, exact vector widths and
    /// vector bodies followed by a tail (for both 4- and 8-lane kernels).
    const LENGTHS: [usize; 7] = [0, 1, 3, 4, 8, 11, 19];

    /// Build a test vector `[offset, offset + step, offset + 2 * step, ...]`.
    fn ramp(len: usize, offset: f32, step: f32) -> Array1<f32> {
        Array1::from_vec((0..len).map(|i| offset + step * i as f32).collect())
    }

    #[test]
    fn test_gradient_descent() {
        let optimizer = GradientDescent::new(0.1).with_momentum(0.9);

        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut velocity = Array1::zeros(3);

        let params_before = params.clone();
        optimizer.step(
            &mut params.view_mut(),
            &gradient.view(),
            &mut velocity.view_mut(),
        );

        for i in 0..params.len() {
            // Parameters should have moved in the opposite direction of the gradient
            assert!(params[i] < params_before[i]);
            // With a zero starting velocity the first step is a plain gradient step
            assert_relative_eq!(velocity[i], gradient[i], epsilon = 1e-6);
            assert_relative_eq!(
                params[i],
                params_before[i] - 0.1 * gradient[i],
                epsilon = 1e-5
            );
        }
    }

    #[test]
    fn test_gradient_descent_without_momentum() {
        let optimizer = GradientDescent::new(0.5);

        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradient = Array1::from_vec(vec![0.5, 1.0, 1.5]);
        let mut velocity = Array1::zeros(3);

        optimizer.step(
            &mut params.view_mut(),
            &gradient.view(),
            &mut velocity.view_mut(),
        );

        let expected = [0.75, 1.5, 2.25];
        for i in 0..params.len() {
            assert_relative_eq!(params[i], expected[i], epsilon = 1e-6);
            // The velocity buffer is untouched when momentum is disabled
            assert_eq!(velocity[i], 0.0);
        }
    }

    #[test]
    fn test_gradient_descent_nesterov() {
        let optimizer = GradientDescent::new(0.1).with_momentum(0.9).with_nesterov();

        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut velocity = Array1::zeros(3);

        let params_before = params.clone();
        optimizer.step(
            &mut params.view_mut(),
            &gradient.view(),
            &mut velocity.view_mut(),
        );

        // v = grad, so the applied direction is momentum * grad + grad = 1.9 * grad
        for i in 0..params.len() {
            assert_relative_eq!(
                params[i],
                params_before[i] - 0.1 * (1.9 * gradient[i]),
                epsilon = 1e-5
            );
        }
    }

    #[test]
    fn test_coordinate_descent() {
        let optimizer = CoordinateDescent::new(0.1);

        // Simple 2D problem
        let x = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("shape and data length should match");
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let mut coeff = Array1::zeros(2);

        let result = optimizer.optimize_lasso(&x, &y, &mut coeff);
        assert!(result.is_ok());
        assert!(coeff.iter().all(|c| c.is_finite()));
    }

    #[test]
    fn test_simd_axpy() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut y = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]);
        let alpha = 2.0;

        let expected = &y + &(&x * alpha);
        simd_axpy(alpha, &x.view(), &mut y.view_mut());

        for i in 0..x.len() {
            assert_relative_eq!(y[i], expected[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simd_axpy_remainder_lengths() {
        for &len in &LENGTHS {
            let x = ramp(len, 1.0, 1.0);
            let mut y = ramp(len, 0.0, 10.0);

            let expected = &y + &(&x * 2.0);
            simd_axpy(2.0, &x.view(), &mut y.view_mut());

            for i in 0..len {
                assert_relative_eq!(y[i], expected[i], epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_simd_scale() {
        let mut x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let alpha = 2.5;

        let expected = &x * alpha;
        simd_scale(alpha, &mut x.view_mut());

        for i in 0..x.len() {
            assert_relative_eq!(x[i], expected[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_simd_scale_remainder_lengths() {
        for &len in &LENGTHS {
            let mut x = ramp(len, 1.0, 1.0);

            let expected = &x * 2.5;
            simd_scale(2.5, &mut x.view_mut());

            for i in 0..len {
                assert_relative_eq!(x[i], expected[i], epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_momentum_update() {
        let grad = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut velocity = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let momentum = 0.9;

        let expected = &velocity * momentum + &grad;
        simd_momentum_update(momentum, &grad.view(), &mut velocity.view_mut());

        for i in 0..grad.len() {
            assert_relative_eq!(velocity[i], expected[i], epsilon = 1e-6);
        }
    }

    #[test]
    fn test_momentum_update_remainder_lengths() {
        for &len in &LENGTHS {
            // Values are exactly representable, so fused and unfused kernels agree
            let grad = ramp(len, 0.25, 0.25);
            let mut velocity = ramp(len, 1.0, 1.0);

            let expected = &velocity * 0.5 + &grad;
            simd_momentum_update(0.5, &grad.view(), &mut velocity.view_mut());

            for i in 0..len {
                assert_relative_eq!(velocity[i], expected[i], epsilon = 1e-6);
            }
        }
    }

    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(soft_threshold(-0.5, 1.0), 0.0);
        // Values exactly on the band edge are clamped to zero
        assert_eq!(soft_threshold(1.0, 1.0), 0.0);
        assert_eq!(soft_threshold(-1.0, 1.0), 0.0);
    }
}
706}