Skip to main content

sochdb_vector/simd/
dot_i8.rs

1//! Int8 Dot Product Kernel
2//!
3//! This module provides SIMD-accelerated int8 dot product computation
4//! for reranking candidates after the initial BPS scan.
5//!
6//! # Algorithm
7//!
8//! ```text
9//! dot(Q, V) = Σ_{d=0}^{D-1} Q[d] × V[d]
10//! ```
11//!
12//! # Overflow Analysis
13//!
14//! For D=768 dimensions with i8 values in [-127, 127]:
15//! ```text
16//! max_product = 127 × 127 = 16,129
17//! max_sum = 768 × 16,129 = 12,387,072 < 2^31 - 1 (i32 max)
18//! ```
19//! Thus, i32 accumulation is sufficient.
20//!
21//! # Implementation Strategy
22//!
23//! ## x86_64 AVX2
24//! Uses sign-extension to i16 followed by `_mm256_madd_epi16`:
25//! 1. Load 32 i8 values
26//! 2. Sign-extend to 2×16 i16 values
27//! 3. Multiply-add pairs: (a0*b0 + a1*b1) -> i32
28//! 4. Accumulate i32 results
29//!
30//! ## ARM NEON
31//! Uses `vmull_s8` to multiply 8 i8 pairs to i16, then `vpadalq_s16` to
32//! widen and accumulate to i32.
33//!
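//! ## Future: VNNI/SDOT
//! - x86 AVX-VNNI-INT8: `_mm256_dpbssd_epi32` (single instruction, i8×i8→i32).
//!   Note that plain AVX-512 VNNI only provides the u8×i8 form (`vpdpbusd`).
//! - ARM v8.2 dot-product extension (SDOT): `vdotq_s32` (single instruction, i8×i8→i32)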
37
38use super::dispatch::cpu_features;
39
40/// Compute the dot product of two i8 vectors.
41///
42/// # Arguments
43/// * `a` - First vector (i8)
44/// * `b` - Second vector (i8, same length as `a`)
45///
46/// # Returns
47/// The dot product as i32
48///
49/// # Panics
50/// Panics if `a.len() != b.len()`
51#[inline]
52pub fn dot_i8(a: &[i8], b: &[i8]) -> i32 {
53    assert_eq!(a.len(), b.len(), "vectors must have equal length");
54
55    let features = cpu_features();
56
57    #[cfg(target_arch = "x86_64")]
58    {
59        if features.has_avx2 {
60            // Safety: AVX2 feature is verified
61            return unsafe { dot_i8_avx2(a, b) };
62        }
63    }
64
65    #[cfg(target_arch = "aarch64")]
66    {
67        if features.has_neon {
68            // Safety: NEON is mandatory on aarch64
69            return unsafe { dot_i8_neon(a, b) };
70        }
71    }
72
73    dot_i8_scalar(a, b)
74}
75
76/// Compute dot products for a batch of vectors with dequantization.
77///
78/// Computes: `result[i] = dot(query, vectors[i * dim..(i+1) * dim]) * scales[i]`
79///
80/// # Arguments
81/// * `query` - Query vector (i8)
82/// * `vectors` - Flattened database vectors (i8, n_vec × dim)
83/// * `scales` - Per-vector dequantization scales
84/// * `dim` - Dimension of each vector
85/// * `results` - Output dequantized dot products
86///
87/// # Panics
88/// Panics if buffer sizes are inconsistent
89#[inline]
90pub fn dot_i8_batch(query: &[i8], vectors: &[i8], scales: &[f32], dim: usize, results: &mut [f32]) {
91    let n_vec = scales.len();
92    assert!(query.len() >= dim, "query too short");
93    assert!(vectors.len() >= n_vec * dim, "vectors buffer too small");
94    assert!(results.len() >= n_vec, "results buffer too small");
95
96    let features = cpu_features();
97
98    #[cfg(target_arch = "x86_64")]
99    {
100        if features.has_avx2 {
101            unsafe { dot_i8_batch_avx2(query, vectors, scales, dim, results) };
102            return;
103        }
104    }
105
106    #[cfg(target_arch = "aarch64")]
107    {
108        if features.has_neon {
109            unsafe { dot_i8_batch_neon(query, vectors, scales, dim, results) };
110            return;
111        }
112    }
113
114    dot_i8_batch_scalar(query, vectors, scales, dim, results);
115}
116
117/// Compute dot products for indexed candidates.
118///
119/// # Arguments
120/// * `query` - Query vector (i8)
121/// * `vectors` - All vectors (i8, total_vecs × dim)
122/// * `cand_ids` - Candidate indices to compute
123/// * `dim` - Dimension of each vector
124/// * `out_scores` - Output i32 dot products
125#[inline]
126pub fn dot_i8_indexed(
127    query: &[i8],
128    vectors: &[i8],
129    cand_ids: &[u32],
130    dim: usize,
131    out_scores: &mut [i32],
132) {
133    assert!(query.len() >= dim);
134    assert!(out_scores.len() >= cand_ids.len());
135
136    for (i, &cand_id) in cand_ids.iter().enumerate() {
137        let offset = cand_id as usize * dim;
138        let vec = &vectors[offset..offset + dim];
139        out_scores[i] = dot_i8(&query[..dim], vec);
140    }
141}
142
143// ============================================================================
144// x86_64 AVX2 Implementation
145// ============================================================================
146
/// AVX2 kernel: `i8` dot product with `i32` accumulation.
///
/// Handles 32 elements per iteration by sign-extending each 16-byte half to
/// `i16` and feeding it to `_mm256_madd_epi16`; the elements left over after
/// the last full group of 32 are handled by a scalar loop.
///
/// # Safety
/// * The CPU must support AVX2.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through
///   raw pointers at offsets derived from `a.len()`.
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
unsafe fn dot_i8_avx2(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::x86_64::*;

    unsafe {
        let n = a.len();
        // Largest multiple of 32 that fits in `n`.
        let simd_len = n - n % 32;
        let pa = a.as_ptr();
        let pb = b.as_ptr();

        // Eight running i32 partial sums.
        let mut lanes = _mm256_setzero_si256();

        let mut off = 0;
        while off < simd_len {
            // Unaligned 32-byte loads from both operands.
            let qa = _mm256_loadu_si256(pa.add(off) as *const __m256i);
            let vb = _mm256_loadu_si256(pb.add(off) as *const __m256i);

            // Lower 16 bytes: widen to i16, then multiply and add adjacent
            // pairs into 8 i32 values.
            let lower = _mm256_madd_epi16(
                _mm256_cvtepi8_epi16(_mm256_castsi256_si128(qa)),
                _mm256_cvtepi8_epi16(_mm256_castsi256_si128(vb)),
            );
            // Upper 16 bytes, same treatment.
            let upper = _mm256_madd_epi16(
                _mm256_cvtepi8_epi16(_mm256_extracti128_si256(qa, 1)),
                _mm256_cvtepi8_epi16(_mm256_extracti128_si256(vb, 1)),
            );

            lanes = _mm256_add_epi32(_mm256_add_epi32(lanes, lower), upper);
            off += 32;
        }

        // Reduce 8 lanes -> 4 lanes -> 1 value.
        let folded = _mm_add_epi32(
            _mm256_castsi256_si128(lanes),
            _mm256_extracti128_si256(lanes, 1),
        );
        let pairs = _mm_hadd_epi32(folded, folded);
        let total = _mm_hadd_epi32(pairs, pairs);

        let mut sum = _mm_cvtsi128_si32(total);

        // Scalar tail for the last `n % 32` elements.
        for idx in simd_len..n {
            sum += i32::from(a[idx]) * i32::from(b[idx]);
        }

        sum
    }
}
206
207#[cfg(target_arch = "x86_64")]
208#[target_feature(enable = "avx2")]
209unsafe fn dot_i8_batch_avx2(
210    query: &[i8],
211    vectors: &[i8],
212    scales: &[f32],
213    dim: usize,
214    results: &mut [f32],
215) {
216    unsafe {
217        let n_vec = scales.len();
218
219        for v in 0..n_vec {
220            let offset = v * dim;
221            let vec = &vectors[offset..offset + dim];
222            let int_dot = dot_i8_avx2(&query[..dim], vec);
223            results[v] = int_dot as f32 * scales[v];
224        }
225    }
226}
227
228// ============================================================================
229// aarch64 NEON Implementation
230// ============================================================================
231
/// NEON kernel: `i8` dot product with `i32` accumulation.
///
/// Handles 16 elements per iteration: `vmull_s8` produces `i16` products for
/// each 8-element half and `vpadalq_s16` adds adjacent pairs into four `i32`
/// accumulators. Elements left over after the last full group of 16 are
/// handled by a scalar loop.
///
/// # Safety
/// * The CPU must support NEON.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through
///   raw pointers at offsets derived from `a.len()`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_i8_neon(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::aarch64::*;

    unsafe {
        let n = a.len();
        // Largest multiple of 16 that fits in `n`.
        let simd_len = n - n % 16;

        // Four running i32 partial sums.
        let mut lanes = vdupq_n_s32(0);

        for off in (0..simd_len).step_by(16) {
            let qa = vld1q_s8(a.as_ptr().add(off));
            let vb = vld1q_s8(b.as_ptr().add(off));

            // i8 x i8 -> i16 products for each 8-element half.
            let lower = vmull_s8(vget_low_s8(qa), vget_low_s8(vb));
            let upper = vmull_s8(vget_high_s8(qa), vget_high_s8(vb));

            // Pairwise add the i16 products into the i32 accumulators.
            lanes = vpadalq_s16(vpadalq_s16(lanes, lower), upper);
        }

        // Collapse the four lanes into one value.
        let mut sum = vaddvq_s32(lanes);

        // Scalar tail for the last `n % 16` elements.
        for idx in simd_len..n {
            sum += i32::from(a[idx]) * i32::from(b[idx]);
        }

        sum
    }
}
272
/// NEON batch routine: one dequantized dot product per entry of `scales`.
///
/// Vector `i` occupies `vectors[i * dim..(i + 1) * dim]`; its integer dot
/// product with `query[..dim]` is converted to `f32`, multiplied by
/// `scales[i]` and written to `results[i]`.
///
/// # Safety
/// The CPU must support NEON. All buffer accesses go through bounds-checked
/// slicing, so undersized buffers panic.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn dot_i8_batch_neon(
    query: &[i8],
    vectors: &[i8],
    scales: &[f32],
    dim: usize,
    results: &mut [f32],
) {
    unsafe {
        for (row, &scale) in scales.iter().enumerate() {
            let start = row * dim;
            let candidate = &vectors[start..start + dim];
            let raw = dot_i8_neon(&query[..dim], candidate);
            results[row] = raw as f32 * scale;
        }
    }
}
293
294// ============================================================================
295// Scalar Fallback
296// ============================================================================
297
/// Portable reference dot product.
///
/// Widens every element to `i32` before multiplying and sums the products.
/// If the slices differ in length, only the common prefix is used.
#[inline]
fn dot_i8_scalar(a: &[i8], b: &[i8]) -> i32 {
    let mut total = 0i32;
    for (&x, &y) in a.iter().zip(b) {
        total += i32::from(x) * i32::from(y);
    }
    total
}
306
307/// Scalar batch with dequantization
308#[inline]
309fn dot_i8_batch_scalar(
310    query: &[i8],
311    vectors: &[i8],
312    scales: &[f32],
313    dim: usize,
314    results: &mut [f32],
315) {
316    for (i, &scale) in scales.iter().enumerate() {
317        let offset = i * dim;
318        let vec = &vectors[offset..offset + dim];
319        let int_dot = dot_i8_scalar(&query[..dim], vec);
320        results[i] = int_dot as f32 * scale;
321    }
322}
323
324// ============================================================================
325// L2 Distance (bonus)
326// ============================================================================
327
/// Squared Euclidean (L2) distance between two `i8` vectors.
///
/// dist = sum((a[i] - b[i])^2), with differences and squares computed in `i32`.
///
/// On aarch64 the NEON kernel is used when `cpu_features()` reports NEON;
/// every other case runs the scalar loop below.
///
/// # Panics
/// Panics if `a.len() != b.len()`
#[inline]
pub fn l2_distance_i8(a: &[i8], b: &[i8]) -> i32 {
    assert_eq!(a.len(), b.len());

    #[cfg(target_arch = "aarch64")]
    {
        if cpu_features().has_neon {
            // SAFETY: the NEON flag was just checked and the slices were
            // asserted to have the same length.
            return unsafe { l2_distance_i8_neon(a, b) };
        }
    }

    // Scalar fallback
    let mut total = 0i32;
    for (&x, &y) in a.iter().zip(b) {
        let delta = i32::from(x) - i32::from(y);
        total += delta * delta;
    }
    total
}
352
/// NEON kernel: squared L2 distance between two `i8` vectors.
///
/// Handles 16 elements per iteration: `vsubl_s8` produces widened `i16`
/// differences and `vmlal_s16` squares them into four `i32` accumulators.
/// Elements left over after the last full group of 16 are handled by a
/// scalar loop.
///
/// # Safety
/// * The CPU must support NEON.
/// * `b` must be at least as long as `a`: the vector loop reads `b` through
///   raw pointers at offsets derived from `a.len()`.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn l2_distance_i8_neon(a: &[i8], b: &[i8]) -> i32 {
    use std::arch::aarch64::*;

    unsafe {
        let n = a.len();
        // Largest multiple of 16 that fits in `n`.
        let simd_len = n - n % 16;

        // Four running i32 partial sums.
        let mut lanes = vdupq_n_s32(0);

        for off in (0..simd_len).step_by(16) {
            let xa = vld1q_s8(a.as_ptr().add(off));
            let xb = vld1q_s8(b.as_ptr().add(off));

            // Widening subtract: i8 - i8 -> i16, so the difference cannot wrap.
            let delta_lo = vsubl_s8(vget_low_s8(xa), vget_low_s8(xb));
            let delta_hi = vsubl_s8(vget_high_s8(xa), vget_high_s8(xb));

            // Split into four groups of four i16 differences.
            let (d0, d1) = (vget_low_s16(delta_lo), vget_high_s16(delta_lo));
            let (d2, d3) = (vget_low_s16(delta_hi), vget_high_s16(delta_hi));

            // Multiply each difference by itself, widening to i32, and accumulate.
            lanes = vmlal_s16(lanes, d0, d0);
            lanes = vmlal_s16(lanes, d1, d1);
            lanes = vmlal_s16(lanes, d2, d2);
            lanes = vmlal_s16(lanes, d3, d3);
        }

        // Collapse the four lanes into one value.
        let mut sum = vaddvq_s32(lanes);

        // Scalar tail for the last `n % 16` elements.
        for idx in simd_len..n {
            let delta = i32::from(a[idx]) - i32::from(b[idx]);
            sum += delta * delta;
        }

        sum
    }
}
391
#[cfg(test)]
mod tests {
    use super::*;

    /// Short input: shorter than any SIMD width, so only the scalar tail of a
    /// SIMD kernel (or the scalar fallback) runs.
    #[test]
    fn test_dot_i8_basic() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5, 6, 7, 8];
        let b: Vec<i8> = vec![8, 7, 6, 5, 4, 3, 2, 1];

        let result = dot_i8(&a, &b);
        let expected: i32 = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| (x as i32) * (y as i32))
            .sum();

        assert_eq!(result, expected);
    }

    /// Empty input must yield zero rather than panic.
    #[test]
    fn test_dot_i8_empty() {
        assert_eq!(dot_i8(&[], &[]), 0);
    }

    /// Typical embedding dimension (a multiple of both 16 and 32).
    #[test]
    fn test_dot_i8_large() {
        // Test with typical embedding dimension
        let dim = 768;
        let a: Vec<i8> = (0..dim)
            .map(|i| ((i % 256) as i8).wrapping_add(-128))
            .collect();
        let b: Vec<i8> = (0..dim)
            .map(|i| ((i * 7 % 256) as i8).wrapping_add(-128))
            .collect();

        let result = dot_i8(&a, &b);
        let expected = dot_i8_scalar(&a, &b);

        assert_eq!(result, expected);
    }

    /// Lengths around the 16- and 32-element SIMD widths, so that the vector
    /// loop and the scalar tail are exercised together.
    #[test]
    fn test_dot_i8_vector_loop_with_tail() {
        for &len in &[1usize, 15, 16, 17, 31, 32, 33, 63, 65, 100] {
            let a: Vec<i8> = (0..len).map(|i| ((i * 37) % 256) as u8 as i8).collect();
            let b: Vec<i8> = (0..len).map(|i| ((i * 91 + 5) % 256) as u8 as i8).collect();

            assert_eq!(dot_i8(&a, &b), dot_i8_scalar(&a, &b), "len={}", len);
        }
    }

    /// Extreme operands: `i8::MIN * i8::MIN` is the largest possible product.
    #[test]
    fn test_dot_i8_extreme_values() {
        let lows = vec![i8::MIN; 99];
        let highs = vec![i8::MAX; 99];

        assert_eq!(dot_i8(&lows, &lows), 99 * 16_384);
        assert_eq!(dot_i8(&lows, &highs), 99 * (-128 * 127));
    }

    /// Mismatched lengths are rejected by the public entry point.
    #[test]
    #[should_panic(expected = "vectors must have equal length")]
    fn test_dot_i8_length_mismatch_panics() {
        let _ = dot_i8(&[1, 2], &[1]);
    }

    #[test]
    fn test_dot_i8_batch() {
        let dim = 128;
        let n_vec = 10;
        let query: Vec<i8> = (0..dim).map(|i| (i % 127) as i8).collect();
        let vectors: Vec<i8> = (0..n_vec * dim).map(|i| ((i * 3) % 127) as i8).collect();
        let scales: Vec<f32> = (0..n_vec).map(|i| 0.01 * (i + 1) as f32).collect();
        let mut results = vec![0.0f32; n_vec];

        dot_i8_batch(&query, &vectors, &scales, dim, &mut results);

        // Verify against scalar
        let mut expected = vec![0.0f32; n_vec];
        dot_i8_batch_scalar(&query, &vectors, &scales, dim, &mut expected);

        for (r, e) in results.iter().zip(expected.iter()) {
            assert!((r - e).abs() < 1e-6, "result={}, expected={}", r, e);
        }
    }

    /// Candidates are scored in the order given, each against its own slice
    /// of the flattened buffer.
    #[test]
    fn test_dot_i8_indexed() {
        let dim: usize = 40;
        let n_vec: usize = 5;
        let query: Vec<i8> = (0..dim).map(|i| (i as i8) - 20).collect();
        let vectors: Vec<i8> = (0..n_vec * dim)
            .map(|i| ((i * 13) % 256) as u8 as i8)
            .collect();
        let cand_ids: Vec<u32> = vec![4, 0, 2];
        let mut scores = vec![0i32; cand_ids.len()];

        dot_i8_indexed(&query, &vectors, &cand_ids, dim, &mut scores);

        for (&id, &score) in cand_ids.iter().zip(scores.iter()) {
            let start = id as usize * dim;
            let expected = dot_i8_scalar(&query, &vectors[start..start + dim]);
            assert_eq!(score, expected, "cand_id={}", id);
        }
    }

    #[test]
    fn test_l2_distance() {
        let a: Vec<i8> = vec![10, 20, 30, 40];
        let b: Vec<i8> = vec![11, 22, 33, 44];

        let result = l2_distance_i8(&a, &b);
        // (10-11)^2 + (20-22)^2 + (30-33)^2 + (40-44)^2 = 1 + 4 + 9 + 16 = 30
        assert_eq!(result, 30);
    }

    /// Long enough to run the 16-wide vector loop plus a scalar tail, and
    /// includes the largest possible per-element difference (255).
    #[test]
    fn test_l2_distance_vector_loop_with_tail() {
        let len = 53usize;
        let a: Vec<i8> = (0..len).map(|i| ((i * 37) % 256) as u8 as i8).collect();
        let b: Vec<i8> = (0..len).map(|i| ((i * 91 + 5) % 256) as u8 as i8).collect();
        let expected: i32 = a
            .iter()
            .zip(b.iter())
            .map(|(&x, &y)| {
                let diff = (x as i32) - (y as i32);
                diff * diff
            })
            .sum();

        assert_eq!(l2_distance_i8(&a, &b), expected);

        let lows = vec![i8::MIN; 40];
        let highs = vec![i8::MAX; 40];
        assert_eq!(l2_distance_i8(&lows, &highs), 40 * 255 * 255);
    }
}