1use serde::{Deserialize, Serialize};
54
/// Errors produced by MUVERA FDE (fixed-dimensional encoding) parameter
/// validation and by the encoders.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum FdeError {
    /// A token's length did not equal the configured `input_dim`.
    #[error("muvera: token dimension {got} != configured input_dim {expected}")]
    DimensionMismatch {
        /// Length of the first offending token.
        got: usize,
        /// The configured `input_dim`.
        expected: usize,
    },

    /// The parameters failed validation; the payload names the violated
    /// constraint.
    #[error("muvera: invalid params: {0}")]
    InvalidParams(String),
}
66
/// Default value for `FdeParams::seed`. It is the same 64-bit constant that
/// the internal SplitMix64 generator adds to its state on every draw.
pub const DEFAULT_FDE_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

/// Largest `k_sim` accepted by `FdeParams::validate`, i.e. at most `2^16`
/// buckets per repetition.
const MAX_K_SIM: u32 = 16;
/// Largest total encoding length (`reps * 2^k_sim * proj_dim`) accepted by
/// `FdeParams::validate`.
const MAX_FDE_DIM: usize = 200_000;
76
/// Configuration of a MUVERA fixed-dimensional encoder.
///
/// The encoded vector has length `reps * 2^k_sim * proj_dim()` (see
/// `FdeParams::fde_dim`). `FdeEncoder::new` calls `FdeParams::validate`
/// before using these values.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FdeParams {
    /// Number of SimHash hyperplanes per repetition. A token's bucket index is
    /// built from the signs of its dot products with them (`2^k_sim` buckets).
    pub k_sim: u32,
    /// Number of independent repetitions. Each has its own hyperplanes and
    /// projection, and their blocks are concatenated in the output.
    pub reps: u32,
    /// Per-bucket projected dimension, using a random `±1/sqrt(d_proj)` sign
    /// matrix. `0` disables projection and keeps the raw `input_dim` values.
    pub d_proj: u32,
    /// Dimension that every input token must have.
    pub input_dim: u32,
    /// Base seed. Each repetition derives its own RNG seed from it.
    pub seed: u64,
}
92
93impl FdeParams {
94 #[inline]
96 pub fn proj_dim(&self) -> usize {
97 if self.d_proj == 0 {
98 self.input_dim as usize
99 } else {
100 self.d_proj as usize
101 }
102 }
103
104 #[inline]
106 pub fn buckets(&self) -> usize {
107 1usize << self.k_sim
108 }
109
110 #[inline]
112 pub fn fde_dim(&self) -> usize {
113 self.reps as usize * self.buckets() * self.proj_dim()
114 }
115
116 pub fn validate(&self) -> Result<(), FdeError> {
118 if self.k_sim == 0 || self.k_sim > MAX_K_SIM {
119 return Err(FdeError::InvalidParams(format!(
120 "k_sim must be in 1..={MAX_K_SIM}, got {}",
121 self.k_sim
122 )));
123 }
124 if self.reps == 0 {
125 return Err(FdeError::InvalidParams("reps must be >= 1".to_string()));
126 }
127 if self.input_dim == 0 {
128 return Err(FdeError::InvalidParams(
129 "input_dim must be >= 1".to_string(),
130 ));
131 }
132 let dim = self.fde_dim();
133 if dim == 0 || dim > MAX_FDE_DIM {
134 return Err(FdeError::InvalidParams(format!(
135 "fde_dim {dim} out of range (1..={MAX_FDE_DIM}); reduce k_sim/reps/d_proj"
136 )));
137 }
138 Ok(())
139 }
140}
141
/// Minimal SplitMix64 pseudo-random generator.
///
/// The random matrices are a pure function of the seed, with no dependency on
/// an external RNG crate.
struct SplitMix64 {
    /// Current state. It is advanced by a fixed odd increment on every draw.
    state: u64,
}

impl SplitMix64 {
    /// Odd increment added to the state before each output (the standard
    /// SplitMix64 constant).
    const INCREMENT: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator whose state starts at `seed`.
    #[inline]
    fn new(seed: u64) -> Self {
        Self { state: seed }
    }

