ruvector_memopt/accel/
simd.rs

1//! SIMD-optimized operations for pattern matching
2
3use super::cpu::CpuCapabilities;
4
/// SIMD optimizer for vector operations
///
/// Dispatches between AVX2 and scalar code paths based on the CPU
/// features detected when the optimizer is constructed.
pub struct SimdOptimizer {
    // CPU feature flags detected at construction; consulted by the
    // distance/dot-product dispatchers to pick the AVX2 path.
    caps: CpuCapabilities,
}
9
10impl SimdOptimizer {
11    pub fn new() -> Self {
12        Self { caps: CpuCapabilities::detect() }
13    }
14
15    /// SIMD-optimized Euclidean distance calculation
16    #[cfg(target_arch = "x86_64")]
17    pub fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
18        if a.len() != b.len() { return f32::MAX; }
19
20        if self.caps.has_avx2 && a.len() >= 8 {
21            // SAFETY: We've checked has_avx2 is true
22            unsafe { self.euclidean_distance_avx2(a, b) }
23        } else {
24            self.euclidean_distance_scalar(a, b)
25        }
26    }
27
28    #[cfg(not(target_arch = "x86_64"))]
29    pub fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
30        self.euclidean_distance_scalar(a, b)
31    }
32
33    fn euclidean_distance_scalar(&self, a: &[f32], b: &[f32]) -> f32 {
34        a.iter().zip(b.iter())
35            .map(|(x, y)| (x - y).powi(2))
36            .sum::<f32>()
37            .sqrt()
38    }
39
    /// AVX2 Euclidean distance.
    ///
    /// Processes 8 `f32` lanes per iteration with unaligned loads,
    /// accumulates squared differences in a 256-bit register, reduces
    /// the register horizontally, then finishes the `len % 8` tail
    /// with a scalar loop.
    ///
    /// # Safety
    ///
    /// The caller must ensure the CPU supports AVX2 (checked via
    /// `caps.has_avx2` in `euclidean_distance`). The vector loop reads
    /// `a.len()` elements from both slices, so `b` must be at least as
    /// long as `a` — the dispatcher enforces equal lengths.
    #[cfg(target_arch = "x86_64")]
    #[target_feature(enable = "avx2")]
    unsafe fn euclidean_distance_avx2(&self, a: &[f32], b: &[f32]) -> f32 {
        use std::arch::x86_64::*;

        let len = a.len();
        let chunks = len / 8;
        let mut sum = _mm256_setzero_ps();

        // Main loop: 8 floats per iteration; loadu tolerates unaligned data.
        for i in 0..chunks {
            let offset = i * 8;
            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
            let diff = _mm256_sub_ps(va, vb);
            let sq = _mm256_mul_ps(diff, diff);
            sum = _mm256_add_ps(sum, sq);
        }

        // Horizontal sum: fold 256 -> 128 -> 64 -> 32 bits by adding
        // the upper half onto the lower half at each step until lane 0
        // holds the total.
        let high = _mm256_extractf128_ps(sum, 1);
        let low = _mm256_castps256_ps128(sum);
        let sum128 = _mm_add_ps(low, high);
        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));

        let mut result = _mm_cvtss_f32(sum32);

        // Handle remainder: scalar loop over the last len % 8 elements.
        for i in (chunks * 8)..len {
            let diff = a[i] - b[i];
            result += diff * diff;
        }

        result.sqrt()
    }
75
76    /// SIMD-optimized dot product
77    pub fn dot_product(&self, a: &[f32], b: &[f32]) -> f32 {
78        if a.len() != b.len() { return 0.0; }
79
80        #[cfg(target_arch = "x86_64")]
81        if self.caps.has_avx2 && a.len() >= 8 {
82            return unsafe { self.dot_product_avx2(a, b) };
83        }
84
85        self.dot_product_scalar(a, b)
86    }
87
88    fn dot_product_scalar(&self, a: &[f32], b: &[f32]) -> f32 {
89        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
90    }
91
92    #[cfg(target_arch = "x86_64")]
93    #[target_feature(enable = "avx2", enable = "fma")]
94    unsafe fn dot_product_avx2(&self, a: &[f32], b: &[f32]) -> f32 {
95        use std::arch::x86_64::*;
96
97        let len = a.len().min(b.len());
98        let chunks = len / 8;
99        let mut sum = _mm256_setzero_ps();
100
101        for i in 0..chunks {
102            let offset = i * 8;
103            let va = _mm256_loadu_ps(a.as_ptr().add(offset));
104            let vb = _mm256_loadu_ps(b.as_ptr().add(offset));
105            sum = _mm256_fmadd_ps(va, vb, sum);
106        }
107
108        // Horizontal sum
109        let high = _mm256_extractf128_ps(sum, 1);
110        let low = _mm256_castps256_ps128(sum);
111        let sum128 = _mm_add_ps(low, high);
112        let sum64 = _mm_add_ps(sum128, _mm_movehl_ps(sum128, sum128));
113        let sum32 = _mm_add_ss(sum64, _mm_shuffle_ps(sum64, sum64, 1));
114
115        let mut result = _mm_cvtss_f32(sum32);
116
117        for i in (chunks * 8)..len {
118            result += a[i] * b[i];
119        }
120
121        result
122    }
123
124    /// Batch distance calculation
125    pub fn batch_distances(&self, query: &[f32], vectors: &[Vec<f32>]) -> Vec<f32> {
126        vectors.iter()
127            .map(|v| self.euclidean_distance(query, v))
128            .collect()
129    }
130
131    /// Benchmark SIMD vs scalar
132    pub fn benchmark(&self, dim: usize, iterations: usize) -> (f64, f64, f64) {
133        use std::time::Instant;
134
135        let a: Vec<f32> = (0..dim).map(|i| i as f32 * 0.1).collect();
136        let b: Vec<f32> = (0..dim).map(|i| (dim - i) as f32 * 0.1).collect();
137
138        // Scalar benchmark
139        let start = Instant::now();
140        for _ in 0..iterations {
141            let _ = self.euclidean_distance_scalar(&a, &b);
142        }
143        let scalar_time = start.elapsed().as_secs_f64();
144
145        // SIMD benchmark
146        let start = Instant::now();
147        for _ in 0..iterations {
148            let _ = self.euclidean_distance(&a, &b);
149        }
150        let simd_time = start.elapsed().as_secs_f64();
151
152        let speedup = scalar_time / simd_time;
153        (scalar_time, simd_time, speedup)
154    }
155
156    pub fn capabilities(&self) -> &CpuCapabilities { &self.caps }
157}
158
159impl Default for SimdOptimizer {
160    fn default() -> Self { Self::new() }
161}
162
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean_distance() {
        // A vector is at distance zero from itself (exercises the
        // AVX2 path when available, since len == 8).
        let opt = SimdOptimizer::new();
        let v: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let w = v.clone();
        assert!((opt.euclidean_distance(&v, &w) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_dot_product() {
        // Dot with the all-ones vector equals the element sum:
        // 1 + 2 + ... + 8 = 36.
        let opt = SimdOptimizer::new();
        let a: Vec<f32> = (1..=8).map(|i| i as f32).collect();
        let ones = vec![1.0f32; 8];
        assert!((opt.dot_product(&a, &ones) - 36.0).abs() < 0.001);
    }
}