ruvector_core/
simd_intrinsics.rs

#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

/// Euclidean distance between two equal-length f32 slices.
///
/// Dispatches to the AVX2 implementation when the CPU supports it at
/// runtime and falls back to the scalar version otherwise.
#[inline]
pub fn euclidean_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "input slices must be the same length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { euclidean_distance_avx2_impl(a, b) }
        } else {
            euclidean_distance_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        euclidean_distance_scalar(a, b)
    }
}

/// # Safety
///
/// Caller must ensure AVX2 is available and that `a` and `b` have equal length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn euclidean_distance_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    // Process eight f32 lanes per iteration.
    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;

        // Unaligned loads: the slices carry no alignment guarantee.
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        let diff = _mm256_sub_ps(va, vb);
        let sq = _mm256_mul_ps(diff, diff);
        sum = _mm256_add_ps(sum, sq);
    }

    // Horizontal sum: reinterpret the 256-bit register as eight f32s.
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Scalar tail for lengths that are not a multiple of 8.
    for i in (chunks * 8)..len {
        let diff = a[i] - b[i];
        total += diff * diff;
    }

    total.sqrt()
}
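
// An alternative horizontal reduction, kept as a sketch for comparison (it is
// not called anywhere in this file): instead of transmuting the register to an
// array, fold the lanes with shuffles. Both forms compute the same sum.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
#[allow(dead_code)]
unsafe fn hsum256_ps(v: __m256) -> f32 {
    // Add the high 128-bit lane onto the low lane.
    let quad = _mm_add_ps(_mm256_castps256_ps128(v), _mm256_extractf128_ps::<1>(v));
    // Fold four lanes down to two, then two down to one.
    let pair = _mm_add_ps(quad, _mm_movehl_ps(quad, quad));
    let single = _mm_add_ss(pair, _mm_shuffle_ps::<0b01>(pair, pair));
    _mm_cvtss_f32(single)
}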

/// Dot product of two equal-length f32 slices, using AVX2 when available.
#[inline]
pub fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "input slices must be the same length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { dot_product_avx2_impl(a, b) }
        } else {
            dot_product_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        dot_product_scalar(a, b)
    }
}

/// # Safety
///
/// Caller must ensure AVX2 is available and that `a` and `b` have equal length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_product_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));
        let prod = _mm256_mul_ps(va, vb);
        sum = _mm256_add_ps(sum, prod);
    }

    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();

    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..len {
        total += a[i] * b[i];
    }

    total
}
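
// On CPUs that also report FMA, the multiply and add in the loop above can be
// fused into a single instruction. A minimal sketch, assuming the caller has
// checked is_x86_feature_detected!("fma") as well as "avx2"; this variant is
// not wired into the public dispatch.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2,fma")]
#[allow(dead_code)]
unsafe fn dot_product_fma_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    let mut sum = _mm256_setzero_ps();
    for i in 0..len / 8 {
        let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
        let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));
        // sum = va * vb + sum with a single rounding step.
        sum = _mm256_fmadd_ps(va, vb, sum);
    }
    let sum_arr: [f32; 8] = std::mem::transmute(sum);
    let mut total = sum_arr.iter().sum::<f32>();
    for i in (len / 8 * 8)..len {
        total += a[i] * b[i];
    }
    total
}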

/// Cosine similarity of two equal-length f32 slices, using AVX2 when available.
#[inline]
pub fn cosine_similarity_avx2(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len(), "input slices must be the same length");

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2") {
            unsafe { cosine_similarity_avx2_impl(a, b) }
        } else {
            cosine_similarity_scalar(a, b)
        }
    }

    #[cfg(not(target_arch = "x86_64"))]
    {
        cosine_similarity_scalar(a, b)
    }
}

/// # Safety
///
/// Caller must ensure AVX2 is available and that `a` and `b` have equal length.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn cosine_similarity_avx2_impl(a: &[f32], b: &[f32]) -> f32 {
    let len = a.len();
    // Accumulate the dot product and both squared norms in a single pass.
    let mut dot = _mm256_setzero_ps();
    let mut norm_a = _mm256_setzero_ps();
    let mut norm_b = _mm256_setzero_ps();

    let chunks = len / 8;
    for i in 0..chunks {
        let idx = i * 8;
        let va = _mm256_loadu_ps(a.as_ptr().add(idx));
        let vb = _mm256_loadu_ps(b.as_ptr().add(idx));

        dot = _mm256_add_ps(dot, _mm256_mul_ps(va, vb));
        norm_a = _mm256_add_ps(norm_a, _mm256_mul_ps(va, va));
        norm_b = _mm256_add_ps(norm_b, _mm256_mul_ps(vb, vb));
    }

    let dot_arr: [f32; 8] = std::mem::transmute(dot);
    let norm_a_arr: [f32; 8] = std::mem::transmute(norm_a);
    let norm_b_arr: [f32; 8] = std::mem::transmute(norm_b);

    let mut dot_sum = dot_arr.iter().sum::<f32>();
    let mut norm_a_sum = norm_a_arr.iter().sum::<f32>();
    let mut norm_b_sum = norm_b_arr.iter().sum::<f32>();

    // Scalar tail for the remaining elements.
    for i in (chunks * 8)..len {
        dot_sum += a[i] * b[i];
        norm_a_sum += a[i] * a[i];
        norm_b_sum += b[i] * b[i];
    }

    dot_sum / (norm_a_sum.sqrt() * norm_b_sum.sqrt())
}
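
// Note (added): the division above yields NaN when either vector has zero
// norm. A guarded wrapper is sketched here under that assumption; the name
// `cosine_similarity_checked` is hypothetical and not part of the original API.
#[allow(dead_code)]
fn cosine_similarity_checked(a: &[f32], b: &[f32]) -> Option<f32> {
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    let denom = norm_a * norm_b;
    if denom == 0.0 {
        None
    } else {
        Some(dot_product_avx2(a, b) / denom)
    }
}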

// Scalar fallbacks, also used as reference implementations in the tests.

fn euclidean_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let diff = x - y;
            diff * diff
        })
        .sum::<f32>()
        .sqrt()
}

fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
    let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    dot / (norm_a * norm_b)
}
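
// Example usage (a sketch; the vectors and their dimension are made up).
// Callers only see the safe public wrappers; feature detection and `unsafe`
// stay inside this module.
#[allow(dead_code)]
fn example_usage() {
    let query = vec![0.1_f32; 384];
    let candidate = vec![0.2_f32; 384];
    let _distance = euclidean_distance_avx2(&query, &candidate);
    let _dot = dot_product_avx2(&query, &candidate);
    let _similarity = cosine_similarity_avx2(&query, &candidate);
}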

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_euclidean_distance_avx2() {
        let a = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let b = vec![2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0];

        let result = euclidean_distance_avx2(&a, &b);
        let expected = euclidean_distance_scalar(&a, &b);

        assert!(
            (result - expected).abs() < 0.001,
            "AVX2 result {} differs from scalar result {}",
            result,
            expected
        );
    }

    #[test]
    fn test_dot_product_avx2() {
        let a = vec![1.0; 16];
        let b = vec![2.0; 16];

        let result = dot_product_avx2(&a, &b);
        assert!((result - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_avx2() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let result = cosine_similarity_avx2(&a, &b);
        assert!((result - 1.0).abs() < 0.001);
    }
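
    // Added test (not in the original suite): a length that is not a multiple
    // of 8, so the scalar remainder loop after the SIMD chunks runs as well.
    #[test]
    fn test_remainder_lanes_match_scalar() {
        let a: Vec<f32> = (0..11).map(|i| i as f32).collect();
        let b: Vec<f32> = (0..11).map(|i| (i * 2) as f32).collect();

        let d = euclidean_distance_avx2(&a, &b) - euclidean_distance_scalar(&a, &b);
        assert!(d.abs() < 0.001);

        let p = dot_product_avx2(&a, &b) - dot_product_scalar(&a, &b);
        assert!(p.abs() < 0.001);
    }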
}