1use serde::{Deserialize, Serialize};
54
/// Errors returned by MUVERA fixed-dimensional encoding.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum FdeError {
    /// A token vector's length differs from the configured `input_dim`.
    #[error("muvera: token dimension {got} != configured input_dim {expected}")]
    DimensionMismatch {
        /// Length of the offending token.
        got: usize,
        /// The configured `input_dim`.
        expected: usize,
    },

    /// Parameters failed [`FdeParams::validate`]; the message names the
    /// violated constraint.
    #[error("muvera: invalid params: {0}")]
    InvalidParams(String),
}
66
/// Default base seed for [`FdeParams::seed`]: `0x9E37_79B9_7F4A_7C15`
/// (approximately 2^64 / golden ratio, the same constant SplitMix64 uses as
/// its per-draw increment).
pub const DEFAULT_FDE_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

/// Largest accepted `k_sim`, i.e. at most 2^16 SimHash buckets per repetition.
const MAX_K_SIM: u32 = 16;
/// Largest accepted total encoding length (see `FdeParams::fde_dim`).
const MAX_FDE_DIM: usize = 200_000;

/// Largest accepted number of repetitions.
const MAX_REPS: u32 = 1 << 10;
/// Largest accepted explicit projection width (`d_proj`).
const MAX_PROJ_DIM: u32 = 1 << 12;
87
/// Configuration of a MUVERA fixed-dimensional encoding (FDE).
///
/// Range checks live in [`FdeParams::validate`], which `FdeEncoder::new`
/// calls before sampling any random matrices.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FdeParams {
    /// Number of SimHash hyperplanes per repetition; each repetition has
    /// `2^k_sim` buckets.
    pub k_sim: u32,
    /// Number of independent repetitions concatenated into the encoding.
    pub reps: u32,
    /// Output width of the random +/-1/sqrt(d_proj) token projection; `0`
    /// disables projection and keeps tokens at `input_dim`.
    pub d_proj: u32,
    /// Required length of every input token vector.
    pub input_dim: u32,
    /// Base seed from which each repetition's random matrices are derived.
    pub seed: u64,
}
103
104impl FdeParams {
105 #[inline]
107 pub fn proj_dim(&self) -> usize {
108 if self.d_proj == 0 {
109 self.input_dim as usize
110 } else {
111 self.d_proj as usize
112 }
113 }
114
115 #[inline]
117 pub fn buckets(&self) -> usize {
118 1usize.checked_shl(self.k_sim).unwrap_or(0)
122 }
123
124 #[inline]
130 pub fn fde_dim(&self) -> usize {
131 self.buckets()
132 .checked_mul(self.proj_dim())
133 .and_then(|x| x.checked_mul(self.reps as usize))
134 .unwrap_or(usize::MAX)
135 }
136
137 pub fn validate(&self) -> Result<(), FdeError> {
139 if self.k_sim == 0 || self.k_sim > MAX_K_SIM {
140 return Err(FdeError::InvalidParams(format!(
141 "k_sim must be in 1..={MAX_K_SIM}, got {}",
142 self.k_sim
143 )));
144 }
145 if self.reps == 0 || self.reps > MAX_REPS {
146 return Err(FdeError::InvalidParams(format!(
147 "reps must be in 1..={MAX_REPS}, got {}",
148 self.reps
149 )));
150 }
151 if self.input_dim == 0 {
152 return Err(FdeError::InvalidParams(
153 "input_dim must be >= 1".to_string(),
154 ));
155 }
156 if self.d_proj > MAX_PROJ_DIM {
160 return Err(FdeError::InvalidParams(format!(
161 "d_proj must be <= {MAX_PROJ_DIM}, got {}",
162 self.d_proj
163 )));
164 }
165 let dim = self.fde_dim();
166 if dim == 0 || dim > MAX_FDE_DIM {
167 return Err(FdeError::InvalidParams(format!(
168 "fde_dim {dim} out of range (1..={MAX_FDE_DIM}); reduce k_sim/reps/d_proj"
169 )));
170 }
171 Ok(())
172 }
173}
174
/// Minimal, non-cryptographic SplitMix64 pseudo-random generator.
///
/// The output sequence depends only on the seed, which is what makes the
/// sampled encoder matrices reproducible.
struct SplitMix64 {
    /// Generator state; advanced by a fixed odd increment before every draw.
    state: u64,
}

