Skip to main content

ruvector_rabitq/
kernel.rs

1//! `VectorKernel` trait — the pluggable execution backend for RaBitQ
2//! scan + rerank. Defined here (ADR-157 §"Where each piece lives")
3//! because kernels are RaBitQ primitives; the cache is a consumer.
4//!
5//! Ships with one implementation — `CpuKernel` — which delegates to
6//! the existing `RabitqPlusIndex::search_with_rerank`. GPU / SIMD /
7//! WASM kernels live in separate crates (`ruvector-rabitq-cuda` etc.)
8//! and register themselves with the caller (e.g. `ruvector-rulake`'s
9//! dispatcher) as optional accelerators.
10//!
11//! ## Determinism contract
12//!
13//! Scan-phase output (top-k by 1-bit Hamming distance) must be
14//! bit-reproducible across every kernel. Rerank-phase output (exact
15//! L2²) may differ in the last ulp on reduction-order-sensitive
16//! kernels (GPU with float reduction reorder); these set
17//! `caps().deterministic = false`, and the caller's dispatch policy
18//! filters them out of `Consistency::Fresh` / `Consistency::Frozen`
19//! paths.
20//!
21//! The witness chain is NOT recomputed per kernel; it stays anchored
22//! on `(data_ref, dim, rotation_seed, rerank_factor, generation)`.
23//! Kernel identity is surfaced in caps + stats, not in the witness.
24
25use crate::index::{AnnIndex, RabitqPlusIndex, SearchResult};
26use crate::RabitqError;
27
/// Capability advertisement for a vector kernel.
///
/// The caller's dispatch policy compares these values against the
/// request to pick the best kernel for a given batch + determinism
/// requirement. Two capability sets are equal when every field is equal,
/// so a whole set can be checked with a single `==` / `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct KernelCaps {
    /// Symbolic accelerator label: "cpu", "cpu-simd", "cuda",
    /// "metal", "rocm", "wasm-simd", etc. Surfaced in stats.
    pub accelerator: &'static str,
    /// Minimum batch size at which this kernel is ever chosen. CPU
    /// kernels report 1; GPU kernels typically ≥ 64.
    pub min_batch: usize,
    /// Maximum dimensionality the kernel supports without falling
    /// back to a slower path. `usize::MAX` means "no constraint".
    pub max_dim: usize,
    /// Does the kernel produce byte-identical output (scan + rerank)
    /// vs the reference CPU kernel? Only deterministic kernels can
    /// feed witness-sealed outputs under Fresh/Frozen consistency.
    pub deterministic: bool,
}

