Skip to main content

uni_common/
muvera.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! MUVERA — Fixed Dimensional Encoding (FDE) for multi-vector (ColBERT) retrieval.
5//!
6//! MUVERA (arXiv:2405.19504) maps a *multi-vector* (a set of per-token vectors,
7//! the ColBERT/late-interaction representation) into ONE fixed-dimensional dense
8//! vector — the FDE — such that the **inner product** of two FDEs approximates the
9//! exact MaxSim score:
10//!
11//! ```text
12//! ⟨encode_query(q), encode_doc(d)⟩ ≈ MaxSim(q, d)  (= Σ_i max_j ⟨q_i, d_j⟩)
13//! ```
14//!
//! This lets the fast, mature single-*vector* ANN index do first-stage retrieval over a
//! derived FDE column, with the exact MaxSim kernel
//! (`uni_query_functions::similar_to::maxsim`) re-ranking the candidates. Because the
//! approximation is an inner product, the FDE ANN index must always use the **Dot**
//! metric, independent of the metric the exact re-rank uses. Using Dot stays consistent
//! with a cosine re-rank because ColBERT tokens are L2-normalised, so per-token cosine
//! equals dot.
22//!
23//! ## Algorithm (one *repetition*, `B = 2^k_sim` buckets)
24//! - **SimHash buckets:** `k_sim` random Gaussian hyperplanes; a token's bucket id in
25//!   `[0, B)` is the sign-bit pattern of its dot products with the hyperplanes.
26//! - **Inner projection (optional):** project each token from `input_dim` down to
27//!   `d_proj` via a random ±1/√d_proj matrix (skipped when `d_proj == 0`).
28//! - **Document FDE:** each bucket holds the **centroid** (mean) of the (projected)
29//!   doc tokens that fall in it; empty buckets are filled from the non-empty bucket at
30//!   smallest Hamming distance on the `k_sim` bits (ties → lowest index — deterministic).
31//! - **Query FDE:** each bucket holds the **sum** of the (projected) query tokens in it;
32//!   no centroid, no empty-bucket filling.
33//! - Repeat `reps` times with independent matrices and concatenate →
34//!   `fde_dim = reps * 2^k_sim * (d_proj or input_dim)`.
35//!
36//! ## Determinism
37//! All random matrices are derived from a persisted `seed` using a self-contained
38//! SplitMix64 PRNG + Box–Muller Gaussian transform (no external RNG crate). This
39//! guarantees bit-for-bit identical matrices across platforms **and binary upgrades**
40//! — essential because the document FDEs are materialised at write time and the query
41//! FDE is computed later (possibly after a restart/upgrade); both must use the *same*
42//! transform or the inner-product approximation breaks.
43//!
44//! ## Parameter tuning (important)
45//! FDE recall is corpus-dependent. The shipped defaults (`k_sim=4, reps=20, d_proj=16`,
46//! see `uni_common::vector_index_opts`) are reasonable starting points but are **not**
47//! validated for recall on any particular corpus. Higher `reps`/`k_sim` raise recall at the
48//! cost of a larger `fde_dim`. Synthetic self-retrieval (an exact-match doc ranking first)
49//! is robust at any setting and is NOT evidence of real recall; measure recall@k on a real
50//! ColBERT corpus with `crates/uni-store/examples/multivec_recall_real.rs` and tune from
51//! there. The exact MaxSim re-rank means a poor FDE only costs recall, never precision.
52
53use serde::{Deserialize, Serialize};
54
55/// Errors produced while building or applying an FDE transform.
56#[derive(Debug, thiserror::Error, PartialEq, Eq)]
57pub enum FdeError {
58    /// A token vector's length does not match the configured `input_dim`.
59    #[error("muvera: token dimension {got} != configured input_dim {expected}")]
60    DimensionMismatch { got: usize, expected: usize },
61
62    /// The parameters are out of the supported range.
63    #[error("muvera: invalid params: {0}")]
64    InvalidParams(String),
65}
66
/// Default master seed used when a MUVERA index is created without an explicit one.
/// Fixed so behaviour is reproducible across runs; the value is the 64-bit golden-ratio
/// constant, matching the repo's other seeded RNG defaults.
pub const DEFAULT_FDE_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

