1#![cfg_attr(test, allow(unused_variables, unused_mut))]
3use anyhow::{anyhow, Result};
8#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
9use std::arch::x86_64::*;
10
/// Tuning knobs controlling which SIMD code paths the optimizer may use.
///
/// Each `enable_*` flag is a *permission*, not a guarantee: the dispatch
/// methods still verify actual CPU support at runtime before taking a
/// vectorized path.
#[derive(Debug, Clone)]
pub struct SIMDConfig {
    // Allow AVX2 kernels on x86/x86_64 when the CPU reports support.
    pub enable_avx2: bool,
    // Allow AVX-512 kernels (detected but no kernels use it in this file).
    pub enable_avx512: bool,
    // Allow NEON on aarch64 (detected but no kernels use it in this file).
    pub enable_neon: bool,
    // Slices shorter than this are processed with scalar fallbacks.
    pub min_vector_size: usize,
    // Loop-unrolling hint (currently unused by the kernels in this file).
    pub enable_unrolling: bool,
}
25
26impl Default for SIMDConfig {
27 fn default() -> Self {
28 Self {
29 enable_avx2: true,
30 enable_avx512: true,
31 enable_neon: true,
32 min_vector_size: 8,
33 enable_unrolling: true,
34 }
35 }
36}
37
/// Optimizer-math kernels (Adam, AdamW, SGD-momentum, clipping, BLAS-like
/// helpers) with AVX2 fast paths and scalar fallbacks, selected at runtime.
pub struct SIMDOptimizer {
    // Dispatch policy; see `SIMDConfig` for the meaning of each flag.
    config: SIMDConfig,
}
42
43impl SIMDOptimizer {
    /// Creates an optimizer with an explicit, caller-chosen configuration.
    /// Use `SIMDOptimizer::default()` to auto-detect CPU capabilities instead.
    pub fn new(config: SIMDConfig) -> Self {
        Self { config }
    }
48
49 pub fn detect_capabilities() -> SIMDConfig {
51 SIMDConfig {
52 enable_avx2: {
53 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
54 {
55 is_x86_feature_detected!("avx2")
56 }
57 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
58 {
59 false
60 }
61 },
62 enable_avx512: {
63 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
64 {
65 is_x86_feature_detected!("avx512f")
66 }
67 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
68 {
69 false
70 }
71 },
72 enable_neon: cfg!(target_arch = "aarch64"),
73 min_vector_size: 8,
74 enable_unrolling: true,
75 }
76 }
77
78 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
80 #[target_feature(enable = "avx2")]
81 pub unsafe fn adam_update_avx2(
82 &self,
83 params: &mut [f32],
84 gradients: &[f32],
85 momentum: &mut [f32],
86 velocity: &mut [f32],
87 lr: f32,
88 beta1: f32,
89 beta2: f32,
90 eps: f32,
91 step: i32,
92 ) -> Result<()> {
93 if params.len() != gradients.len()
94 || params.len() != momentum.len()
95 || params.len() != velocity.len()
96 {
97 return Err(anyhow!("All arrays must have the same length"));
98 }
99
100 let bias_correction1 = 1.0 - beta1.powi(step);
101 let bias_correction2 = 1.0 - beta2.powi(step);
102 let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);
103
104 let beta1_vec = _mm256_set1_ps(beta1);
106 let beta2_vec = _mm256_set1_ps(beta2);
107 let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
108 let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
109 let eps_vec = _mm256_set1_ps(eps);
110 let lr_vec = _mm256_set1_ps(corrected_lr);
111
112 let len = params.len();
113 let chunks = len / 8;
114 let _remainder = len % 8;
115
116 for i in 0..chunks {
118 let idx = i * 8;
119
120 let p = _mm256_loadu_ps(params.as_ptr().add(idx));
122 let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
123 let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
124 let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));
125
126 let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));
128
129 let g_sq = _mm256_mul_ps(g, g);
131 let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));
132
133 let v_sqrt = _mm256_sqrt_ps(v_new);
135 let v_sqrt_eps = _mm256_add_ps(v_sqrt, eps_vec);
136 let update = _mm256_div_ps(m_new, v_sqrt_eps);
137 let p_new = _mm256_fnmadd_ps(lr_vec, update, p);
138
139 _mm256_storeu_ps(params.as_mut_ptr().add(idx), p_new);
141 _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
142 _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
143 }
144
145 for i in (chunks * 8)..len {
147 let g = gradients[i];
148 let m = momentum[i];
149 let v = velocity[i];
150
151 let m_new = beta1 * m + (1.0 - beta1) * g;
152 let v_new = beta2 * v + (1.0 - beta2) * g * g;
153
154 momentum[i] = m_new;
155 velocity[i] = v_new;
156 params[i] -= corrected_lr * m_new / (v_new.sqrt() + eps);
157 }
158
159 Ok(())
160 }
161
162 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
164 #[target_feature(enable = "avx2")]
165 pub unsafe fn adamw_update_avx2(
166 &self,
167 params: &mut [f32],
168 gradients: &[f32],
169 momentum: &mut [f32],
170 velocity: &mut [f32],
171 lr: f32,
172 beta1: f32,
173 beta2: f32,
174 eps: f32,
175 weight_decay: f32,
176 step: i32,
177 ) -> Result<()> {
178 if params.len() != gradients.len()
179 || params.len() != momentum.len()
180 || params.len() != velocity.len()
181 {
182 return Err(anyhow!("All arrays must have the same length"));
183 }
184
185 let bias_correction1 = 1.0 - beta1.powi(step);
186 let bias_correction2 = 1.0 - beta2.powi(step);
187 let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);
188
189 let beta1_vec = _mm256_set1_ps(beta1);
191 let beta2_vec = _mm256_set1_ps(beta2);
192 let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
193 let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
194 let eps_vec = _mm256_set1_ps(eps);
195 let lr_vec = _mm256_set1_ps(corrected_lr);
196 let wd_vec = _mm256_set1_ps(1.0 - lr * weight_decay);
197
198 let len = params.len();
199 let chunks = len / 8;
200
201 for i in 0..chunks {
202 let idx = i * 8;
203
204 let p = _mm256_loadu_ps(params.as_ptr().add(idx));
205 let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
206 let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
207 let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));
208
209 let p_decayed = _mm256_mul_ps(p, wd_vec);
211
212 let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));
214 let g_sq = _mm256_mul_ps(g, g);
215 let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));
216
217 let v_sqrt = _mm256_sqrt_ps(v_new);
219 let v_sqrt_eps = _mm256_add_ps(v_sqrt, eps_vec);
220 let update = _mm256_div_ps(m_new, v_sqrt_eps);
221 let p_new = _mm256_fnmadd_ps(lr_vec, update, p_decayed);
222
223 _mm256_storeu_ps(params.as_mut_ptr().add(idx), p_new);
224 _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
225 _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
226 }
227
228 for i in (chunks * 8)..len {
230 let p = params[i];
231 let g = gradients[i];
232 let m = momentum[i];
233 let v = velocity[i];
234
235 let p_decayed = p * (1.0 - lr * weight_decay);
236 let m_new = beta1 * m + (1.0 - beta1) * g;
237 let v_new = beta2 * v + (1.0 - beta2) * g * g;
238
239 momentum[i] = m_new;
240 velocity[i] = v_new;
241 params[i] = p_decayed - corrected_lr * m_new / (v_new.sqrt() + eps);
242 }
243
244 Ok(())
245 }
246
247 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
249 #[target_feature(enable = "avx2")]
250 pub unsafe fn sgd_momentum_update_avx2(
251 &self,
252 params: &mut [f32],
253 gradients: &[f32],
254 momentum: &mut [f32],
255 lr: f32,
256 momentum_factor: f32,
257 weight_decay: f32,
258 nesterov: bool,
259 ) -> Result<()> {
260 if params.len() != gradients.len() || params.len() != momentum.len() {
261 return Err(anyhow!("All arrays must have the same length"));
262 }
263
264 let lr_vec = _mm256_set1_ps(lr);
265 let momentum_vec = _mm256_set1_ps(momentum_factor);
266 let wd_vec = _mm256_set1_ps(weight_decay);
267
268 let len = params.len();
269 let chunks = len / 8;
270
271 for i in 0..chunks {
272 let idx = i * 8;
273
274 let p = _mm256_loadu_ps(params.as_ptr().add(idx));
275 let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
276 let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
277
278 let g_wd = _mm256_fmadd_ps(wd_vec, p, g);
280
281 let m_new = _mm256_fmadd_ps(momentum_vec, m, g_wd);
283
284 let update = if nesterov {
286 _mm256_fmadd_ps(momentum_vec, m_new, g_wd)
288 } else {
289 m_new
291 };
292
293 let p_new = _mm256_fnmadd_ps(lr_vec, update, p);
294
295 _mm256_storeu_ps(params.as_mut_ptr().add(idx), p_new);
296 _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
297 }
298
299 for i in (chunks * 8)..len {
301 let p = params[i];
302 let g = gradients[i] + weight_decay * p;
303 let m = momentum[i];
304
305 let m_new = momentum_factor * m + g;
306 momentum[i] = m_new;
307
308 if nesterov {
309 params[i] = p - lr * (momentum_factor * m_new + g);
310 } else {
311 params[i] = p - lr * m_new;
312 }
313 }
314
315 Ok(())
316 }
317
318 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
320 #[target_feature(enable = "avx2")]
321 pub unsafe fn clip_gradients_avx2(&self, gradients: &mut [f32], max_norm: f32) -> Result<f32> {
322 let len = gradients.len();
323 let chunks = len / 8;
324
325 let mut norm_sq_vec = _mm256_setzero_ps();
327
328 for i in 0..chunks {
329 let idx = i * 8;
330 let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
331 let g_sq = _mm256_mul_ps(g, g);
332 norm_sq_vec = _mm256_add_ps(norm_sq_vec, g_sq);
333 }
334
335 let mut norm_sq = 0.0f32;
337 let norm_sq_array: [f32; 8] = std::mem::transmute(norm_sq_vec);
338 for &val in &norm_sq_array {
339 norm_sq += val;
340 }
341
342 for i in (chunks * 8)..len {
344 norm_sq += gradients[i] * gradients[i];
345 }
346
347 let global_norm = norm_sq.sqrt();
348
349 if global_norm > max_norm {
350 let scale = max_norm / global_norm;
351 let scale_vec = _mm256_set1_ps(scale);
352
353 for i in 0..chunks {
355 let idx = i * 8;
356 let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
357 let g_scaled = _mm256_mul_ps(g, scale_vec);
358 _mm256_storeu_ps(gradients.as_mut_ptr().add(idx), g_scaled);
359 }
360
361 for i in (chunks * 8)..len {
363 gradients[i] *= scale;
364 }
365 }
366
367 Ok(global_norm)
368 }
369
370 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
372 #[target_feature(enable = "avx2")]
373 pub unsafe fn vector_add_avx2(&self, a: &mut [f32], b: &[f32], scale: f32) -> Result<()> {
374 if a.len() != b.len() {
375 return Err(anyhow!("Vectors must have the same length"));
376 }
377
378 let scale_vec = _mm256_set1_ps(scale);
379 let len = a.len();
380 let chunks = len / 8;
381
382 for i in 0..chunks {
383 let idx = i * 8;
384 let a_vec = _mm256_loadu_ps(a.as_ptr().add(idx));
385 let b_vec = _mm256_loadu_ps(b.as_ptr().add(idx));
386 let result = _mm256_fmadd_ps(b_vec, scale_vec, a_vec);
387 _mm256_storeu_ps(a.as_mut_ptr().add(idx), result);
388 }
389
390 for i in (chunks * 8)..len {
392 a[i] += scale * b[i];
393 }
394
395 Ok(())
396 }
397
398 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
400 #[target_feature(enable = "avx2")]
401 pub unsafe fn dot_product_avx2(&self, a: &[f32], b: &[f32]) -> Result<f32> {
402 if a.len() != b.len() {
403 return Err(anyhow!("Vectors must have the same length"));
404 }
405
406 let len = a.len();
407 let chunks = len / 8;
408 let mut result_vec = _mm256_setzero_ps();
409
410 for i in 0..chunks {
411 let idx = i * 8;
412 let a_vec = _mm256_loadu_ps(a.as_ptr().add(idx));
413 let b_vec = _mm256_loadu_ps(b.as_ptr().add(idx));
414 let prod = _mm256_mul_ps(a_vec, b_vec);
415 result_vec = _mm256_add_ps(result_vec, prod);
416 }
417
418 let result_array: [f32; 8] = std::mem::transmute(result_vec);
420 let mut result = result_array.iter().sum::<f32>();
421
422 for i in (chunks * 8)..len {
424 result += a[i] * b[i];
425 }
426
427 Ok(result)
428 }
429
430 pub fn adam_update_fallback(
432 &self,
433 params: &mut [f32],
434 gradients: &[f32],
435 momentum: &mut [f32],
436 velocity: &mut [f32],
437 lr: f32,
438 beta1: f32,
439 beta2: f32,
440 eps: f32,
441 step: i32,
442 ) -> Result<()> {
443 if params.len() != gradients.len()
444 || params.len() != momentum.len()
445 || params.len() != velocity.len()
446 {
447 return Err(anyhow!("All arrays must have the same length"));
448 }
449
450 let bias_correction1 = 1.0 - beta1.powi(step);
451 let bias_correction2 = 1.0 - beta2.powi(step);
452 let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);
453
454 for i in 0..params.len() {
455 let g = gradients[i];
456 let m = momentum[i];
457 let v = velocity[i];
458
459 let m_new = beta1 * m + (1.0 - beta1) * g;
460 let v_new = beta2 * v + (1.0 - beta2) * g * g;
461
462 momentum[i] = m_new;
463 velocity[i] = v_new;
464 params[i] -= corrected_lr * m_new / (v_new.sqrt() + eps);
465 }
466
467 Ok(())
468 }
469
470 pub fn adam_update(
472 &self,
473 params: &mut [f32],
474 gradients: &[f32],
475 momentum: &mut [f32],
476 velocity: &mut [f32],
477 lr: f32,
478 beta1: f32,
479 beta2: f32,
480 eps: f32,
481 step: i32,
482 ) -> Result<()> {
483 if params.len() < self.config.min_vector_size {
484 return self.adam_update_fallback(
485 params, gradients, momentum, velocity, lr, beta1, beta2, eps, step,
486 );
487 }
488
489 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
490 {
491 if self.config.enable_avx2 && is_x86_feature_detected!("avx2") {
492 return unsafe {
493 self.adam_update_avx2(
494 params, gradients, momentum, velocity, lr, beta1, beta2, eps, step,
495 )
496 };
497 }
498 }
499
500 self.adam_update_fallback(
501 params, gradients, momentum, velocity, lr, beta1, beta2, eps, step,
502 )
503 }
504
505 pub fn adamw_update(
507 &self,
508 params: &mut [f32],
509 gradients: &[f32],
510 momentum: &mut [f32],
511 velocity: &mut [f32],
512 lr: f32,
513 beta1: f32,
514 beta2: f32,
515 eps: f32,
516 weight_decay: f32,
517 step: i32,
518 ) -> Result<()> {
519 if params.len() < self.config.min_vector_size {
520 let bias_correction1 = 1.0 - beta1.powi(step);
522 let bias_correction2 = 1.0 - beta2.powi(step);
523 let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);
524
525 for i in 0..params.len() {
526 let p = params[i];
527 let g = gradients[i];
528 let m = momentum[i];
529 let v = velocity[i];
530
531 let p_decayed = p * (1.0 - lr * weight_decay);
532 let m_new = beta1 * m + (1.0 - beta1) * g;
533 let v_new = beta2 * v + (1.0 - beta2) * g * g;
534
535 momentum[i] = m_new;
536 velocity[i] = v_new;
537 params[i] = p_decayed - corrected_lr * m_new / (v_new.sqrt() + eps);
538 }
539 return Ok(());
540 }
541
542 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
543 {
544 if self.config.enable_avx2 && is_x86_feature_detected!("avx2") {
545 return unsafe {
546 self.adamw_update_avx2(
547 params,
548 gradients,
549 momentum,
550 velocity,
551 lr,
552 beta1,
553 beta2,
554 eps,
555 weight_decay,
556 step,
557 )
558 };
559 }
560 }
561
562 self.adamw_update(
564 params,
565 gradients,
566 momentum,
567 velocity,
568 lr,
569 beta1,
570 beta2,
571 eps,
572 weight_decay,
573 step,
574 )
575 }
576
577 pub fn get_performance_info(&self) -> SIMDPerformanceInfo {
579 SIMDPerformanceInfo {
580 avx2_available: {
581 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
582 {
583 is_x86_feature_detected!("avx2")
584 }
585 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
586 {
587 false
588 }
589 },
590 avx512_available: {
591 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
592 {
593 is_x86_feature_detected!("avx512f")
594 }
595 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
596 {
597 false
598 }
599 },
600 neon_available: cfg!(target_arch = "aarch64"),
601 vector_width: {
602 #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
603 {
604 if is_x86_feature_detected!("avx2") {
605 8
606 } else {
607 1
608 }
609 }
610 #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
611 {
612 1
613 }
614 },
615 recommended_min_size: self.config.min_vector_size,
616 }
617 }
618}
619
620impl Default for SIMDOptimizer {
621 fn default() -> Self {
622 Self::new(SIMDOptimizer::detect_capabilities())
623 }
624}
625
/// Snapshot of the SIMD capabilities visible to this process, plus the
/// configured dispatch threshold. Produced by
/// `SIMDOptimizer::get_performance_info`.
#[derive(Debug, Clone)]
pub struct SIMDPerformanceInfo {
    // True when the host CPU reports AVX2 support (x86/x86_64 only).
    pub avx2_available: bool,
    // True when the host CPU reports AVX-512F support (x86/x86_64 only).
    pub avx512_available: bool,
    // True when compiled for aarch64 (compile-time check, not a runtime probe).
    pub neon_available: bool,
    // f32 lanes processed per vector iteration (8 with AVX2, else 1).
    pub vector_width: usize,
    // Minimum slice length before vector kernels are dispatched.
    pub recommended_min_size: usize,
}
635
#[cfg(test)]
mod tests {
    use super::*;

