// velesdb_core/simd_avx512.rs
use wide::f32x8;
/// Runtime-detected SIMD capability tier of the current CPU.
///
/// Ordered from most to least capable; produced by [`detect_simd_level`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SimdLevel {
    /// AVX-512 Foundation (512-bit vector registers) is available.
    Avx512,
    /// AVX2 (256-bit vector registers) is available.
    Avx2,
    /// No advanced SIMD detected (or a non-x86_64 target); use scalar code.
    Scalar,
}
33
34#[must_use]
47pub fn detect_simd_level() -> SimdLevel {
48 #[cfg(target_arch = "x86_64")]
49 {
50 if is_x86_feature_detected!("avx512f") {
51 return SimdLevel::Avx512;
52 }
53 if is_x86_feature_detected!("avx2") {
54 return SimdLevel::Avx2;
55 }
56 }
57 SimdLevel::Scalar
58}
59
/// Report whether the AVX-512F instruction set is usable at runtime.
///
/// Always `false` on non-x86_64 targets, where the runtime probe does not exist.
#[must_use]
#[inline]
pub fn has_avx512() -> bool {
    #[cfg(not(target_arch = "x86_64"))]
    {
        false
    }
    #[cfg(target_arch = "x86_64")]
    {
        is_x86_feature_detected!("avx512f")
    }
}
73
74#[inline]
86#[must_use]
87pub fn dot_product_auto(a: &[f32], b: &[f32]) -> f32 {
88 assert_eq!(a.len(), b.len(), "Vector dimensions must match");
89
90 if a.len() >= 16 {
92 return dot_product_wide16(a, b);
93 }
94
95 crate::simd_explicit::dot_product_simd(a, b)
97}
98
99#[inline]
105#[must_use]
106pub fn squared_l2_auto(a: &[f32], b: &[f32]) -> f32 {
107 assert_eq!(a.len(), b.len(), "Vector dimensions must match");
108
109 if a.len() >= 16 {
110 return squared_l2_wide16(a, b);
111 }
112
113 crate::simd_explicit::squared_l2_distance_simd(a, b)
114}
115
116#[inline]
118#[must_use]
119pub fn euclidean_auto(a: &[f32], b: &[f32]) -> f32 {
120 squared_l2_auto(a, b).sqrt()
121}
122
123#[inline]
129#[must_use]
130pub fn cosine_similarity_auto(a: &[f32], b: &[f32]) -> f32 {
131 assert_eq!(a.len(), b.len(), "Vector dimensions must match");
132
133 if a.len() >= 16 {
134 return cosine_similarity_wide16(a, b);
135 }
136
137 crate::simd_explicit::cosine_similarity_simd(a, b)
138}
139
140#[inline]
150fn dot_product_wide16(a: &[f32], b: &[f32]) -> f32 {
151 let len = a.len();
152 let simd_len = len / 32;
153
154 let mut sum0 = f32x8::ZERO;
156 let mut sum1 = f32x8::ZERO;
157 let mut sum2 = f32x8::ZERO;
158 let mut sum3 = f32x8::ZERO;
159
160 for i in 0..simd_len {
162 let offset = i * 32;
163
164 let va0 = f32x8::from(&a[offset..offset + 8]);
165 let vb0 = f32x8::from(&b[offset..offset + 8]);
166 sum0 = va0.mul_add(vb0, sum0);
167
168 let va1 = f32x8::from(&a[offset + 8..offset + 16]);
169 let vb1 = f32x8::from(&b[offset + 8..offset + 16]);
170 sum1 = va1.mul_add(vb1, sum1);
171
172 let va2 = f32x8::from(&a[offset + 16..offset + 24]);
173 let vb2 = f32x8::from(&b[offset + 16..offset + 24]);
174 sum2 = va2.mul_add(vb2, sum2);
175
176 let va3 = f32x8::from(&a[offset + 24..offset + 32]);
177 let vb3 = f32x8::from(&b[offset + 24..offset + 32]);
178 sum3 = va3.mul_add(vb3, sum3);
179 }
180
181 let combined01 = sum0 + sum1;
183 let combined23 = sum2 + sum3;
184 let mut result = (combined01 + combined23).reduce_add();
185
186 let base = simd_len * 32;
188 let mut pos = base;
189
190 while pos + 8 <= len {
191 let va = f32x8::from(&a[pos..pos + 8]);
192 let vb = f32x8::from(&b[pos..pos + 8]);
193 result += va.mul_add(vb, f32x8::ZERO).reduce_add();
194 pos += 8;
195 }
196
197 while pos < len {
199 result += a[pos] * b[pos];
200 pos += 1;
201 }
202
203 result
204}
205
206#[inline]
208fn squared_l2_wide16(a: &[f32], b: &[f32]) -> f32 {
209 let len = a.len();
210 let simd_len = len / 32;
211
212 let mut sum0 = f32x8::ZERO;
213 let mut sum1 = f32x8::ZERO;
214 let mut sum2 = f32x8::ZERO;
215 let mut sum3 = f32x8::ZERO;
216
217 for i in 0..simd_len {
218 let offset = i * 32;
219
220 let va0 = f32x8::from(&a[offset..offset + 8]);
221 let vb0 = f32x8::from(&b[offset..offset + 8]);
222 let diff0 = va0 - vb0;
223 sum0 = diff0.mul_add(diff0, sum0);
224
225 let va1 = f32x8::from(&a[offset + 8..offset + 16]);
226 let vb1 = f32x8::from(&b[offset + 8..offset + 16]);
227 let diff1 = va1 - vb1;
228 sum1 = diff1.mul_add(diff1, sum1);
229
230 let va2 = f32x8::from(&a[offset + 16..offset + 24]);
231 let vb2 = f32x8::from(&b[offset + 16..offset + 24]);
232 let diff2 = va2 - vb2;
233 sum2 = diff2.mul_add(diff2, sum2);
234
235 let va3 = f32x8::from(&a[offset + 24..offset + 32]);
236 let vb3 = f32x8::from(&b[offset + 24..offset + 32]);
237 let diff3 = va3 - vb3;
238 sum3 = diff3.mul_add(diff3, sum3);
239 }
240
241 let combined01 = sum0 + sum1;
242 let combined23 = sum2 + sum3;
243 let mut result = (combined01 + combined23).reduce_add();
244
245 let base = simd_len * 32;
247 let mut pos = base;
248
249 while pos + 8 <= len {
250 let va = f32x8::from(&a[pos..pos + 8]);
251 let vb = f32x8::from(&b[pos..pos + 8]);
252 let diff = va - vb;
253 result += diff.mul_add(diff, f32x8::ZERO).reduce_add();
254 pos += 8;
255 }
256
257 while pos < len {
258 let diff = a[pos] - b[pos];
259 result += diff * diff;
260 pos += 1;
261 }
262
263 result
264}
265
266#[inline]
270#[allow(clippy::similar_names)]
271fn cosine_similarity_wide16(a: &[f32], b: &[f32]) -> f32 {
272 let len = a.len();
273 let simd_len = len / 32;
274
275 let mut dot0 = f32x8::ZERO;
277 let mut dot1 = f32x8::ZERO;
278 let mut dot2 = f32x8::ZERO;
279 let mut dot3 = f32x8::ZERO;
280 let mut na0 = f32x8::ZERO;
281 let mut na1 = f32x8::ZERO;
282 let mut na2 = f32x8::ZERO;
283 let mut na3 = f32x8::ZERO;
284 let mut nb0 = f32x8::ZERO;
285 let mut nb1 = f32x8::ZERO;
286 let mut nb2 = f32x8::ZERO;
287 let mut nb3 = f32x8::ZERO;
288
289 for i in 0..simd_len {
290 let offset = i * 32;
291
292 let va0 = f32x8::from(&a[offset..offset + 8]);
293 let vb0 = f32x8::from(&b[offset..offset + 8]);
294 dot0 = va0.mul_add(vb0, dot0);
295 na0 = va0.mul_add(va0, na0);
296 nb0 = vb0.mul_add(vb0, nb0);
297
298 let va1 = f32x8::from(&a[offset + 8..offset + 16]);
299 let vb1 = f32x8::from(&b[offset + 8..offset + 16]);
300 dot1 = va1.mul_add(vb1, dot1);
301 na1 = va1.mul_add(va1, na1);
302 nb1 = vb1.mul_add(vb1, nb1);
303
304 let va2 = f32x8::from(&a[offset + 16..offset + 24]);
305 let vb2 = f32x8::from(&b[offset + 16..offset + 24]);
306 dot2 = va2.mul_add(vb2, dot2);
307 na2 = va2.mul_add(va2, na2);
308 nb2 = vb2.mul_add(vb2, nb2);
309
310 let va3 = f32x8::from(&a[offset + 24..offset + 32]);
311 let vb3 = f32x8::from(&b[offset + 24..offset + 32]);
312 dot3 = va3.mul_add(vb3, dot3);
313 na3 = va3.mul_add(va3, na3);
314 nb3 = vb3.mul_add(vb3, nb3);
315 }
316
317 let mut dot = ((dot0 + dot1) + (dot2 + dot3)).reduce_add();
319 let mut norm_a_sq = ((na0 + na1) + (na2 + na3)).reduce_add();
320 let mut norm_b_sq = ((nb0 + nb1) + (nb2 + nb3)).reduce_add();
321
322 let base = simd_len * 32;
324 let mut pos = base;
325
326 while pos + 8 <= len {
327 let va = f32x8::from(&a[pos..pos + 8]);
328 let vb = f32x8::from(&b[pos..pos + 8]);
329 dot += va.mul_add(vb, f32x8::ZERO).reduce_add();
330 norm_a_sq += va.mul_add(va, f32x8::ZERO).reduce_add();
331 norm_b_sq += vb.mul_add(vb, f32x8::ZERO).reduce_add();
332 pos += 8;
333 }
334
335 while pos < len {
336 let ai = a[pos];
337 let bi = b[pos];
338 dot += ai * bi;
339 norm_a_sq += ai * ai;
340 norm_b_sq += bi * bi;
341 pos += 1;
342 }
343
344 let norm_a = norm_a_sq.sqrt();
345 let norm_b = norm_b_sq.sqrt();
346
347 if norm_a == 0.0 || norm_b == 0.0 {
348 return 0.0;
349 }
350
351 dot / (norm_a * norm_b)
352}
353
/// Cosine similarity for pre-normalized (unit-length) vectors.
///
/// When ‖a‖ = ‖b‖ = 1, cosine similarity reduces to the plain dot product,
/// skipping both norm computations and the division.
/// NOTE(review): normalization is assumed, not checked — callers must
/// guarantee unit-length inputs or the result is not a cosine similarity.
///
/// # Panics
/// Panics if `a` and `b` have different lengths (via `dot_product_auto`).
#[inline]
#[must_use]
pub fn cosine_similarity_normalized(a: &[f32], b: &[f32]) -> f32 {
    dot_product_auto(a, b)
}
394
395#[must_use]
405pub fn batch_cosine_normalized(candidates: &[&[f32]], query: &[f32]) -> Vec<f32> {
406 let mut results = Vec::with_capacity(candidates.len());
407
408 for (i, candidate) in candidates.iter().enumerate() {
409 if i + 4 < candidates.len() {
411 #[cfg(target_arch = "x86_64")]
412 unsafe {
413 use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
414 _mm_prefetch(candidates[i + 4].as_ptr().cast::<i8>(), _MM_HINT_T0);
415 }
416 }
417
418 results.push(dot_product_auto(candidate, query));
419 }
420
421 results
422}
423
424