/// Upper bound on `k_sim`, keeping the per-repetition bucket count `2^k_sim` sane.
const MAX_K_SIM: u32 = 16;

/// Upper bound on the resulting `fde_dim`, so absurd configurations fail fast instead of
/// allocating enormous FDE vectors.
const MAX_FDE_DIM: usize = 200_000;

/// Cap on the user-supplied `reps`, applied before the
/// `fde_dim = reps · 2^k_sim · proj_dim` product is formed (M-DOCUMENTED-MAGIC).
///
/// With `k_sim ≤ 16`, `reps ≤ 1024` and an explicit `d_proj ≤ 4096`, the product is at
/// most `2^16 · 2^10 · 2^12 = 2^38 ≈ 2.7e11`. With `d_proj == 0` the per-bucket width is
/// `input_dim` (up to `u32::MAX`), giving at most `2^16 · 2^10 · 2^32 = 2^58`. Both fit
/// in a 64-bit `usize`, so there the product cannot wrap *under* the `MAX_FDE_DIM` guard
/// (which previously let an absurd config bypass validation or panic an
/// overflow-checked build — issue #96). On 32-bit targets the product CAN exceed
/// `usize::MAX`, which is why `FdeParams::fde_dim` also uses `checked_mul` and
/// saturates. The real ceiling is still enforced by `MAX_FDE_DIM`; these caps only
/// bound the multiplication.
const MAX_REPS: u32 = 1024;

