Skip to main content

sklears_simd/vector/
basic_operations.rs

1//! # Basic SIMD Vector Operations
2//!
3//! Fundamental vector operations optimized with SIMD instructions including
4//! dot products, norms, distance calculations, and basic linear algebra operations.
5//!
6//! ## Features
7//!
8//! - **Dot Product**: SIMD-optimized dot product with platform-specific implementations
9//! - **Vector Norms**: L1, L2, and infinity norms with high performance
10//! - **Distance Metrics**: Euclidean, Manhattan, and cosine distance/similarity
11//! - **Advanced Operations**: Cross product and outer product for linear algebra
12//! - **Multi-Platform**: SSE2, AVX2, AVX512, and NEON optimizations
13//! - **Scalar Fallbacks**: Automatic fallback to scalar operations when needed
14//!
15//! ## Implementation Details
16//!
17//! All functions automatically detect the best available SIMD instruction set
18//! and provide graceful fallback to scalar implementations. The functions are
19//! designed to handle arbitrary vector lengths efficiently by processing
20//! SIMD-sized chunks and handling remainders appropriately.
21
22#[cfg(feature = "no-std")]
23use alloc::vec;
24#[cfg(feature = "no-std")]
25use alloc::vec::Vec;
26#[cfg(not(feature = "no-std"))]
27use std::vec::Vec;
28
29// Import ARM64 feature detection macro
30#[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
31use std::arch::is_aarch64_feature_detected;
32
33/// SIMD-optimized dot product computation
34///
35/// Computes the dot product of two vectors using the best available SIMD
36/// instruction set. Automatically falls back to scalar implementation if
37/// no SIMD support is available.
38///
39/// # Arguments
40/// * `a` - First input vector
41/// * `b` - Second input vector (must have same length as `a`)
42///
43/// # Returns
44/// The dot product as a single f32 value
45///
46/// # Panics
47/// Panics if the vectors have different lengths
48///
49/// # Examples
50/// ```rust
51/// use sklears_simd::vector::basic_operations::dot_product;
52///
53/// let a = vec![1.0, 2.0, 3.0, 4.0];
54/// let b = vec![5.0, 6.0, 7.0, 8.0];
55/// let result = dot_product(&a, &b);
56/// assert_eq!(result, 70.0); // 1*5 + 2*6 + 3*7 + 4*8 = 70
57/// ```
58pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
59    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
60
61    if a.is_empty() {
62        return 0.0;
63    }
64
65    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
66    {
67        if crate::simd_feature_detected!("avx512f") {
68            return unsafe { dot_product_avx512(a, b) };
69        } else if crate::simd_feature_detected!("avx2") {
70            return unsafe { dot_product_avx2(a, b) };
71        } else if crate::simd_feature_detected!("sse2") {
72            return unsafe { dot_product_sse2(a, b) };
73        }
74    }
75
76    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
77    {
78        if is_aarch64_feature_detected!("neon") {
79            return unsafe { dot_product_neon(a, b) };
80        }
81    }
82
83    dot_product_scalar(a, b)
84}
85
86/// SIMD-optimized L2 norm (Euclidean norm) computation
87///
88/// Computes the L2 norm (||x||₂) of a vector using SIMD-optimized dot product.
89///
90/// # Arguments
91/// * `x` - Input vector
92///
93/// # Returns
94/// The L2 norm as a single f32 value
95///
96/// # Examples
97/// ```rust
98/// use sklears_simd::vector::basic_operations::norm_l2;
99///
100/// let x = vec![3.0, 4.0];
101/// let result = norm_l2(&x);
102/// assert_eq!(result, 5.0); // sqrt(3² + 4²) = 5
103/// ```
104pub fn norm_l2(x: &[f32]) -> f32 {
105    dot_product(x, x).sqrt()
106}
107
108/// SIMD-optimized L1 norm (Manhattan norm) computation
109///
110/// Computes the L1 norm (||x||₁) of a vector using SIMD instructions
111/// for absolute value computation and summation.
112///
113/// # Arguments
114/// * `x` - Input vector
115///
116/// # Returns
117/// The L1 norm as a single f32 value
118///
119/// # Examples
120/// ```rust
121/// use sklears_simd::vector::basic_operations::norm_l1;
122///
123/// let x = vec![-3.0, 4.0, -5.0];
124/// let result = norm_l1(&x);
125/// assert_eq!(result, 12.0); // |−3| + |4| + |−5| = 12
126/// ```
127pub fn norm_l1(x: &[f32]) -> f32 {
128    if x.is_empty() {
129        return 0.0;
130    }
131
132    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
133    {
134        if crate::simd_feature_detected!("avx2") {
135            return unsafe { norm_l1_avx2(x) };
136        } else if crate::simd_feature_detected!("sse2") {
137            return unsafe { norm_l1_sse2(x) };
138        }
139    }
140
141    #[cfg(all(target_arch = "aarch64", not(feature = "no-std")))]
142    {
143        if is_aarch64_feature_detected!("neon") {
144            return unsafe { norm_l1_neon(x) };
145        }
146    }
147
148    norm_l1_scalar(x)
149}
150
151/// SIMD-optimized infinity norm computation
152///
153/// Computes the L∞ norm (||x||∞) of a vector, which is the maximum
154/// absolute value of all elements.
155///
156/// # Arguments
157/// * `x` - Input vector
158///
159/// # Returns
160/// The infinity norm as a single f32 value
161///
162/// # Examples
163/// ```rust
164/// use sklears_simd::vector::basic_operations::norm_inf;
165///
166/// let x = vec![-3.0, 4.0, -5.0, 2.0];
167/// let result = norm_inf(&x);
168/// assert_eq!(result, 5.0); // max(|−3|, |4|, |−5|, |2|) = 5
169/// ```
170pub fn norm_inf(x: &[f32]) -> f32 {
171    if x.is_empty() {
172        return 0.0;
173    }
174
175    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
176    {
177        if crate::simd_feature_detected!("avx2") {
178            return unsafe { norm_inf_avx2(x) };
179        } else if crate::simd_feature_detected!("sse2") {
180            return unsafe { norm_inf_sse2(x) };
181        }
182    }
183
184    norm_inf_scalar(x)
185}
186
187/// SIMD-optimized Euclidean distance computation
188///
189/// Computes the Euclidean distance between two vectors:
190/// ||a - b||₂ = sqrt(Σ(aᵢ - bᵢ)²)
191///
192/// # Arguments
193/// * `a` - First vector
194/// * `b` - Second vector (must have same length as `a`)
195///
196/// # Returns
197/// The Euclidean distance as a single f32 value
198///
199/// # Panics
200/// Panics if the vectors have different lengths
201///
202/// # Examples
203/// ```rust
204/// use sklears_simd::vector::basic_operations::euclidean_distance;
205///
206/// let a = vec![1.0, 2.0, 3.0];
207/// let b = vec![4.0, 5.0, 6.0];
208/// let result = euclidean_distance(&a, &b);
209/// // sqrt((1-4)² + (2-5)² + (3-6)²) = sqrt(9 + 9 + 9) = sqrt(27) ≈ 5.196
210/// assert!((result - 5.196).abs() < 0.01);
211/// ```
212pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
213    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
214
215    if a.is_empty() {
216        return 0.0;
217    }
218
219    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
220    {
221        if crate::simd_feature_detected!("avx2") {
222            return unsafe { euclidean_distance_avx2(a, b) };
223        } else if crate::simd_feature_detected!("sse2") {
224            return unsafe { euclidean_distance_sse2(a, b) };
225        }
226    }
227
228    euclidean_distance_scalar(a, b)
229}
230
231/// SIMD-optimized cosine similarity computation
232///
233/// Computes the cosine similarity between two vectors:
234/// cos_sim(a, b) = (a · b) / (||a||₂ * ||b||₂)
235///
236/// # Arguments
237/// * `a` - First vector
238/// * `b` - Second vector (must have same length as `a`)
239///
240/// # Returns
241/// The cosine similarity as a single f32 value between -1.0 and 1.0
242///
243/// # Panics
244/// Panics if the vectors have different lengths
245///
246/// # Examples
247/// ```rust
248/// use sklears_simd::vector::basic_operations::cosine_similarity;
249///
250/// let a = vec![1.0, 0.0, 0.0];
251/// let b = vec![0.0, 1.0, 0.0];
252/// let result = cosine_similarity(&a, &b);
253/// assert_eq!(result, 0.0); // Orthogonal vectors have cosine similarity of 0
254/// ```
255pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
256    assert_eq!(a.len(), b.len(), "Vectors must have the same length");
257
258    if a.is_empty() {
259        return 1.0; // Convention: empty vectors are perfectly similar
260    }
261
262    let dot_ab = dot_product(a, b);
263    let norm_a = norm_l2(a);
264    let norm_b = norm_l2(b);
265
266    if norm_a == 0.0 || norm_b == 0.0 {
267        return 0.0; // Zero vectors are orthogonal to everything
268    }
269
270    dot_ab / (norm_a * norm_b)
271}
272
273/// SIMD-optimized cross product for 3D vectors
274///
275/// Computes the cross product a × b for two 3-dimensional vectors.
276///
277/// # Arguments
278/// * `a` - First 3D vector
279/// * `b` - Second 3D vector
280///
281/// # Returns
282/// The cross product as a 3-element vector, or an error if input vectors
283/// are not exactly 3 elements long
284///
285/// # Examples
286/// ```rust
287/// use sklears_simd::vector::basic_operations::cross_product;
288///
289/// let a = vec![1.0, 0.0, 0.0];  // Unit vector along x-axis
290/// let b = vec![0.0, 1.0, 0.0];  // Unit vector along y-axis
291/// let result = cross_product(&a, &b).unwrap();
292/// assert_eq!(result, vec![0.0, 0.0, 1.0]);  // Unit vector along z-axis
293/// ```
294pub fn cross_product(a: &[f32], b: &[f32]) -> Result<Vec<f32>, &'static str> {
295    if a.len() != 3 || b.len() != 3 {
296        return Err("Cross product requires exactly 3-dimensional vectors");
297    }
298
299    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
300    {
301        if crate::simd_feature_detected!("sse2") {
302            return Ok(unsafe { cross_product_sse2(a, b) });
303        }
304    }
305
306    Ok(cross_product_scalar(a, b))
307}
308
309/// SIMD-optimized outer product computation
310///
311/// Computes the outer product of two vectors, resulting in a matrix
312/// where element (i,j) = a\[i\] * b\[j\].
313///
314/// # Arguments
315/// * `a` - First vector (m elements)
316/// * `b` - Second vector (n elements)
317///
318/// # Returns
319/// An m×n matrix represented as `Vec<Vec<f32>>`
320///
321/// # Examples
322/// ```rust
323/// use sklears_simd::vector::basic_operations::outer_product;
324///
325/// let a = vec![1.0, 2.0];
326/// let b = vec![3.0, 4.0, 5.0];
327/// let result = outer_product(&a, &b);
328/// // Expected: [[3.0, 4.0, 5.0], [6.0, 8.0, 10.0]]
329/// assert_eq!(result[0], vec![3.0, 4.0, 5.0]);
330/// assert_eq!(result[1], vec![6.0, 8.0, 10.0]);
331/// ```
332pub fn outer_product(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
333    let m = a.len();
334    let n = b.len();
335
336    if m == 0 || n == 0 {
337        return vec![];
338    }
339
340    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
341    {
342        if crate::simd_feature_detected!("avx2") {
343            return unsafe { outer_product_avx2(a, b) };
344        } else if crate::simd_feature_detected!("sse2") {
345            return unsafe { outer_product_sse2(a, b) };
346        }
347    }
348
349    outer_product_scalar(a, b)
350}
351
352// ============================================================================
353// Scalar implementations (fallbacks)
354// ============================================================================
355
/// Scalar fallback for the dot product.
///
/// Pairs elements with `zip`, so any surplus elements of the longer slice are
/// ignored.
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
359
/// Scalar fallback for the L1 norm: sums the absolute value of every element.
fn norm_l1_scalar(x: &[f32]) -> f32 {
    let magnitudes = x.iter().copied().map(f32::abs);
    magnitudes.sum()
}
363
/// Scalar fallback for the infinity norm.
///
/// Starts from `0.0` and keeps the running maximum of the absolute values, so
/// an empty slice yields `0.0`.
fn norm_inf_scalar(x: &[f32]) -> f32 {
    let mut largest = 0.0f32;
    for value in x {
        largest = largest.max(value.abs());
    }
    largest
}
367
/// Scalar fallback for the Euclidean distance.
///
/// Sums the squared element-wise differences (pairing with `zip`, so surplus
/// elements of the longer slice are ignored) and takes the square root.
fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let squared_distance: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared_distance.sqrt()
}
378
/// Scalar fallback for the 3D cross product.
///
/// Reads indices 0..=2 of both slices, so it panics on slices shorter than
/// three elements; any elements past the third are ignored.
fn cross_product_scalar(a: &[f32], b: &[f32]) -> Vec<f32> {
    let (ax, ay, az) = (a[0], a[1], a[2]);
    let (bx, by, bz) = (b[0], b[1], b[2]);

    let i_component = ay * bz - az * by;
    let j_component = az * bx - ax * bz;
    let k_component = ax * by - ay * bx;

    vec![i_component, j_component, k_component]
}
386
/// Scalar fallback for the outer product.
///
/// Builds one row per element of `a`; row `i` holds `a[i] * b[j]` for every
/// `j`. An empty `a` gives no rows, an empty `b` gives `a.len()` empty rows.
fn outer_product_scalar(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
    a.iter()
        .map(|&a_val| b.iter().map(|&b_val| a_val * b_val).collect::<Vec<f32>>())
        .collect()
}
400
401// ============================================================================
402// SSE2 implementations (x86/x86_64)
403// ============================================================================
404
/// SSE2 kernel for the dot product: four lanes per iteration plus a scalar
/// tail for the last `len % 4` elements.
///
/// # Safety
/// * The executing CPU must support SSE2.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through a
///   raw pointer using `a.len()` as its bound.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn dot_product_sse2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // Process 4 elements at a time; unaligned loads accept any slice alignment.
    while i + 4 <= a.len() {
        // SAFETY: `i + 4 <= a.len()` keeps the read inside `a`, and inside `b`
        // under the caller's length contract.
        let a_vec = _mm_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm_loadu_ps(b.as_ptr().add(i));
        let prod = _mm_mul_ps(a_vec, b_vec);
        sum = _mm_add_ps(sum, prod);
        i += 4;
    }

    // Horizontal sum of the 4 lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut scalar_sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail (bounds-checked indexing).
    while i < a.len() {
        scalar_sum += a[i] * b[i];
        i += 1;
    }

    scalar_sum
}
438
/// SSE2 kernel for the L1 norm: four lanes per iteration plus a scalar tail.
///
/// Absolute values are obtained by masking off the IEEE-754 sign bit.
///
/// # Safety
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn norm_l1_sse2(x: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // All bits set except the sign bit: AND-ing with it yields |v|.
    let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFFFFFF));
    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    // Process 4 elements at a time.
    while i + 4 <= x.len() {
        // SAFETY: `i + 4 <= x.len()` keeps the unaligned read inside `x`.
        let x_vec = _mm_loadu_ps(x.as_ptr().add(i));
        let abs_vec = _mm_and_ps(x_vec, abs_mask);
        sum = _mm_add_ps(sum, abs_vec);
        i += 4;
    }

    // Horizontal sum of the 4 lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut scalar_sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail.
    while i < x.len() {
        scalar_sum += x[i].abs();
        i += 1;
    }

    scalar_sum
}
473
/// SSE2 kernel for the infinity norm: lane-wise running maximum of absolute
/// values, four lanes per iteration, plus a scalar tail.
///
/// # Safety
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn norm_inf_sse2(x: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // All bits set except the sign bit: AND-ing with it yields |v|.
    let abs_mask = _mm_set1_ps(f32::from_bits(0x7FFFFFFF));
    // Absolute values are never negative, so zero is a valid starting maximum.
    let mut max_vec = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= x.len() {
        // SAFETY: `i + 4 <= x.len()` keeps the unaligned read inside `x`.
        let x_vec = _mm_loadu_ps(x.as_ptr().add(i));
        let abs_vec = _mm_and_ps(x_vec, abs_mask);
        max_vec = _mm_max_ps(max_vec, abs_vec);
        i += 4;
    }

    // Reduce the 4 lane maxima to one value.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), max_vec);
    let mut max_val = lanes[0].max(lanes[1]).max(lanes[2]).max(lanes[3]);

    // Scalar tail.
    while i < x.len() {
        max_val = max_val.max(x[i].abs());
        i += 1;
    }

    max_val
}
506
/// SSE2 kernel for the Euclidean distance: accumulates squared differences
/// four lanes at a time, handles the tail in scalar code, then takes the
/// square root.
///
/// # Safety
/// * The executing CPU must support SSE2.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through a
///   raw pointer using `a.len()` as its bound.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn euclidean_distance_sse2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut sum = _mm_setzero_ps();
    let mut i = 0;

    while i + 4 <= a.len() {
        // SAFETY: `i + 4 <= a.len()` keeps the read inside `a`, and inside `b`
        // under the caller's length contract.
        let a_vec = _mm_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(a_vec, b_vec);
        let squared = _mm_mul_ps(diff, diff);
        sum = _mm_add_ps(sum, squared);
        i += 4;
    }

    // Horizontal sum of the 4 lane accumulators.
    let mut lanes = [0.0f32; 4];
    _mm_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut scalar_sum = lanes[0] + lanes[1] + lanes[2] + lanes[3];

    // Scalar tail (bounds-checked indexing).
    while i < a.len() {
        let diff = a[i] - b[i];
        scalar_sum += diff * diff;
        i += 1;
    }

    scalar_sum.sqrt()
}
541
/// SSE2 kernel for the 3D cross product.
///
/// Uses the shuffle identity
/// `a × b = a_yzx * b_zxy - a_zxy * b_yzx` evaluated lane-wise; the fourth
/// lane is padding and is discarded.
///
/// Both slices are indexed at 0..=2, so slices shorter than three elements
/// cause a panic (not undefined behaviour).
///
/// # Safety
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn cross_product_sse2(a: &[f32], b: &[f32]) -> Vec<f32> {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // `_mm_set_ps` takes lanes from highest to lowest, so these are
    // [a0, a1, a2, 0] and [b0, b1, b2, 0] in lane order.
    let a_vec = _mm_set_ps(0.0, a[2], a[1], a[0]);
    let b_vec = _mm_set_ps(0.0, b[2], b[1], b[0]);

    // Shuffle immediates select source lanes two bits per destination lane:
    // 0xC9 = 0b11_00_10_01 -> lanes [1, 2, 0, 3] (yzx order)
    // 0xD2 = 0b11_01_00_10 -> lanes [2, 0, 1, 3] (zxy order)
    let a_yzx = _mm_shuffle_ps(a_vec, a_vec, 0xC9);
    let b_zxy = _mm_shuffle_ps(b_vec, b_vec, 0xD2);
    let a_zxy = _mm_shuffle_ps(a_vec, a_vec, 0xD2);
    let b_yzx = _mm_shuffle_ps(b_vec, b_vec, 0xC9);

    // Cross product: a_yzx * b_zxy - a_zxy * b_yzx
    let prod1 = _mm_mul_ps(a_yzx, b_zxy);
    let prod2 = _mm_mul_ps(a_zxy, b_yzx);
    let result_vec = _mm_sub_ps(prod1, prod2);

    // Spill all four lanes and keep the three meaningful ones.
    let mut output = [0.0f32; 4];
    _mm_storeu_ps(output.as_mut_ptr(), result_vec);

    vec![output[0], output[1], output[2]]
}
576
/// SSE2 kernel for the outer product.
///
/// For each `a[i]` the value is broadcast to all lanes and multiplied with
/// `b` four elements at a time, writing directly into row `i`; the last
/// `n % 4` columns are filled by scalar code.
///
/// # Safety
/// The executing CPU must support SSE2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse2")]
unsafe fn outer_product_sse2(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let m = a.len();
    let n = b.len();
    let mut result = vec![vec![0.0; n]; m];

    for i in 0..m {
        let a_broadcast = _mm_set1_ps(a[i]);
        let mut j = 0;

        while j + 4 <= n {
            // SAFETY: `j + 4 <= n` keeps the read inside `b` and the write
            // inside row `i`, which was allocated with exactly `n` elements.
            let b_vec = _mm_loadu_ps(b.as_ptr().add(j));
            let prod = _mm_mul_ps(a_broadcast, b_vec);
            _mm_storeu_ps(result[i].as_mut_ptr().add(j), prod);
            j += 4;
        }

        // Scalar tail for the remaining columns.
        while j < n {
            result[i][j] = a[i] * b[j];
            j += 1;
        }
    }

    result
}
609
610// ============================================================================
611// AVX2 implementations (x86/x86_64)
612// ============================================================================
613
/// AVX2 kernel for the dot product: eight lanes per iteration plus a scalar
/// tail for the last `len % 8` elements.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through a
///   raw pointer using `a.len()` as its bound.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    // Process 8 elements at a time.
    while i + 8 <= a.len() {
        // SAFETY: `i + 8 <= a.len()` keeps the read inside `a`, and inside `b`
        // under the caller's length contract.
        let a_vec = _mm256_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_ps(b.as_ptr().add(i));
        let prod = _mm256_mul_ps(a_vec, b_vec);
        sum = _mm256_add_ps(sum, prod);
        i += 8;
    }

    // Horizontal sum of the 8 lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut scalar_sum = lanes.iter().sum::<f32>();

    // Scalar tail (bounds-checked indexing).
    while i < a.len() {
        scalar_sum += a[i] * b[i];
        i += 1;
    }

    scalar_sum
}
647
/// AVX2 kernel for the L1 norm: eight lanes per iteration plus a scalar tail.
///
/// Absolute values are obtained by masking off the IEEE-754 sign bit.
///
/// # Safety
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn norm_l1_avx2(x: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // All bits set except the sign bit: AND-ing with it yields |v|.
    let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFFFFFF));
    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= x.len() {
        // SAFETY: `i + 8 <= x.len()` keeps the unaligned read inside `x`.
        let x_vec = _mm256_loadu_ps(x.as_ptr().add(i));
        let abs_vec = _mm256_and_ps(x_vec, abs_mask);
        sum = _mm256_add_ps(sum, abs_vec);
        i += 8;
    }

    // Horizontal sum of the 8 lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut scalar_sum = lanes.iter().sum::<f32>();

    // Scalar tail.
    while i < x.len() {
        scalar_sum += x[i].abs();
        i += 1;
    }

    scalar_sum
}
680
/// AVX2 kernel for the infinity norm: lane-wise running maximum of absolute
/// values, eight lanes per iteration, plus a scalar tail.
///
/// # Safety
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn norm_inf_avx2(x: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // All bits set except the sign bit: AND-ing with it yields |v|.
    let abs_mask = _mm256_set1_ps(f32::from_bits(0x7FFFFFFF));
    // Absolute values are never negative, so zero is a valid starting maximum.
    let mut max_vec = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= x.len() {
        // SAFETY: `i + 8 <= x.len()` keeps the unaligned read inside `x`.
        let x_vec = _mm256_loadu_ps(x.as_ptr().add(i));
        let abs_vec = _mm256_and_ps(x_vec, abs_mask);
        max_vec = _mm256_max_ps(max_vec, abs_vec);
        i += 8;
    }

    // Reduce the 8 lane maxima to one value.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), max_vec);
    let mut max_val = lanes.iter().fold(0.0f32, |a, &b| a.max(b));

    // Scalar tail.
    while i < x.len() {
        max_val = max_val.max(x[i].abs());
        i += 1;
    }

    max_val
}
713
/// AVX2 kernel for the Euclidean distance: accumulates squared differences
/// eight lanes at a time, handles the tail in scalar code, then takes the
/// square root.
///
/// # Safety
/// * The executing CPU must support AVX2.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through a
///   raw pointer using `a.len()` as its bound.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "Vectors must have the same length");

    let mut sum = _mm256_setzero_ps();
    let mut i = 0;

    while i + 8 <= a.len() {
        // SAFETY: `i + 8 <= a.len()` keeps the read inside `a`, and inside `b`
        // under the caller's length contract.
        let a_vec = _mm256_loadu_ps(a.as_ptr().add(i));
        let b_vec = _mm256_loadu_ps(b.as_ptr().add(i));
        let diff = _mm256_sub_ps(a_vec, b_vec);
        let squared = _mm256_mul_ps(diff, diff);
        sum = _mm256_add_ps(sum, squared);
        i += 8;
    }

    // Horizontal sum of the 8 lane accumulators.
    let mut lanes = [0.0f32; 8];
    _mm256_storeu_ps(lanes.as_mut_ptr(), sum);
    let mut scalar_sum = lanes.iter().sum::<f32>();

    // Scalar tail (bounds-checked indexing).
    while i < a.len() {
        let diff = a[i] - b[i];
        scalar_sum += diff * diff;
        i += 1;
    }

    scalar_sum.sqrt()
}
748
/// AVX2 kernel for the outer product.
///
/// For each `a[i]` the value is broadcast to all lanes and multiplied with
/// `b` eight elements at a time, writing directly into row `i`; the last
/// `n % 8` columns are filled by scalar code.
///
/// # Safety
/// The executing CPU must support AVX2.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
unsafe fn outer_product_avx2(a: &[f32], b: &[f32]) -> Vec<Vec<f32>> {
    // The intrinsics live in an architecture-specific module and
    // `core::arch::x86_64` does not exist when compiling for 32-bit x86, so
    // pick the module matching the target.
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    let m = a.len();
    let n = b.len();
    let mut result = vec![vec![0.0; n]; m];

    for i in 0..m {
        let a_broadcast = _mm256_set1_ps(a[i]);
        let mut j = 0;

        while j + 8 <= n {
            // SAFETY: `j + 8 <= n` keeps the read inside `b` and the write
            // inside row `i`, which was allocated with exactly `n` elements.
            let b_vec = _mm256_loadu_ps(b.as_ptr().add(j));
            let prod = _mm256_mul_ps(a_broadcast, b_vec);
            _mm256_storeu_ps(result[i].as_mut_ptr().add(j), prod);
            j += 8;
        }

        // Scalar tail for the remaining columns.
        while j < n {
            result[i][j] = a[i] * b[j];
            j += 1;
        }
    }

    result
}
781
782// ============================================================================
783// AVX512 implementations (x86/x86_64)
784// ============================================================================
785
786#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
787#[target_feature(enable = "avx512f")]
788unsafe fn dot_product_avx512(a: &[f32], b: &[f32]) -> f32 {
789    #[cfg(feature = "no-std")]
790    use core::arch::x86_64::*;
791    #[cfg(not(feature = "no-std"))]
792    use core::arch::x86_64::*;
793
794    let mut sum = _mm512_setzero_ps();
795    let mut i = 0;
796
797    // Process 16 elements at a time
798    while i + 16 <= a.len() {
799        let a_vec = _mm512_loadu_ps(a.as_ptr().add(i));
800        let b_vec = _mm512_loadu_ps(b.as_ptr().add(i));
801        sum = _mm512_fmadd_ps(a_vec, b_vec, sum); // Fused multiply-add
802        i += 16;
803    }
804
805    // Horizontal sum of the 16 elements in sum
806    let scalar_sum = _mm512_reduce_add_ps(sum);
807
808    // Handle remaining elements
809    let mut remaining_sum = 0.0f32;
810    while i < a.len() {
811        remaining_sum += a[i] * b[i];
812        i += 1;
813    }
814
815    scalar_sum + remaining_sum
816}
817
818// ============================================================================
819// NEON implementations (ARM AArch64)
820// ============================================================================
821
/// NEON kernel for the dot product: four lanes per iteration using fused
/// multiply-add, plus a scalar tail for the last `len % 4` elements.
///
/// # Safety
/// * The executing CPU must support NEON.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through a
///   raw pointer using `a.len()` as its bound.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let len = a.len();
    let mut acc = vdupq_n_f32(0.0);
    let mut offset = 0;

    // Vector part: consume complete groups of four elements.
    while offset + 4 <= len {
        let lhs = vld1q_f32(a.as_ptr().add(offset));
        let rhs = vld1q_f32(b.as_ptr().add(offset));
        acc = vfmaq_f32(acc, lhs, rhs); // acc += lhs * rhs, fused
        offset += 4;
    }

    // Two pairwise adds fold the four lanes into lane 0.
    let pair_sums = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
    let folded = vpadd_f32(pair_sums, pair_sums);
    let mut total = vget_lane_f32(folded, 0);

    // Scalar tail (bounds-checked indexing).
    for idx in offset..len {
        total += a[idx] * b[idx];
    }

    total
}
852
/// NEON kernel for the L1 norm: four lanes per iteration plus a scalar tail
/// for the last `len % 4` elements.
///
/// # Safety
/// The executing CPU must support NEON.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // NEON dispatch; unused in --all-features (no-std disables runtime detection)
#[target_feature(enable = "neon")]
unsafe fn norm_l1_neon(x: &[f32]) -> f32 {
    use core::arch::aarch64::*;

    let len = x.len();
    let mut acc = vdupq_n_f32(0.0);
    let mut offset = 0;

    // Vector part: consume complete groups of four elements.
    while offset + 4 <= len {
        let values = vld1q_f32(x.as_ptr().add(offset));
        acc = vaddq_f32(acc, vabsq_f32(values));
        offset += 4;
    }

    // Two pairwise adds fold the four lanes into lane 0.
    let pair_sums = vpadd_f32(vget_low_f32(acc), vget_high_f32(acc));
    let folded = vpadd_f32(pair_sums, pair_sums);
    let mut total = vget_lane_f32(folded, 0);

    // Scalar tail.
    for idx in offset..len {
        total += x[idx].abs();
    }

    total
}
882
#[allow(non_snake_case)]
#[cfg(all(test, not(feature = "no-std")))]
mod tests {
    use super::*;