    /// Advances the state and returns 64 mixed pseudo-random bits.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::INCREMENT);
        let first = self.state;
        let second = (first ^ (first >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let third = (second ^ (second >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        third ^ (third >> 31)
    }

    /// Uniform sample in `[0, 1)` built from the top 53 bits of one draw.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        let denominator = (1u64 << 53) as f64;
        top_bits as f64 / denominator
    }

    /// Standard-normal sample via Box–Muller.
    ///
    /// Only the cosine branch is used, and it consumes two uniform draws. The
    /// first draw is clamped away from zero so the logarithm stays finite.
    #[inline]
    fn next_gaussian(&mut self) -> f32 {
        let u1 = self.next_f64().max(1e-12);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        (radius * angle.cos()) as f32
    }
}
180
/// Derives the RNG seed for repetition `rep` from the base seed.
///
/// `base` is offset by `rep` times a fixed odd multiplier (wrapping). The
/// result then goes through the same xor-shift/multiply finalizer that
/// SplitMix64 applies to its outputs.
#[inline]
fn rep_seed(base: u64, rep: u32) -> u64 {
    const REP_STRIDE: u64 = 0xD1B5_4A32_D192_ED03;
    const ROUNDS: [(u32, u64); 2] = [(30, 0xBF58_476D_1CE4_E5B9), (27, 0x94D0_49BB_1331_11EB)];

    let start = base.wrapping_add(u64::from(rep).wrapping_mul(REP_STRIDE));
    let mixed = ROUNDS
        .iter()
        .fold(start, |acc, &(shift, mul)| (acc ^ (acc >> shift)).wrapping_mul(mul));
    mixed ^ (mixed >> 31)
}
190
/// Random matrices for a single repetition of the encoder.
struct RepMatrices {
    /// `k_sim` Gaussian hyperplanes, row-major, `k_sim * input_dim` floats.
    hyperplanes: Vec<f32>,
    /// Row-major `d_proj x input_dim` matrix with entries `±1/sqrt(d_proj)`,
    /// or `None` when `d_proj == 0` (tokens are used unprojected).
    projection: Option<Vec<f32>>,
}
198
199impl RepMatrices {
200 fn build(params: &FdeParams, rep: u32) -> Self {
201 let mut rng = SplitMix64::new(rep_seed(params.seed, rep));
202 let d = params.input_dim as usize;
203 let hyperplanes = (0..params.k_sim as usize * d)
204 .map(|_| rng.next_gaussian())
205 .collect();
206 let projection = if params.d_proj == 0 {
207 None
208 } else {
209 let pd = params.d_proj as usize;
210 let scale = 1.0f32 / (pd as f32).sqrt();
211 let proj = (0..pd * d)
213 .map(|_| {
214 if rng.next_u64() & 1 == 0 {
215 scale
216 } else {
217 -scale
218 }
219 })
220 .collect();
221 Some(proj)
222 };
223 Self {
224 hyperplanes,
225 projection,
226 }
227 }
228
229 #[inline]
231 fn bucket_of(&self, token: &[f32], k_sim: u32, d: usize) -> usize {
232 let mut bucket = 0usize;
233 for h in 0..k_sim as usize {
234 let row = &self.hyperplanes[h * d..(h + 1) * d];
235 let mut dot = 0.0f32;
236 for i in 0..d {
237 dot += row[i] * token[i];
238 }
239 if dot > 0.0 {
240 bucket |= 1 << h;
241 }
242 }
243 bucket
244 }
245
246 #[inline]
248 fn project(&self, token: &[f32], proj_dim: usize, d: usize) -> Vec<f32> {
249 match &self.projection {
250 None => token.to_vec(),
251 Some(p) => {
252 let mut out = vec![0.0f32; proj_dim];
253 for (r, slot) in out.iter_mut().enumerate() {
254 let row = &p[r * d..(r + 1) * d];
255 let mut acc = 0.0f32;
256 for i in 0..d {
257 acc += row[i] * token[i];
258 }
259 *slot = acc;
260 }
261 out
262 }
263 }
264 }
265}
266
/// Reusable MUVERA encoder: validated [`FdeParams`] plus the random matrices
/// sampled from them, one set per repetition.
///
/// Construction samples `reps * (k_sim + d_proj) * input_dim` random values.
/// The free functions `encode_doc` / `encode_query` redo that on every call,
/// so build one encoder and reuse it when encoding many inputs.
pub struct FdeEncoder {
    /// Validated parameters the matrices were sampled from.
    params: FdeParams,
    /// Per-repetition matrices, indexed by repetition number.
    reps: Vec<RepMatrices>,
}
274
275impl FdeEncoder {
276 pub fn new(params: &FdeParams) -> Result<Self, FdeError> {
278 params.validate()?;
279 let reps = (0..params.reps)
280 .map(|r| RepMatrices::build(params, r))
281 .collect();
282 Ok(Self {
283 params: params.clone(),
284 reps,
285 })
286 }
287
288 #[inline]
290 pub fn params(&self) -> &FdeParams {
291 &self.params
292 }
293
294 #[inline]
296 pub fn fde_dim(&self) -> usize {
297 self.params.fde_dim()
298 }
299
300 fn check_tokens(&self, tokens: &[Vec<f32>]) -> Result<(), FdeError> {
301 let d = self.params.input_dim as usize;
302 for tok in tokens {
303 if tok.len() != d {
304 return Err(FdeError::DimensionMismatch {
305 got: tok.len(),
306 expected: d,
307 });
308 }
309 }
310 Ok(())
311 }
312
313 pub fn encode_doc(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
315 self.check_tokens(tokens)?;
316 let pd = self.params.proj_dim();
317 let b = self.params.buckets();
318 let d = self.params.input_dim as usize;
319 let mut out = vec![0.0f32; self.params.fde_dim()];
320
321 for (ri, rep) in self.reps.iter().enumerate() {
322 let base = ri * b * pd;
323 let mut sums = vec![0.0f32; b * pd];
324 let mut counts = vec![0u32; b];
325 for tok in tokens {
326 let bk = rep.bucket_of(tok, self.params.k_sim, d);
327 let proj = rep.project(tok, pd, d);
328 let slot = &mut sums[bk * pd..(bk + 1) * pd];
329 for (s, p) in slot.iter_mut().zip(proj.iter()) {
330 *s += *p;
331 }
332 counts[bk] += 1;
333 }
334 for bk in 0..b {
336 if counts[bk] > 0 {
337 let inv = 1.0f32 / counts[bk] as f32;
338 let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
339 let src = &sums[bk * pd..(bk + 1) * pd];
340 for (o, s) in dst.iter_mut().zip(src.iter()) {
341 *o = *s * inv;
342 }
343 }
344 }
345 for bk in 0..b {
347 if counts[bk] == 0
348 && let Some(src) = nearest_nonempty(bk, &counts)
349 {
350 let (lo, hi) = (bk.min(src), bk.max(src));
351 let (left, right) = out[base..base + b * pd].split_at_mut(hi * pd);
353 let (src_slice, dst_slice) = if bk == lo {
354 (&right[0..pd], &mut left[bk * pd..bk * pd + pd])
356 } else {
357 (&left[src * pd..src * pd + pd], &mut right[0..pd])
359 };
360 dst_slice.copy_from_slice(src_slice);
361 }
362 }
363 }
364 Ok(out)
365 }
366
367 pub fn encode_query(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
369 self.check_tokens(tokens)?;
370 let pd = self.params.proj_dim();
371 let b = self.params.buckets();
372 let d = self.params.input_dim as usize;
373 let mut out = vec![0.0f32; self.params.fde_dim()];
374
375 for (ri, rep) in self.reps.iter().enumerate() {
376 let base = ri * b * pd;
377 for tok in tokens {
378 let bk = rep.bucket_of(tok, self.params.k_sim, d);
379 let proj = rep.project(tok, pd, d);
380 let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
381 for (o, p) in dst.iter_mut().zip(proj.iter()) {
382 *o += *p;
383 }
384 }
385 }
386 Ok(out)
387 }
388}
389
/// Index of the non-empty bucket (`counts[i] > 0`) whose index is closest to
/// `bucket` in Hamming distance.
///
/// Returns `None` if every count is zero. Ties go to the lowest index. This is
/// a linear scan: `O(counts.len())`.
#[inline]
fn nearest_nonempty(bucket: usize, counts: &[u32]) -> Option<usize> {
    counts
        .iter()
        .enumerate()
        .filter_map(|(idx, &count)| (count > 0).then_some(idx))
        // `min_by_key` keeps the first of equal minima, i.e. the lowest index.
        .min_by_key(|&idx| (bucket ^ idx).count_ones())
}
406
407pub fn encode_doc(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
410 FdeEncoder::new(params)?.encode_doc(tokens)
411}
412
413pub fn encode_query(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
416 FdeEncoder::new(params)?.encode_query(tokens)
417}
418
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact late-interaction score. For each query token it takes the best
    /// dot product over all document tokens, then sums over query tokens. An
    /// empty document contributes 0 per query token.
    fn maxsim_dot(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
        query
            .iter()
            .map(|q| {
                if doc.is_empty() {
                    0.0
                } else {
                    doc.iter()
                        .map(|d| dot(q, d))
                        .fold(f32::NEG_INFINITY, f32::max)
                }
            })
            .sum()
    }

    /// Deterministic test-data generator built on the module's `SplitMix64`.
    struct Gen(SplitMix64);
    impl Gen {
        fn new(seed: u64) -> Self {
            Self(SplitMix64::new(seed))
        }
        /// Gaussian vector of length `dim`, normalised to unit L2 norm (the
        /// norm is floored at 1e-12 to avoid dividing by zero).
        fn unit_token(&mut self, dim: usize) -> Vec<f32> {
            let mut v: Vec<f32> = (0..dim).map(|_| self.0.next_gaussian()).collect();
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            for x in &mut v {
                *x /= norm;
            }
            v
        }
        /// `n` independent unit tokens of dimension `dim`.
        fn multivec(&mut self, n: usize, dim: usize) -> Vec<Vec<f32>> {
            (0..n).map(|_| self.unit_token(dim)).collect()
        }
        /// Pseudo-random integer in `lo..=hi` (modulo-biased, which is fine
        /// for test data).
        fn count(&mut self, lo: usize, hi: usize) -> usize {
            lo + (self.0.next_u64() as usize) % (hi - lo + 1)
        }
    }

    /// Builds params with the default seed.
    fn params(k_sim: u32, reps: u32, d_proj: u32, input_dim: u32) -> FdeParams {
        FdeParams {
            k_sim,
            reps,
            d_proj,
            input_dim,
            seed: DEFAULT_FDE_SEED,
        }
    }

    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Pearson correlation coefficient of two equal-length samples.
    fn pearson(xs: &[f32], ys: &[f32]) -> f32 {
        let n = xs.len() as f32;
        let mx = xs.iter().sum::<f32>() / n;
        let my = ys.iter().sum::<f32>() / n;
        let mut cov = 0.0;
        let mut vx = 0.0;
        let mut vy = 0.0;
        for (x, y) in xs.iter().zip(ys) {
            let dx = x - mx;
            let dy = y - my;
            cov += dx * dy;
            vx += dx * dx;
            vy += dy * dy;
        }
        cov / (vx.sqrt() * vy.sqrt()).max(1e-12)
    }

    #[test]
    fn fde_dim_arithmetic() {
        assert_eq!(params(4, 20, 16, 96).fde_dim(), 20 * 16 * 16);
        // d_proj == 0: the per-bucket width falls back to input_dim (8).
        assert_eq!(params(3, 2, 0, 8).fde_dim(), 2 * 8 * 8);
        assert_eq!(params(4, 20, 16, 96).buckets(), 16);
    }

    #[test]
    fn validate_rejects_bad_params() {
        // k_sim out of 1..=MAX_K_SIM.
        assert!(params(0, 1, 0, 8).validate().is_err());
        assert!(params(MAX_K_SIM + 1, 1, 0, 8).validate().is_err());
        // reps == 0, input_dim == 0.
        assert!(params(4, 0, 0, 8).validate().is_err());
        assert!(params(4, 1, 0, 0).validate().is_err());
        // fde_dim = 1000 * 2^16 * 64, far above MAX_FDE_DIM.
        assert!(params(16, 1000, 64, 96).validate().is_err());
        assert!(params(4, 20, 16, 96).validate().is_ok());
    }

    /// Each document, re-encoded as a query, must score highest against its
    /// own document encoding among 50 random documents.
    #[test]
    fn fde_self_retrieval_ranks_first() {
        let dim = 32usize;
        let p = params(4, 20, 16, dim as u32);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let corpus: Vec<Vec<Vec<f32>>> = (0..50)
            .map(|_| {
                let n = g.count(4, 16);
                g.multivec(n, dim)
            })
            .collect();
        let dfde: Vec<Vec<f32>> = corpus.iter().map(|d| enc.encode_doc(d).unwrap()).collect();
        for (j, d) in corpus.iter().enumerate() {
            let fq = enc.encode_query(d).unwrap();
            let top = (0..corpus.len())
                .max_by(|&a, &b| dot(&fq, &dfde[a]).total_cmp(&dot(&fq, &dfde[b])))
                .unwrap();
            assert_eq!(top, j, "doc {j} did not self-retrieve as FDE top-1");
        }
    }

    /// Over 400 random query/document pairs, the Pearson correlation between
    /// FDE dot products and exact MaxSim scores must be at least 0.55. This is
    /// a regression guard on approximation quality, not a theoretical bound.
    #[test]
    fn fde_dot_positively_correlates_with_maxsim() {
        let dim = 32usize;
        let p = params(4, 24, 16, dim as u32);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(42);

        let n_pairs = 400;
        let mut fde_scores = Vec::with_capacity(n_pairs);
        let mut exact_scores = Vec::with_capacity(n_pairs);
        for _ in 0..n_pairs {
            let (qn, dn) = (g.count(2, 6), g.count(4, 16));
            let q = g.multivec(qn, dim);
            let d = g.multivec(dn, dim);
            fde_scores.push(dot(
                &enc.encode_query(&q).unwrap(),
                &enc.encode_doc(&d).unwrap(),
            ));
            exact_scores.push(maxsim_dot(&q, &d));
        }
        let r = pearson(&fde_scores, &exact_scores);
        assert!(r >= 0.55, "FDE/MaxSim correlation regressed: {r}");
    }

    /// Two encoders built from equal params must produce bit-identical output.
    #[test]
    fn deterministic_across_rebuild() {
        let p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let d = g.multivec(10, 16);
        assert_eq!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
        let q = g.multivec(3, 16);
        assert_eq!(e1.encode_query(&q).unwrap(), e2.encode_query(&q).unwrap());
    }

    #[test]
    fn different_seed_changes_output() {
        let mut p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        p.seed = DEFAULT_FDE_SEED ^ 0xDEAD_BEEF;
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(11);
        let d = g.multivec(10, 16);
        assert_ne!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
    }

    #[test]
    fn empty_doc_is_all_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let fde = enc.encode_doc(&[]).unwrap();
        assert_eq!(fde.len(), p.fde_dim());
        assert!(fde.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn empty_query_scores_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(3);
        let fq = enc.encode_query(&[]).unwrap();
        let fd = enc.encode_doc(&g.multivec(8, 16)).unwrap();
        assert_eq!(dot(&fq, &fd), 0.0);
    }

    #[test]
    fn dim_mismatch_errors() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        // One 15-dim token against input_dim 16.
        let bad = vec![vec![1.0f32; 15]];
        assert_eq!(
            enc.encode_doc(&bad),
            Err(FdeError::DimensionMismatch {
                got: 15,
                expected: 16
            })
        );
        assert!(enc.encode_query(&bad).is_err());
    }

