Skip to main content

uni_common/
muvera.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright 2024-2026 Dragonscale Team
3
4//! MUVERA — Fixed Dimensional Encoding (FDE) for multi-vector (ColBERT) retrieval.
5//!
6//! MUVERA (arXiv:2405.19504) maps a *multi-vector* (a set of per-token vectors,
7//! the ColBERT/late-interaction representation) into ONE fixed-dimensional dense
8//! vector — the FDE — such that the **inner product** of two FDEs approximates the
9//! exact MaxSim score:
10//!
11//! ```text
12//! ⟨encode_query(q), encode_doc(d)⟩ ≈ MaxSim(q, d)  (= Σ_i max_j ⟨q_i, d_j⟩)
13//! ```
14//!
//! This lets the fast, mature single-*vector* ANN index do first-stage retrieval over a
//! derived FDE column, with the exact MaxSim kernel
//! (`uni_query_functions::similar_to::maxsim`) re-ranking the candidates. Because the
//! approximation is an inner product, the FDE ANN index must always use the **Dot**
//! metric, independent of the metric the exact re-rank uses. Dot is also consistent with
//! a cosine re-rank here, because ColBERT tokens are L2-normalised (per-token cosine
//! equals dot).
22//!
23//! ## Algorithm (one *repetition*, `B = 2^k_sim` buckets)
24//! - **SimHash buckets:** `k_sim` random Gaussian hyperplanes; a token's bucket id in
25//!   `[0, B)` is the sign-bit pattern of its dot products with the hyperplanes.
26//! - **Inner projection (optional):** project each token from `input_dim` down to
27//!   `d_proj` via a random ±1/√d_proj matrix (skipped when `d_proj == 0`).
28//! - **Document FDE:** each bucket holds the **centroid** (mean) of the (projected)
29//!   doc tokens that fall in it; empty buckets are filled from the non-empty bucket at
30//!   smallest Hamming distance on the `k_sim` bits (ties → lowest index — deterministic).
31//! - **Query FDE:** each bucket holds the **sum** of the (projected) query tokens in it;
32//!   no centroid, no empty-bucket filling.
33//! - Repeat `reps` times with independent matrices and concatenate →
34//!   `fde_dim = reps * 2^k_sim * (d_proj or input_dim)`.
35//!
36//! ## Determinism
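//! All random matrices are derived from a persisted `seed` using a self-contained
//! SplitMix64 PRNG + Box–Muller Gaussian transform (no external RNG crate), so the
//! generator cannot silently change when a dependency is upgraded. Reproducibility is
//! essential because the document FDEs are materialised at write time and the query
//! FDE is computed later (possibly after a restart/upgrade); both must use the *same*
//! transform or the inner-product approximation breaks. One caveat: Box–Muller calls
//! `f64::ln`/`f64::cos`, whose last-bit precision the Rust standard library documents as
//! platform- and version-dependent. The final `f32` cast absorbs nearly all such
//! differences, but strict bit-for-bit identity across platforms is not formally
//! guaranteed.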
43//!
44//! ## Parameter tuning (important)
45//! FDE recall is corpus-dependent. The shipped defaults (`k_sim=4, reps=20, d_proj=16`,
46//! see `uni_common::vector_index_opts`) are reasonable starting points but are **not**
47//! validated for recall on any particular corpus. Higher `reps`/`k_sim` raise recall at the
48//! cost of a larger `fde_dim`. Synthetic self-retrieval (an exact-match doc ranking first)
49//! is robust at any setting and is NOT evidence of real recall; measure recall@k on a real
50//! ColBERT corpus with `crates/uni-store/examples/multivec_recall_real.rs` and tune from
51//! there. The exact MaxSim re-rank means a poor FDE only costs recall, never precision.
52
53use serde::{Deserialize, Serialize};
54
55/// Errors produced while building or applying an FDE transform.
56#[derive(Debug, thiserror::Error, PartialEq, Eq)]
57pub enum FdeError {
58    /// A token vector's length does not match the configured `input_dim`.
59    #[error("muvera: token dimension {got} != configured input_dim {expected}")]
60    DimensionMismatch { got: usize, expected: usize },
61
62    /// The parameters are out of the supported range.
63    #[error("muvera: invalid params: {0}")]
64    InvalidParams(String),
65}
66
/// Master seed used whenever a MUVERA index is created without an explicit one.
///
/// A fixed value keeps encodings reproducible from run to run. It is the 64-bit
/// golden-ratio constant (≈ 2^64 / φ), the same number SplitMix64 uses as its state
/// increment.
pub const DEFAULT_FDE_SEED: u64 = 0x9e37_79b9_7f4a_7c15;
71
/// Largest accepted `k_sim`, which keeps the per-repetition bucket count (`2^k_sim`)
/// bounded.
const MAX_K_SIM: u32 = 16;

/// Largest accepted `fde_dim`. Absurd configurations fail fast during validation
/// instead of attempting multi-gigabyte allocations.
const MAX_FDE_DIM: usize = 200_000;
76
77/// Parameters of an FDE transform. Persisted (via the raw fields on
78/// `VectorIndexType::Muvera`) so query-time encoding reproduces document-time encoding.
79#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
80pub struct FdeParams {
81    /// Number of SimHash hyperplanes per repetition; produces `2^k_sim` buckets.
82    pub k_sim: u32,
83    /// Number of independent repetitions concatenated into the final FDE.
84    pub reps: u32,
85    /// Inner-projection target dimension. `0` means "no projection" (use `input_dim`).
86    pub d_proj: u32,
87    /// Dimension of each input token vector (resolved from the source column at build).
88    pub input_dim: u32,
89    /// Master seed; all hyperplanes/projections are derived from it.
90    pub seed: u64,
91}
92
93impl FdeParams {
94    /// Per-bucket vector dimension (the projected dim, or `input_dim` if no projection).
95    #[inline]
96    pub fn proj_dim(&self) -> usize {
97        if self.d_proj == 0 {
98            self.input_dim as usize
99        } else {
100            self.d_proj as usize
101        }
102    }
103
104    /// Number of buckets per repetition (`2^k_sim`).
105    #[inline]
106    pub fn buckets(&self) -> usize {
107        1usize << self.k_sim
108    }
109
110    /// Final FDE dimension: `reps * 2^k_sim * proj_dim`.
111    #[inline]
112    pub fn fde_dim(&self) -> usize {
113        self.reps as usize * self.buckets() * self.proj_dim()
114    }
115
116    /// Validate the parameters, returning a descriptive error if unsupported.
117    pub fn validate(&self) -> Result<(), FdeError> {
118        if self.k_sim == 0 || self.k_sim > MAX_K_SIM {
119            return Err(FdeError::InvalidParams(format!(
120                "k_sim must be in 1..={MAX_K_SIM}, got {}",
121                self.k_sim
122            )));
123        }
124        if self.reps == 0 {
125            return Err(FdeError::InvalidParams("reps must be >= 1".to_string()));
126        }
127        if self.input_dim == 0 {
128            return Err(FdeError::InvalidParams(
129                "input_dim must be >= 1".to_string(),
130            ));
131        }
132        let dim = self.fde_dim();
133        if dim == 0 || dim > MAX_FDE_DIM {
134            return Err(FdeError::InvalidParams(format!(
135                "fde_dim {dim} out of range (1..={MAX_FDE_DIM}); reduce k_sim/reps/d_proj"
136            )));
137        }
138        Ok(())
139    }
140}
141
/// A minimal, fully-specified SplitMix64 PRNG. Deterministic and portable across
/// platforms and binary versions — unlike `rand`'s `StdRng`, whose algorithm is not
/// guaranteed stable. Only what the FDE encoder needs (uniform + Gaussian) is exposed.
///
/// The exact output sequence is part of the persisted-index contract: document FDEs are
/// written long before the matching query FDEs are computed. Do not change any constant
/// or the order of operations below, or existing indexes silently stop matching.
struct SplitMix64 {
    /// Generator state; advanced by the golden-ratio increment before every output.
    state: u64,
}

impl SplitMix64 {
    /// Start a generator whose state is exactly `seed` (no extra scrambling here).
    #[inline]
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Next 64-bit output: bump the state by the golden-ratio increment, then run the
    /// SplitMix64 xor-shift / multiply finaliser over the new state.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(0x9E37_79B9_7F4A_7C15);
        let mut z = self.state;
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^ (z >> 31)
    }

    /// Uniform f64 in `[0, 1)` using the top 53 bits.
    /// The shifted value is below 2^53 and the divisor is a power of two, so both the
    /// conversion and the division are exact.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        (self.next_u64() >> 11) as f64 / (1u64 << 53) as f64
    }

    /// One standard-normal sample via the Box–Muller transform (cos branch).
    /// Consumes exactly two uniforms per call; the sin-branch sample is discarded.
    ///
    /// NOTE(review): Rust's std docs describe the precision of `f64::ln`/`f64::cos` as
    /// platform- and version-dependent. The final `f32` cast hides almost all last-bit
    /// f64 differences, but strict cross-platform bit identity is not formally
    /// guaranteed. Any change here would invalidate persisted FDEs, so confirm the
    /// deployment targets before relying on exact cross-platform equality.
    #[inline]
    fn next_gaussian(&mut self) -> f32 {
        // Clamp u1 away from 0 so ln() is finite.
        let u1 = self.next_f64().max(1e-12);
        let u2 = self.next_f64();
        let r = (-2.0 * u1.ln()).sqrt();
        (r * (2.0 * std::f64::consts::PI * u2).cos()) as f32
    }
}
180
/// Derive the seed for repetition `rep` from the master seed `base`.
///
/// `rep` is scaled by an odd 64-bit constant and added to `base`. The sum is then pushed
/// through the SplitMix64 finaliser, so each repetition gets an independent-looking
/// sub-seed. The mapping must stay bit-for-bit stable because persisted document FDEs
/// depend on it.
#[inline]
fn rep_seed(base: u64, rep: u32) -> u64 {
    let offset = u64::from(rep).wrapping_mul(0xD1B5_4A32_D192_ED03);
    let x = base.wrapping_add(offset);
    let x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
    let x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
    x ^ (x >> 31)
}
190
191/// Precomputed random matrices for one repetition.
192struct RepMatrices {
193    /// `k_sim * input_dim`, row-major (one hyperplane per row). Gaussian entries.
194    hyperplanes: Vec<f32>,
195    /// `d_proj * input_dim`, row-major. `±1/√d_proj` entries. `None` = no projection.
196    projection: Option<Vec<f32>>,
197}
198
199impl RepMatrices {
200    fn build(params: &FdeParams, rep: u32) -> Self {
201        let mut rng = SplitMix64::new(rep_seed(params.seed, rep));
202        let d = params.input_dim as usize;
203        let hyperplanes = (0..params.k_sim as usize * d)
204            .map(|_| rng.next_gaussian())
205            .collect();
206        let projection = if params.d_proj == 0 {
207            None
208        } else {
209            let pd = params.d_proj as usize;
210            let scale = 1.0f32 / (pd as f32).sqrt();
211            // Draw the ±1 entries AFTER the hyperplanes so the draw order is fixed.
212            let proj = (0..pd * d)
213                .map(|_| {
214                    if rng.next_u64() & 1 == 0 {
215                        scale
216                    } else {
217                        -scale
218                    }
219                })
220                .collect();
221            Some(proj)
222        };
223        Self {
224            hyperplanes,
225            projection,
226        }
227    }
228
229    /// SimHash bucket id of a (raw, `input_dim`) token: sign-bit pattern over hyperplanes.
230    #[inline]
231    fn bucket_of(&self, token: &[f32], k_sim: u32, d: usize) -> usize {
232        let mut bucket = 0usize;
233        for h in 0..k_sim as usize {
234            let row = &self.hyperplanes[h * d..(h + 1) * d];
235            let mut dot = 0.0f32;
236            for i in 0..d {
237                dot += row[i] * token[i];
238            }
239            if dot > 0.0 {
240                bucket |= 1 << h;
241            }
242        }
243        bucket
244    }
245
246    /// Project a raw token to `proj_dim` (identity if no projection matrix).
247    #[inline]
248    fn project(&self, token: &[f32], proj_dim: usize, d: usize) -> Vec<f32> {
249        match &self.projection {
250            None => token.to_vec(),
251            Some(p) => {
252                let mut out = vec![0.0f32; proj_dim];
253                for (r, slot) in out.iter_mut().enumerate() {
254                    let row = &p[r * d..(r + 1) * d];
255                    let mut acc = 0.0f32;
256                    for i in 0..d {
257                        acc += row[i] * token[i];
258                    }
259                    *slot = acc;
260                }
261                out
262            }
263        }
264    }
265}
266
267/// A reusable FDE encoder holding all repetitions' random matrices. Build it ONCE per
268/// flush batch / per query (matrix generation is the expensive part) and reuse it
269/// across many `encode_doc`/`encode_query` calls.
270pub struct FdeEncoder {
271    params: FdeParams,
272    reps: Vec<RepMatrices>,
273}
274
275impl FdeEncoder {
276    /// Materialise all random matrices from the seed. Validates `params`.
277    pub fn new(params: &FdeParams) -> Result<Self, FdeError> {
278        params.validate()?;
279        let reps = (0..params.reps)
280            .map(|r| RepMatrices::build(params, r))
281            .collect();
282        Ok(Self {
283            params: params.clone(),
284            reps,
285        })
286    }
287
288    /// The parameters this encoder was built from.
289    #[inline]
290    pub fn params(&self) -> &FdeParams {
291        &self.params
292    }
293
294    /// Output FDE dimension (== `self.params().fde_dim()`).
295    #[inline]
296    pub fn fde_dim(&self) -> usize {
297        self.params.fde_dim()
298    }
299
300    fn check_tokens(&self, tokens: &[Vec<f32>]) -> Result<(), FdeError> {
301        let d = self.params.input_dim as usize;
302        for tok in tokens {
303            if tok.len() != d {
304                return Err(FdeError::DimensionMismatch {
305                    got: tok.len(),
306                    expected: d,
307                });
308            }
309        }
310        Ok(())
311    }
312
313    /// Encode a **document** multi-vector: per-bucket centroid + empty-bucket fill.
314    pub fn encode_doc(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
315        self.check_tokens(tokens)?;
316        let pd = self.params.proj_dim();
317        let b = self.params.buckets();
318        let d = self.params.input_dim as usize;
319        let mut out = vec![0.0f32; self.params.fde_dim()];
320
321        for (ri, rep) in self.reps.iter().enumerate() {
322            let base = ri * b * pd;
323            let mut sums = vec![0.0f32; b * pd];
324            let mut counts = vec![0u32; b];
325            for tok in tokens {
326                let bk = rep.bucket_of(tok, self.params.k_sim, d);
327                let proj = rep.project(tok, pd, d);
328                let slot = &mut sums[bk * pd..(bk + 1) * pd];
329                for (s, p) in slot.iter_mut().zip(proj.iter()) {
330                    *s += *p;
331                }
332                counts[bk] += 1;
333            }
334            // Centroid per non-empty bucket, written into the output region directly.
335            for bk in 0..b {
336                if counts[bk] > 0 {
337                    let inv = 1.0f32 / counts[bk] as f32;
338                    let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
339                    let src = &sums[bk * pd..(bk + 1) * pd];
340                    for (o, s) in dst.iter_mut().zip(src.iter()) {
341                        *o = *s * inv;
342                    }
343                }
344            }
345            // fill_empty: copy the centroid of the Hamming-nearest non-empty bucket.
346            for bk in 0..b {
347                if counts[bk] == 0
348                    && let Some(src) = nearest_nonempty(bk, &counts)
349                {
350                    let (lo, hi) = (bk.min(src), bk.max(src));
351                    // Split to satisfy the borrow checker, then copy src→bk.
352                    let (left, right) = out[base..base + b * pd].split_at_mut(hi * pd);
353                    let (src_slice, dst_slice) = if bk == lo {
354                        // dst (bk) is in `left`, src is in `right`
355                        (&right[0..pd], &mut left[bk * pd..bk * pd + pd])
356                    } else {
357                        // src is in `left`, dst (bk) is in `right`
358                        (&left[src * pd..src * pd + pd], &mut right[0..pd])
359                    };
360                    dst_slice.copy_from_slice(src_slice);
361                }
362            }
363        }
364        Ok(out)
365    }
366
367    /// Encode a **query** multi-vector: per-bucket sum, no centroid, no fill_empty.
368    pub fn encode_query(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
369        self.check_tokens(tokens)?;
370        let pd = self.params.proj_dim();
371        let b = self.params.buckets();
372        let d = self.params.input_dim as usize;
373        let mut out = vec![0.0f32; self.params.fde_dim()];
374
375        for (ri, rep) in self.reps.iter().enumerate() {
376            let base = ri * b * pd;
377            for tok in tokens {
378                let bk = rep.bucket_of(tok, self.params.k_sim, d);
379                let proj = rep.project(tok, pd, d);
380                let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
381                for (o, p) in dst.iter_mut().zip(proj.iter()) {
382                    *o += *p;
383                }
384            }
385        }
386        Ok(out)
387    }
388}
389
/// Among the non-empty buckets (`counts[i] > 0`), the one at the smallest Hamming
/// distance from `bucket`, with ties going to the lowest index. Returns `None` when no
/// bucket is non-empty (i.e. the document had no tokens).
#[inline]
fn nearest_nonempty(bucket: usize, counts: &[u32]) -> Option<usize> {
    // `min_by_key` keeps the FIRST minimum, and candidates are visited in ascending index
    // order, which gives the lowest-index tie rule.
    counts
        .iter()
        .enumerate()
        .filter_map(|(idx, &n)| (n > 0).then_some(idx))
        .min_by_key(|&idx| (bucket ^ idx).count_ones())
}
406
407/// Encode a single document multi-vector (builds a transient encoder). Prefer
408/// [`FdeEncoder`] when encoding many vectors with the same params.
409pub fn encode_doc(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
410    FdeEncoder::new(params)?.encode_doc(tokens)
411}
412
413/// Encode a single query multi-vector (builds a transient encoder). Prefer
414/// [`FdeEncoder`] when encoding many vectors with the same params.
415pub fn encode_query(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
416    FdeEncoder::new(params)?.encode_query(tokens)
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact MaxSim under the dot metric: Σ_i max_j ⟨q_i, d_j⟩ (empty doc → 0). Local to
    /// the test so this foundational module stays dependency-free; the production kernel
    /// lives in `uni_query_functions::similar_to::maxsim`.
    fn maxsim_dot(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
        query
            .iter()
            .map(|q| {
                if doc.is_empty() {
                    0.0
                } else {
                    doc.iter()
                        .map(|d| dot(q, d))
                        .fold(f32::NEG_INFINITY, f32::max)
                }
            })
            .sum()
    }

    /// Deterministic unit-norm random multi-vector generator (own PRNG, no rand crate).
    ///
    /// Every fixture below is a function of the seed given to `Gen::new` and of the exact
    /// order of draws made here. The thresholds asserted in the tests were observed with
    /// this generator as written, so reordering draws changes every fixture.
    struct Gen(SplitMix64);
    impl Gen {
        fn new(seed: u64) -> Self {
            Self(SplitMix64::new(seed))
        }
        /// `dim` Gaussian samples scaled to unit L2 length (norm floored at 1e-12).
        fn unit_token(&mut self, dim: usize) -> Vec<f32> {
            let mut v: Vec<f32> = (0..dim).map(|_| self.0.next_gaussian()).collect();
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            for x in &mut v {
                *x /= norm;
            }
            v
        }
        /// A multi-vector of `n` unit tokens, each of length `dim`.
        fn multivec(&mut self, n: usize, dim: usize) -> Vec<Vec<f32>> {
            (0..n).map(|_| self.unit_token(dim)).collect()
        }
        /// Pseudo-random integer in `lo..=hi` via modulo reduction (expects `lo <= hi`).
        fn count(&mut self, lo: usize, hi: usize) -> usize {
            lo + (self.0.next_u64() as usize) % (hi - lo + 1)
        }
    }

    /// Build `FdeParams` with the default seed.
    fn params(k_sim: u32, reps: u32, d_proj: u32, input_dim: u32) -> FdeParams {
        FdeParams {
            k_sim,
            reps,
            d_proj,
            input_dim,
            seed: DEFAULT_FDE_SEED,
        }
    }

    /// Dot product over the common prefix of `a` and `b` (`zip` stops at the shorter).
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Pearson correlation of `xs` and `ys`. The denominator is floored at 1e-12 so
    /// constant inputs do not divide by zero.
    fn pearson(xs: &[f32], ys: &[f32]) -> f32 {
        let n = xs.len() as f32;
        let mx = xs.iter().sum::<f32>() / n;
        let my = ys.iter().sum::<f32>() / n;
        let mut cov = 0.0;
        let mut vx = 0.0;
        let mut vy = 0.0;
        for (x, y) in xs.iter().zip(ys) {
            let dx = x - mx;
            let dy = y - my;
            cov += dx * dy;
            vx += dx * dx;
            vy += dy * dy;
        }
        cov / (vx.sqrt() * vy.sqrt()).max(1e-12)
    }

    /// `fde_dim = reps * 2^k_sim * proj_dim`, where `proj_dim` falls back to `input_dim`
    /// when `d_proj == 0`.
    #[test]
    fn fde_dim_arithmetic() {
        assert_eq!(params(4, 20, 16, 96).fde_dim(), 20 * 16 * 16);
        // d_proj == 0 → use input_dim.
        assert_eq!(params(3, 2, 0, 8).fde_dim(), 2 * 8 * 8);
        assert_eq!(params(4, 20, 16, 96).buckets(), 16);
    }

    /// Each out-of-range knob is rejected on its own; the shipped defaults pass.
    #[test]
    fn validate_rejects_bad_params() {
        assert!(params(0, 1, 0, 8).validate().is_err()); // k_sim 0
        assert!(params(MAX_K_SIM + 1, 1, 0, 8).validate().is_err());
        assert!(params(4, 0, 0, 8).validate().is_err()); // reps 0
        assert!(params(4, 1, 0, 0).validate().is_err()); // input_dim 0
        // absurd fde_dim
        assert!(params(16, 1000, 64, 96).validate().is_err());
        assert!(params(4, 20, 16, 96).validate().is_ok());
    }

    #[test]
    fn fde_self_retrieval_ranks_first() {
        // LOAD-BEARING correctness guard. A document queried by its OWN tokens must be
        // the FDE-dot top-1 against a corpus of other (random) documents. This is the
        // strong-signal property a faithful MUVERA estimator must satisfy and it holds
        // even on cluster-free synthetic data (where *random-pair* recall is meaningless
        // — see the project's documented "don't trust synthetic ANN recall" lesson; the
        // real recall/latency gate is the multivec_recall_real bench on ColBERT data).
        let dim = 32usize;
        let p = params(4, 20, 16, dim as u32); // minimal/default params on purpose
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let corpus: Vec<Vec<Vec<f32>>> = (0..50)
            .map(|_| {
                let n = g.count(4, 16);
                g.multivec(n, dim)
            })
            .collect();
        let dfde: Vec<Vec<f32>> = corpus.iter().map(|d| enc.encode_doc(d).unwrap()).collect();
        for (j, d) in corpus.iter().enumerate() {
            let fq = enc.encode_query(d).unwrap();
            // Brute-force argmax of the FDE inner product over the whole corpus.
            let top = (0..corpus.len())
                .max_by(|&a, &b| dot(&fq, &dfde[a]).total_cmp(&dot(&fq, &dfde[b])))
                .unwrap();
            assert_eq!(top, j, "doc {j} did not self-retrieve as FDE top-1");
        }
    }

    #[test]
    fn fde_dot_positively_correlates_with_maxsim() {
        // Regression guard: the FDE inner product must track exact MaxSim. The estimator
        // is biased (centroid < max) so over cluster-free random pairs the correlation
        // tops out well below 1.0; assert a conservative floor that a correct impl clears
        // comfortably (observed ~0.68 at these minimal params). Quality on real data is
        // the bench's job, not this unit test's.
        let dim = 32usize;
        let p = params(4, 24, 16, dim as u32);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(42);

        let n_pairs = 400;
        let mut fde_scores = Vec::with_capacity(n_pairs);
        let mut exact_scores = Vec::with_capacity(n_pairs);
        for _ in 0..n_pairs {
            let (qn, dn) = (g.count(2, 6), g.count(4, 16));
            let q = g.multivec(qn, dim);
            let d = g.multivec(dn, dim);
            fde_scores.push(dot(
                &enc.encode_query(&q).unwrap(),
                &enc.encode_doc(&d).unwrap(),
            ));
            exact_scores.push(maxsim_dot(&q, &d));
        }
        let r = pearson(&fde_scores, &exact_scores);
        assert!(r >= 0.55, "FDE/MaxSim correlation regressed: {r}");
    }

    #[test]
    fn deterministic_across_rebuild() {
        // Two encoders from identical params (simulating doc-time vs query-time after a
        // restart) must produce byte-identical output.
        let p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let d = g.multivec(10, 16);
        assert_eq!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
        let q = g.multivec(3, 16);
        assert_eq!(e1.encode_query(&q).unwrap(), e2.encode_query(&q).unwrap());
    }

    /// Changing only the seed must change the document FDE, i.e. the seed actually
    /// feeds the generated matrices.
    #[test]
    fn different_seed_changes_output() {
        let mut p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        p.seed = DEFAULT_FDE_SEED ^ 0xDEAD_BEEF;
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(11);
        let d = g.multivec(10, 16);
        assert_ne!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
    }

    /// No tokens → every bucket empty → nothing to fill from, so the FDE is all zeros
    /// (but still full length).
    #[test]
    fn empty_doc_is_all_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let fde = enc.encode_doc(&[]).unwrap();
        assert_eq!(fde.len(), p.fde_dim());
        assert!(fde.iter().all(|&x| x == 0.0));
    }

    /// An empty query encodes to all zeros and therefore scores exactly 0 against any
    /// document.
    #[test]
    fn empty_query_scores_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(3);
        let fq = enc.encode_query(&[]).unwrap();
        let fd = enc.encode_doc(&g.multivec(8, 16)).unwrap();
        assert_eq!(dot(&fq, &fd), 0.0);
    }

    /// A token whose length differs from `input_dim` is rejected by both encoders; the
    /// doc-side error reports the exact lengths.
    #[test]
    fn dim_mismatch_errors() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let bad = vec![vec![1.0f32; 15]]; // 15 != 16
        assert_eq!(
            enc.encode_doc(&bad),
            Err(FdeError::DimensionMismatch {
                got: 15,
                expected: 16
            })
        );
        assert!(enc.encode_query(&bad).is_err());
    }

    #[test]
    fn single_token_doc_fills_all_buckets() {
        // One token → exactly one non-empty bucket → fill_empty copies it everywhere,
        // so every per-bucket slot equals that token's projection.
        let p = params(3, 1, 0, 8); // no projection, 1 rep, 8 buckets
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(99);
        let tok = g.unit_token(8);
        let fde = enc.encode_doc(&[tok]).unwrap();
        let pd = p.proj_dim();
        let first = &fde[0..pd];
        for bk in 1..p.buckets() {
            assert_eq!(&fde[bk * pd..(bk + 1) * pd], first, "bucket {bk} differs");
        }
        assert!(first.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn query_leaves_empty_buckets_zero() {
        // A single query token → exactly one non-empty bucket; the rest stay zero
        // (no fill_empty for queries).
        let p = params(3, 1, 0, 8);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(123);
        let tok = g.unit_token(8);
        let fde = enc.encode_query(&[tok]).unwrap();
        let pd = p.proj_dim();
        let nonzero_buckets = (0..p.buckets())
            .filter(|&bk| fde[bk * pd..(bk + 1) * pd].iter().any(|&x| x != 0.0))
            .count();
        assert_eq!(nonzero_buckets, 1);
    }

    /// The one-shot free functions produce exactly what a reused encoder produces.
    #[test]
    fn free_fns_match_encoder() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(55);
        let d = g.multivec(6, 16);
        assert_eq!(encode_doc(&d, &p).unwrap(), enc.encode_doc(&d).unwrap());
        let q = g.multivec(2, 16);
        assert_eq!(encode_query(&q, &p).unwrap(), enc.encode_query(&q).unwrap());
    }
}