    /// Shared zero-length input for the empty-vector cases.
    const EMPTY: [f32; 0] = [];

    /// Dot product on a four-element, an empty and a single-element input.
    #[test]
    fn test_dot_product() {
        // 1*5 + 2*6 + 3*7 + 4*8 = 5 + 12 + 21 + 32 = 70
        assert_eq!(
            dot_product(&[1.0, 2.0, 3.0, 4.0], &[5.0, 6.0, 7.0, 8.0]),
            70.0
        );

        // Empty input
        assert_eq!(dot_product(&EMPTY, &EMPTY), 0.0);

        // Single element
        assert_eq!(dot_product(&[3.0], &[4.0]), 12.0);
    }

    /// L2, L1 and L∞ norms on positive, mixed-sign and empty inputs.
    #[test]
    fn test_norms() {
        let three_four = [3.0, 4.0];

        assert_eq!(norm_l2(&three_four), 5.0); // sqrt(3² + 4²) = sqrt(25) = 5
        assert_eq!(norm_l1(&three_four), 7.0); // |3| + |4| = 7
        assert_eq!(norm_inf(&three_four), 4.0); // max(|3|, |4|) = 4

        // Negative values
        let mixed_sign = [-3.0, 4.0, -5.0];
        assert_eq!(norm_l1(&mixed_sign), 12.0); // |-3| + |4| + |-5| = 12
        assert_eq!(norm_inf(&mixed_sign), 5.0); // max(|-3|, |4|, |-5|) = 5

        // Empty input
        assert_eq!(norm_l2(&EMPTY), 0.0);
        assert_eq!(norm_l1(&EMPTY), 0.0);
        assert_eq!(norm_inf(&EMPTY), 0.0);
    }

