1use crate::matrix::matrix_vector_multiply_f32;
8use crate::vector::{dot_product, norm_l2};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
10
11#[cfg(feature = "no-std")]
13use alloc::string::String;
14#[cfg(not(feature = "no-std"))]
15use std::string::String;
16
17pub struct GradientDescent {
19 learning_rate: f32,
20 momentum: f32,
21 #[allow(dead_code)]
22 dampening: f32,
24 weight_decay: f32,
25 nesterov: bool,
26}
27
28impl GradientDescent {
29 pub fn new(learning_rate: f32) -> Self {
31 Self {
32 learning_rate,
33 momentum: 0.0,
34 dampening: 0.0,
35 weight_decay: 0.0,
36 nesterov: false,
37 }
38 }
39
40 pub fn with_momentum(mut self, momentum: f32) -> Self {
42 self.momentum = momentum;
43 self
44 }
45
46 pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
48 self.weight_decay = weight_decay;
49 self
50 }
51
52 pub fn with_nesterov(mut self) -> Self {
54 self.nesterov = true;
55 self
56 }
57
58 pub fn step(
60 &self,
61 params: &mut ArrayViewMut1<f32>,
62 gradient: &ArrayView1<f32>,
63 velocity: &mut ArrayViewMut1<f32>,
64 ) {
65 let mut grad = gradient.to_owned();
67 if self.weight_decay != 0.0 {
68 simd_axpy(self.weight_decay, ¶ms.view(), &mut grad.view_mut());
69 }
70
71 if self.momentum != 0.0 {
72 simd_momentum_update(self.momentum, &grad.view(), velocity);
74
75 if self.nesterov {
76 let mut nesterov_grad = grad.clone();
78 simd_axpy(
79 self.momentum,
80 &velocity.view(),
81 &mut nesterov_grad.view_mut(),
82 );
83 simd_axpy(-self.learning_rate, &nesterov_grad.view(), params);
84 } else {
85 simd_axpy(-self.learning_rate, &velocity.view(), params);
87 }
88 } else {
89 simd_axpy(-self.learning_rate, &grad.view(), params);
91 }
92 }
93}
94
95pub struct CoordinateDescent {
97 alpha: f32,
98 tolerance: f32,
99 max_iterations: usize,
100}
101
102impl CoordinateDescent {
103 pub fn new(alpha: f32) -> Self {
105 Self {
106 alpha,
107 tolerance: 1e-4,
108 max_iterations: 1000,
109 }
110 }
111
112 pub fn with_tolerance(mut self, tolerance: f32) -> Self {
114 self.tolerance = tolerance;
115 self
116 }
117
118 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
120 self.max_iterations = max_iterations;
121 self
122 }
123
124 pub fn optimize_lasso(
126 &self,
127 x: &Array2<f32>,
128 y: &Array1<f32>,
129 coeff: &mut Array1<f32>,
130 ) -> Result<(), String> {
131 let n_features = x.ncols();
132 let n_samples = x.nrows();
133
134 let mut xtx_diag = Array1::zeros(n_features);
136 for j in 0..n_features {
137 let col = x.column(j).to_owned();
138 xtx_diag[j] = dot_product(
139 col.as_slice().expect("slice operation should succeed"),
140 col.as_slice().expect("slice operation should succeed"),
141 );
142 }
143
144 let mut residuals = y.clone();
146 let pred = matrix_vector_multiply_f32(x, coeff);
147 simd_axpy(-1.0, &pred.view(), &mut residuals.view_mut());
148
149 for _ in 0..self.max_iterations {
150 let mut max_change: f32 = 0.0;
151
152 for j in 0..n_features {
153 let old_coeff = coeff[j];
154
155 let col = x.column(j);
157 simd_axpy(old_coeff, &col.to_owned().view(), &mut residuals.view_mut());
158
159 let col_slice = col.to_owned();
161 let rho = dot_product(
162 col_slice
163 .as_slice()
164 .expect("slice operation should succeed"),
165 residuals
166 .as_slice()
167 .expect("slice operation should succeed"),
168 );
169 let new_coeff = soft_threshold(rho / n_samples as f32, self.alpha)
170 / (xtx_diag[j] / n_samples as f32);
171
172 coeff[j] = new_coeff;
174 let change = new_coeff - old_coeff;
175 max_change = max_change.max(change.abs());
176
177 simd_axpy(
179 -new_coeff,
180 &col.to_owned().view(),
181 &mut residuals.view_mut(),
182 );
183 }
184
185 if max_change < self.tolerance {
186 return Ok(());
187 }
188 }
189
190 Ok(())
191 }
192}
193
194pub struct QuasiNewton {
196 memory_size: usize,
197 tolerance: f32,
198 max_iterations: usize,
199 line_search_max_iter: usize,
200}
201
202impl Default for QuasiNewton {
203 fn default() -> Self {
204 Self::new()
205 }
206}
207
208impl QuasiNewton {
209 pub fn new() -> Self {
211 Self {
212 memory_size: 10,
213 tolerance: 1e-6,
214 max_iterations: 1000,
215 line_search_max_iter: 20,
216 }
217 }
218
219 pub fn with_memory_size(mut self, memory_size: usize) -> Self {
221 self.memory_size = memory_size;
222 self
223 }
224
225 pub fn optimize<F, G>(
227 &self,
228 mut x: Array1<f32>,
229 objective: F,
230 gradient: G,
231 ) -> Result<Array1<f32>, String>
232 where
233 F: Fn(&Array1<f32>) -> f32,
234 G: Fn(&Array1<f32>) -> Array1<f32>,
235 {
236 let n = x.len();
237 let mut grad = gradient(&x);
238 let h_inv = Array2::eye(n); for _ in 0..self.max_iterations {
241 let grad_norm = norm_l2(grad.as_slice().expect("slice operation should succeed"));
242 if grad_norm < self.tolerance {
243 return Ok(x);
244 }
245
246 let direction = matrix_vector_multiply_f32(&h_inv, &grad);
248 let mut search_dir = direction;
249 simd_scale(-1.0, &mut search_dir.view_mut());
250
251 let step_size = self.line_search(&x, &search_dir, &objective, &gradient)?;
253
254 let mut step = search_dir.clone();
256 simd_scale(step_size, &mut step.view_mut());
257 let x_new = &x + &step;
258
259 let grad_new = gradient(&x_new);
260
261 let s = &x_new - &x;
263 let y = &grad_new - &grad;
264
265 let sy = dot_product(
266 s.as_slice().expect("slice operation should succeed"),
267 y.as_slice().expect("slice operation should succeed"),
268 );
269 if sy > 1e-10 {
270 }
273
274 x = x_new;
275 grad = grad_new;
276 }
277
278 Ok(x)
279 }
280
281 fn line_search<F, G>(
283 &self,
284 x: &Array1<f32>,
285 direction: &Array1<f32>,
286 objective: &F,
287 gradient: &G,
288 ) -> Result<f32, String>
289 where
290 F: Fn(&Array1<f32>) -> f32,
291 G: Fn(&Array1<f32>) -> Array1<f32>,
292 {
293 let c1 = 1e-4;
294 let mut alpha = 1.0;
295 let f_x = objective(x);
296 let grad_x = gradient(x);
297 let grad_dot_dir = dot_product(
298 grad_x.as_slice().expect("slice operation should succeed"),
299 direction
300 .as_slice()
301 .expect("slice operation should succeed"),
302 );
303
304 for _ in 0..self.line_search_max_iter {
305 let mut x_new = x.clone();
306 let mut step = direction.clone();
307 simd_scale(alpha, &mut step.view_mut());
308 simd_axpy(1.0, &step.view(), &mut x_new.view_mut());
309
310 let f_x_new = objective(&x_new);
311
312 if f_x_new <= f_x + c1 * alpha * grad_dot_dir {
314 return Ok(alpha);
315 }
316
317 alpha *= 0.5;
318 }
319
320 Ok(alpha)
321 }
322}
323
324pub fn simd_axpy(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
326 assert_eq!(x.len(), y.len(), "Arrays must have the same length");
327
328 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
329 {
330 if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
331 unsafe { simd_axpy_avx2_fma(alpha, x, y) };
332 return;
333 } else if crate::simd_feature_detected!("avx2") {
334 unsafe { simd_axpy_avx2(alpha, x, y) };
335 return;
336 } else if crate::simd_feature_detected!("sse2") {
337 unsafe { simd_axpy_sse2(alpha, x, y) };
338 return;
339 }
340 }
341
342 for i in 0..x.len() {
344 y[i] += alpha * x[i];
345 }
346}
347
348pub fn simd_scale(alpha: f32, x: &mut ArrayViewMut1<f32>) {
350 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
351 {
352 if crate::simd_feature_detected!("avx2") {
353 unsafe { simd_scale_avx2(alpha, x) };
354 return;
355 } else if crate::simd_feature_detected!("sse2") {
356 unsafe { simd_scale_sse2(alpha, x) };
357 return;
358 }
359 }
360
361 for val in x.iter_mut() {
363 *val *= alpha;
364 }
365}
366
367pub fn simd_momentum_update(
369 momentum: f32,
370 grad: &ArrayView1<f32>,
371 velocity: &mut ArrayViewMut1<f32>,
372) {
373 assert_eq!(
374 grad.len(),
375 velocity.len(),
376 "Arrays must have the same length"
377 );
378
379 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
380 {
381 if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
382 unsafe { simd_momentum_update_avx2_fma(momentum, grad, velocity) };
383 return;
384 } else if crate::simd_feature_detected!("avx2") {
385 unsafe { simd_momentum_update_avx2(momentum, grad, velocity) };
386 return;
387 } else if crate::simd_feature_detected!("sse2") {
388 unsafe { simd_momentum_update_sse2(momentum, grad, velocity) };
389 return;
390 }
391 }
392
393 for i in 0..grad.len() {
395 velocity[i] = momentum * velocity[i] + grad[i];
396 }
397}
398
399fn soft_threshold(x: f32, threshold: f32) -> f32 {
401 if x > threshold {
402 x - threshold
403 } else if x < -threshold {
404 x + threshold
405 } else {
406 0.0
407 }
408}
409
410#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
413#[target_feature(enable = "sse2")]
414unsafe fn simd_axpy_sse2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
415 use core::arch::x86_64::*;
416
417 let alpha_vec = _mm_set1_ps(alpha);
418 let len = x.len();
419 let mut i = 0;
420
421 while i + 4 <= len {
422 let x_vec = _mm_loadu_ps(&x[i]);
423 let y_vec = _mm_loadu_ps(&y[i]);
424 let result = _mm_add_ps(_mm_mul_ps(alpha_vec, x_vec), y_vec);
425 _mm_storeu_ps(&mut y[i], result);
426 i += 4;
427 }
428
429 while i < len {
431 y[i] += alpha * x[i];
432 i += 1;
433 }
434}
435
436#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
437#[target_feature(enable = "avx2")]
438unsafe fn simd_axpy_avx2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
439 use core::arch::x86_64::*;
440
441 let alpha_vec = _mm256_set1_ps(alpha);
442 let len = x.len();
443 let mut i = 0;
444
445 while i + 8 <= len {
446 let x_vec = _mm256_loadu_ps(&x[i]);
447 let y_vec = _mm256_loadu_ps(&y[i]);
448 let result = _mm256_add_ps(_mm256_mul_ps(alpha_vec, x_vec), y_vec);
449 _mm256_storeu_ps(&mut y[i], result);
450 i += 8;
451 }
452
453 while i < len {
455 y[i] += alpha * x[i];
456 i += 1;
457 }
458}
459
460#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
461#[target_feature(enable = "avx2", enable = "fma")]
462unsafe fn simd_axpy_avx2_fma(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
463 use core::arch::x86_64::*;
464
465 let alpha_vec = _mm256_set1_ps(alpha);
466 let len = x.len();
467 let mut i = 0;
468
469 while i + 8 <= len {
470 let x_vec = _mm256_loadu_ps(&x[i]);
471 let y_vec = _mm256_loadu_ps(&y[i]);
472 let result = _mm256_fmadd_ps(alpha_vec, x_vec, y_vec);
473 _mm256_storeu_ps(&mut y[i], result);
474 i += 8;
475 }
476
477 while i < len {
479 y[i] += alpha * x[i];
480 i += 1;
481 }
482}
483
484#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
485#[target_feature(enable = "sse2")]
486unsafe fn simd_scale_sse2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
487 use core::arch::x86_64::*;
488
489 let alpha_vec = _mm_set1_ps(alpha);
490 let len = x.len();
491 let mut i = 0;
492
493 while i + 4 <= len {
494 let x_vec = _mm_loadu_ps(&x[i]);
495 let result = _mm_mul_ps(alpha_vec, x_vec);
496 _mm_storeu_ps(&mut x[i], result);
497 i += 4;
498 }
499
500 while i < len {
502 x[i] *= alpha;
503 i += 1;
504 }
505}
506
507#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
508#[target_feature(enable = "avx2")]
509unsafe fn simd_scale_avx2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
510 use core::arch::x86_64::*;
511
512 let alpha_vec = _mm256_set1_ps(alpha);
513 let len = x.len();
514 let mut i = 0;
515
516 while i + 8 <= len {
517 let x_vec = _mm256_loadu_ps(&x[i]);
518 let result = _mm256_mul_ps(alpha_vec, x_vec);
519 _mm256_storeu_ps(&mut x[i], result);
520 i += 8;
521 }
522
523 while i < len {
525 x[i] *= alpha;
526 i += 1;
527 }
528}
529
530#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
531#[target_feature(enable = "sse2")]
532unsafe fn simd_momentum_update_sse2(
533 momentum: f32,
534 grad: &ArrayView1<f32>,
535 velocity: &mut ArrayViewMut1<f32>,
536) {
537 use core::arch::x86_64::*;
538
539 let momentum_vec = _mm_set1_ps(momentum);
540 let len = grad.len();
541 let mut i = 0;
542
543 while i + 4 <= len {
544 let grad_vec = _mm_loadu_ps(&grad[i]);
545 let vel_vec = _mm_loadu_ps(&velocity[i]);
546 let result = _mm_add_ps(_mm_mul_ps(momentum_vec, vel_vec), grad_vec);
547 _mm_storeu_ps(&mut velocity[i], result);
548 i += 4;
549 }
550
551 while i < len {
553 velocity[i] = momentum * velocity[i] + grad[i];
554 i += 1;
555 }
556}
557
558#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
559#[target_feature(enable = "avx2")]
560unsafe fn simd_momentum_update_avx2(
561 momentum: f32,
562 grad: &ArrayView1<f32>,
563 velocity: &mut ArrayViewMut1<f32>,
564) {
565 use core::arch::x86_64::*;
566
567 let momentum_vec = _mm256_set1_ps(momentum);
568 let len = grad.len();
569 let mut i = 0;
570
571 while i + 8 <= len {
572 let grad_vec = _mm256_loadu_ps(&grad[i]);
573 let vel_vec = _mm256_loadu_ps(&velocity[i]);
574 let result = _mm256_add_ps(_mm256_mul_ps(momentum_vec, vel_vec), grad_vec);
575 _mm256_storeu_ps(&mut velocity[i], result);
576 i += 8;
577 }
578
579 while i < len {
581 velocity[i] = momentum * velocity[i] + grad[i];
582 i += 1;
583 }
584}
585
586#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
587#[target_feature(enable = "avx2", enable = "fma")]
588unsafe fn simd_momentum_update_avx2_fma(
589 momentum: f32,
590 grad: &ArrayView1<f32>,
591 velocity: &mut ArrayViewMut1<f32>,
592) {
593 use core::arch::x86_64::*;
594
595 let momentum_vec = _mm256_set1_ps(momentum);
596 let len = grad.len();
597 let mut i = 0;
598
599 while i + 8 <= len {
600 let grad_vec = _mm256_loadu_ps(&grad[i]);
601 let vel_vec = _mm256_loadu_ps(&velocity[i]);
602 let result = _mm256_fmadd_ps(momentum_vec, vel_vec, grad_vec);
603 _mm256_storeu_ps(&mut velocity[i], result);
604 i += 8;
605 }
606
607 while i < len {
609 velocity[i] = momentum * velocity[i] + grad[i];
610 i += 1;
611 }
612}
613
614#[allow(non_snake_case)]
615#[cfg(all(test, not(feature = "no-std")))]
616mod tests {
617 use super::*;
618 use approx::assert_relative_eq;
619
620 #[cfg(feature = "no-std")]
621 use alloc::{vec, vec::Vec};
622
623 #[test]
624 fn test_gradient_descent() {
625 let optimizer = GradientDescent::new(0.1).with_momentum(0.9);
626
627 let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
628 let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
629 let mut velocity = Array1::zeros(3);
630
631 let params_before = params.clone();
632 optimizer.step(
633 &mut params.view_mut(),
634 &gradient.view(),
635 &mut velocity.view_mut(),
636 );
637
638 for i in 0..params.len() {
640 assert!(params[i] < params_before[i]);
641 }
642 }
643
644 #[test]
645 fn test_coordinate_descent() {
646 let optimizer = CoordinateDescent::new(0.1);
647
648 let x = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
650 .expect("shape and data length should match");
651 let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
652 let mut coeff = Array1::zeros(2);
653
654 let result = optimizer.optimize_lasso(&x, &y, &mut coeff);
655 assert!(result.is_ok());
656 }
657
658 #[test]
659 fn test_simd_axpy() {
660 let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
661 let mut y = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]);
662 let alpha = 2.0;
663
664 let expected = &y + &(&x * alpha);
665 simd_axpy(alpha, &x.view(), &mut y.view_mut());
666
667 for i in 0..x.len() {
668 assert_relative_eq!(y[i], expected[i], epsilon = 1e-6);
669 }
670 }
671
672 #[test]
673 fn test_simd_scale() {
674 let mut x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
675 let alpha = 2.5;
676
677 let expected = &x * alpha;
678 simd_scale(alpha, &mut x.view_mut());
679
680 for i in 0..x.len() {
681 assert_relative_eq!(x[i], expected[i], epsilon = 1e-6);
682 }
683 }
684
685 #[test]
686 fn test_momentum_update() {
687 let grad = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
688 let mut velocity = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
689 let momentum = 0.9;
690
691 let expected = &velocity * momentum + &grad;
692 simd_momentum_update(momentum, &grad.view(), &mut velocity.view_mut());
693
694 for i in 0..grad.len() {
695 assert_relative_eq!(velocity[i], expected[i], epsilon = 1e-6);
696 }
697 }
698
699 #[test]
700 fn test_soft_threshold() {
701 assert_eq!(soft_threshold(2.0, 1.0), 1.0);
702 assert_eq!(soft_threshold(-2.0, 1.0), -1.0);
703 assert_eq!(soft_threshold(0.5, 1.0), 0.0);
704 assert_eq!(soft_threshold(-0.5, 1.0), 0.0);
705 }
706}