/// Cap on an explicit (non-zero) `d_proj`, applied before the `fde_dim` product is formed
/// (issue #96); see `MAX_REPS` for the overflow argument. `d_proj == 0` means "no
/// projection" and is not limited by this cap.
const MAX_PROJ_DIM: u32 = 4096;
87
88/// Parameters of an FDE transform. Persisted (via the raw fields on
89/// `VectorIndexType::Muvera`) so query-time encoding reproduces document-time encoding.
90#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
91pub struct FdeParams {
92    /// Number of SimHash hyperplanes per repetition; produces `2^k_sim` buckets.
93    pub k_sim: u32,
94    /// Number of independent repetitions concatenated into the final FDE.
95    pub reps: u32,
96    /// Inner-projection target dimension. `0` means "no projection" (use `input_dim`).
97    pub d_proj: u32,
98    /// Dimension of each input token vector (resolved from the source column at build).
99    pub input_dim: u32,
100    /// Master seed; all hyperplanes/projections are derived from it.
101    pub seed: u64,
102}
103
104impl FdeParams {
105    /// Per-bucket vector dimension (the projected dim, or `input_dim` if no projection).
106    #[inline]
107    pub fn proj_dim(&self) -> usize {
108        if self.d_proj == 0 {
109            self.input_dim as usize
110        } else {
111            self.d_proj as usize
112        }
113    }
114
115    /// Number of buckets per repetition (`2^k_sim`).
116    #[inline]
117    pub fn buckets(&self) -> usize {
118        // `checked_shl` guards an out-of-range `k_sim` (≥ `usize::BITS`) that a raw
119        // `1usize << k_sim` would panic on; `validate` rejects `k_sim > MAX_K_SIM`
120        // regardless, so this only hardens unvalidated callers.
121        1usize.checked_shl(self.k_sim).unwrap_or(0)
122    }
123
124    /// Final FDE dimension: `reps * 2^k_sim * proj_dim`.
125    ///
126    /// Saturates to `usize::MAX` on overflow rather than panicking (M-PANIC-IS-STOP):
127    /// an unvalidated caller with absurd `reps`/`d_proj` must not crash, and a saturated
128    /// value cleanly trips the `dim > MAX_FDE_DIM` check in [`FdeParams::validate`].
129    #[inline]
130    pub fn fde_dim(&self) -> usize {
131        self.buckets()
132            .checked_mul(self.proj_dim())
133            .and_then(|x| x.checked_mul(self.reps as usize))
134            .unwrap_or(usize::MAX)
135    }
136
137    /// Validate the parameters, returning a descriptive error if unsupported.
138    pub fn validate(&self) -> Result<(), FdeError> {
139        if self.k_sim == 0 || self.k_sim > MAX_K_SIM {
140            return Err(FdeError::InvalidParams(format!(
141                "k_sim must be in 1..={MAX_K_SIM}, got {}",
142                self.k_sim
143            )));
144        }
145        if self.reps == 0 || self.reps > MAX_REPS {
146            return Err(FdeError::InvalidParams(format!(
147                "reps must be in 1..={MAX_REPS}, got {}",
148                self.reps
149            )));
150        }
151        if self.input_dim == 0 {
152            return Err(FdeError::InvalidParams(
153                "input_dim must be >= 1".to_string(),
154            ));
155        }
156        // Bound `d_proj` before forming the `fde_dim` product so the multiplication
157        // cannot overflow `usize` (issue #96). `d_proj == 0` legitimately means
158        // "no projection" (use `input_dim`), so only the upper bound is checked here.
159        if self.d_proj > MAX_PROJ_DIM {
160            return Err(FdeError::InvalidParams(format!(
161                "d_proj must be <= {MAX_PROJ_DIM}, got {}",
162                self.d_proj
163            )));
164        }
165        let dim = self.fde_dim();
166        if dim == 0 || dim > MAX_FDE_DIM {
167            return Err(FdeError::InvalidParams(format!(
168                "fde_dim {dim} out of range (1..={MAX_FDE_DIM}); reduce k_sim/reps/d_proj"
169            )));
170        }
171        Ok(())
172    }
173}
174
/// Self-contained SplitMix64 generator.
///
/// The algorithm is fully specified here (no external RNG crate), so the streams it
/// produces are identical across platforms and binary versions — unlike `rand`'s
/// `StdRng`, whose algorithm is not guaranteed stable. It exposes only what the FDE
/// encoder needs: raw 64-bit words, uniform `f64`s and standard-normal `f32`s.
struct SplitMix64 {
    state: u64,
}

impl SplitMix64 {
    /// Increment added to the state on every draw.
    const GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;
    /// First multiplier of the output finaliser.
    const MIX1: u64 = 0xBF58_476D_1CE4_E5B9;
    /// Second multiplier of the output finaliser.
    const MIX2: u64 = 0x94D0_49BB_1331_11EB;

