ruvector_rabitq/kernel.rs
1//! `VectorKernel` trait — the pluggable execution backend for RaBitQ
2//! scan + rerank. Defined here (ADR-157 §"Where each piece lives")
3//! because kernels are RaBitQ primitives; the cache is a consumer.
4//!
//! Ships with one implementation — `CpuKernel` — which delegates to
//! the existing `RabitqPlusIndex` search path (`search`, or
//! `search_with_rerank` when a per-call rerank factor is given). GPU / SIMD /
7//! WASM kernels live in separate crates (`ruvector-rabitq-cuda` etc.)
8//! and register themselves with the caller (e.g. `ruvector-rulake`'s
9//! dispatcher) as optional accelerators.
10//!
11//! ## Determinism contract
12//!
13//! Scan-phase output (top-k by 1-bit Hamming distance) must be
14//! bit-reproducible across every kernel. Rerank-phase output (exact
15//! L2²) may differ in the last ulp on reduction-order-sensitive
16//! kernels (GPU with float reduction reorder); these set
17//! `caps().deterministic = false`, and the caller's dispatch policy
18//! filters them out of `Consistency::Fresh` / `Consistency::Frozen`
19//! paths.
20//!
21//! The witness chain is NOT recomputed per kernel; it stays anchored
22//! on `(data_ref, dim, rotation_seed, rerank_factor, generation)`.
23//! Kernel identity is surfaced in caps + stats, not in the witness.
24
25use crate::index::{AnnIndex, RabitqPlusIndex, SearchResult};
26use crate::RabitqError;
27
/// Capability advertisement for a vector kernel. The caller's
/// dispatch policy compares these against the request to pick the
/// best kernel for a given batch + determinism requirement.
///
/// Derives `PartialEq`/`Eq` so a dispatcher (or a test) can compare a
/// kernel's advertised caps against an expected set with a single
/// `==` / `assert_eq!` instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelCaps {
    /// Symbolic accelerator label: "cpu", "cpu-simd", "cuda",
    /// "metal", "rocm", "wasm-simd", etc. Surfaced in stats.
    pub accelerator: &'static str,
    /// Minimum batch size at which this kernel is ever chosen. CPU
    /// kernels report 1; GPU kernels typically ≥ 64.
    pub min_batch: usize,
    /// Maximum dimensionality the kernel supports without falling
    /// back to a slower path. `usize::MAX` means "no constraint".
    pub max_dim: usize,
    /// Does the kernel produce byte-identical output (scan + rerank)
    /// vs the reference CPU kernel? Only deterministic kernels can
    /// feed witness-sealed outputs under Fresh/Frozen consistency.
    pub deterministic: bool,
}