impl SplitMix64 {
    /// Per-draw state increment (approximately 2^64 / golden ratio).
    const GOLDEN_GAMMA: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Creates a generator starting from `seed`.
    #[inline]
    fn new(seed: u64) -> Self {
        SplitMix64 { state: seed }
    }

    /// Advances the state and returns 64 mixed pseudo-random bits.
    #[inline]
    fn next_u64(&mut self) -> u64 {
        self.state = self.state.wrapping_add(Self::GOLDEN_GAMMA);
        let x = self.state;
        let x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        let x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        x ^ (x >> 31)
    }

    /// Uniform `f64` in `[0, 1)` built from the top 53 bits of one draw.
    #[inline]
    fn next_f64(&mut self) -> f64 {
        let top_bits = self.next_u64() >> 11;
        let denom = (1u64 << 53) as f64;
        top_bits as f64 / denom
    }

    /// Standard-normal sample via the Box–Muller transform (cosine branch),
    /// narrowed to `f32`. Consumes two uniform draws.
    #[inline]
    fn next_gaussian(&mut self) -> f32 {
        // Clamp away from zero so ln() stays finite.
        let u1 = self.next_f64().max(1e-12);
        let u2 = self.next_f64();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f64::consts::PI * u2;
        (radius * angle.cos()) as f32
    }
}
213
/// Derives the RNG seed for repetition `rep` from the base seed.
///
/// Offsets `base` by `rep` times an odd 64-bit constant, then scrambles the sum
/// with the same xor-shift/multiply finalizer that `SplitMix64::next_u64` uses.
#[inline]
fn rep_seed(base: u64, rep: u32) -> u64 {
    const MIX_ROUNDS: [(u32, u64); 2] = [
        (30, 0xBF58_476D_1CE4_E5B9),
        (27, 0x94D0_49BB_1331_11EB),
    ];
    let offset = u64::from(rep).wrapping_mul(0xD1B5_4A32_D192_ED03);
    let mixed = MIX_ROUNDS
        .iter()
        .fold(base.wrapping_add(offset), |z, &(shift, mult)| {
            (z ^ (z >> shift)).wrapping_mul(mult)
        });
    mixed ^ (mixed >> 31)
}
223
224struct RepMatrices {
226 hyperplanes: Vec<f32>,
228 projection: Option<Vec<f32>>,
230}
231
232impl RepMatrices {
233 fn build(params: &FdeParams, rep: u32) -> Self {
234 let mut rng = SplitMix64::new(rep_seed(params.seed, rep));
235 let d = params.input_dim as usize;
236 let hyperplanes = (0..params.k_sim as usize * d)
237 .map(|_| rng.next_gaussian())
238 .collect();
239 let projection = if params.d_proj == 0 {
240 None
241 } else {
242 let pd = params.d_proj as usize;
243 let scale = 1.0f32 / (pd as f32).sqrt();
244 let proj = (0..pd * d)
246 .map(|_| {
247 if rng.next_u64() & 1 == 0 {
248 scale
249 } else {
250 -scale
251 }
252 })
253 .collect();
254 Some(proj)
255 };
256 Self {
257 hyperplanes,
258 projection,
259 }
260 }
261
262 #[inline]
264 fn bucket_of(&self, token: &[f32], k_sim: u32, d: usize) -> usize {
265 let mut bucket = 0usize;
266 for h in 0..k_sim as usize {
267 let row = &self.hyperplanes[h * d..(h + 1) * d];
268 let mut dot = 0.0f32;
269 for i in 0..d {
270 dot += row[i] * token[i];
271 }
272 if dot > 0.0 {
273 bucket |= 1 << h;
274 }
275 }
276 bucket
277 }
278
279 #[inline]
281 fn project(&self, token: &[f32], proj_dim: usize, d: usize) -> Vec<f32> {
282 match &self.projection {
283 None => token.to_vec(),
284 Some(p) => {
285 let mut out = vec![0.0f32; proj_dim];
286 for (r, slot) in out.iter_mut().enumerate() {
287 let row = &p[r * d..(r + 1) * d];
288 let mut acc = 0.0f32;
289 for i in 0..d {
290 acc += row[i] * token[i];
291 }
292 *slot = acc;
293 }
294 out
295 }
296 }
297 }
298}
299
300pub struct FdeEncoder {
304 params: FdeParams,
305 reps: Vec<RepMatrices>,
306}
307
308impl FdeEncoder {
309 pub fn new(params: &FdeParams) -> Result<Self, FdeError> {
311 params.validate()?;
312 let reps = (0..params.reps)
313 .map(|r| RepMatrices::build(params, r))
314 .collect();
315 Ok(Self {
316 params: params.clone(),
317 reps,
318 })
319 }
320
321 #[inline]
323 pub fn params(&self) -> &FdeParams {
324 &self.params
325 }
326
327 #[inline]
329 pub fn fde_dim(&self) -> usize {
330 self.params.fde_dim()
331 }
332
333 fn check_tokens(&self, tokens: &[Vec<f32>]) -> Result<(), FdeError> {
334 let d = self.params.input_dim as usize;
335 for tok in tokens {
336 if tok.len() != d {
337 return Err(FdeError::DimensionMismatch {
338 got: tok.len(),
339 expected: d,
340 });
341 }
342 }
343 Ok(())
344 }
345
346 pub fn encode_doc(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
348 self.check_tokens(tokens)?;
349 let pd = self.params.proj_dim();
350 let b = self.params.buckets();
351 let d = self.params.input_dim as usize;
352 let mut out = vec![0.0f32; self.params.fde_dim()];
353
354 for (ri, rep) in self.reps.iter().enumerate() {
355 let base = ri * b * pd;
356 let mut sums = vec![0.0f32; b * pd];
357 let mut counts = vec![0u32; b];
358 for tok in tokens {
359 let bk = rep.bucket_of(tok, self.params.k_sim, d);
360 let proj = rep.project(tok, pd, d);
361 let slot = &mut sums[bk * pd..(bk + 1) * pd];
362 for (s, p) in slot.iter_mut().zip(proj.iter()) {
363 *s += *p;
364 }
365 counts[bk] += 1;
366 }
367 for bk in 0..b {
369 if counts[bk] > 0 {
370 let inv = 1.0f32 / counts[bk] as f32;
371 let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
372 let src = &sums[bk * pd..(bk + 1) * pd];
373 for (o, s) in dst.iter_mut().zip(src.iter()) {
374 *o = *s * inv;
375 }
376 }
377 }
378 for bk in 0..b {
380 if counts[bk] == 0
381 && let Some(src) = nearest_nonempty(bk, &counts)
382 {
383 let (lo, hi) = (bk.min(src), bk.max(src));
384 let (left, right) = out[base..base + b * pd].split_at_mut(hi * pd);
386 let (src_slice, dst_slice) = if bk == lo {
387 (&right[0..pd], &mut left[bk * pd..bk * pd + pd])
389 } else {
390 (&left[src * pd..src * pd + pd], &mut right[0..pd])
392 };
393 dst_slice.copy_from_slice(src_slice);
394 }
395 }
396 }
397 Ok(out)
398 }
399
400 pub fn encode_query(&self, tokens: &[Vec<f32>]) -> Result<Vec<f32>, FdeError> {
402 self.check_tokens(tokens)?;
403 let pd = self.params.proj_dim();
404 let b = self.params.buckets();
405 let d = self.params.input_dim as usize;
406 let mut out = vec![0.0f32; self.params.fde_dim()];
407
408 for (ri, rep) in self.reps.iter().enumerate() {
409 let base = ri * b * pd;
410 for tok in tokens {
411 let bk = rep.bucket_of(tok, self.params.k_sim, d);
412 let proj = rep.project(tok, pd, d);
413 let dst = &mut out[base + bk * pd..base + (bk + 1) * pd];
414 for (o, p) in dst.iter_mut().zip(proj.iter()) {
415 *o += *p;
416 }
417 }
418 }
419 Ok(out)
420 }
421}
422
423#[inline]
426fn nearest_nonempty(bucket: usize, counts: &[u32]) -> Option<usize> {
427 let mut best: Option<(u32, usize)> = None;
428 for (cand, &c) in counts.iter().enumerate() {
429 if c > 0 {
430 let h = (bucket ^ cand).count_ones();
431 match best {
432 Some((bh, _)) if h >= bh => {}
433 _ => best = Some((h, cand)),
434 }
435 }
436 }
437 best.map(|(_, idx)| idx)
438}
439
440pub fn encode_doc(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
443 FdeEncoder::new(params)?.encode_doc(tokens)
444}
445
446pub fn encode_query(tokens: &[Vec<f32>], params: &FdeParams) -> Result<Vec<f32>, FdeError> {
449 FdeEncoder::new(params)?.encode_query(tokens)
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Exact late-interaction score used as ground truth. For each query token
    /// it takes the largest dot product with any document token, then sums over
    /// query tokens. An empty document contributes 0 per query token.
    fn maxsim_dot(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
        query
            .iter()
            .map(|q| {
                if doc.is_empty() {
                    0.0
                } else {
                    doc.iter()
                        .map(|d| dot(q, d))
                        .fold(f32::NEG_INFINITY, f32::max)
                }
            })
            .sum()
    }

    /// Seeded generator for reproducible synthetic token data.
    struct Gen(SplitMix64);

    impl Gen {
        fn new(seed: u64) -> Self {
            Self(SplitMix64::new(seed))
        }
        /// Gaussian vector of length `dim` scaled to unit L2 norm. The norm is
        /// clamped to 1e-12 so a zero vector cannot divide by zero.
        fn unit_token(&mut self, dim: usize) -> Vec<f32> {
            let mut v: Vec<f32> = (0..dim).map(|_| self.0.next_gaussian()).collect();
            let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt().max(1e-12);
            for x in &mut v {
                *x /= norm;
            }
            v
        }
        /// `n` independent unit tokens of length `dim`.
        fn multivec(&mut self, n: usize, dim: usize) -> Vec<Vec<f32>> {
            (0..n).map(|_| self.unit_token(dim)).collect()
        }
        /// Pseudo-random integer in `lo..=hi`; the slight modulo bias is
        /// irrelevant for test data.
        fn count(&mut self, lo: usize, hi: usize) -> usize {
            lo + (self.0.next_u64() as usize) % (hi - lo + 1)
        }
    }

    /// Builds params with the crate's default seed.
    fn params(k_sim: u32, reps: u32, d_proj: u32, input_dim: u32) -> FdeParams {
        FdeParams {
            k_sim,
            reps,
            d_proj,
            input_dim,
            seed: DEFAULT_FDE_SEED,
        }
    }

    /// Dot product over the common prefix of `a` and `b`.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Pearson correlation of two equal-length samples. The denominator is
    /// clamped to 1e-12 so a constant sample does not divide by zero.
    fn pearson(xs: &[f32], ys: &[f32]) -> f32 {
        let n = xs.len() as f32;
        let mx = xs.iter().sum::<f32>() / n;
        let my = ys.iter().sum::<f32>() / n;
        let mut cov = 0.0;
        let mut vx = 0.0;
        let mut vy = 0.0;
        for (x, y) in xs.iter().zip(ys) {
            let dx = x - mx;
            let dy = y - my;
            cov += dx * dy;
            vx += dx * dx;
            vy += dy * dy;
        }
        cov / (vx.sqrt() * vy.sqrt()).max(1e-12)
    }

    #[test]
    fn fde_dim_arithmetic() {
        // 2^4 = 16 buckets * d_proj 16 * 20 reps.
        assert_eq!(params(4, 20, 16, 96).fde_dim(), 20 * 16 * 16);
        // d_proj == 0 falls back to input_dim: 8 buckets * 8 dims * 2 reps.
        assert_eq!(params(3, 2, 0, 8).fde_dim(), 2 * 8 * 8);
        assert_eq!(params(4, 20, 16, 96).buckets(), 16);
    }

    #[test]
    fn validate_rejects_bad_params() {
        // k_sim outside 1..=MAX_K_SIM.
        assert!(params(0, 1, 0, 8).validate().is_err());
        assert!(params(MAX_K_SIM + 1, 1, 0, 8).validate().is_err());
        // Zero repetitions.
        assert!(params(4, 0, 0, 8).validate().is_err());
        // Zero input dimension.
        assert!(params(4, 1, 0, 0).validate().is_err());
        // Each knob individually in range, but the product exceeds MAX_FDE_DIM.
        assert!(params(16, 1000, 64, 96).validate().is_err());
        assert!(params(4, 20, 16, 96).validate().is_ok());
    }

    #[test]
    fn validate_rejects_overflowing_reps_and_d_proj_without_panicking() {
        // Extreme values must be rejected, not overflow.
        assert!(params(16, u32::MAX, u32::MAX, 96).validate().is_err());
        // Each knob just past its own cap.
        assert!(params(16, MAX_REPS + 1, 16, 96).validate().is_err());
        assert!(params(16, 20, MAX_PROJ_DIM + 1, 96).validate().is_err());
        // 2 * 0xFFFF_0001 * 0x8000_8000 == 2^64 + 2^16, so with 64-bit wrapping
        // arithmetic fde_dim would come out as a plausible 65_536. It must
        // still be rejected.
        assert!(
            params(1, 2_147_516_416, 4_294_901_761, 96)
                .validate()
                .is_err()
        );
        // fde_dim saturates instead of wrapping.
        assert_eq!(params(16, u32::MAX, u32::MAX, 96).fde_dim(), usize::MAX);
        // A shift past the word size reports 0 buckets instead of panicking.
        assert_eq!(params(64, 1, 0, 8).buckets(), 0);
        // Knobs at their individual caps can still exceed MAX_FDE_DIM together,
        // e.g. 16 buckets * 16 * 1024 reps = 262_144.
        assert!(params(16, MAX_REPS, MAX_PROJ_DIM, 96).validate().is_err());
        assert!(params(4, MAX_REPS, 16, 96).validate().is_err());
    }

    #[test]
    fn fde_self_retrieval_ranks_first() {
        let dim = 32usize;
        // 16 buckets * 16 projected dims * 20 reps.
        let p = params(4, 20, 16, dim as u32);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        // 50 random documents with 4..=16 unit tokens each.
        let corpus: Vec<Vec<Vec<f32>>> = (0..50)
            .map(|_| {
                let n = g.count(4, 16);
                g.multivec(n, dim)
            })
            .collect();
        let dfde: Vec<Vec<f32>> = corpus.iter().map(|d| enc.encode_doc(d).unwrap()).collect();
        // Querying with a document's own tokens must rank that document first.
        for (j, d) in corpus.iter().enumerate() {
            let fq = enc.encode_query(d).unwrap();
            let top = (0..corpus.len())
                .max_by(|&a, &b| dot(&fq, &dfde[a]).total_cmp(&dot(&fq, &dfde[b])))
                .unwrap();
            assert_eq!(top, j, "doc {j} did not self-retrieve as FDE top-1");
        }
    }

    #[test]
    fn fde_dot_positively_correlates_with_maxsim() {
        let dim = 32usize;
        let p = params(4, 24, 16, dim as u32);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(42);

        // Random (query, doc) pairs: the FDE dot product should track the exact
        // MaxSim score across pairs.
        let n_pairs = 400;
        let mut fde_scores = Vec::with_capacity(n_pairs);
        let mut exact_scores = Vec::with_capacity(n_pairs);
        for _ in 0..n_pairs {
            let (qn, dn) = (g.count(2, 6), g.count(4, 16));
            let q = g.multivec(qn, dim);
            let d = g.multivec(dn, dim);
            fde_scores.push(dot(
                &enc.encode_query(&q).unwrap(),
                &enc.encode_doc(&d).unwrap(),
            ));
            exact_scores.push(maxsim_dot(&q, &d));
        }
        let r = pearson(&fde_scores, &exact_scores);
        assert!(r >= 0.55, "FDE/MaxSim correlation regressed: {r}");
    }

    #[test]
    fn deterministic_across_rebuild() {
        // Two encoders built from identical params must produce identical output.
        let p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(7);
        let d = g.multivec(10, 16);
        assert_eq!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
        let q = g.multivec(3, 16);
        assert_eq!(e1.encode_query(&q).unwrap(), e2.encode_query(&q).unwrap());
    }

    #[test]
    fn different_seed_changes_output() {
        let mut p = params(4, 8, 8, 16);
        let e1 = FdeEncoder::new(&p).unwrap();
        p.seed = DEFAULT_FDE_SEED ^ 0xDEAD_BEEF;
        let e2 = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(11);
        let d = g.multivec(10, 16);
        assert_ne!(e1.encode_doc(&d).unwrap(), e2.encode_doc(&d).unwrap());
    }

    #[test]
    fn empty_doc_is_all_zero() {
        // No tokens means no occupied bucket to copy from, so nothing is filled.
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let fde = enc.encode_doc(&[]).unwrap();
        assert_eq!(fde.len(), p.fde_dim());
        assert!(fde.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn empty_query_scores_zero() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(3);
        let fq = enc.encode_query(&[]).unwrap();
        let fd = enc.encode_doc(&g.multivec(8, 16)).unwrap();
        assert_eq!(dot(&fq, &fd), 0.0);
    }

    #[test]
    fn dim_mismatch_errors() {
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        // One token one element short of input_dim.
        let bad = vec![vec![1.0f32; 15]];
        assert_eq!(
            enc.encode_doc(&bad),
            Err(FdeError::DimensionMismatch {
                got: 15,
                expected: 16
            })
        );
        assert!(enc.encode_query(&bad).is_err());
    }

    #[test]
    fn single_token_doc_fills_all_buckets() {
        // Identity projection, one repetition, 8 buckets.
        let p = params(3, 1, 0, 8);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(99);
        let tok = g.unit_token(8);
        let fde = enc.encode_doc(&[tok]).unwrap();
        let pd = p.proj_dim();
        // Only one bucket is occupied, so every bucket ends up as a copy of it.
        let first = &fde[0..pd];
        for bk in 1..p.buckets() {
            assert_eq!(&fde[bk * pd..(bk + 1) * pd], first, "bucket {bk} differs");
        }
        assert!(first.iter().any(|&x| x != 0.0));
    }

    #[test]
    fn query_leaves_empty_buckets_zero() {
        // Queries are never filled: one token occupies exactly one bucket.
        let p = params(3, 1, 0, 8);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(123);
        let tok = g.unit_token(8);
        let fde = enc.encode_query(&[tok]).unwrap();
        let pd = p.proj_dim();
        let nonzero_buckets = (0..p.buckets())
            .filter(|&bk| fde[bk * pd..(bk + 1) * pd].iter().any(|&x| x != 0.0))
            .count();
        assert_eq!(nonzero_buckets, 1);
    }

    #[test]
    fn free_fns_match_encoder() {
        // The one-shot helpers must agree with a reused encoder.
        let p = params(4, 4, 8, 16);
        let enc = FdeEncoder::new(&p).unwrap();
        let mut g = Gen::new(55);
        let d = g.multivec(6, 16);
        assert_eq!(encode_doc(&d, &p).unwrap(), enc.encode_doc(&d).unwrap());
        let q = g.multivec(2, 16);
        assert_eq!(encode_query(&q, &p).unwrap(), enc.encode_query(&q).unwrap());
    }
}