1use crate::matrix::matrix_vector_multiply_f32;
8use crate::vector::{dot_product, norm_l2};
9use scirs2_core::ndarray::{Array1, Array2, ArrayView1, ArrayViewMut1};
10
11#[cfg(feature = "no-std")]
13use alloc::string::String;
14#[cfg(not(feature = "no-std"))]
15use std::string::String;
16
/// Stochastic gradient descent optimizer with optional momentum, Nesterov
/// acceleration and L2 weight decay (torch-style SGD hyper-parameters).
pub struct GradientDescent {
    // Step size applied to every parameter update.
    learning_rate: f32,
    // Momentum coefficient; 0.0 disables the velocity buffer entirely.
    momentum: f32,
    // Dampening for momentum. NOTE(review): stored but never applied in
    // `step` — confirm whether it should scale the gradient term.
    dampening: f32,
    // L2 regularization coefficient folded into the gradient before the step.
    weight_decay: f32,
    // When true, uses Nesterov look-ahead momentum (only meaningful with
    // a nonzero `momentum`).
    nesterov: bool,
}
25
26impl GradientDescent {
27 pub fn new(learning_rate: f32) -> Self {
29 Self {
30 learning_rate,
31 momentum: 0.0,
32 dampening: 0.0,
33 weight_decay: 0.0,
34 nesterov: false,
35 }
36 }
37
38 pub fn with_momentum(mut self, momentum: f32) -> Self {
40 self.momentum = momentum;
41 self
42 }
43
44 pub fn with_weight_decay(mut self, weight_decay: f32) -> Self {
46 self.weight_decay = weight_decay;
47 self
48 }
49
50 pub fn with_nesterov(mut self) -> Self {
52 self.nesterov = true;
53 self
54 }
55
56 pub fn step(
58 &self,
59 params: &mut ArrayViewMut1<f32>,
60 gradient: &ArrayView1<f32>,
61 velocity: &mut ArrayViewMut1<f32>,
62 ) {
63 let mut grad = gradient.to_owned();
65 if self.weight_decay != 0.0 {
66 simd_axpy(self.weight_decay, ¶ms.view(), &mut grad.view_mut());
67 }
68
69 if self.momentum != 0.0 {
70 simd_momentum_update(self.momentum, &grad.view(), velocity);
72
73 if self.nesterov {
74 let mut nesterov_grad = grad.clone();
76 simd_axpy(
77 self.momentum,
78 &velocity.view(),
79 &mut nesterov_grad.view_mut(),
80 );
81 simd_axpy(-self.learning_rate, &nesterov_grad.view(), params);
82 } else {
83 simd_axpy(-self.learning_rate, &velocity.view(), params);
85 }
86 } else {
87 simd_axpy(-self.learning_rate, &grad.view(), params);
89 }
90 }
91}
92
/// Cyclic coordinate-descent solver for L1-regularized least squares (Lasso).
pub struct CoordinateDescent {
    // L1 regularization strength (soft-threshold level).
    alpha: f32,
    // Convergence tolerance on the largest per-sweep coefficient change.
    tolerance: f32,
    // Maximum number of full sweeps over all coordinates.
    max_iterations: usize,
}
99
100impl CoordinateDescent {
101 pub fn new(alpha: f32) -> Self {
103 Self {
104 alpha,
105 tolerance: 1e-4,
106 max_iterations: 1000,
107 }
108 }
109
110 pub fn with_tolerance(mut self, tolerance: f32) -> Self {
112 self.tolerance = tolerance;
113 self
114 }
115
116 pub fn with_max_iterations(mut self, max_iterations: usize) -> Self {
118 self.max_iterations = max_iterations;
119 self
120 }
121
122 pub fn optimize_lasso(
124 &self,
125 x: &Array2<f32>,
126 y: &Array1<f32>,
127 coeff: &mut Array1<f32>,
128 ) -> Result<(), String> {
129 let n_features = x.ncols();
130 let n_samples = x.nrows();
131
132 let mut xtx_diag = Array1::zeros(n_features);
134 for j in 0..n_features {
135 let col = x.column(j).to_owned();
136 xtx_diag[j] = dot_product(
137 col.as_slice().expect("slice operation should succeed"),
138 col.as_slice().expect("slice operation should succeed"),
139 );
140 }
141
142 let mut residuals = y.clone();
144 let pred = matrix_vector_multiply_f32(x, coeff);
145 simd_axpy(-1.0, &pred.view(), &mut residuals.view_mut());
146
147 for _ in 0..self.max_iterations {
148 let mut max_change: f32 = 0.0;
149
150 for j in 0..n_features {
151 let old_coeff = coeff[j];
152
153 let col = x.column(j);
155 simd_axpy(old_coeff, &col.to_owned().view(), &mut residuals.view_mut());
156
157 let col_slice = col.to_owned();
159 let rho = dot_product(
160 col_slice
161 .as_slice()
162 .expect("slice operation should succeed"),
163 residuals
164 .as_slice()
165 .expect("slice operation should succeed"),
166 );
167 let new_coeff = soft_threshold(rho / n_samples as f32, self.alpha)
168 / (xtx_diag[j] / n_samples as f32);
169
170 coeff[j] = new_coeff;
172 let change = new_coeff - old_coeff;
173 max_change = max_change.max(change.abs());
174
175 simd_axpy(
177 -new_coeff,
178 &col.to_owned().view(),
179 &mut residuals.view_mut(),
180 );
181 }
182
183 if max_change < self.tolerance {
184 return Ok(());
185 }
186 }
187
188 Ok(())
189 }
190}
191
/// Quasi-Newton (BFGS-style) unconstrained minimizer with a backtracking
/// line search.
pub struct QuasiNewton {
    // History size. NOTE(review): not consumed by the dense update in
    // `optimize`; presumably reserved for an L-BFGS variant — confirm.
    memory_size: usize,
    // Convergence tolerance on the gradient L2 norm.
    tolerance: f32,
    // Maximum number of outer iterations.
    max_iterations: usize,
    // Maximum number of step-halving trials per line search.
    line_search_max_iter: usize,
}
199
impl Default for QuasiNewton {
    /// Delegates to [`QuasiNewton::new`] (memory 10, tol 1e-6, 1000 iters).
    fn default() -> Self {
        Self::new()
    }
}
205
206impl QuasiNewton {
207 pub fn new() -> Self {
209 Self {
210 memory_size: 10,
211 tolerance: 1e-6,
212 max_iterations: 1000,
213 line_search_max_iter: 20,
214 }
215 }
216
217 pub fn with_memory_size(mut self, memory_size: usize) -> Self {
219 self.memory_size = memory_size;
220 self
221 }
222
223 pub fn optimize<F, G>(
225 &self,
226 mut x: Array1<f32>,
227 objective: F,
228 gradient: G,
229 ) -> Result<Array1<f32>, String>
230 where
231 F: Fn(&Array1<f32>) -> f32,
232 G: Fn(&Array1<f32>) -> Array1<f32>,
233 {
234 let n = x.len();
235 let mut grad = gradient(&x);
236 let h_inv = Array2::eye(n); for _ in 0..self.max_iterations {
239 let grad_norm = norm_l2(grad.as_slice().expect("slice operation should succeed"));
240 if grad_norm < self.tolerance {
241 return Ok(x);
242 }
243
244 let direction = matrix_vector_multiply_f32(&h_inv, &grad);
246 let mut search_dir = direction;
247 simd_scale(-1.0, &mut search_dir.view_mut());
248
249 let step_size = self.line_search(&x, &search_dir, &objective, &gradient)?;
251
252 let mut step = search_dir.clone();
254 simd_scale(step_size, &mut step.view_mut());
255 let x_new = &x + &step;
256
257 let grad_new = gradient(&x_new);
258
259 let s = &x_new - &x;
261 let y = &grad_new - &grad;
262
263 let sy = dot_product(
264 s.as_slice().expect("slice operation should succeed"),
265 y.as_slice().expect("slice operation should succeed"),
266 );
267 if sy > 1e-10 {
268 }
271
272 x = x_new;
273 grad = grad_new;
274 }
275
276 Ok(x)
277 }
278
279 fn line_search<F, G>(
281 &self,
282 x: &Array1<f32>,
283 direction: &Array1<f32>,
284 objective: &F,
285 gradient: &G,
286 ) -> Result<f32, String>
287 where
288 F: Fn(&Array1<f32>) -> f32,
289 G: Fn(&Array1<f32>) -> Array1<f32>,
290 {
291 let c1 = 1e-4;
292 let mut alpha = 1.0;
293 let f_x = objective(x);
294 let grad_x = gradient(x);
295 let grad_dot_dir = dot_product(
296 grad_x.as_slice().expect("slice operation should succeed"),
297 direction
298 .as_slice()
299 .expect("slice operation should succeed"),
300 );
301
302 for _ in 0..self.line_search_max_iter {
303 let mut x_new = x.clone();
304 let mut step = direction.clone();
305 simd_scale(alpha, &mut step.view_mut());
306 simd_axpy(1.0, &step.view(), &mut x_new.view_mut());
307
308 let f_x_new = objective(&x_new);
309
310 if f_x_new <= f_x + c1 * alpha * grad_dot_dir {
312 return Ok(alpha);
313 }
314
315 alpha *= 0.5;
316 }
317
318 Ok(alpha)
319 }
320}
321
322pub fn simd_axpy(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
324 assert_eq!(x.len(), y.len(), "Arrays must have the same length");
325
326 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
327 {
328 if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
329 unsafe { simd_axpy_avx2_fma(alpha, x, y) };
330 return;
331 } else if crate::simd_feature_detected!("avx2") {
332 unsafe { simd_axpy_avx2(alpha, x, y) };
333 return;
334 } else if crate::simd_feature_detected!("sse2") {
335 unsafe { simd_axpy_sse2(alpha, x, y) };
336 return;
337 }
338 }
339
340 for i in 0..x.len() {
342 y[i] += alpha * x[i];
343 }
344}
345
346pub fn simd_scale(alpha: f32, x: &mut ArrayViewMut1<f32>) {
348 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
349 {
350 if crate::simd_feature_detected!("avx2") {
351 unsafe { simd_scale_avx2(alpha, x) };
352 return;
353 } else if crate::simd_feature_detected!("sse2") {
354 unsafe { simd_scale_sse2(alpha, x) };
355 return;
356 }
357 }
358
359 for val in x.iter_mut() {
361 *val *= alpha;
362 }
363}
364
365pub fn simd_momentum_update(
367 momentum: f32,
368 grad: &ArrayView1<f32>,
369 velocity: &mut ArrayViewMut1<f32>,
370) {
371 assert_eq!(
372 grad.len(),
373 velocity.len(),
374 "Arrays must have the same length"
375 );
376
377 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
378 {
379 if crate::simd_feature_detected!("avx2") && crate::simd_feature_detected!("fma") {
380 unsafe { simd_momentum_update_avx2_fma(momentum, grad, velocity) };
381 return;
382 } else if crate::simd_feature_detected!("avx2") {
383 unsafe { simd_momentum_update_avx2(momentum, grad, velocity) };
384 return;
385 } else if crate::simd_feature_detected!("sse2") {
386 unsafe { simd_momentum_update_sse2(momentum, grad, velocity) };
387 return;
388 }
389 }
390
391 for i in 0..grad.len() {
393 velocity[i] = momentum * velocity[i] + grad[i];
394 }
395}
396
/// Soft-thresholding (L1 proximal) operator: shrinks `x` toward zero by
/// `threshold`, clamping everything inside `[-threshold, threshold]` to 0.
fn soft_threshold(x: f32, threshold: f32) -> f32 {
    match x {
        v if v > threshold => v - threshold,
        v if v < -threshold => v + threshold,
        _ => 0.0,
    }
}
407
408#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
411#[target_feature(enable = "sse2")]
412unsafe fn simd_axpy_sse2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
413 use core::arch::x86_64::*;
414
415 let alpha_vec = _mm_set1_ps(alpha);
416 let len = x.len();
417 let mut i = 0;
418
419 while i + 4 <= len {
420 let x_vec = _mm_loadu_ps(&x[i]);
421 let y_vec = _mm_loadu_ps(&y[i]);
422 let result = _mm_add_ps(_mm_mul_ps(alpha_vec, x_vec), y_vec);
423 _mm_storeu_ps(&mut y[i], result);
424 i += 4;
425 }
426
427 while i < len {
429 y[i] += alpha * x[i];
430 i += 1;
431 }
432}
433
434#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
435#[target_feature(enable = "avx2")]
436unsafe fn simd_axpy_avx2(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
437 use core::arch::x86_64::*;
438
439 let alpha_vec = _mm256_set1_ps(alpha);
440 let len = x.len();
441 let mut i = 0;
442
443 while i + 8 <= len {
444 let x_vec = _mm256_loadu_ps(&x[i]);
445 let y_vec = _mm256_loadu_ps(&y[i]);
446 let result = _mm256_add_ps(_mm256_mul_ps(alpha_vec, x_vec), y_vec);
447 _mm256_storeu_ps(&mut y[i], result);
448 i += 8;
449 }
450
451 while i < len {
453 y[i] += alpha * x[i];
454 i += 1;
455 }
456}
457
458#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
459#[target_feature(enable = "avx2", enable = "fma")]
460unsafe fn simd_axpy_avx2_fma(alpha: f32, x: &ArrayView1<f32>, y: &mut ArrayViewMut1<f32>) {
461 use core::arch::x86_64::*;
462
463 let alpha_vec = _mm256_set1_ps(alpha);
464 let len = x.len();
465 let mut i = 0;
466
467 while i + 8 <= len {
468 let x_vec = _mm256_loadu_ps(&x[i]);
469 let y_vec = _mm256_loadu_ps(&y[i]);
470 let result = _mm256_fmadd_ps(alpha_vec, x_vec, y_vec);
471 _mm256_storeu_ps(&mut y[i], result);
472 i += 8;
473 }
474
475 while i < len {
477 y[i] += alpha * x[i];
478 i += 1;
479 }
480}
481
482#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
483#[target_feature(enable = "sse2")]
484unsafe fn simd_scale_sse2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
485 use core::arch::x86_64::*;
486
487 let alpha_vec = _mm_set1_ps(alpha);
488 let len = x.len();
489 let mut i = 0;
490
491 while i + 4 <= len {
492 let x_vec = _mm_loadu_ps(&x[i]);
493 let result = _mm_mul_ps(alpha_vec, x_vec);
494 _mm_storeu_ps(&mut x[i], result);
495 i += 4;
496 }
497
498 while i < len {
500 x[i] *= alpha;
501 i += 1;
502 }
503}
504
505#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
506#[target_feature(enable = "avx2")]
507unsafe fn simd_scale_avx2(alpha: f32, x: &mut ArrayViewMut1<f32>) {
508 use core::arch::x86_64::*;
509
510 let alpha_vec = _mm256_set1_ps(alpha);
511 let len = x.len();
512 let mut i = 0;
513
514 while i + 8 <= len {
515 let x_vec = _mm256_loadu_ps(&x[i]);
516 let result = _mm256_mul_ps(alpha_vec, x_vec);
517 _mm256_storeu_ps(&mut x[i], result);
518 i += 8;
519 }
520
521 while i < len {
523 x[i] *= alpha;
524 i += 1;
525 }
526}
527
528#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
529#[target_feature(enable = "sse2")]
530unsafe fn simd_momentum_update_sse2(
531 momentum: f32,
532 grad: &ArrayView1<f32>,
533 velocity: &mut ArrayViewMut1<f32>,
534) {
535 use core::arch::x86_64::*;
536
537 let momentum_vec = _mm_set1_ps(momentum);
538 let len = grad.len();
539 let mut i = 0;
540
541 while i + 4 <= len {
542 let grad_vec = _mm_loadu_ps(&grad[i]);
543 let vel_vec = _mm_loadu_ps(&velocity[i]);
544 let result = _mm_add_ps(_mm_mul_ps(momentum_vec, vel_vec), grad_vec);
545 _mm_storeu_ps(&mut velocity[i], result);
546 i += 4;
547 }
548
549 while i < len {
551 velocity[i] = momentum * velocity[i] + grad[i];
552 i += 1;
553 }
554}
555
556#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
557#[target_feature(enable = "avx2")]
558unsafe fn simd_momentum_update_avx2(
559 momentum: f32,
560 grad: &ArrayView1<f32>,
561 velocity: &mut ArrayViewMut1<f32>,
562) {
563 use core::arch::x86_64::*;
564
565 let momentum_vec = _mm256_set1_ps(momentum);
566 let len = grad.len();
567 let mut i = 0;
568
569 while i + 8 <= len {
570 let grad_vec = _mm256_loadu_ps(&grad[i]);
571 let vel_vec = _mm256_loadu_ps(&velocity[i]);
572 let result = _mm256_add_ps(_mm256_mul_ps(momentum_vec, vel_vec), grad_vec);
573 _mm256_storeu_ps(&mut velocity[i], result);
574 i += 8;
575 }
576
577 while i < len {
579 velocity[i] = momentum * velocity[i] + grad[i];
580 i += 1;
581 }
582}
583
584#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
585#[target_feature(enable = "avx2", enable = "fma")]
586unsafe fn simd_momentum_update_avx2_fma(
587 momentum: f32,
588 grad: &ArrayView1<f32>,
589 velocity: &mut ArrayViewMut1<f32>,
590) {
591 use core::arch::x86_64::*;
592
593 let momentum_vec = _mm256_set1_ps(momentum);
594 let len = grad.len();
595 let mut i = 0;
596
597 while i + 8 <= len {
598 let grad_vec = _mm256_loadu_ps(&grad[i]);
599 let vel_vec = _mm256_loadu_ps(&velocity[i]);
600 let result = _mm256_fmadd_ps(momentum_vec, vel_vec, grad_vec);
601 _mm256_storeu_ps(&mut velocity[i], result);
602 i += 8;
603 }
604
605 while i < len {
607 velocity[i] = momentum * velocity[i] + grad[i];
608 i += 1;
609 }
610}
611
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    // Fix: the original also had `#[cfg(feature = "no-std")] use alloc::…`
    // here, which could never compile because this module is gated on
    // `not(feature = "no-std")`; the dead import has been removed, along
    // with an unused `#[allow(non_snake_case)]` on the module.

    /// One momentum-SGD step against an all-positive gradient must decrease
    /// every parameter.
    #[test]
    fn test_gradient_descent() {
        let optimizer = GradientDescent::new(0.1).with_momentum(0.9);

        let mut params = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let gradient = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let mut velocity = Array1::zeros(3);

        let params_before = params.clone();
        optimizer.step(
            &mut params.view_mut(),
            &gradient.view(),
            &mut velocity.view_mut(),
        );

        for i in 0..params.len() {
            assert!(params[i] < params_before[i]);
        }
    }

    /// Lasso coordinate descent runs to completion on a small system.
    #[test]
    fn test_coordinate_descent() {
        let optimizer = CoordinateDescent::new(0.1);

        let x = Array2::from_shape_vec((4, 2), vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])
            .expect("shape and data length should match");
        let y = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let mut coeff = Array1::zeros(2);

        let result = optimizer.optimize_lasso(&x, &y, &mut coeff);
        assert!(result.is_ok());
    }

    /// `simd_axpy` matches the ndarray reference `y + alpha * x`.
    #[test]
    fn test_simd_axpy() {
        let x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let mut y = Array1::from_vec(vec![10.0, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0]);
        let alpha = 2.0;

        let expected = &y + &(&x * alpha);
        simd_axpy(alpha, &x.view(), &mut y.view_mut());

        for i in 0..x.len() {
            assert_relative_eq!(y[i], expected[i], epsilon = 1e-6);
        }
    }

    /// `simd_scale` matches the ndarray reference `x * alpha`.
    #[test]
    fn test_simd_scale() {
        let mut x = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
        let alpha = 2.5;

        let expected = &x * alpha;
        simd_scale(alpha, &mut x.view_mut());

        for i in 0..x.len() {
            assert_relative_eq!(x[i], expected[i], epsilon = 1e-6);
        }
    }

    /// `simd_momentum_update` matches `momentum * velocity + grad`.
    #[test]
    fn test_momentum_update() {
        let grad = Array1::from_vec(vec![0.1, 0.2, 0.3, 0.4]);
        let mut velocity = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let momentum = 0.9;

        let expected = &velocity * momentum + &grad;
        simd_momentum_update(momentum, &grad.view(), &mut velocity.view_mut());

        for i in 0..grad.len() {
            assert_relative_eq!(velocity[i], expected[i], epsilon = 1e-6);
        }
    }

    /// Soft-threshold shrinks beyond the threshold and zeros within it.
    #[test]
    fn test_soft_threshold() {
        assert_eq!(soft_threshold(2.0, 1.0), 1.0);
        assert_eq!(soft_threshold(-2.0, 1.0), -1.0);
        assert_eq!(soft_threshold(0.5, 1.0), 0.0);
        assert_eq!(soft_threshold(-0.5, 1.0), 0.0);
    }
}