    /// With a single token, every empty bucket is filled from the one
    /// non-empty bucket, so all buckets must be identical and non-zero.
    /// Projection is disabled (d_proj == 0), so that value is the token itself.
    #[test]
    fn single_token_doc_fills_all_buckets() {
        let p = params(3, 1, 0, 8);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(99);
        let tok = g.unit_token(8);
        let fde = enc.encode_doc(&[tok]).unwrap();
        let pd = p.proj_dim();
        let first = &fde[0..pd];
        for bk in 1..p.buckets() {
            assert_eq!(&fde[bk * pd..(bk + 1) * pd], first, "bucket {bk} differs");
        }
        assert!(first.iter().any(|&x| x != 0.0));
    }

    /// Queries do not fill empty buckets, so one token must touch exactly
    /// one bucket.
    #[test]
    fn query_leaves_empty_buckets_zero() {
        let p = params(3, 1, 0, 8);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(123);
        let tok = g.unit_token(8);
        let fde = enc.encode_query(&[tok]).unwrap();
        let pd = p.proj_dim();
        let nonzero_buckets = (0..p.buckets())
            .filter(|&bk| fde[bk * pd..(bk + 1) * pd].iter().any(|&x| x != 0.0))
            .count();
        assert_eq!(nonzero_buckets, 1);
    }

    /// The one-shot free functions must agree with a reused encoder.
    #[test]
    fn free_fns_match_encoder() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(55);
        let d = g.multivec(6, 16);
        assert_eq!(encode_doc(&d, &p).unwrap(), enc.encode_doc(&d).unwrap());
        let q = g.multivec(2, 16);
        assert_eq!(encode_query(&q, &p).unwrap(), enc.encode_query(&q).unwrap());
    }
}