    // Capability detection must produce a sane config on any architecture;
    // the feature flags themselves are host-dependent, so only the
    // threshold invariant is asserted.
    #[test]
    fn test_simd_config_detection() {
        let config = SIMDOptimizer::detect_capabilities();
        assert!(config.min_vector_size > 0);
    }

    // One scalar Adam step from zeroed moments: params must move against
    // the (positive) gradient and both moment buffers must be populated.
    #[test]
    fn test_adam_update_fallback() {
        let optimizer = SIMDOptimizer::default();
        let mut params = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![0.1, 0.2, 0.3, 0.4];
        let mut momentum = vec![0.0; 4];
        let mut velocity = vec![0.0; 4];

        optimizer
            .adam_update_fallback(
                &mut params,
                &gradients,
                &mut momentum,
                &mut velocity,
                0.001,
                0.9,
                0.999,
                1e-8,
                1,
            )
            .unwrap();

        assert!(params[0] < 1.0);
        assert!(momentum[0] > 0.0);
        assert!(velocity[0] > 0.0);
    }

    // Exercises the dispatching entry point with a slice large enough
    // (16 >= min_vector_size) to take the SIMD path where available.
    #[test]
    fn test_auto_dispatch_adam() {
        let optimizer = SIMDOptimizer::default();
        let mut params = vec![1.0; 16];
        let gradients = vec![0.1; 16];
        let mut momentum = vec![0.0; 16];
        let mut velocity = vec![0.0; 16];

        optimizer
            .adam_update(
                &mut params,
                &gradients,
                &mut momentum,
                &mut velocity,
                0.001,
                0.9,
                0.999,
                1e-8,
                1,
            )
            .unwrap();

        assert!(params.iter().all(|&p| p < 1.0));
        assert!(momentum.iter().all(|&m| m > 0.0));
    }

    // Performance info must report positive width/threshold regardless of
    // which instruction sets the host actually has.
    #[test]
    fn test_performance_info() {
        let optimizer = SIMDOptimizer::default();
        let info = optimizer.get_performance_info();

        assert!(info.vector_width > 0);
        assert!(info.recommended_min_size > 0);
    }

    // axpy check (a += 2.0 * b) — only runs when AVX2 is detected, since
    // calling the kernel without it would be undefined behavior.
    #[test]
    fn test_vector_operations() {
        let optimizer = SIMDOptimizer::default();
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![0.5, 0.5, 0.5, 0.5];

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("avx2") {
                unsafe {
                    optimizer.vector_add_avx2(&mut a, &b, 2.0).unwrap();
                }
                assert_eq!(a, vec![2.0, 3.0, 4.0, 5.0]);
            }
        }
    }

    // Dot product check — gated on AVX2 detection for the same reason.
    #[test]
    fn test_dot_product() {
        let optimizer = SIMDOptimizer::default();
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("avx2") {
                unsafe {
                    let result = optimizer.dot_product_avx2(&a, &b).unwrap();
                    assert_eq!(result, 10.0);
                }
            }
        }
    }
}