1use std::sync::OnceLock;
16
17pub use crate::simd::dispatch::{CpuFeatures, SimdLevel};
19
20static CPU_FEATURES: OnceLock<CpuFeatures> = OnceLock::new();
22
23pub fn cpu_features() -> &'static CpuFeatures {
25 CPU_FEATURES.get_or_init(CpuFeatures::detect)
26}
27
28pub fn simd_level() -> SimdLevel {
30 cpu_features().best_level()
31}
32
33pub struct BpsScanDispatcher;
39
40impl BpsScanDispatcher {
41 pub fn scan(
47 bps: &[u8],
48 n_vec: usize,
49 n_blocks: usize,
50 _proj: usize, query: &[u8],
52 out: &mut [u16],
53 ) {
54 crate::simd::bps_scan::bps_scan(bps, n_vec, n_blocks, query, out);
55 }
56
57 pub fn scan_u32(bps: &[u8], n_vec: usize, n_blocks: usize, query: &[u8], out: &mut [u32]) {
59 crate::simd::bps_scan::bps_scan_u32(bps, n_vec, n_blocks, query, out);
60 }
61
62 #[allow(dead_code)]
64 pub(crate) fn scan_fallback(
65 bps: &[u8],
66 n_vec: usize,
67 n_blocks: usize,
68 proj: usize,
69 query: &[u8],
70 out: &mut [u16],
71 ) {
72 let slots = n_blocks * proj;
73
74 for d in out.iter_mut().take(n_vec) {
76 *d = 0;
77 }
78
79 for slot in 0..slots {
80 let q = query[slot] as i16;
81 let base = slot * n_vec;
82
83 for vec_id in 0..n_vec {
84 let v = bps[base + vec_id] as i16;
85 let diff = (q - v).abs() as u16;
86 out[vec_id] = out[vec_id].saturating_add(diff);
87 }
88 }
89 }
90
91 #[allow(dead_code)]
93 pub(crate) fn scan_fallback_u32(
94 bps: &[u8],
95 n_vec: usize,
96 n_blocks: usize,
97 query: &[u8],
98 out: &mut [u32],
99 ) {
100 for d in out.iter_mut().take(n_vec) {
102 *d = 0;
103 }
104
105 for block in 0..n_blocks {
106 let q = query[block];
107 let base = block * n_vec;
108
109 for vec_id in 0..n_vec {
110 let v = bps[base + vec_id];
111 let diff = if q > v { q - v } else { v - q };
112 out[vec_id] += diff as u32;
113 }
114 }
115 }
116}
117
118pub struct DotI8Dispatcher;
124
125impl DotI8Dispatcher {
126 pub fn dot(a: &[i8], b: &[i8]) -> i32 {
128 crate::simd::dot_i8::dot_i8(a, b)
129 }
130
131 pub fn compute(
133 query: &[i8],
134 vectors: &[i8],
135 cand_ids: &[u32],
136 dim: usize,
137 out_scores: &mut [i32],
138 ) {
139 crate::simd::dot_i8::dot_i8_indexed(query, vectors, cand_ids, dim, out_scores);
140 }
141
142 pub fn compute_batch_contiguous(
144 query: &[i8],
145 vectors: &[i8],
146 scales: &[f32],
147 dim: usize,
148 out_scores: &mut [f32],
149 ) {
150 crate::simd::dot_i8::dot_i8_batch(query, vectors, scales, dim, out_scores);
151 }
152
153 pub fn compute_batch(
155 query: &[i8],
156 vectors: &[i8],
157 cand_ids: &[u32],
158 dim: usize,
159 query_scale: f32,
160 vec_scales: &[f32],
161 out_scores: &mut [f32],
162 ) {
163 let n_cand = cand_ids.len();
164 assert!(query.len() >= dim);
165 assert!(out_scores.len() >= n_cand);
166
167 let mut int_scores = vec![0i32; n_cand];
169 Self::compute(query, vectors, cand_ids, dim, &mut int_scores);
170
171 let denom = 127.0 * 127.0;
173 for (i, &cand_id) in cand_ids.iter().enumerate() {
174 let scale = query_scale * vec_scales[cand_id as usize] / denom;
175 out_scores[i] = int_scores[i] as f32 * scale;
176 }
177 }
178
179 #[allow(dead_code)]
181 pub(crate) fn dot_fallback(a: &[i8], b: &[i8]) -> i32 {
182 a.iter()
183 .zip(b.iter())
184 .map(|(&x, &y)| (x as i32) * (y as i32))
185 .sum()
186 }
187
188 #[allow(dead_code)]
190 pub(crate) fn compute_fallback(
191 query: &[i8],
192 vectors: &[i8],
193 cand_ids: &[u32],
194 dim: usize,
195 out_scores: &mut [i32],
196 ) {
197 for (i, &cand_id) in cand_ids.iter().enumerate() {
198 let offset = cand_id as usize * dim;
199 let vec = &vectors[offset..offset + dim];
200 out_scores[i] = Self::dot_fallback(&query[..dim], vec);
201 }
202 }
203
204 #[allow(dead_code)]
206 pub(crate) fn compute_batch_fallback(
207 query: &[i8],
208 vectors: &[i8],
209 scales: &[f32],
210 dim: usize,
211 out_scores: &mut [f32],
212 ) {
213 for (i, &scale) in scales.iter().enumerate() {
214 let offset = i * dim;
215 let vec = &vectors[offset..offset + dim];
216 let int_score = Self::dot_fallback(&query[..dim], vec);
217 out_scores[i] = int_score as f32 * scale;
218 }
219 }
220}
221
222pub struct VisibilityDispatcher;
232
233impl VisibilityDispatcher {
234 pub fn check_batch(commit_timestamps: &[u64], snapshot_ts: u64, visible_mask: &mut [u8]) {
244 crate::simd::visibility::visibility_check(commit_timestamps, snapshot_ts, visible_mask);
245 }
246
247 pub fn check_batch_with_txn(
260 commit_timestamps: &[u64],
261 txn_ids: &[u64],
262 snapshot_ts: u64,
263 current_txn_id: u64,
264 visible_mask: &mut [u8],
265 ) {
266 crate::simd::visibility::visibility_check_with_txn(
267 commit_timestamps,
268 txn_ids,
269 snapshot_ts,
270 current_txn_id,
271 visible_mask,
272 );
273 }
274
275 #[allow(dead_code)]
277 pub(crate) fn check_batch_fallback(
278 commit_timestamps: &[u64],
279 snapshot_ts: u64,
280 visible_mask: &mut [u8],
281 ) {
282 for (i, &commit_ts) in commit_timestamps.iter().enumerate() {
283 visible_mask[i] = if commit_ts != 0 && commit_ts < snapshot_ts {
284 1
285 } else {
286 0
287 };
288 }
289 }
290
291 #[allow(dead_code)]
293 pub(crate) fn check_batch_with_txn_fallback(
294 commit_timestamps: &[u64],
295 txn_ids: &[u64],
296 snapshot_ts: u64,
297 current_txn_id: u64,
298 visible_mask: &mut [u8],
299 ) {
300 for i in 0..commit_timestamps.len() {
301 let commit_ts = commit_timestamps[i];
302 let txn_id = txn_ids[i];
303 let visible = (commit_ts != 0 && commit_ts < snapshot_ts) || txn_id == current_txn_id;
304 visible_mask[i] = if visible { 1 } else { 0 };
305 }
306 }
307}
308
309pub fn simd_available() -> bool {
315 cpu_features().has_simd()
316}
317
318pub fn dispatch_info() -> String {
320 crate::simd::dispatch::dispatch_info()
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// The u16 scalar scan must yield the exact sum of absolute differences.
    #[test]
    fn test_bps_scan_fallback() {
        let n_vec = 100;
        let n_blocks = 4;
        let proj = 1;

        // Every slot stores the vector index as its byte value.
        let mut bps = vec![0u8; n_blocks * proj * n_vec];
        for i in 0..n_vec {
            for b in 0..n_blocks {
                bps[b * n_vec + i] = (i % 256) as u8;
            }
        }

        let query = vec![128u8; n_blocks * proj];
        let mut out = vec![0u16; n_vec];

        BpsScanDispatcher::scan_fallback(&bps, n_vec, n_blocks, proj, &query, &mut out);

        // Each slot contributes |128 - i|; the total stays far below u16::MAX,
        // so saturation does not come into play here.
        for (i, &d) in out.iter().enumerate() {
            let per_slot = u16::from(128u8.abs_diff((i % 256) as u8));
            let expected = per_slot * (n_blocks * proj) as u16;
            assert_eq!(d, expected, "unexpected distance for vector {}", i);
        }
    }

    /// The u32 scalar scan must yield the exact sum of absolute differences.
    #[test]
    fn test_bps_scan_fallback_u32() {
        let n_vec = 100;
        let n_blocks = 4;

        // Every block stores the vector index as its byte value.
        let mut bps = vec![0u8; n_blocks * n_vec];
        for i in 0..n_vec {
            for b in 0..n_blocks {
                bps[b * n_vec + i] = (i % 256) as u8;
            }
        }

        let query = vec![128u8; n_blocks];
        let mut out = vec![0u32; n_vec];

        BpsScanDispatcher::scan_fallback_u32(&bps, n_vec, n_blocks, &query, &mut out);

        for (i, &d) in out.iter().enumerate() {
            let expected: u32 = (0..n_blocks)
                .map(|_b| {
                    let v = (i % 256) as u8;
                    let q = 128u8;
                    (if q > v { q - v } else { v - q }) as u32
                })
                .sum();
            assert_eq!(d, expected);
        }
    }

    /// The indexed scalar dot product must match an independently computed value.
    #[test]
    fn test_dot_i8_fallback() {
        let dim = 64;
        let n_vec = 10;

        let query: Vec<i8> = (0..dim).map(|i| (i % 128) as i8).collect();
        // Vector `j` holds the constant `2 * j` in every dimension.
        let vectors: Vec<i8> = (0..n_vec * dim)
            .map(|i| ((i / dim) as i8).wrapping_mul(2))
            .collect();
        let cand_ids: Vec<u32> = (0..n_vec as u32).collect();
        let mut out = vec![0i32; n_vec];

        DotI8Dispatcher::compute_fallback(&query, &vectors, &cand_ids, dim, &mut out);

        for (j, &score) in out.iter().enumerate() {
            let value = (j as i8).wrapping_mul(2);
            let expected: i32 = query
                .iter()
                .map(|&q| i32::from(q) * i32::from(value))
                .sum();
            assert_eq!(score, expected, "unexpected score for candidate {}", j);
        }
    }

    #[test]
    fn test_dot_single() {
        let a: Vec<i8> = vec![1, 2, 3, 4, 5];
        let b: Vec<i8> = vec![1, 2, 3, 4, 5];
        let result = DotI8Dispatcher::dot_fallback(&a, &b);
        assert_eq!(result, 1 + 4 + 9 + 16 + 25);
    }

    #[test]
    fn test_dispatch_info() {
        let info = dispatch_info();
        assert!(!info.is_empty());
        println!("Dispatch: {}", info);
    }

    /// The dispatched kernels must agree with the scalar fallbacks.
    #[test]
    fn test_simd_dispatch_cross_validation() {
        let n_vec = 256;
        let n_blocks = 8;

        let bps: Vec<u8> = (0..(n_blocks * n_vec))
            .map(|i| ((i * 17 + 13) % 256) as u8)
            .collect();
        let query: Vec<u8> = (0..n_blocks).map(|i| (i * 31 + 7) as u8).collect();

        let mut ref_distances = vec![0u16; n_vec];
        BpsScanDispatcher::scan_fallback(&bps, n_vec, n_blocks, 1, &query, &mut ref_distances);

        let mut dispatch_distances = vec![0u16; n_vec];
        BpsScanDispatcher::scan(&bps, n_vec, n_blocks, 1, &query, &mut dispatch_distances);

        for i in 0..n_vec {
            assert_eq!(
                ref_distances[i], dispatch_distances[i],
                "BPS scan mismatch at vector {}: fallback={}, dispatch={}",
                i, ref_distances[i], dispatch_distances[i]
            );
        }

        let dim = 128;
        let a: Vec<i8> = (0..dim).map(|i| ((i * 3 - 64) % 128) as i8).collect();
        let b: Vec<i8> = (0..dim).map(|i| ((i * 7 + 32) % 128) as i8).collect();

        let ref_dot = DotI8Dispatcher::dot_fallback(&a, &b);
        let dispatch_dot = DotI8Dispatcher::dot(&a, &b);

        assert_eq!(
            ref_dot, dispatch_dot,
            "int8 dot product mismatch: fallback={}, dispatch={}",
            ref_dot, dispatch_dot
        );
    }

    #[test]
    fn test_cpu_features_detection() {
        let features = cpu_features();
        let level = simd_level();

        println!("CPU Features: {:?}", features);
        println!("SIMD Level: {:?}", level);
        println!("Dispatch Info: {}", dispatch_info());

        #[cfg(target_arch = "x86_64")]
        {
            assert!(level >= SimdLevel::Scalar);
        }

        #[cfg(target_arch = "aarch64")]
        {
            assert!(features.has_neon);
            assert!(level >= SimdLevel::Neon);
        }
    }

    #[test]
    fn test_visibility_check_basic() {
        let commit_timestamps = vec![10, 0, 5, 15, 20, 8];
        let snapshot_ts = 12;
        let mut visible_mask = vec![0u8; 6];

        VisibilityDispatcher::check_batch(&commit_timestamps, snapshot_ts, &mut visible_mask);

        assert_eq!(visible_mask, vec![1, 0, 1, 0, 0, 1]);
    }

    #[test]
    fn test_visibility_check_with_txn() {
        let commit_timestamps = vec![10, 0, 5, 0, 20, 8];
        let txn_ids = vec![1, 2, 3, 99, 5, 6];
        let snapshot_ts = 12;
        let current_txn_id = 99;
        let mut visible_mask = vec![0u8; 6];

        VisibilityDispatcher::check_batch_with_txn(
            &commit_timestamps,
            &txn_ids,
            snapshot_ts,
            current_txn_id,
            &mut visible_mask,
        );

        assert_eq!(visible_mask, vec![1, 0, 1, 1, 0, 1]);
    }

    /// The dispatched visibility checks must agree with the scalar fallbacks.
    #[test]
    fn test_visibility_simd_equivalence() {
        let n_rows = 1024;

        let commit_timestamps: Vec<u64> = (0..n_rows)
            .map(|i| if i % 5 == 0 { 0 } else { (i * 7 % 100) as u64 })
            .collect();
        let txn_ids: Vec<u64> = (0..n_rows).map(|i| (i % 10) as u64).collect();
        let snapshot_ts = 50;
        let current_txn_id = 7;

        let mut ref_mask = vec![0u8; n_rows];
        let mut dispatch_mask = vec![0u8; n_rows];

        VisibilityDispatcher::check_batch_fallback(&commit_timestamps, snapshot_ts, &mut ref_mask);
        VisibilityDispatcher::check_batch(&commit_timestamps, snapshot_ts, &mut dispatch_mask);

        for i in 0..n_rows {
            assert_eq!(
                ref_mask[i], dispatch_mask[i],
                "Visibility mismatch at row {}: fallback={}, dispatch={}",
                i, ref_mask[i], dispatch_mask[i]
            );
        }

        let mut ref_mask_txn = vec![0u8; n_rows];
        let mut dispatch_mask_txn = vec![0u8; n_rows];

        VisibilityDispatcher::check_batch_with_txn_fallback(
            &commit_timestamps,
            &txn_ids,
            snapshot_ts,
            current_txn_id,
            &mut ref_mask_txn,
        );
        VisibilityDispatcher::check_batch_with_txn(
            &commit_timestamps,
            &txn_ids,
            snapshot_ts,
            current_txn_id,
            &mut dispatch_mask_txn,
        );

        for i in 0..n_rows {
            assert_eq!(
                ref_mask_txn[i], dispatch_mask_txn[i],
                "Visibility+txn mismatch at row {}: fallback={}, dispatch={}",
                i, ref_mask_txn[i], dispatch_mask_txn[i]
            );
        }
    }
}