impl KernelCaps {
    /// Default CPU caps: available always (`min_batch = 1`),
    /// deterministic, and no dimensionality cap (`max_dim = usize::MAX`).
    ///
    /// `const` so it can seed `const`/`static` caps tables.
    pub const fn cpu_default() -> Self {
        Self {
            accelerator: "cpu",
            min_batch: 1,
            max_dim: usize::MAX,
            deterministic: true,
        }
    }
}
59
60/// A batch of query vectors against a single index. The index is
61/// borrowed by reference so GPU kernels don't need to own its
62/// lifetime — the cache holds the authoritative copy.
63pub struct ScanRequest<'a> {
64 pub index: &'a RabitqPlusIndex,
65 pub queries: &'a [Vec<f32>],
66 pub k: usize,
67 /// Optional per-call rerank factor. `None` uses the index's stored
68 /// default. Used by `ruvector-rulake` to divide rerank cost
69 /// across K shards (ADR-155 federation path).
70 pub rerank_factor: Option<usize>,
71}
72
73/// Batched top-k results, one `Vec<SearchResult>` per query. Order
74/// matches the input `queries`.
75pub type ScanResponse = Vec<Vec<SearchResult>>;
76
77/// A vector kernel executes scan + exact rerank for one or more
78/// queries against a compressed RaBitQ index.
79///
80/// Implementations are stateless w.r.t. the index — they receive it
81/// by reference on every call, so a single kernel instance can serve
82/// many caches / collections concurrently. Concrete GPU kernels may
83/// carry a driver handle or stream object; that's kernel state, not
84/// index state.
85pub trait VectorKernel: Send + Sync {
86 /// Stable identifier surfaced in stats + logs. Must be unique per
87 /// kernel type (e.g. `"cpu"`, `"cuda:0"`, `"metal"`).
88 fn id(&self) -> &str;
89
90 /// Capability advertisement — what this kernel can do. Return a
91 /// fresh struct (not a static reference) so kernels can narrow
92 /// caps at runtime (e.g. GPU-down → `min_batch = usize::MAX`).
93 fn caps(&self) -> KernelCaps;
94
95 /// Run the scan + rerank for every query in `req`. Returns one
96 /// `Vec<SearchResult>` per query, in the input order.
97 fn scan(&self, req: ScanRequest<'_>) -> Result<ScanResponse, RabitqError>;
98}
99
/// Reference CPU kernel. Delegates each query to the index's own
/// search path:
/// - the plain `search` when the request carries no rerank override;
/// - `RabitqPlusIndex::search_with_rerank` when it does.
///
/// Deterministic by construction (integer popcount scan + stable
/// exact L2² rerank with total-order tie break via position).
///
/// This is the default kernel every consumer gets for free; GPU /
/// SIMD implementations plug in alongside it via registration.
///
/// Zero-sized and `Copy`, so it can be created wherever needed,
/// including in `const` contexts via [`CpuKernel::new`].
#[derive(Debug, Default, Clone, Copy)]
pub struct CpuKernel;

impl CpuKernel {
    /// Creates the reference CPU kernel. Equivalent to
    /// `CpuKernel::default()`, but usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}
114
115impl VectorKernel for CpuKernel {
116 fn id(&self) -> &str {
117 "cpu"
118 }
119
120 fn caps(&self) -> KernelCaps {
121 KernelCaps::cpu_default()
122 }
123
124 fn scan(&self, req: ScanRequest<'_>) -> Result<ScanResponse, RabitqError> {
125 let mut out = Vec::with_capacity(req.queries.len());
126 for q in req.queries {
127 let hits = match req.rerank_factor {
128 None => req.index.search(q, req.k)?,
129 Some(rf) => req.index.search_with_rerank(q, req.k, rf)?,
130 };
131 out.push(hits);
132 }
133 Ok(out)
134 }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a small 8-dim index with 16 vectors; vector `i` is
    /// `[i, i + 1, ..., i + 7]`.
    fn tiny_index() -> RabitqPlusIndex {
        let d = 8;
        let mut idx = RabitqPlusIndex::new(d, 42, 5);
        for i in 0..16 {
            let v: Vec<f32> = (0..d).map(|j| (i + j) as f32).collect();
            idx.add(i, v).unwrap();
        }
        idx
    }

    #[test]
    fn cpu_kernel_matches_direct_search() {
        let idx = tiny_index();
        let kernel = CpuKernel::new();
        let q: Vec<f32> = vec![2.0; 8];
        let direct = idx.search(&q, 4).unwrap();
        let batched = kernel
            .scan(ScanRequest {
                index: &idx,
                queries: std::slice::from_ref(&q),
                k: 4,
                rerank_factor: None,
            })
            .unwrap();
        assert_eq!(batched.len(), 1);
        let batch = &batched[0];
        assert_eq!(batch.len(), direct.len());
        for (a, b) in batch.iter().zip(direct.iter()) {
            assert_eq!(a.id, b.id);
            assert!((a.score - b.score).abs() < 1e-5);
        }
    }

    #[test]
    fn cpu_kernel_respects_rerank_override() {
        let idx = tiny_index();
        let kernel = CpuKernel::new();
        let q: Vec<f32> = vec![2.0; 8];
        // An explicit rerank factor must be routed to
        // `search_with_rerank`, so the kernel's output has to match a
        // direct call with the same factor. (A smaller factor reranks
        // fewer candidates, so the result is not necessarily a prefix
        // of the default-factor result; compare against the direct
        // call instead.)
        let direct = idx.search_with_rerank(&q, 3, 2).unwrap();
        let out = kernel
            .scan(ScanRequest {
                index: &idx,
                queries: &[q.clone(), q.clone()],
                k: 3,
                rerank_factor: Some(2),
            })
            .unwrap();
        assert_eq!(out.len(), 2, "one result vec per input query");
        for v in &out {
            assert_eq!(v.len(), direct.len(), "override must match direct call");
            for (a, b) in v.iter().zip(direct.iter()) {
                assert_eq!(a.id, b.id);
                assert!((a.score - b.score).abs() < 1e-5);
            }
            for w in v.windows(2) {
                assert!(w[0].score <= w[1].score, "hits must be sorted");
            }
        }
    }

    #[test]
    fn cpu_kernel_empty_batch_yields_empty_response() {
        let idx = tiny_index();
        let out = CpuKernel::new()
            .scan(ScanRequest {
                index: &idx,
                queries: &[],
                k: 4,
                rerank_factor: None,
            })
            .unwrap();
        assert!(out.is_empty(), "no queries → no result lists");
    }

    #[test]
    fn cpu_caps_are_deterministic_and_unbounded() {
        let c = CpuKernel::new().caps();
        assert_eq!(c.accelerator, "cpu");
        assert_eq!(c.min_batch, 1);
        assert_eq!(c.max_dim, usize::MAX);
        assert!(c.deterministic);
    }
}