    /// Euclidean distance on distinct, identical and empty inputs.
    #[test]
    fn test_euclidean_distance() {
        let a = [1.0, 2.0, 3.0];
        let b = [4.0, 5.0, 6.0];

        // sqrt((1-4)² + (2-5)² + (3-6)²) = sqrt(9 + 9 + 9) = sqrt(27) ≈ 5.196
        let distance = euclidean_distance(&a, &b);
        assert!((distance - 5.196).abs() < 0.01);

        // A vector is at distance zero from itself
        assert_eq!(euclidean_distance(&a, &a), 0.0);

        // Empty input
        assert_eq!(euclidean_distance(&EMPTY, &EMPTY), 0.0);
    }

    /// Cosine similarity for orthogonal, identical, opposite, zero and empty
    /// inputs.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0, 0.0, 0.0];
        let y_axis = [0.0, 1.0, 0.0];

        // Orthogonal vectors
        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!((orthogonal - 0.0).abs() < f32::EPSILON);

        // Identical vectors
        let identical = cosine_similarity(&x_axis, &x_axis);
        assert!((identical - 1.0).abs() < f32::EPSILON);

        // Opposite vectors
        let opposite_sim = cosine_similarity(&x_axis, &[-1.0, 0.0, 0.0]);
        assert!((opposite_sim - (-1.0)).abs() < f32::EPSILON);