impl KernelCaps {
    /// Default CPU caps: labelled `"cpu"`, usable from a batch of one,
    /// no dimensionality cap, and deterministic.
    ///
    /// `const`, so it can also initialise constants and statics.
    #[must_use]
    pub const fn cpu_default() -> Self {
        Self {
            accelerator: "cpu",
            min_batch: 1,
            max_dim: usize::MAX,
            deterministic: true,
        }
    }
}
59
60/// A batch of query vectors against a single index. The index is
61/// borrowed by reference so GPU kernels don't need to own its
62/// lifetime — the cache holds the authoritative copy.
63pub struct ScanRequest<'a> {
64    pub index: &'a RabitqPlusIndex,
65    pub queries: &'a [Vec<f32>],
66    pub k: usize,
67    /// Optional per-call rerank factor. `None` uses the index's stored
68    /// default. Used by `ruvector-rulake` to divide rerank cost
69    /// across K shards (ADR-155 federation path).
70    pub rerank_factor: Option<usize>,
71}
72
73/// Batched top-k results, one `Vec<SearchResult>` per query. Order
74/// matches the input `queries`.
75pub type ScanResponse = Vec<Vec<SearchResult>>;
76
77/// A vector kernel executes scan + exact rerank for one or more
78/// queries against a compressed RaBitQ index.
79///
80/// Implementations are stateless w.r.t. the index — they receive it
81/// by reference on every call, so a single kernel instance can serve
82/// many caches / collections concurrently. Concrete GPU kernels may
83/// carry a driver handle or stream object; that's kernel state, not
84/// index state.
85pub trait VectorKernel: Send + Sync {
86    /// Stable identifier surfaced in stats + logs. Must be unique per
87    /// kernel type (e.g. `"cpu"`, `"cuda:0"`, `"metal"`).
88    fn id(&self) -> &str;
89
90    /// Capability advertisement — what this kernel can do. Return a
91    /// fresh struct (not a static reference) so kernels can narrow
92    /// caps at runtime (e.g. GPU-down → `min_batch = usize::MAX`).
93    fn caps(&self) -> KernelCaps;
94
95    /// Run the scan + rerank for every query in `req`. Returns one
96    /// `Vec<SearchResult>` per query, in the input order.
97    fn scan(&self, req: ScanRequest<'_>) -> Result<ScanResponse, RabitqError>;
98}
99
/// Reference CPU kernel.
///
/// A zero-sized handle whose `VectorKernel::scan` forwards each query to
/// the index itself: `search` when the request carries no rerank
/// override, `search_with_rerank` when it does. Deterministic by
/// construction (integer popcount scan + stable exact L2² rerank with
/// total-order tie break via position).
///
/// Every consumer gets this kernel for free; GPU / SIMD implementations
/// plug in alongside it via registration.
#[derive(Debug, Default, Clone, Copy)]
pub struct CpuKernel;

impl CpuKernel {
    /// Creates the kernel. Usable in `const` contexts; yields the same
    /// value as `CpuKernel::default()`.
    pub const fn new() -> Self {
        CpuKernel
    }
}
114
115impl VectorKernel for CpuKernel {
116    fn id(&self) -> &str {
117        "cpu"
118    }
119
120    fn caps(&self) -> KernelCaps {
121        KernelCaps::cpu_default()
122    }
123
124    fn scan(&self, req: ScanRequest<'_>) -> Result<ScanResponse, RabitqError> {
125        let mut out = Vec::with_capacity(req.queries.len());
126        for q in req.queries {
127            let hits = match req.rerank_factor {
128                None => req.index.search(q, req.k)?,
129                Some(rf) => req.index.search_with_rerank(q, req.k, rf)?,
130            };
131            out.push(hits);
132        }
133        Ok(out)
134    }
135}
136
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 16-vector, 8-dimensional fixture index where vector
    /// `row` holds the values `row, row + 1, …, row + 7`.
    fn tiny_index() -> RabitqPlusIndex {
        let dim = 8;
        let mut index = RabitqPlusIndex::new(dim, 42, 5);
        for row in 0..16 {
            let vector: Vec<f32> = (0..dim).map(|col| (row + col) as f32).collect();
            index.add(row, vector).unwrap();
        }
        index
    }

    /// A single-query batch with no rerank override must return the
    /// same ids (and scores within 1e-5) as calling `search` directly.
    #[test]
    fn cpu_kernel_matches_direct_search() {
        let index = tiny_index();
        let query: Vec<f32> = vec![2.0; 8];
        let expected = index.search(&query, 4).unwrap();

        let request = ScanRequest {
            index: &index,
            queries: std::slice::from_ref(&query),
            k: 4,
            rerank_factor: None,
        };
        let responses = CpuKernel::new().scan(request).unwrap();

        assert_eq!(responses.len(), 1);
        let got = &responses[0];
        assert_eq!(got.len(), expected.len());
        for (from_kernel, from_index) in got.iter().zip(expected.iter()) {
            assert_eq!(from_kernel.id, from_index.id);
            assert!((from_kernel.score - from_index.score).abs() < 1e-5);
        }
    }

    /// With an explicit (smaller) rerank factor the kernel still returns
    /// one hit list per query, each sorted by ascending score.
    #[test]
    fn cpu_kernel_respects_rerank_override() {
        let index = tiny_index();
        let query: Vec<f32> = vec![2.0; 8];
        let batch = [query.clone(), query.clone()];

        let request = ScanRequest {
            index: &index,
            queries: &batch,
            k: 3,
            rerank_factor: Some(2),
        };
        let responses = CpuKernel::new().scan(request).unwrap();

        assert_eq!(responses.len(), 2, "one result vec per input query");
        for hits in &responses {
            assert!(
                hits.windows(2).all(|pair| pair[0].score <= pair[1].score),
                "hits must be sorted"
            );
        }
    }

    /// The reference kernel advertises the default CPU capability set.
    #[test]
    fn cpu_caps_are_deterministic_and_unbounded() {
        let KernelCaps {
            accelerator,
            min_batch,
            max_dim,
            deterministic,
        } = CpuKernel::new().caps();

        assert_eq!(accelerator, "cpu");
        assert_eq!(min_batch, 1);
        assert_eq!(max_dim, usize::MAX);
        assert!(deterministic);
    }
}