trustformers_optim/simd_optimizations.rs

#![cfg_attr(test, allow(unused_variables, unused_mut))]

//! SIMD Optimizations for Optimizers
//!
//! This module provides SIMD-optimized implementations of optimizer operations
//! for improved performance on x86_64, ARM, and other architectures.
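//!
//! # Example
//!
//! A minimal usage sketch (the `trustformers_optim::simd_optimizations` path is
//! assumed from the file location and may differ, so the snippet is not compiled
//! as a doctest):
//!
//! ```ignore
//! use trustformers_optim::simd_optimizations::SIMDOptimizer;
//!
//! // Runtime dispatch: AVX2 when available, scalar fallback otherwise.
//! let opt = SIMDOptimizer::default();
//! let mut params = vec![1.0f32; 1024];
//! let grads = vec![0.1f32; 1024];
//! let (mut m, mut v) = (vec![0.0f32; 1024], vec![0.0f32; 1024]);
//! opt.adam_update(&mut params, &grads, &mut m, &mut v, 1e-3, 0.9, 0.999, 1e-8, 1).unwrap();
//! ```
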
use anyhow::{anyhow, Result};
#[cfg(target_arch = "x86")]
use std::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// SIMD-optimized operations configuration
#[derive(Debug, Clone)]
pub struct SIMDConfig {
    /// Enable AVX2 operations (x86_64)
    pub enable_avx2: bool,
    /// Enable AVX-512 operations (x86_64)
    pub enable_avx512: bool,
    /// Enable NEON operations (ARM)
    pub enable_neon: bool,
    /// Minimum vector size for SIMD operations
    pub min_vector_size: usize,
    /// Enable unrolled loops
    pub enable_unrolling: bool,
}

impl Default for SIMDConfig {
    fn default() -> Self {
        Self {
            enable_avx2: true,
            enable_avx512: true,
            enable_neon: true,
            min_vector_size: 8,
            enable_unrolling: true,
        }
    }
}

/// SIMD operations for optimizer kernels
pub struct SIMDOptimizer {
    config: SIMDConfig,
}

impl SIMDOptimizer {
    /// Create a new SIMD optimizer with configuration
    pub fn new(config: SIMDConfig) -> Self {
        Self { config }
    }

    /// Detect available SIMD instruction sets
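    ///
    /// A small sketch of the intended pairing with [`SIMDOptimizer::new`] (this
    /// is also what the [`Default`] impl does); detection happens at runtime, so
    /// no compile-time feature flags are required:
    ///
    /// ```ignore
    /// let opt = SIMDOptimizer::new(SIMDOptimizer::detect_capabilities());
    /// ```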
    pub fn detect_capabilities() -> SIMDConfig {
        SIMDConfig {
            enable_avx2: {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    is_x86_feature_detected!("avx2")
                }
                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
                {
                    false
                }
            },
            enable_avx512: {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    is_x86_feature_detected!("avx512f")
                }
                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
                {
                    false
                }
            },
            enable_neon: cfg!(target_arch = "aarch64"),
            min_vector_size: 8,
            enable_unrolling: true,
        }
    }

    /// SIMD-optimized Adam update with AVX2
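    ///
    /// # Safety
    ///
    /// This function is compiled with the `avx2` and `fma` target features. The
    /// caller must ensure the running CPU actually supports both (e.g. via
    /// `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`);
    /// calling it without them is undefined behavior.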
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn adam_update_avx2(
        &self,
        params: &mut [f32],
        gradients: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) -> Result<()> {
        if params.len() != gradients.len()
            || params.len() != momentum.len()
            || params.len() != velocity.len()
        {
            return Err(anyhow!("All arrays must have the same length"));
        }

        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);

        // SIMD constants
        let beta1_vec = _mm256_set1_ps(beta1);
        let beta2_vec = _mm256_set1_ps(beta2);
        let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
        let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
        let eps_vec = _mm256_set1_ps(eps);
        let lr_vec = _mm256_set1_ps(corrected_lr);

        let len = params.len();
        let chunks = len / 8;

        // Process 8 elements at a time with AVX2
        for i in 0..chunks {
            let idx = i * 8;

            // Load values
            let p = _mm256_loadu_ps(params.as_ptr().add(idx));
            let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
            let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
            let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));

            // Update momentum: m = β₁ * m + (1 - β₁) * g
            let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));

            // Update velocity: v = β₂ * v + (1 - β₂) * g²
            let g_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));

            // Update parameters: p = p - α * m / (√v + ε)
            let v_sqrt = _mm256_sqrt_ps(v_new);
            let v_sqrt_eps = _mm256_add_ps(v_sqrt, eps_vec);
            let update = _mm256_div_ps(m_new, v_sqrt_eps);
            let p_new = _mm256_fnmadd_ps(lr_vec, update, p);

            // Store results
            _mm256_storeu_ps(params.as_mut_ptr().add(idx), p_new);
            _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
            _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            let g = gradients[i];
            let m = momentum[i];
            let v = velocity[i];

            let m_new = beta1 * m + (1.0 - beta1) * g;
            let v_new = beta2 * v + (1.0 - beta2) * g * g;

            momentum[i] = m_new;
            velocity[i] = v_new;
            params[i] -= corrected_lr * m_new / (v_new.sqrt() + eps);
        }

        Ok(())
    }

    /// SIMD-optimized AdamW update with AVX2 (decoupled weight decay)
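    ///
    /// # Safety
    ///
    /// This function is compiled with the `avx2` and `fma` target features. The
    /// caller must ensure the running CPU actually supports both (e.g. via
    /// `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`);
    /// calling it without them is undefined behavior.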
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn adamw_update_avx2(
        &self,
        params: &mut [f32],
        gradients: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        step: i32,
    ) -> Result<()> {
        if params.len() != gradients.len()
            || params.len() != momentum.len()
            || params.len() != velocity.len()
        {
            return Err(anyhow!("All arrays must have the same length"));
        }

        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);

        // SIMD constants
        let beta1_vec = _mm256_set1_ps(beta1);
        let beta2_vec = _mm256_set1_ps(beta2);
        let one_minus_beta1 = _mm256_set1_ps(1.0 - beta1);
        let one_minus_beta2 = _mm256_set1_ps(1.0 - beta2);
        let eps_vec = _mm256_set1_ps(eps);
        let lr_vec = _mm256_set1_ps(corrected_lr);
        let wd_vec = _mm256_set1_ps(1.0 - lr * weight_decay);

        let len = params.len();
        let chunks = len / 8;

        for i in 0..chunks {
            let idx = i * 8;

            let p = _mm256_loadu_ps(params.as_ptr().add(idx));
            let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
            let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));
            let v = _mm256_loadu_ps(velocity.as_ptr().add(idx));

            // Apply weight decay first: p = p * (1 - lr * wd)
            let p_decayed = _mm256_mul_ps(p, wd_vec);

            // Update momentum and velocity
            let m_new = _mm256_fmadd_ps(beta1_vec, m, _mm256_mul_ps(one_minus_beta1, g));
            let g_sq = _mm256_mul_ps(g, g);
            let v_new = _mm256_fmadd_ps(beta2_vec, v, _mm256_mul_ps(one_minus_beta2, g_sq));

            // Update parameters
            let v_sqrt = _mm256_sqrt_ps(v_new);
            let v_sqrt_eps = _mm256_add_ps(v_sqrt, eps_vec);
            let update = _mm256_div_ps(m_new, v_sqrt_eps);
            let p_new = _mm256_fnmadd_ps(lr_vec, update, p_decayed);

            _mm256_storeu_ps(params.as_mut_ptr().add(idx), p_new);
            _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
            _mm256_storeu_ps(velocity.as_mut_ptr().add(idx), v_new);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            let p = params[i];
            let g = gradients[i];
            let m = momentum[i];
            let v = velocity[i];

            let p_decayed = p * (1.0 - lr * weight_decay);
            let m_new = beta1 * m + (1.0 - beta1) * g;
            let v_new = beta2 * v + (1.0 - beta2) * g * g;

            momentum[i] = m_new;
            velocity[i] = v_new;
            params[i] = p_decayed - corrected_lr * m_new / (v_new.sqrt() + eps);
        }

        Ok(())
    }

    /// SIMD-optimized SGD with momentum update
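    ///
    /// # Safety
    ///
    /// This function is compiled with the `avx2` and `fma` target features. The
    /// caller must ensure the running CPU actually supports both (e.g. via
    /// `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`);
    /// calling it without them is undefined behavior.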
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn sgd_momentum_update_avx2(
        &self,
        params: &mut [f32],
        gradients: &[f32],
        momentum: &mut [f32],
        lr: f32,
        momentum_factor: f32,
        weight_decay: f32,
        nesterov: bool,
    ) -> Result<()> {
        if params.len() != gradients.len() || params.len() != momentum.len() {
            return Err(anyhow!("All arrays must have the same length"));
        }

        let lr_vec = _mm256_set1_ps(lr);
        let momentum_vec = _mm256_set1_ps(momentum_factor);
        let wd_vec = _mm256_set1_ps(weight_decay);

        let len = params.len();
        let chunks = len / 8;

        for i in 0..chunks {
            let idx = i * 8;

            let p = _mm256_loadu_ps(params.as_ptr().add(idx));
            let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
            let m = _mm256_loadu_ps(momentum.as_ptr().add(idx));

            // Apply weight decay to gradient: g = g + wd * p
            let g_wd = _mm256_fmadd_ps(wd_vec, p, g);

            // Update momentum: m = momentum * m + g
            let m_new = _mm256_fmadd_ps(momentum_vec, m, g_wd);

            // Update parameters
            let update = if nesterov {
                // Nesterov: p = p - lr * (momentum * m + g)
                _mm256_fmadd_ps(momentum_vec, m_new, g_wd)
            } else {
                // Standard: p = p - lr * m
                m_new
            };

            let p_new = _mm256_fnmadd_ps(lr_vec, update, p);

            _mm256_storeu_ps(params.as_mut_ptr().add(idx), p_new);
            _mm256_storeu_ps(momentum.as_mut_ptr().add(idx), m_new);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            let p = params[i];
            let g = gradients[i] + weight_decay * p;
            let m = momentum[i];

            let m_new = momentum_factor * m + g;
            momentum[i] = m_new;

            if nesterov {
                params[i] = p - lr * (momentum_factor * m_new + g);
            } else {
                params[i] = p - lr * m_new;
            }
        }

        Ok(())
    }

    /// SIMD-optimized gradient clipping
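    ///
    /// Returns the global L2 norm computed before any scaling is applied.
    ///
    /// # Safety
    ///
    /// This function is compiled with the `avx2` target feature. The caller must
    /// ensure the running CPU supports AVX2 (e.g. via
    /// `is_x86_feature_detected!("avx2")`); calling it without AVX2 is undefined
    /// behavior.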
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn clip_gradients_avx2(&self, gradients: &mut [f32], max_norm: f32) -> Result<f32> {
        let len = gradients.len();
        let chunks = len / 8;

        // Compute global norm
        let mut norm_sq_vec = _mm256_setzero_ps();

        for i in 0..chunks {
            let idx = i * 8;
            let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
            let g_sq = _mm256_mul_ps(g, g);
            norm_sq_vec = _mm256_add_ps(norm_sq_vec, g_sq);
        }

        // Horizontal sum of norm_sq_vec
        let mut norm_sq = 0.0f32;
        let norm_sq_array: [f32; 8] = std::mem::transmute(norm_sq_vec);
        for &val in &norm_sq_array {
            norm_sq += val;
        }

        // Add remaining elements
        for i in (chunks * 8)..len {
            norm_sq += gradients[i] * gradients[i];
        }

        let global_norm = norm_sq.sqrt();

        if global_norm > max_norm {
            let scale = max_norm / global_norm;
            let scale_vec = _mm256_set1_ps(scale);

            // Scale gradients
            for i in 0..chunks {
                let idx = i * 8;
                let g = _mm256_loadu_ps(gradients.as_ptr().add(idx));
                let g_scaled = _mm256_mul_ps(g, scale_vec);
                _mm256_storeu_ps(gradients.as_mut_ptr().add(idx), g_scaled);
            }

            // Scale remaining elements
            for i in (chunks * 8)..len {
                gradients[i] *= scale;
            }
        }

        Ok(global_norm)
    }

    /// SIMD-optimized vector addition (for gradient accumulation)
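    ///
    /// # Safety
    ///
    /// This function is compiled with the `avx2` and `fma` target features. The
    /// caller must ensure the running CPU actually supports both (e.g. via
    /// `is_x86_feature_detected!("avx2")` and `is_x86_feature_detected!("fma")`);
    /// calling it without them is undefined behavior.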
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2,fma")]
    pub unsafe fn vector_add_avx2(&self, a: &mut [f32], b: &[f32], scale: f32) -> Result<()> {
        if a.len() != b.len() {
            return Err(anyhow!("Vectors must have the same length"));
        }

        let scale_vec = _mm256_set1_ps(scale);
        let len = a.len();
        let chunks = len / 8;

        for i in 0..chunks {
            let idx = i * 8;
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(idx));
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(idx));
            let result = _mm256_fmadd_ps(b_vec, scale_vec, a_vec);
            _mm256_storeu_ps(a.as_mut_ptr().add(idx), result);
        }

        // Handle remaining elements
        for i in (chunks * 8)..len {
            a[i] += scale * b[i];
        }

        Ok(())
    }

    /// SIMD-optimized dot product
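    ///
    /// # Safety
    ///
    /// This function is compiled with the `avx2` target feature. The caller must
    /// ensure the running CPU supports AVX2 (e.g. via
    /// `is_x86_feature_detected!("avx2")`); calling it without AVX2 is undefined
    /// behavior.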
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[target_feature(enable = "avx2")]
    pub unsafe fn dot_product_avx2(&self, a: &[f32], b: &[f32]) -> Result<f32> {
        if a.len() != b.len() {
            return Err(anyhow!("Vectors must have the same length"));
        }

        let len = a.len();
        let chunks = len / 8;
        let mut result_vec = _mm256_setzero_ps();

        for i in 0..chunks {
            let idx = i * 8;
            let a_vec = _mm256_loadu_ps(a.as_ptr().add(idx));
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(idx));
            let prod = _mm256_mul_ps(a_vec, b_vec);
            result_vec = _mm256_add_ps(result_vec, prod);
        }

        // Horizontal sum
        let result_array: [f32; 8] = std::mem::transmute(result_vec);
        let mut result = result_array.iter().sum::<f32>();

        // Add remaining elements
        for i in (chunks * 8)..len {
            result += a[i] * b[i];
        }

        Ok(result)
    }

    /// Scalar Adam fallback (non-x86 targets, small arrays, or when AVX2 is unavailable)
    pub fn adam_update_fallback(
        &self,
        params: &mut [f32],
        gradients: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) -> Result<()> {
        if params.len() != gradients.len()
            || params.len() != momentum.len()
            || params.len() != velocity.len()
        {
            return Err(anyhow!("All arrays must have the same length"));
        }

        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);

        for i in 0..params.len() {
            let g = gradients[i];
            let m = momentum[i];
            let v = velocity[i];

            let m_new = beta1 * m + (1.0 - beta1) * g;
            let v_new = beta2 * v + (1.0 - beta2) * g * g;

            momentum[i] = m_new;
            velocity[i] = v_new;
            params[i] -= corrected_lr * m_new / (v_new.sqrt() + eps);
        }

        Ok(())
    }

    /// Auto-dispatch to best available implementation
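    ///
    /// A brief usage sketch (bindings such as `opt`, `params`, and `grads` are
    /// assumed to be set up as in the module-level example; values are
    /// illustrative only):
    ///
    /// ```ignore
    /// // lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, step = 1
    /// opt.adam_update(&mut params, &grads, &mut m, &mut v, 1e-3, 0.9, 0.999, 1e-8, 1)?;
    /// ```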
    pub fn adam_update(
        &self,
        params: &mut [f32],
        gradients: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        step: i32,
    ) -> Result<()> {
        if params.len() < self.config.min_vector_size {
            return self.adam_update_fallback(
                params, gradients, momentum, velocity, lr, beta1, beta2, eps, step,
            );
        }

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if self.config.enable_avx2
                && is_x86_feature_detected!("avx2")
                && is_x86_feature_detected!("fma")
            {
                return unsafe {
                    self.adam_update_avx2(
                        params, gradients, momentum, velocity, lr, beta1, beta2, eps, step,
                    )
                };
            }
        }

        self.adam_update_fallback(
            params, gradients, momentum, velocity, lr, beta1, beta2, eps, step,
        )
    }

    /// Auto-dispatch AdamW
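    ///
    /// A brief usage sketch (bindings are assumed to be set up as in the
    /// module-level example; values are illustrative only):
    ///
    /// ```ignore
    /// // lr = 1e-3, beta1 = 0.9, beta2 = 0.999, eps = 1e-8, weight_decay = 0.01, step = 1
    /// opt.adamw_update(&mut params, &grads, &mut m, &mut v, 1e-3, 0.9, 0.999, 1e-8, 0.01, 1)?;
    /// ```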
    pub fn adamw_update(
        &self,
        params: &mut [f32],
        gradients: &[f32],
        momentum: &mut [f32],
        velocity: &mut [f32],
        lr: f32,
        beta1: f32,
        beta2: f32,
        eps: f32,
        weight_decay: f32,
        step: i32,
    ) -> Result<()> {
        if params.len() != gradients.len()
            || params.len() != momentum.len()
            || params.len() != velocity.len()
        {
            return Err(anyhow!("All arrays must have the same length"));
        }

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if params.len() >= self.config.min_vector_size
                && self.config.enable_avx2
                && is_x86_feature_detected!("avx2")
                && is_x86_feature_detected!("fma")
            {
                return unsafe {
                    self.adamw_update_avx2(
                        params,
                        gradients,
                        momentum,
                        velocity,
                        lr,
                        beta1,
                        beta2,
                        eps,
                        weight_decay,
                        step,
                    )
                };
            }
        }

        // Scalar fallback (small arrays, non-x86 targets, or AVX2/FMA unavailable)
        let bias_correction1 = 1.0 - beta1.powi(step);
        let bias_correction2 = 1.0 - beta2.powi(step);
        let corrected_lr = lr * (bias_correction2.sqrt() / bias_correction1);

        for i in 0..params.len() {
            let p = params[i];
            let g = gradients[i];
            let m = momentum[i];
            let v = velocity[i];

            let p_decayed = p * (1.0 - lr * weight_decay);
            let m_new = beta1 * m + (1.0 - beta1) * g;
            let v_new = beta2 * v + (1.0 - beta2) * g * g;

            momentum[i] = m_new;
            velocity[i] = v_new;
            params[i] = p_decayed - corrected_lr * m_new / (v_new.sqrt() + eps);
        }

        Ok(())
    }

    /// Get performance statistics
    pub fn get_performance_info(&self) -> SIMDPerformanceInfo {
        SIMDPerformanceInfo {
            avx2_available: {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    is_x86_feature_detected!("avx2")
                }
                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
                {
                    false
                }
            },
            avx512_available: {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    is_x86_feature_detected!("avx512f")
                }
                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
                {
                    false
                }
            },
            neon_available: cfg!(target_arch = "aarch64"),
            vector_width: {
                #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
                {
                    if is_x86_feature_detected!("avx2") {
                        8
                    } else {
                        1
                    }
                }
                #[cfg(not(any(target_arch = "x86", target_arch = "x86_64")))]
                {
                    1
                }
            },
            recommended_min_size: self.config.min_vector_size,
        }
    }
}

impl Default for SIMDOptimizer {
    fn default() -> Self {
        Self::new(SIMDOptimizer::detect_capabilities())
    }
}

/// SIMD performance information
#[derive(Debug, Clone)]
pub struct SIMDPerformanceInfo {
    pub avx2_available: bool,
    pub avx512_available: bool,
    pub neon_available: bool,
    pub vector_width: usize,
    pub recommended_min_size: usize,
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simd_config_detection() {
        let config = SIMDOptimizer::detect_capabilities();
        // Test will pass regardless of actual hardware capabilities
        assert!(config.min_vector_size > 0);
    }

    #[test]
    fn test_adam_update_fallback() {
        let optimizer = SIMDOptimizer::default();
        let mut params = vec![1.0, 2.0, 3.0, 4.0];
        let gradients = vec![0.1, 0.2, 0.3, 0.4];
        let mut momentum = vec![0.0; 4];
        let mut velocity = vec![0.0; 4];

        optimizer
            .adam_update_fallback(
                &mut params,
                &gradients,
                &mut momentum,
                &mut velocity,
                0.001,
                0.9,
                0.999,
                1e-8,
                1,
            )
            .unwrap();

        // Check that parameters were updated
        assert!(params[0] < 1.0);
        assert!(momentum[0] > 0.0);
        assert!(velocity[0] > 0.0);
    }

    #[test]
    fn test_auto_dispatch_adam() {
        let optimizer = SIMDOptimizer::default();
        let mut params = vec![1.0; 16];
        let gradients = vec![0.1; 16];
        let mut momentum = vec![0.0; 16];
        let mut velocity = vec![0.0; 16];

        optimizer
            .adam_update(
                &mut params,
                &gradients,
                &mut momentum,
                &mut velocity,
                0.001,
                0.9,
                0.999,
                1e-8,
                1,
            )
            .unwrap();

        // Verify update occurred
        assert!(params.iter().all(|&p| p < 1.0));
        assert!(momentum.iter().all(|&m| m > 0.0));
    }

    #[test]
    fn test_performance_info() {
        let optimizer = SIMDOptimizer::default();
        let info = optimizer.get_performance_info();

        assert!(info.vector_width > 0);
        assert!(info.recommended_min_size > 0);
    }

    #[test]
    fn test_vector_operations() {
        let optimizer = SIMDOptimizer::default();
        let mut a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![0.5, 0.5, 0.5, 0.5];

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("avx2") && is_x86_feature_detected!("fma") {
                unsafe {
                    optimizer.vector_add_avx2(&mut a, &b, 2.0).unwrap();
                }
                assert_eq!(a, vec![2.0, 3.0, 4.0, 5.0]);
            }
        }
    }

    #[test]
    fn test_dot_product() {
        let optimizer = SIMDOptimizer::default();
        let a = vec![1.0, 2.0, 3.0, 4.0];
        let b = vec![1.0, 1.0, 1.0, 1.0];

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            if is_x86_feature_detected!("avx2") {
                unsafe {
                    let result = optimizer.dot_product_avx2(&a, &b).unwrap();
                    assert_eq!(result, 10.0);
                }
            }
        }
    }
}