1use super::dispatch::cpu_features;
39
40#[inline]
52pub fn dot_i8(a: &[i8], b: &[i8]) -> i32 {
53 assert_eq!(a.len(), b.len(), "vectors must have equal length");
54
55 let features = cpu_features();
56
57 #[cfg(target_arch = "x86_64")]
58 {
59 if features.has_avx2 {
60 return unsafe { dot_i8_avx2(a, b) };
62 }
63 }
64
65 #[cfg(target_arch = "aarch64")]
66 {
67 if features.has_neon {
68 return unsafe { dot_i8_neon(a, b) };
70 }
71 }
72
73 dot_i8_scalar(a, b)
74}
75
76#[inline]
90pub fn dot_i8_batch(query: &[i8], vectors: &[i8], scales: &[f32], dim: usize, results: &mut [f32]) {
91 let n_vec = scales.len();
92 assert!(query.len() >= dim, "query too short");
93 assert!(vectors.len() >= n_vec * dim, "vectors buffer too small");
94 assert!(results.len() >= n_vec, "results buffer too small");
95
96 let features = cpu_features();
97
98 #[cfg(target_arch = "x86_64")]
99 {
100 if features.has_avx2 {
101 unsafe { dot_i8_batch_avx2(query, vectors, scales, dim, results) };
102 return;
103 }
104 }
105
106 #[cfg(target_arch = "aarch64")]
107 {
108 if features.has_neon {
109 unsafe { dot_i8_batch_neon(query, vectors, scales, dim, results) };
110 return;
111 }
112 }
113
114 dot_i8_batch_scalar(query, vectors, scales, dim, results);
115}
116
117#[inline]
126pub fn dot_i8_indexed(
127 query: &[i8],
128 vectors: &[i8],
129 cand_ids: &[u32],
130 dim: usize,
131 out_scores: &mut [i32],
132) {
133 assert!(query.len() >= dim);
134 assert!(out_scores.len() >= cand_ids.len());
135
136 for (i, &cand_id) in cand_ids.iter().enumerate() {
137 let offset = cand_id as usize * dim;
138 let vec = &vectors[offset..offset + dim];
139 out_scores[i] = dot_i8(&query[..dim], vec);
140 }
141}
142
/// AVX2 kernel for [`dot_i8`].
///
/// Each iteration processes 32 elements:
/// - Each 32-byte chunk is split into two 16-byte halves and sign-extended to `i16`.
/// - The halves are multiplied with `_mm256_madd_epi16`, which multiplies
///   matching `i16` lanes and sums adjacent pairs of products into `i32` lanes.
/// - A pair sum is at most `2 * (-128) * (-128) = 32768`, so that step cannot
///   overflow. The running `i32` vector accumulator wraps on overflow.
///
/// The final `len % 32` elements are handled by a scalar tail.
///
/// # Safety
///
/// The caller must ensure that:
/// - the running CPU supports AVX2;
/// - `b.len() >= a.len()` (callers in this module always pass equal lengths).
///
/// The vector loads are not bounds-checked, so a shorter `b` would be read out of bounds.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    debug_assert_eq!(a.len(), b.len(), "vectors must have equal length");

    let len = a.len();
    let dim_aligned = len - len % 32;

    // SAFETY: AVX2 availability is guaranteed by the caller. Every load reads
    // 32 bytes starting at offset `d`, with `d + 32 <= dim_aligned <= len` and
    // `len <= b.len()` by the caller's contract, so all loads stay inside both
    // slices. `_mm256_loadu_si256` has no alignment requirement.
    let simd_sum = unsafe {
        let mut acc = _mm256_setzero_si256();

        for d in (0..dim_aligned).step_by(32) {
            let q = _mm256_loadu_si256(a.as_ptr().add(d) as *const __m256i);
            let v = _mm256_loadu_si256(b.as_ptr().add(d) as *const __m256i);

            // Sign-extend the low and high 16 bytes of each operand to i16.
            let q_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(q));
            let q_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(q, 1));
            let v_lo = _mm256_cvtepi8_epi16(_mm256_castsi256_si128(v));
            let v_hi = _mm256_cvtepi8_epi16(_mm256_extracti128_si256(v, 1));

            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(q_lo, v_lo));
            acc = _mm256_add_epi32(acc, _mm256_madd_epi16(q_hi, v_hi));
        }

        // Horizontal reduction of the eight i32 lanes down to one.
        let sum128 = _mm_add_epi32(
            _mm256_castsi256_si128(acc),
            _mm256_extracti128_si256(acc, 1),
        );
        let sum128 = _mm_hadd_epi32(sum128, sum128);
        let sum128 = _mm_hadd_epi32(sum128, sum128);
        _mm_cvtsi128_si32(sum128)
    };

    // Scalar tail for the last `len % 32` elements (bounds-checked).
    a[dim_aligned..]
        .iter()
        .zip(&b[dim_aligned..len])
        .fold(simd_sum, |acc, (&x, &y)| acc + (x as i32) * (y as i32))
}
206
207#[cfg(target_arch = "x86_64")]
208#[target_feature(enable = "avx2")]
209unsafe fn dot_i8_batch_avx2(
210 query: &[i8],
211 vectors: &[i8],
212 scales: &[f32],
213 dim: usize,
214 results: &mut [f32],
215) {
216 unsafe {
217 let n_vec = scales.len();
218
219 for v in 0..n_vec {
220 let offset = v * dim;
221 let vec = &vectors[offset..offset + dim];
222 let int_dot = dot_i8_avx2(&query[..dim], vec);
223 results[v] = int_dot as f32 * scales[v];
224 }
225 }
226}
227
/// NEON kernel for [`dot_i8`].
///
/// Each iteration processes 16 elements:
/// - `vmull_s8` widens each 8-byte half into eight `i16` products. Each product
///   is at most `(-128) * (-128) = 16384`, so it fits in `i16`.
/// - `vpadalq_s16` adds adjacent product pairs into the four `i32` accumulator lanes.
///
/// The final `len % 16` elements are handled by a scalar tail.
///
/// # Safety
///
/// The caller must ensure that:
/// - NEON is available;
/// - `b.len() >= a.len()` (callers in this module always pass equal lengths).
///
/// The vector loads are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_i8_neon(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "vectors must have equal length");

    let len = a.len();
    let mut i = 0;

    // SAFETY: NEON availability is guaranteed by the caller. Each load reads 16
    // bytes at offset `i` with `i + 16 <= len` and `len <= b.len()` by the
    // caller's contract, so all loads stay inside both slices.
    let simd_sum = unsafe {
        let mut acc = vdupq_n_s32(0);

        while i + 16 <= len {
            let va = vld1q_s8(a.as_ptr().add(i));
            let vb = vld1q_s8(b.as_ptr().add(i));

            let lo = vmull_s8(vget_low_s8(va), vget_low_s8(vb));
            let hi = vmull_s8(vget_high_s8(va), vget_high_s8(vb));

            acc = vpadalq_s16(acc, lo);
            acc = vpadalq_s16(acc, hi);

            i += 16;
        }

        vaddvq_s32(acc)
    };

    // Scalar tail for the last `len % 16` elements (bounds-checked).
    a[i..]
        .iter()
        .zip(&b[i..len])
        .fold(simd_sum, |acc, (&x, &y)| acc + (x as i32) * (y as i32))
}
272
/// NEON implementation of [`dot_i8_batch`]: one scaled dot product per row.
///
/// Row `r` of `vectors` is the slice `vectors[r * dim..(r + 1) * dim]`. Its
/// integer dot product with `query[..dim]` is converted to `f32`, multiplied by
/// `scales[r]`, and stored in `results[r]`, for every `r < scales.len()`.
/// Slicing and indexing are bounds-checked, so undersized buffers panic.
///
/// # Safety
///
/// The caller must ensure NEON is available on the running CPU.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_i8_batch_neon(
    query: &[i8],
    vectors: &[i8],
    scales: &[f32],
    dim: usize,
    results: &mut [f32],
) {
    for (row, &scale) in scales.iter().enumerate() {
        let start = row * dim;
        let candidate = &vectors[start..start + dim];
        // SAFETY: the caller guarantees NEON support, and both slices passed
        // here have length `dim`.
        let raw = unsafe { dot_i8_neon(&query[..dim], candidate) };
        results[row] = raw as f32 * scale;
    }
}
293
/// Portable scalar dot product of two `i8` slices, accumulated in `i32`.
///
/// Elements are paired by position, and pairing stops at the end of the shorter slice.
/// This is the fallback when no SIMD kernel applies, and the reference used by the tests.
#[inline]
fn dot_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut total: i32 = 0;
    for (&x, &y) in a.iter().zip(b) {
        total += i32::from(x) * i32::from(y);
    }
    total
}
306
307#[inline]
309fn dot_i8_batch_scalar(
310 query: &[i8],
311 vectors: &[i8],
312 scales: &[f32],
313 dim: usize,
314 results: &mut [f32],
315) {
316 for (i, &scale) in scales.iter().enumerate() {
317 let offset = i * dim;
318 let vec = &vectors[offset..offset + dim];
319 let int_dot = dot_i8_scalar(&query[..dim], vec);
320 results[i] = int_dot as f32 * scale;
321 }
322}
323
/// Squared Euclidean (L2²) distance between two equal-length `i8` vectors.
///
/// No square root is taken, so the result is the exact integer sum of
/// `(a[i] - b[i])²`, computed in `i32`. On aarch64 with NEON the work goes to
/// the vectorised kernel; every other case uses the scalar loop below.
///
/// # Panics
///
/// Panics if `a.len() != b.len()`.
#[inline]
pub fn l2_distance_i8(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "aarch64")]
    {
        if cpu_features().has_neon {
            // SAFETY: NEON support was just confirmed and the lengths match.
            return unsafe { l2_distance_i8_neon(a, b) };
        }
    }

    let mut total = 0i32;
    for (&x, &y) in a.iter().zip(b) {
        let delta = i32::from(x) - i32::from(y);
        total += delta * delta;
    }
    total
}
352
/// NEON kernel for [`l2_distance_i8`].
///
/// Each iteration processes 16 elements:
/// - `vsubl_s8` widens each 8-byte half into eight `i16` differences, in the range `-255..=255`.
/// - `vmlal_s16` squares those differences and accumulates them into the four
///   `i32` lanes. Each square is at most `255² = 65025`.
///
/// The final `len % 16` elements are handled by a scalar tail.
///
/// # Safety
///
/// The caller must ensure that:
/// - NEON is available;
/// - `b.len() >= a.len()` (callers in this module always pass equal lengths).
///
/// The vector loads are not bounds-checked.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_i8_neon(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::aarch64::*;

    debug_assert_eq!(a.len(), b.len(), "vectors must have equal length");

    let len = a.len();
    let mut i = 0;

    // SAFETY: NEON availability is guaranteed by the caller. Each load reads 16
    // bytes at offset `i` with `i + 16 <= len` and `len <= b.len()` by the
    // caller's contract, so all loads stay inside both slices.
    let simd_sum = unsafe {
        let mut acc = vdupq_n_s32(0);

        while i + 16 <= len {
            let va = vld1q_s8(a.as_ptr().add(i));
            let vb = vld1q_s8(b.as_ptr().add(i));

            let diff_lo = vsubl_s8(vget_low_s8(va), vget_low_s8(vb));
            let diff_hi = vsubl_s8(vget_high_s8(va), vget_high_s8(vb));

            acc = vmlal_s16(acc, vget_low_s16(diff_lo), vget_low_s16(diff_lo));
            acc = vmlal_s16(acc, vget_high_s16(diff_lo), vget_high_s16(diff_lo));
            acc = vmlal_s16(acc, vget_low_s16(diff_hi), vget_low_s16(diff_hi));
            acc = vmlal_s16(acc, vget_high_s16(diff_hi), vget_high_s16(diff_hi));

            i += 16;
        }

        vaddvq_s32(acc)
    };

    // Scalar tail for the last `len % 16` elements (bounds-checked).
    a[i..].iter().zip(&b[i..len]).fold(simd_sum, |acc, (&x, &y)| {
        let diff = (x as i32) - (y as i32);
        acc + diff * diff
    })
}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Reference dot product computed independently of the module's kernels.
    fn reference_dot(a: &[i8], b: &[i8]) -> i32 {
        a.iter().zip(b).map(|(&x, &y)| (x as i32) * (y as i32)).sum()
    }

    /// Reference squared L2 distance computed independently of the module's kernels.
    fn reference_l2(a: &[i8], b: &[i8]) -> i32 {
        a.iter()
            .zip(b)
            .map(|(&x, &y)| {
                let d = (x as i32) - (y as i32);
                d * d
            })
            .sum()
    }

    /// Deterministic `i8` test data whose values span the full -128..=127 range.
    fn pattern(len: usize, mul: usize, add: usize) -> Vec<i8> {
        (0..len).map(|i| ((i * mul + add) % 256) as u8 as i8).collect()
    }

    #[test]
    fn test_dot_i8_basic() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let b: Vec<i8> = vec![8, 7, 6, 5, 4, 3, 2, 1];

        let result = dot_i8(&a, &b);
        let expected: i32 = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x as i32) * (y as i32))
            .sum();

        assert_eq!(result, expected);
    }

    #[test]
    fn test_dot_i8_large() {
        let dim = 768;
        let a: Vec<i8> = (0..dim)
            .map(|i| ((i % 256) as i8).wrapping_add(-128))
            .collect();
        let b: Vec<i8> = (0..dim)
            .map(|i| ((i * 7 % 256) as i8).wrapping_add(-128))
            .collect();

        let result = dot_i8(&a, &b);
        let expected = dot_i8_scalar(&a, &b);

        assert_eq!(result, expected);
    }

    #[test]
    fn test_dot_i8_tail_lengths() {
        // Lengths around the 16-lane (NEON) and 32-lane (AVX2) boundaries, so the
        // SIMD body, the scalar tail, and the tail-only case are all exercised.
        for len in [0, 1, 7, 15, 16, 17, 31, 32, 33, 47, 63, 65, 100] {
            let a = pattern(len, 37, 5);
            let b = pattern(len, 91, 200);
            assert_eq!(dot_i8(&a, &b), reference_dot(&a, &b), "len={}", len);
        }
    }

    #[test]
    fn test_dot_i8_extremes() {
        let lo = vec![-128i8; 67];
        let hi = vec![127i8; 67];
        assert_eq!(dot_i8(&lo, &lo), 67 * 16384);
        assert_eq!(dot_i8(&lo, &hi), -67 * 16256);
        assert_eq!(dot_i8(&hi, &hi), 67 * 16129);
    }

    #[test]
    #[should_panic(expected = "vectors must have equal length")]
    fn test_dot_i8_length_mismatch_panics() {
        let _ = dot_i8(&[1, 2, 3], &[1, 2]);
    }

    #[test]
    fn test_dot_i8_batch() {
        let dim = 128;
        let n_vec = 10;
        let query: Vec<i8> = (0..dim).map(|i| (i % 127) as i8).collect();
        let vectors: Vec<i8> = (0..n_vec * dim).map(|i| ((i * 3) % 127) as i8).collect();
        let scales: Vec<f32> = (0..n_vec).map(|i| 0.01 * (i + 1) as f32).collect();
        let mut results = vec![0.0f32; n_vec];

        dot_i8_batch(&query, &vectors, &scales, dim, &mut results);

        let mut expected = vec![0.0f32; n_vec];
        dot_i8_batch_scalar(&query, &vectors, &scales, dim, &mut expected);

        for (r, e) in results.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "result={}, expected={}", r, e);
        }
    }

    #[test]
    fn test_dot_i8_indexed() {
        // 40 = 32 + 8, so each row uses both the SIMD body and the scalar tail.
        let dim = 40;
        let n_vec = 5;
        let query = pattern(dim, 13, 7);
        let vectors = pattern(n_vec * dim, 29, 3);
        let cand_ids = [4u32, 0, 2, 2];
        let mut scores = vec![i32::MIN; cand_ids.len() + 1];

        dot_i8_indexed(&query, &vectors, &cand_ids, dim, &mut scores);

        for (i, &id) in cand_ids.iter().enumerate() {
            let start = id as usize * dim;
            let row = &vectors[start..start + dim];
            assert_eq!(scores[i], reference_dot(&query, row), "cand_id={}", id);
        }
        // The slot past `cand_ids.len()` must be left untouched.
        assert_eq!(scores[cand_ids.len()], i32::MIN);
    }

    #[test]
    fn test_l2_distance() {
        let a: Vec<i8> = vec![10, 20, 30, 40];
        let b: Vec<i8> = vec![11, 22, 33, 44];

        let result = l2_distance_i8(&a, &b);
        assert_eq!(result, 30);
    }

    #[test]
    fn test_l2_distance_tail_lengths() {
        // Lengths >= 16 make sure the NEON loop actually runs on aarch64.
        for len in [0, 1, 15, 16, 17, 33, 100] {
            let a = pattern(len, 37, 5);
            let b = pattern(len, 91, 200);
            assert_eq!(l2_distance_i8(&a, &b), reference_l2(&a, &b), "len={}", len);
        }
        // Largest per-element difference: (127 - (-128))^2 = 65025.
        let lo = vec![-128i8; 20];
        let hi = vec![127i8; 20];
        assert_eq!(l2_distance_i8(&lo, &hi), 20 * 65025);
    }
}