        // Zero vector
        assert_eq!(cosine_similarity(&x_axis, &[0.0, 0.0, 0.0]), 0.0);

        // Empty input
        assert_eq!(cosine_similarity(&EMPTY, &EMPTY), 1.0);
    }

    /// Cross product on unit vectors, general vectors and a wrong-sized input.
    #[test]
    fn test_cross_product() {
        let i = [1.0, 0.0, 0.0];
        let j = [0.0, 1.0, 0.0];

        // i × j = k
        let k = cross_product(&i, &j).expect("operation should succeed");
        assert_eq!(k, vec![0.0, 0.0, 1.0]);

        // (2*6 - 3*5, 3*4 - 1*6, 1*5 - 2*4) = (-3, 6, -3)
        let general =
            cross_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]).expect("operation should succeed");
        assert_eq!(general, vec![-3.0, 6.0, -3.0]);

        // Inputs that are not 3-dimensional are rejected
        assert!(cross_product(&[1.0, 2.0], &j).is_err());
    }

    /// Outer product shape and values, plus both empty-input cases.
    #[test]
    fn test_outer_product() {
        let a = [1.0, 2.0];
        let b = [3.0, 4.0, 5.0];
        let matrix = outer_product(&a, &b);

        // [[1*3, 1*4, 1*5], [2*3, 2*4, 2*5]] = [[3, 4, 5], [6, 8, 10]]
        assert_eq!(matrix.len(), 2);
        assert_eq!(matrix[0].len(), 3);
        assert_eq!(matrix[0], vec![3.0, 4.0, 5.0]);
        assert_eq!(matrix[1], vec![6.0, 8.0, 10.0]);

        // Either operand being empty yields no rows
        assert!(outer_product(&EMPTY, &b).is_empty());
        assert!(outer_product(&a, &EMPTY).is_empty());
    }

    /// Mismatched lengths must trip the length assertion in `dot_product`.
    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_dot_product_dimension_mismatch() {
        dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0]);
    }

    /// Mismatched lengths must trip the length assertion in
    /// `euclidean_distance`.
    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_euclidean_distance_dimension_mismatch() {
        euclidean_distance(&[1.0, 2.0, 3.0], &[4.0, 5.0]);
    }

    /// Mismatched lengths must trip the length assertion in
    /// `cosine_similarity`.
    #[test]
    #[should_panic(expected = "Vectors must have the same length")]
    fn test_cosine_similarity_dimension_mismatch() {
        cosine_similarity(&[1.0, 2.0, 3.0], &[4.0, 5.0]);
    }
}