    /// Starts a stream whose initial state is exactly `seed`.
    #[inline]
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns the next 64-bit output.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GAMMA);
        let x = self.state;
        let x = (x ^ (x >> 30)).wrapping_mul(Self::MIX1);
        let x = (x ^ (x >> 27)).wrapping_mul(Self::MIX2);
        x ^ (x >> 31)
    }

    /// Uniform `f64` in `[0, 1)` built from the 53 high bits of one draw.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        let mantissa = self.next_u64() >> 11;
        mantissa as f64 / (1u64 << 53) as f64
    }

    /// Standard-normal sample from the cosine branch of Box–Muller. Consumes two
    /// uniform draws; the sine branch is discarded.
    #[inline]
    fn next_gaussian(&mut self) -> f32 {
        // Floor u1 at 1e-12 so ln(u1) stays finite when the draw is exactly 0.
        let u1 = f64::max(self.next_f64(), 1e-12);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        (radius * angle.cos()) as f32
    }
}
213
/// Derives the sub-seed for repetition `rep` from the master seed `base`.
///
/// The repetition index is spread by an odd multiplier, added to `base`, and passed
/// through the SplitMix64 output finaliser, so each repetition draws its matrices from
/// an independent-looking stream.
#[inline]
fn rep_seed(base: u64, rep: u32) -> u64 {
    const REP_STRIDE: u64 = 0xD1B5_4A32_D192_ED03;
    let salted = base.wrapping_add(u64::from(rep).wrapping_mul(REP_STRIDE));
    let stage1 = (salted ^ (salted >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let stage2 = (stage1 ^ (stage1 >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    stage2 ^ (stage2 >> 31)
}
223
224/// Precomputed random matrices for one repetition.
225struct RepMatrices {
226    /// `k_sim * input_dim`, row-major (one hyperplane per row). Gaussian entries.
227    hyperplanes: Vec<f32>,
228    /// `d_proj * input_dim`, row-major. `±1/√d_proj` entries. `None` = no projection.
229    projection: Option<Vec<f32>>,
230}
231
232impl RepMatrices {
233    fn build(params: &FdeParams, rep: u32) -> Self {
234        let mut rng = SplitMix64::new(rep_seed(params.seed, rep));
235        let d = params.input_dim as usize;
236        let hyperplanes = (0..params.k_sim as usize * d)
237            .map(|_| rng.next_gaussian())
238            .collect();
239        let projection = if params.d_proj == 0 {
240            None
241        } else {
242            let pd = params.d_proj as usize;
243            let scale = 1.0f32 / (pd as f32).sqrt();
244            // Draw the ±1 entries AFTER the hyperplanes so the draw order is fixed.
245            let proj = (0..pd * d)
246                .map(|_| {
247                    if rng.next_u64() & 1 == 0 {
248                        scale
249                    } else {
250                        -scale
251                    }
252                })
253                .collect();
254            Some(proj)
255        };
256        Self {
257            hyperplanes,
258            projection,
259        }
260    }
261
262    /// SimHash bucket id of a (raw, `input_dim`) token: sign-bit pattern over hyperplanes.
263    #[inline]
264    fn bucket_of(&self, token: &[f32], k_sim: u32, d: usize) -> usize {
265        let mut bucket = 0usize;
266        for h in 0..k_sim as usize {
267            let row = &self.hyperplanes[h * d..(h + 1) * d];
268            let mut dot = 0.0f32;
269            for i in 0..d {
270                dot += row[i] * token[i];
271            }
272            if dot > 0.0 {
273                bucket |= 1 << h;
274            }
275        }
276        bucket
277    }
278
279    /// Project a raw token to `proj_dim` (identity if no projection matrix).
280    #[inline]
281    fn project(&self, token: &[f32], proj_dim: usize, d: usize) -> Vec<f32> {
282        match &self.projection {
283            None => token.to_vec(),
284            Some(p) => {
285                let mut out = vec![0.0f32; proj_dim];
286                for (r, slot) in out.iter_mut().enumerate() {
287                    let row = &p[r * d..(r + 1) * d];
288                    let mut acc = 0.0f32;
289                    for i in 0..d {
290                        acc += row[i] * token[i];
291                    }
292                    *slot = acc;
293                }
294                out
295            }
296        }
297    }
298}
299
300/// A reusable FDE encoder holding all repetitions' random matrices. Build it ONCE per
301/// flush batch / per query (matrix generation is the expensive part) and reuse it
302/// across many `encode_doc`/`encode_query` calls.
303pub struct FdeEncoder {
304    params: FdeParams,
305    reps: Vec<RepMatrices>,
306}
307
308impl FdeEncoder {
309    /// Materialise all random matrices from the seed. Validates `params`.
310    pub fn new(params: &FdeParams) -> Result<Self, FdeError> {
311        params.validate()?;
312        let reps = (0..params.reps)
313            .map(|r| RepMatrices::build(params, r))
314            .collect();
315        Ok(Self {
316            params: params.clone(),
317            reps,
318        })
319    }
320
321    /// The parameters this encoder was built from.
322    #[inline]
323    pub fn params(&self) -> &FdeParams {
324        &self.params
325    }
326
327    /// Output FDE dimension (== `self.params().fde_dim()`).
328    #[inline]
329    pub fn fde_dim(&self) -> usize {
330        self.params.fde_dim()
331    }
332
333    fn check_tokens(&self, tokens: &[Vec<f32>]) -> Result<(), FdeError> {
334        let d = self.params.input_dim as usize;
335        for tok in tokens {
336            if tok.len() != d {
337                return Err(FdeError::DimensionMismatch {
338                    got: tok.len(),
339                    expected: d,
340                });
341            }
342        }
343        Ok(())
344    }
345
346    /// Encode a **document** multi-vector: per-bucket centroid + empty-bucket fill.
347    pub fn encode_doc(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
348        self.check_tokens(tokens)?;
349        let pd = self.params.proj_dim();
350        let b = self.params.buckets();
351        let d = self.params.input_dim as usize;
352        let mut out = vec![0.0f32; self.params.fde_dim()];
353
354        for (ri, rep) in self.reps.iter().enumerate() {
355            let base = ri * b * pd;
356            let mut sums = vec![0.0f32; b * pd];
357            let mut counts = vec![0u32; b];
358            for tok in tokens {
359                let bk = rep.bucket_of(tok, self.params.k_sim, d);
360                let proj = rep.project(tok, pd, d);
361                let slot = &mut sums[bk * pd..(bk + 1) * pd];
362                for (s, p) in slot.iter_mut().zip(proj.iter()) {
363                    *s += *p;
364                }
365                counts[bk] += 1;
366            }
367            // Centroid per non-empty bucket, written into the output region directly.
368            for bk in 0..b {
369                if counts[bk] > 0 {
370                    let inv = 1.0f32 / counts[bk] as f32;
371                    let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
372                    let src = &sums[bk * pd..(bk + 1) * pd];
373                    for (o, s) in dst.iter_mut().zip(src.iter()) {
374                        *o = *s * inv;
375                    }
376                }
377            }
378            // fill_empty: copy the centroid of the Hamming-nearest non-empty bucket.
379            for bk in 0..b {
380                if counts[bk] == 0
381                    && let Some(src) = nearest_nonempty(bk, &counts)
382                {
383                    let (lo, hi) = (bk.min(src), bk.max(src));
384                    // Split to satisfy the borrow checker, then copy src→bk.
385                    let (left, right) = out[base..base + b * pd].split_at_mut(hi * pd);
386                    let (src_slice, dst_slice) = if bk == lo {
387                        // dst (bk) is in `left`, src is in `right`
388                        (&right[0..pd], &mut left[bk * pd..bk * pd + pd])
389                    } else {
390                        // src is in `left`, dst (bk) is in `right`
391                        (&left[src * pd..src * pd + pd], &mut right[0..pd])
392                    };
393                    dst_slice.copy_from_slice(src_slice);
394                }
395            }
396        }
397        Ok(out)
398    }
399
400    /// Encode a **query** multi-vector: per-bucket sum, no centroid, no fill_empty.
401    pub fn encode_query(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
402        self.check_tokens(tokens)?;
403        let pd = self.params.proj_dim();
404        let b = self.params.buckets();
405        let d = self.params.input_dim as usize;
406        let mut out = vec![0.0f32; self.params.fde_dim()];
407
408        for (ri, rep) in self.reps.iter().enumerate() {
409            let base = ri * b * pd;
410            for tok in tokens {
411                let bk = rep.bucket_of(tok, self.params.k_sim, d);
412                let proj = rep.project(tok, pd, d);
413                let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
414                for (o, p) in dst.iter_mut().zip(proj.iter()) {
415                    *o += *p;
416                }
417            }
418        }
419        Ok(out)
420    }
421}
422
/// Returns the index of the occupied bucket (`counts[i] > 0`) closest to `bucket` in
/// Hamming distance, breaking ties toward the lowest index. Returns `None` when no
/// bucket is occupied (an empty document).
#[inline]
fn nearest_nonempty(bucket: usize, counts: &[u32]) -> Option<usize> {
    counts
        .iter()
        .enumerate()
        .filter(|&(_, &count)| count > 0)
        .map(|(idx, _)| idx)
        // Keying on (distance, index) makes the minimum unique and prefers low indices.
        .min_by_key(|&idx| ((bucket ^ idx).count_ones(), idx))
}
439
440/// Encode a single document multi-vector (builds a transient encoder). Prefer
441/// [`FdeEncoder`] when encoding many vectors with the same params.
442pub fn encode_doc(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
443    FdeEncoder::new(params)?.encode_doc(tokens)
444}
445
446/// Encode a single query multi-vector (builds a transient encoder). Prefer
447/// [`FdeEncoder`] when encoding many vectors with the same params.
448pub fn encode_query(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
449    FdeEncoder::new(params)?.encode_query(tokens)
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact MaxSim under the dot metric: Σ_i max_j ⟨q_i, d_j⟩ (empty doc → 0). Local to
    /// the test so this foundational module stays dependency-free; the production kernel
    /// lives in `uni_query_functions::similar_to::maxsim`.
    fn maxsim_dot(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
        query
            .iter()
            .map(|q| {
                if doc.is_empty() {
                    0.0
                } else {
                    doc.iter()
                        .map(|d| dot(q, d))
                        .fold(f32::NEG_INFINITY, f32::max)
                }
            })
            .sum()
    }

    /// Deterministic unit-norm random multi-vector generator (own PRNG, no rand crate).
    struct Gen(SplitMix64);
    impl Gen {
        fn new(seed: u64) -> Self {
            Self(SplitMix64::new(seed))
        }
        fn unit_token(&mut self, dim: usize) -> Vec<f32> {
            let mut v: Vec<f32> = (0..dim).map(|_| self.0.next_gaussian()).collect();
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            for x in &mut v {
                *x /= norm;
            }
            v
        }
        fn multivec(&mut self, n: usize, dim: usize) -> Vec<Vec<f32>> {
            (0..n).map(|_| self.unit_token(dim)).collect()
        }
        fn count(&mut self, lo: usize, hi: usize) -> usize {
            lo + (self.0.next_u64() as usize) % (hi - lo + 1)
        }
    }

    fn params(k_sim: u32, reps: u32, d_proj: u32, input_dim: u32) -> FdeParams {
        FdeParams {
            k_sim,
            reps,
            d_proj,
            input_dim,
            seed: DEFAULT_FDE_SEED,
        }
    }

    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    fn pearson(xs: &[f32], ys: &[f32]) -> f32 {
        let n = xs.len() as f32;
        let mx = xs.iter().sum::<f32>() / n;
        let my = ys.iter().sum::<f32>() / n;
        let mut cov = 0.0;
        let mut vx = 0.0;
        let mut vy = 0.0;
        for (x, y) in xs.iter().zip(ys) {
            let dx = x - mx;
            let dy = y - my;
            cov += dx * dy;
            vx += dx * dx;
            vy += dy * dy;
        }
        cov / (vx.sqrt() * vy.sqrt()).max(1e-12)
    }

    #[test]
    fn fde_dim_arithmetic() {
        assert_eq!(params(4, 20, 16, 96).fde_dim(), 20 * 16 * 16);
        // d_proj == 0 → use input_dim.
        assert_eq!(params(3, 2, 0, 8).fde_dim(), 2 * 8 * 8);
        assert_eq!(params(4, 20, 16, 96).buckets(), 16);
    }

    #[test]
    fn validate_rejects_bad_params() {
        assert!(params(0, 1, 0, 8).validate().is_err()); // k_sim 0
        assert!(params(MAX_K_SIM + 1, 1, 0, 8).validate().is_err());
        assert!(params(4, 0, 0, 8).validate().is_err()); // reps 0
        assert!(params(4, 1, 0, 0).validate().is_err()); // input_dim 0
        // absurd fde_dim
        assert!(params(16, 1000, 64, 96).validate().is_err());
        assert!(params(4, 20, 16, 96).validate().is_ok());
    }

    #[test]
    fn validate_rejects_overflowing_reps_and_d_proj_without_panicking() {
        // Regression for issue #96: an unbounded `reps`/`d_proj` made
        // `fde_dim = reps · 2^k_sim · proj_dim` overflow `usize`, which panicked an
        // overflow-checked build inside `validate` itself, or wrapped *under* the
        // `MAX_FDE_DIM` guard in release. The per-axis bounds + `checked_mul` must
        // reject these cleanly (an `Err`, never a panic and never an `Ok`).
        assert!(params(16, u32::MAX, u32::MAX, 96).validate().is_err());
        assert!(params(16, MAX_REPS + 1, 16, 96).validate().is_err());
        assert!(params(16, 20, MAX_PROJ_DIM + 1, 96).validate().is_err());
        // The historical wrap-bypass witness (k_sim=1) must also be rejected, not pass.
        assert!(
            params(1, 2_147_516_416, 4_294_901_761, 96)
                .validate()
                .is_err()
        );
        // `fde_dim` itself saturates instead of panicking for an unvalidated caller.
        assert_eq!(params(16, u32::MAX, u32::MAX, 96).fde_dim(), usize::MAX);
        // A `k_sim` at/above `usize::BITS` cannot panic the shift in `buckets`.
        assert_eq!(params(64, 1, 0, 8).buckets(), 0);
        // Each axis at its own ceiling still validates when the product fits.
        assert!(params(MAX_K_SIM, 1, 1, 96).validate().is_ok()); // 2^16 · 1 · 1 = 65_536
        assert!(params(1, MAX_REPS, 16, 96).validate().is_ok()); // 1024 · 2 · 16 = 32_768
        assert!(params(1, 1, MAX_PROJ_DIM, 96).validate().is_ok()); // 1 · 2 · 4096 = 8_192
        // Combinations of ceilings whose product exceeds MAX_FDE_DIM are rejected
        // cleanly (no overflow, no panic).
        assert!(params(16, MAX_REPS, MAX_PROJ_DIM, 96).validate().is_err());
        assert!(params(4, MAX_REPS, 16, 96).validate().is_err()); // 1024 · 16 · 16 = 262_144
    }

    #[test]
    fn fde_self_retrieval_ranks_first() {
        // LOAD-BEARING correctness guard. A document queried by its OWN tokens must be
        // the FDE-dot top-1 against a corpus of other (random) documents. This is the
        // strong-signal property a faithful MUVERA estimator must satisfy and it holds
        // even on cluster-free synthetic data (where *random-pair* recall is meaningless
        // — see the project's documented "don't trust synthetic ANN recall" lesson; the
        // real recall/latency gate is the multivec_recall_real bench on ColBERT data).
        let dim = 32usize;
        let p = params(4, 20, 16, dim as u32); // minimal/default params on purpose
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let corpus: Vec<Vec<Vec<f32>>> = (0..50)
            .map(|_| {
                let n = g.count(4, 16);
                g.multivec(n, dim)
            })
            .collect();
        let dfde: Vec<Vec<f32>> = corpus.iter().map(|d| enc.encode_doc(d).unwrap()).collect();
        for (j, d) in corpus.iter().enumerate() {
            let fq = enc.encode_query(d).unwrap();
            let top = (0..corpus.len())
                .max_by(|&a, &b| dot(&fq, &dfde[a]).total_cmp(&dot(&fq, &dfde[b])))
                .unwrap();
            assert_eq!(top, j, "doc {j} did not self-retrieve as FDE top-1");
        }
    }

    #[test]
    fn fde_dot_positively_correlates_with_maxsim() {
        // Regression guard: the FDE inner product must track exact MaxSim. The estimator
        // is biased (centroid < max) so over cluster-free random pairs the correlation
        // tops out well below 1.0; assert a conservative floor that a correct impl clears
        // comfortably (observed ~0.68 at these minimal params). Quality on real data is
        // the bench's job, not this unit test's.
        let dim = 32usize;
        let p = params(4, 24, 16, dim as u32);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(42);

        let n_pairs = 400;
        let mut fde_scores = Vec::with_capacity(n_pairs);
        let mut exact_scores = Vec::with_capacity(n_pairs);
        for _ in 0..n_pairs {
            let (qn, dn) = (g.count(2, 6), g.count(4, 16));
            let q = g.multivec(qn, dim);
            let d = g.multivec(dn, dim);
            fde_scores.push(dot(
                &enc.encode_query(&q).unwrap(),
                &enc.encode_doc(&d).unwrap(),
            ));
            exact_scores.push(maxsim_dot(&q, &d));
        }
        let r = pearson(&fde_scores, &exact_scores);
        assert!(r >= 0.55, "FDE/MaxSim correlation regressed: {r}");
    }

    #[test]
    fn deterministic_across_rebuild() {
        // Two encoders from identical params (simulating doc-time vs query-time after a
        // restart) must produce byte-identical output.
        let p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let d = g.multivec(10, 16);
        assert_eq!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
        let q = g.multivec(3, 16);
        assert_eq!(e1.encode_query(&q).unwrap(), e2.encode_query(&q).unwrap());
    }

    #[test]
    fn different_seed_changes_output() {
        let mut p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        p.seed = DEFAULT_FDE_SEED ^ 0xDEAD_BEEF;
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(11);
        let d = g.multivec(10, 16);
        assert_ne!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
    }

    #[test]
    fn empty_doc_is_all_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let fde = enc.encode_doc(&[]).unwrap();
        assert_eq!(fde.len(), p.fde_dim());
        assert!(fde.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn empty_query_scores_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(3);
        let fq = enc.encode_query(&[]).unwrap();
        let fd = enc.encode_doc(&g.multivec(8, 16)).unwrap();
        assert_eq!(dot(&fq, &fd), 0.0);
    }

    #[test]
    fn dim_mismatch_errors() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let bad = vec![vec![1.0f32; 15]]; // 15 != 16
        assert_eq!(
            enc.encode_doc(&bad),
            Err(FdeError::DimensionMismatch {
                got: 15,
                expected: 16
            })
        );
        assert!(enc.encode_query(&bad).is_err());
    }

    #[test]
    fn single_token_doc_fills_all_buckets() {
        // One token → exactly one non-empty bucket → fill_empty copies it everywhere,
        // so every per-bucket slot equals that token's projection.
        let p = params(3, 1, 0, 8); // no projection, 1 rep, 8 buckets
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(99);
        let tok = g.unit_token(8);
        let fde = enc.encode_doc(&[tok]).unwrap();
        let pd = p.proj_dim();
        let first = &fde[0..pd];
        for bk in 1..p.buckets() {
            assert_eq!(&fde[bk * pd..(bk + 1) * pd], first, "bucket {bk} differs");
        }
        assert!(first.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn query_leaves_empty_buckets_zero() {
        // A single query token → exactly one non-empty bucket; the rest stay zero
        // (no fill_empty for queries).
        let p = params(3, 1, 0, 8);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(123);
        let tok = g.unit_token(8);
        let fde = enc.encode_query(&[tok]).unwrap();
        let pd = p.proj_dim();
        let nonzero_buckets = (0..p.buckets())
            .filter(|&bk| fde[bk * pd..(bk + 1) * pd].iter().any(|&x| x != 0.0))
            .count();
        assert_eq!(nonzero_buckets, 1);
    }

    #[test]
    fn free_fns_match_encoder() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(55);
        let d = g.multivec(6, 16);
        assert_eq!(encode_doc(&d, &p).unwrap(), enc.encode_doc(&d).unwrap());
        let q = g.multivec(2, 16);
        assert_eq!(encode_query(&q, &p).unwrap(), enc.encode_query(&q).unwrap());
    }
}