#[cfg(target_arch = "x86_64")]
use std::arch::x86_64::*;

#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;

#[cfg(target_arch = "x86_64")]
const MIN_DIM_SIZE_AVX: usize = 32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64", target_arch = "aarch64"))]
const MIN_DIM_SIZE_SIMD: usize = 16;

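/// Dot product of two `f32` slices, dispatched at runtime to the widest SIMD
/// path available (AVX2+FMA, SSE, or NEON) and falling back to a scalar loop
/// for short inputs or unsupported CPUs.
///
/// Returns `0.0` if the slices have different lengths.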
#[inline]
pub fn dot_product_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { dot_product_avx2(a, b) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_sse(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { dot_product_neon(a, b) };
        }
    }

    dot_product_scalar(a, b)
}

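/// AVX2/FMA dot product: two 8-lane accumulators process 16 floats per
/// iteration, followed by a horizontal reduction and a scalar tail loop.
///
/// # Safety
/// The caller must ensure AVX2 and FMA are available on the running CPU and
/// that `b` is at least as long as `a`.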
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn dot_product_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    while i + 15 < dim {
        let vx1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vy1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let vx2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vy2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        sum1 = _mm256_fmadd_ps(vx1, vy1, sum1);
        sum2 = _mm256_fmadd_ps(vx2, vy2, sum2);

        i += 16;
    }

    let combined = _mm256_add_ps(sum1, sum2);

    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut dot = _mm_cvtss_f32(sum_128);

    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

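/// SSE dot product: one 4-lane accumulator, a shuffle-based horizontal sum,
/// and a scalar tail loop.
///
/// # Safety
/// The caller must ensure SSE is available and that `b` is at least as long as `a`.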
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn dot_product_sse(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        sum = _mm_add_ps(sum, _mm_mul_ps(va, vb));
        i += 4;
    }

    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut dot = _mm_cvtss_f32(sum);

    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

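/// NEON dot product: two 4-lane FMA accumulators (8 floats per iteration), a
/// 4-wide cleanup loop, then `vaddvq_f32` for the horizontal sum.
///
/// # Safety
/// The caller must ensure NEON is available and that `b` is at least as long as `a`.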
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn dot_product_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        sum1 = vfmaq_f32(sum1, va1, vb1);
        sum2 = vfmaq_f32(sum2, va2, vb2);

        i += 8;
    }

    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        sum1 = vfmaq_f32(sum1, va, vb);
        i += 4;
    }

    let combined = vaddq_f32(sum1, sum2);
    let mut dot = vaddvq_f32(combined);

    while i < dim {
        dot += a[i] * b[i];
        i += 1;
    }

    dot
}

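/// Portable scalar fallback, unrolled by 8 into two independent accumulators
/// so the compiler can keep more multiply-adds in flight.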
#[inline]
fn dot_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut dot0 = 0.0f32;
    let mut dot1 = 0.0f32;

    let chunks = a.chunks_exact(8);
    let remainder = chunks.remainder();
    let b_chunks = b.chunks_exact(8);

    for (a_chunk, b_chunk) in chunks.zip(b_chunks) {
        dot0 += a_chunk[0] * b_chunk[0]
            + a_chunk[1] * b_chunk[1]
            + a_chunk[2] * b_chunk[2]
            + a_chunk[3] * b_chunk[3];

        dot1 += a_chunk[4] * b_chunk[4]
            + a_chunk[5] * b_chunk[5]
            + a_chunk[6] * b_chunk[6]
            + a_chunk[7] * b_chunk[7];
    }

    for i in (a.len() - remainder.len())..a.len() {
        dot0 += a[i] * b[i];
    }

    dot0 + dot1
}

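/// Euclidean (L2) distance between two `f32` slices, using the same runtime
/// dispatch strategy as [`dot_product_simd`].
///
/// Returns `f32::INFINITY` if the slices have different lengths.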
#[inline]
pub fn l2_distance_simd(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::INFINITY;
    }

    #[cfg(target_arch = "x86_64")]
    {
        if is_x86_feature_detected!("avx2")
            && is_x86_feature_detected!("fma")
            && a.len() >= MIN_DIM_SIZE_AVX
        {
            return unsafe { l2_distance_avx2(a, b) };
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    {
        if is_x86_feature_detected!("sse") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_sse(a, b) };
        }
    }

    #[cfg(target_arch = "aarch64")]
    {
        if std::arch::is_aarch64_feature_detected!("neon") && a.len() >= MIN_DIM_SIZE_SIMD {
            return unsafe { l2_distance_neon(a, b) };
        }
    }

    l2_distance_scalar(a, b)
}

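/// AVX2/FMA squared-difference accumulation, 16 floats per iteration, followed
/// by a horizontal reduction, a scalar tail, and a final square root.
///
/// # Safety
/// The caller must ensure AVX2 and FMA are available and that `b` is at least
/// as long as `a`.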
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2", enable = "fma")]
#[inline]
unsafe fn l2_distance_avx2(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = _mm256_setzero_ps();
    let mut sum2 = _mm256_setzero_ps();

    while i + 15 < dim {
        let va1 = _mm256_loadu_ps(a.as_ptr().add(i));
        let vb1 = _mm256_loadu_ps(b.as_ptr().add(i));
        let va2 = _mm256_loadu_ps(a.as_ptr().add(i + 8));
        let vb2 = _mm256_loadu_ps(b.as_ptr().add(i + 8));

        let diff1 = _mm256_sub_ps(va1, vb1);
        let diff2 = _mm256_sub_ps(va2, vb2);

        sum1 = _mm256_fmadd_ps(diff1, diff1, sum1);
        sum2 = _mm256_fmadd_ps(diff2, diff2, sum2);

        i += 16;
    }

    let combined = _mm256_add_ps(sum1, sum2);

    let sum_high = _mm256_extractf128_ps(combined, 1);
    let sum_low = _mm256_castps256_ps128(combined);
    let mut sum_128 = _mm_add_ps(sum_high, sum_low);

    sum_128 = _mm_hadd_ps(sum_128, sum_128);
    sum_128 = _mm_hadd_ps(sum_128, sum_128);

    let mut sum_sq = _mm_cvtss_f32(sum_128);

    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

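/// SSE squared-difference accumulation with a shuffle-based horizontal sum.
///
/// # Safety
/// The caller must ensure SSE is available and that `b` is at least as long as `a`.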
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "sse")]
#[inline]
unsafe fn l2_distance_sse(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "x86")]
    use std::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use std::arch::x86_64::*;

    let dim = a.len();
    let mut i = 0;
    let mut sum = _mm_setzero_ps();

    while i + 3 < dim {
        let va = _mm_loadu_ps(a.as_ptr().add(i));
        let vb = _mm_loadu_ps(b.as_ptr().add(i));
        let diff = _mm_sub_ps(va, vb);
        sum = _mm_add_ps(sum, _mm_mul_ps(diff, diff));
        i += 4;
    }

    let shuf = _mm_shuffle_ps(sum, sum, 0b10_11_00_01);
    sum = _mm_add_ps(sum, shuf);
    let shuf = _mm_movehl_ps(sum, sum);
    sum = _mm_add_ss(sum, shuf);

    let mut sum_sq = _mm_cvtss_f32(sum);

    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

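/// NEON squared-difference accumulation: two FMA accumulators, a 4-wide
/// cleanup loop, then `vaddvq_f32` and a final square root.
///
/// # Safety
/// The caller must ensure NEON is available and that `b` is at least as long as `a`.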
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[inline]
unsafe fn l2_distance_neon(a: &[f32], b: &[f32]) -> f32 {
    let dim = a.len();
    let mut i = 0;

    let mut sum1 = vdupq_n_f32(0.0);
    let mut sum2 = vdupq_n_f32(0.0);

    while i + 7 < dim {
        let va1 = vld1q_f32(a.as_ptr().add(i));
        let vb1 = vld1q_f32(b.as_ptr().add(i));
        let va2 = vld1q_f32(a.as_ptr().add(i + 4));
        let vb2 = vld1q_f32(b.as_ptr().add(i + 4));

        let diff1 = vsubq_f32(va1, vb1);
        let diff2 = vsubq_f32(va2, vb2);

        sum1 = vfmaq_f32(sum1, diff1, diff1);
        sum2 = vfmaq_f32(sum2, diff2, diff2);

        i += 8;
    }

    while i + 3 < dim {
        let va = vld1q_f32(a.as_ptr().add(i));
        let vb = vld1q_f32(b.as_ptr().add(i));
        let diff = vsubq_f32(va, vb);
        sum1 = vfmaq_f32(sum1, diff, diff);
        i += 4;
    }

    let combined = vaddq_f32(sum1, sum2);
    let mut sum_sq = vaddvq_f32(combined);

    while i < dim {
        let diff = a[i] - b[i];
        sum_sq += diff * diff;
        i += 1;
    }

    sum_sq.sqrt()
}

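/// Portable scalar fallback for the L2 distance, unrolled by 4 into two
/// independent accumulators.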
#[inline]
fn l2_distance_scalar(a: &[f32], b: &[f32]) -> f32 {
    let mut sum0 = 0.0f32;
    let mut sum1 = 0.0f32;

    let chunks = a.chunks_exact(4);
    let remainder = chunks.remainder();
    let b_chunks = b.chunks_exact(4);

    for (a_chunk, b_chunk) in chunks.zip(b_chunks) {
        let d0 = a_chunk[0] - b_chunk[0];
        let d1 = a_chunk[1] - b_chunk[1];
        let d2 = a_chunk[2] - b_chunk[2];
        let d3 = a_chunk[3] - b_chunk[3];

        sum0 += d0 * d0 + d1 * d1;
        sum1 += d2 * d2 + d3 * d3;
    }

    for i in (a.len() - remainder.len())..a.len() {
        let diff = a[i] - b[i];
        sum0 += diff * diff;
    }

    (sum0 + sum1).sqrt()
}

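/// Squared Euclidean norm of `v`, computed as the dot product of `v` with itself.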
#[inline]
pub fn norm_squared_simd(v: &[f32]) -> f32 {
    dot_product_simd(v, v)
}

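/// Euclidean norm of `v`.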
#[inline]
pub fn norm_simd(v: &[f32]) -> f32 {
    norm_squared_simd(v).sqrt()
}
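
// A minimal sanity-check sketch, not part of the original source: it assumes
// this file is compiled as a module of a crate under `cargo test`, and it
// compares whichever SIMD path is reachable on the host CPU against a naive
// scalar reference within a small floating-point tolerance.
#[cfg(test)]
mod tests {
    use super::*;

    // Naive reference implementation used as the ground truth in the tests.
    fn scalar_dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    #[test]
    fn dot_product_matches_scalar_reference() {
        // 67 elements exercises the AVX (>= 32), SSE/NEON (>= 16), and tail paths.
        let a: Vec<f32> = (0..67).map(|i| (i as f32) * 0.25 - 3.0).collect();
        let b: Vec<f32> = (0..67).map(|i| 1.5 - (i as f32) * 0.1).collect();

        let expected = scalar_dot(&a, &b);
        let got = dot_product_simd(&a, &b);
        assert!((got - expected).abs() < 1e-2, "got {got}, expected {expected}");
    }

    #[test]
    fn l2_distance_matches_scalar_reference() {
        let a: Vec<f32> = (0..67).map(|i| (i as f32).sin()).collect();
        let b: Vec<f32> = (0..67).map(|i| (i as f32).cos()).collect();

        let expected = a
            .iter()
            .zip(&b)
            .map(|(x, y)| (x - y) * (x - y))
            .sum::<f32>()
            .sqrt();
        let got = l2_distance_simd(&a, &b);
        assert!((got - expected).abs() < 1e-4, "got {got}, expected {expected}");
    }

    #[test]
    fn mismatched_lengths_return_sentinel_values() {
        assert_eq!(dot_product_simd(&[1.0, 2.0], &[1.0]), 0.0);
        assert!(l2_distance_simd(&[1.0, 2.0], &[1.0]).is_infinite());
    }
}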