1use std::collections::HashMap;
23
24use serde::{Deserialize, Serialize};
25
26use crate::claims::MemorySource;
27use crate::Result;
28
/// BM25 `k1` parameter: scales the term-frequency component in `bm25_score`.
const BM25_K1: f64 = 1.2;
/// BM25 `b` parameter: controls how strongly `bm25_score` normalises by
/// document length relative to the average document length.
const BM25_B: f64 = 0.75;

/// Constant added to a candidate's rank before taking the reciprocal in the
/// rank-fusion step of `rerank_with_config` (`1 / (RRF_K + rank)`).
const RRF_K: f64 = 60.0;
35
/// Score multiplier associated with each [`MemorySource`].
///
/// The table is searched linearly by [`source_weight`]; a source that does not
/// appear here resolves to [`LEGACY_CLAIM_FALLBACK_WEIGHT`].
pub const SOURCE_WEIGHTS: &[(MemorySource, f64)] = &[
    (MemorySource::User, 1.00),
    (MemorySource::UserInferred, 0.90),
    (MemorySource::Derived, 0.70),
    (MemorySource::External, 0.70),
    (MemorySource::Assistant, 0.55),
];
49
/// Multiplier used when no table weight is available: either the candidate has
/// no `source` at all, or its source is not listed in [`SOURCE_WEIGHTS`].
pub const LEGACY_CLAIM_FALLBACK_WEIGHT: f64 = 0.85;
56
57pub fn source_weight(source: MemorySource) -> f64 {
62 SOURCE_WEIGHTS
63 .iter()
64 .find(|(s, _)| *s == source)
65 .map(|(_, w)| *w)
66 .unwrap_or(LEGACY_CLAIM_FALLBACK_WEIGHT)
67}
68
/// Options controlling how `rerank_with_config` scores candidates.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct RerankerConfig {
    /// When `true`, each candidate's fused score is multiplied by the weight of
    /// its source (or by `LEGACY_CLAIM_FALLBACK_WEIGHT` when it has no source).
    /// When `false`, the multiplier is `1.0` and the fused score is unchanged.
    pub apply_source_weights: bool,
}
80
81impl Default for RerankerConfig {
82 fn default() -> Self {
85 RerankerConfig {
86 apply_source_weights: false,
87 }
88 }
89}
90
/// A single item to be scored by `rerank` / `rerank_with_config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Candidate {
    /// Identifier copied into the result; also used as the ascending tie-break
    /// when two results have equal scores.
    pub id: String,
    /// Text that is tokenized and scored against the query with BM25.
    pub text: String,
    /// Embedding compared against the query embedding with cosine similarity.
    pub embedding: Vec<f32>,
    /// Copied into the result unchanged; the reranker does not interpret it.
    pub timestamp: String,
    /// Optional origin of the candidate. Omitted from serialized output when
    /// `None`, and defaulted to `None` when absent on deserialization. With
    /// source weighting enabled, `None` uses `LEGACY_CLAIM_FALLBACK_WEIGHT`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub source: Option<MemorySource>,
}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct RankedResult {
114 pub id: String,
116 pub text: String,
118 pub score: f64,
120 pub bm25_score: f64,
122 pub cosine_score: f64,
124 pub timestamp: String,
126 #[serde(default, skip_serializing_if = "is_one_f64")]
129 pub source_weight: f64,
130}
131
/// Serde `skip_serializing_if` predicate: reports whether `v` is within
/// `f64::EPSILON` (exclusive) of `1.0`.
fn is_one_f64(v: &f64) -> bool {
    let distance_from_one = (*v - 1.0).abs();
    distance_from_one < f64::EPSILON
}
135
136pub fn rerank(
150 query: &str,
151 query_embedding: &[f32],
152 candidates: &[Candidate],
153 top_k: usize,
154) -> Result<Vec<RankedResult>> {
155 rerank_with_config(
156 query,
157 query_embedding,
158 candidates,
159 top_k,
160 RerankerConfig::default(),
161 )
162}
163
164pub fn rerank_with_config(
176 query: &str,
177 query_embedding: &[f32],
178 candidates: &[Candidate],
179 top_k: usize,
180 config: RerankerConfig,
181) -> Result<Vec<RankedResult>> {
182 if candidates.is_empty() {
183 return Ok(Vec::new());
184 }
185
186 let query_tokens = tokenize(query);
188
189 let mut df: HashMap<String, usize> = HashMap::new();
191 let mut doc_tokens: Vec<Vec<String>> = Vec::with_capacity(candidates.len());
192 let mut total_doc_len: usize = 0;
193
194 for candidate in candidates {
195 let tokens = tokenize(&candidate.text);
196 total_doc_len += tokens.len();
197 for token in &tokens {
198 *df.entry(token.clone()).or_insert(0) += 1;
199 }
200 doc_tokens.push(tokens);
201 }
202
203 let avg_doc_len = total_doc_len as f64 / candidates.len() as f64;
204 let n_docs = candidates.len() as f64;
205
206 let mut bm25_scores: Vec<f64> = Vec::with_capacity(candidates.len());
208 for tokens in &doc_tokens {
209 let score = bm25_score(&query_tokens, tokens, &df, n_docs, avg_doc_len);
210 bm25_scores.push(score);
211 }
212
213 let mut cosine_scores: Vec<f64> = Vec::with_capacity(candidates.len());
215 for candidate in candidates {
216 let sim = cosine_similarity_f32(query_embedding, &candidate.embedding);
217 cosine_scores.push(sim);
218 }
219
220 let bm25_ranks = compute_ranks(&bm25_scores);
222 let cosine_ranks = compute_ranks(&cosine_scores);
223
224 let mut results: Vec<RankedResult> = Vec::with_capacity(candidates.len());
226 for (i, candidate) in candidates.iter().enumerate() {
227 let intent_score = cosine_scores[i].clamp(0.0, 1.0);
228 let bm25_weight = 0.3 + 0.3 * (1.0 - intent_score);
229 let cosine_weight = 0.3 + 0.3 * intent_score;
230
231 let rrf_bm25 = 1.0 / (RRF_K + bm25_ranks[i] as f64);
232 let rrf_cosine = 1.0 / (RRF_K + cosine_ranks[i] as f64);
233
234 let fused = bm25_weight * rrf_bm25 + cosine_weight * rrf_cosine;
235
236 let src_weight = if config.apply_source_weights {
238 match candidate.source {
239 Some(src) => source_weight(src),
240 None => LEGACY_CLAIM_FALLBACK_WEIGHT,
241 }
242 } else {
243 1.0
244 };
245
246 let final_score = fused * src_weight;
247
248 results.push(RankedResult {
249 id: candidate.id.clone(),
250 text: candidate.text.clone(),
251 score: final_score,
252 bm25_score: bm25_scores[i],
253 cosine_score: cosine_scores[i],
254 timestamp: candidate.timestamp.clone(),
255 source_weight: src_weight,
256 });
257 }
258
259 results.sort_by(|a, b| {
262 b.score
263 .partial_cmp(&a.score)
264 .unwrap_or(std::cmp::Ordering::Equal)
265 .then_with(|| a.id.cmp(&b.id))
266 });
267
268 results.truncate(top_k);
270
271 Ok(results)
272}
273
274fn bm25_score(
276 query_tokens: &[String],
277 doc_tokens: &[String],
278 df: &HashMap<String, usize>,
279 n_docs: f64,
280 avg_doc_len: f64,
281) -> f64 {
282 let doc_len = doc_tokens.len() as f64;
283
284 let mut tf: HashMap<&str, usize> = HashMap::new();
286 for token in doc_tokens {
287 *tf.entry(token.as_str()).or_insert(0) += 1;
288 }
289
290 let mut score = 0.0;
291 for qt in query_tokens {
292 let term_freq = *tf.get(qt.as_str()).unwrap_or(&0) as f64;
293 if term_freq == 0.0 {
294 continue;
295 }
296
297 let doc_freq = *df.get(qt.as_str()).unwrap_or(&0) as f64;
298 let idf = ((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
300
301 let tf_component = (term_freq * (BM25_K1 + 1.0))
303 / (term_freq + BM25_K1 * (1.0 - BM25_B + BM25_B * doc_len / avg_doc_len));
304
305 score += idf * tf_component;
306 }
307
308 score
309}
310
/// Splits `text` into lowercase tokens.
///
/// The text is lowercased, then cut at every character that is not
/// alphanumeric. Pieces shorter than two bytes are discarded; the length is
/// the UTF-8 byte length, so a single multi-byte character is kept while a
/// single ASCII character is dropped.
fn tokenize(text: &str) -> Vec<String> {
    let lowered = text.to_lowercase();
    let mut tokens = Vec::new();
    for piece in lowered.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.len() < 2 {
            continue;
        }
        tokens.push(piece.to_owned());
    }
    tokens
}
319
/// Computes the cosine similarity of two `f32` vectors, accumulating in `f64`.
///
/// Returns `0.0` for every degenerate input rather than an error or NaN:
/// * the slices differ in length or are empty,
/// * either vector has zero magnitude,
/// * the result is not finite (a component is NaN or infinite).
///
/// The last case matters because a NaN similarity would otherwise flow into
/// score fusion and ranking, where NaN cannot be ordered meaningfully.
pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let mut dot: f64 = 0.0;
    let mut norm_a: f64 = 0.0;
    let mut norm_b: f64 = 0.0;

    for (&x, &y) in a.iter().zip(b) {
        let (x, y) = (f64::from(x), f64::from(y));
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    let similarity = dot / denom;
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
345
/// Converts scores to 1-based ranks, where the highest score gets rank 1.
///
/// Equal scores share a rank and the following rank is skipped (for example
/// scores `[0.9, 0.5, 0.5, 0.1]` produce ranks `[1, 2, 2, 4]`). The returned
/// vector is parallel to `scores`.
///
/// NaN scores are ordered as if they were negative infinity, so they rank
/// last. Without this, a NaN makes the comparator inconsistent and defeats the
/// tie check, collapsing unrelated scores onto the same rank.
fn compute_ranks(scores: &[f64]) -> Vec<usize> {
    // Ordering keys: identical to the scores except that NaN is replaced, which
    // guarantees every pair of keys is comparable.
    let keys: Vec<f64> = scores
        .iter()
        .map(|&score| if score.is_nan() { f64::NEG_INFINITY } else { score })
        .collect();

    let mut order: Vec<usize> = (0..keys.len()).collect();
    order.sort_by(|&a, &b| {
        keys[b]
            .partial_cmp(&keys[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut ranks = vec![0usize; keys.len()];
    let mut current_rank = 1usize;
    let mut previous_key: Option<f64> = None;
    for (position, &idx) in order.iter().enumerate() {
        let key = keys[idx];
        // Exact comparison is intentional: only identical scores tie.
        if previous_key.map_or(false, |previous| key != previous) {
            current_rank = position + 1;
        }
        previous_key = Some(key);
        ranks[idx] = current_rank;
    }
    ranks
}
373
// Unit tests for the reranker: scoring primitives, top-k truncation, the
// optional source-weight multiplier, tie-breaking, and serde behaviour.
#[cfg(test)]
mod tests {
    use super::*;

    // A document containing query terms must receive a positive BM25 score.
    #[test]
    fn test_bm25_basic() {
        let query_tokens = tokenize("dark mode preference");
        let doc_tokens = tokenize("The user prefers dark mode in all applications");

        let mut df: HashMap<String, usize> = HashMap::new();
        for t in &doc_tokens {
            *df.entry(t.clone()).or_insert(0) += 1;
        }

        // Single-document corpus: n_docs = 1 and the average length equals
        // this document's length.
        let score = bm25_score(
            &query_tokens,
            &doc_tokens,
            &df,
            1.0,
            doc_tokens.len() as f64,
        );
        assert!(
            score > 0.0,
            "BM25 score should be positive for matching terms"
        );
    }

    // Identical vectors have similarity 1; orthogonal vectors have similarity 0.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0f32, 0.0, 0.0];
        let b = vec![1.0f32, 0.0, 0.0];
        assert!((cosine_similarity_f32(&a, &b) - 1.0).abs() < 1e-10);

        let c = vec![0.0f32, 1.0, 0.0];
        assert!(cosine_similarity_f32(&a, &c).abs() < 1e-10);
    }

    // Ten candidates with top_k = 3 must yield exactly three results, ordered
    // by non-increasing score.
    #[test]
    fn test_rerank_returns_top_k() {
        let candidates: Vec<Candidate> = (0..10)
            .map(|i| Candidate {
                id: format!("fact_{}", i),
                text: format!("fact number {} about dark mode preferences", i),
                embedding: vec![i as f32 / 10.0; 4],
                timestamp: String::new(),
                source: None,
            })
            .collect();

        let query_embedding = vec![0.5f32; 4];
        let results = rerank("dark mode", &query_embedding, &candidates, 3).unwrap();

        assert_eq!(results.len(), 3);
        for i in 0..results.len() - 1 {
            assert!(results[i].score >= results[i + 1].score);
        }
    }

    // An empty candidate list produces an empty result.
    #[test]
    fn test_rerank_empty() {
        let results = rerank("query", &[0.5f32; 4], &[], 3).unwrap();
        assert!(results.is_empty());
    }

    // Checks the weighting formula on its own: a high intent score favours the
    // cosine weight, a low one favours the BM25 weight, and the two weights
    // sum to 0.9.
    #[test]
    fn test_intent_weighting() {
        let intent_score = 0.9;
        let bm25_weight = 0.3 + 0.3 * (1.0 - intent_score);
        let cosine_weight = 0.3 + 0.3 * intent_score;
        assert!(cosine_weight > bm25_weight);
        assert!(((bm25_weight + cosine_weight) - 0.9_f64).abs() < 1e-10);

        let intent_score = 0.1;
        let bm25_weight = 0.3 + 0.3 * (1.0 - intent_score);
        let cosine_weight = 0.3 + 0.3 * intent_score;
        assert!(bm25_weight > cosine_weight);
    }

    // Test helper: builds a `Candidate` with an empty timestamp.
    fn cand(id: &str, text: &str, embedding: Vec<f32>, source: Option<MemorySource>) -> Candidate {
        Candidate {
            id: id.to_string(),
            text: text.to_string(),
            embedding,
            timestamp: String::new(),
            source,
        }
    }

    // Pins the weight returned for each source.
    #[test]
    fn test_source_weight_table_matches_spec() {
        assert_eq!(source_weight(MemorySource::User), 1.00);
        assert_eq!(source_weight(MemorySource::UserInferred), 0.90);
        assert_eq!(source_weight(MemorySource::Derived), 0.70);
        assert_eq!(source_weight(MemorySource::External), 0.70);
        assert_eq!(source_weight(MemorySource::Assistant), 0.55);
    }

    // The default configuration leaves source weighting disabled.
    #[test]
    fn test_reranker_config_default_is_v0_compat() {
        assert!(!RerankerConfig::default().apply_source_weights);
    }

    // With the flag explicitly off, scores must equal those of `rerank`
    // (which uses the default config) and every reported weight must be 1.0.
    #[test]
    fn test_rerank_source_weight_flag_off_matches_default() {
        let candidates = vec![
            cand(
                "u",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::User),
            ),
            cand(
                "a",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
        ];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];

        let off = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: false,
            },
        )
        .unwrap();
        let default = rerank("dark mode", &query_embedding, &candidates, 10).unwrap();

        assert_eq!(off.len(), default.len());
        for (a, b) in off.iter().zip(default.iter()) {
            assert!(
                (a.score - b.score).abs() < 1e-12,
                "flag off should equal v0 behaviour"
            );
            assert!((a.source_weight - 1.0).abs() < 1e-12, "no weight applied");
        }
    }

    // Two candidates with identical text and embedding tie on base score, so
    // with weighting on the user-sourced one must rank first even though its
    // id sorts after the assistant's, and the score ratio must equal 0.55.
    #[test]
    fn test_rerank_source_weight_promotes_user_over_assistant_on_tie() {
        let candidates = vec![
            cand(
                "a",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            cand(
                "u",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::User),
            ),
        ];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];

        let ranked = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();

        assert_eq!(ranked.len(), 2);
        assert_eq!(
            ranked[0].id, "u",
            "user source must outrank assistant on base-score tie"
        );
        assert_eq!(ranked[1].id, "a");
        assert!((ranked[0].source_weight - 1.00).abs() < 1e-12);
        assert!((ranked[1].source_weight - 0.55).abs() < 1e-12);
        let ratio = ranked[1].score / ranked[0].score;
        assert!(
            (ratio - 0.55).abs() < 1e-6,
            "assistant/user ratio should equal 0.55, got {}",
            ratio
        );
    }

    // The assistant weight scales the score down but must leave it positive.
    #[test]
    fn test_rerank_source_weight_assistant_score_never_zero() {
        let candidates = vec![cand(
            "a",
            "dark mode preference",
            vec![0.9f32, 0.1, 0.0, 0.0],
            Some(MemorySource::Assistant),
        )];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];
        let ranked = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();
        assert_eq!(ranked.len(), 1);
        assert!(
            ranked[0].score > 0.0,
            "assistant score must not drop to zero"
        );
        assert!((ranked[0].source_weight - 0.55).abs() < 1e-12);
    }

    // For every candidate, the weighted score must equal the unweighted score
    // times the reported weight. Since all five candidates tie on base score,
    // the final order follows the weights, with the equal-weight pair
    // ("derived", "ext") ordered by id.
    #[test]
    fn test_rerank_source_weight_preserves_base_score_multiplier() {
        let candidates = vec![
            cand(
                "asst",
                "dark mode preference is set",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            cand(
                "user",
                "dark mode preference is set",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::User),
            ),
            cand(
                "derived",
                "dark mode preference is set",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::Derived),
            ),
            cand(
                "ext",
                "dark mode preference is set",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::External),
            ),
            cand(
                "inferred",
                "dark mode preference is set",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::UserInferred),
            ),
        ];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];

        let off = rerank_with_config(
            "dark mode preference",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: false,
            },
        )
        .unwrap();
        let on = rerank_with_config(
            "dark mode preference",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();

        // Unweighted score per id, used as the baseline for the weighted run.
        let off_map: std::collections::HashMap<_, _> =
            off.iter().map(|r| (r.id.clone(), r.score)).collect();
        for r in &on {
            let expected = off_map[&r.id] * r.source_weight;
            assert!(
                (r.score - expected).abs() < 1e-12,
                "id={}: expected score {} * {} = {}, got {}",
                r.id,
                off_map[&r.id],
                r.source_weight,
                expected,
                r.score
            );
        }

        let ids: Vec<_> = on.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids[0], "user");
        assert_eq!(ids[1], "inferred");
        assert_eq!(ids[2], "derived");
        assert_eq!(ids[3], "ext");
        assert_eq!(ids[4], "asst");
    }

    // A candidate with no source uses the fallback weight, which places it
    // between the user-sourced and assistant-sourced candidates.
    #[test]
    fn test_rerank_legacy_claim_without_source_uses_fallback_weight() {
        let candidates = vec![
            cand(
                "legacy",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                None,
            ),
            cand(
                "asst",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            cand(
                "user",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::User),
            ),
        ];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];

        let ranked = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();

        assert_eq!(ranked[0].id, "user");
        assert_eq!(ranked[1].id, "legacy");
        assert_eq!(ranked[2].id, "asst");
        assert!((ranked[1].source_weight - LEGACY_CLAIM_FALLBACK_WEIGHT).abs() < 1e-12);
    }

    // When every candidate has the same source, weighting scales all scores by
    // the same factor and therefore must not change their relative order.
    #[test]
    fn test_rerank_source_weight_stable_on_all_assistant_candidates() {
        let candidates = vec![
            cand(
                "low",
                "weak signal",
                vec![0.0f32, 0.0, 1.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            cand(
                "mid",
                "medium signal dark mode",
                vec![0.5f32, 0.5, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            cand(
                "hi",
                "very strong dark mode signal",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
        ];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];

        let off = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: false,
            },
        )
        .unwrap();
        let on = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();

        let ids_off: Vec<_> = off.iter().map(|r| r.id.clone()).collect();
        let ids_on: Vec<_> = on.iter().map(|r| r.id.clone()).collect();
        assert_eq!(
            ids_off, ids_on,
            "uniform source must not change relative ordering"
        );

        // Each weighted score is the unweighted score times 0.55.
        for (w, u) in on.iter().zip(off.iter()) {
            assert!((w.score - u.score * 0.55).abs() < 1e-12);
            assert!((w.source_weight - 0.55).abs() < 1e-12);
        }
    }

    // Candidates that tie on score are ordered by ascending id, regardless of
    // their input order.
    #[test]
    fn test_rerank_deterministic_id_tiebreak() {
        let candidates = vec![
            cand(
                "zzz",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::User),
            ),
            cand(
                "aaa",
                "dark mode preference",
                vec![0.9f32, 0.1, 0.0, 0.0],
                Some(MemorySource::User),
            ),
        ];
        let query_embedding = vec![0.9f32, 0.1, 0.0, 0.0];

        let ranked = rerank_with_config(
            "dark mode",
            &query_embedding,
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();

        assert_eq!(ranked[0].id, "aaa");
        assert_eq!(ranked[1].id, "zzz");
    }

    // A present source is written under the `source` key, an absent source is
    // omitted rather than written as null, and both survive a round trip.
    #[test]
    fn test_candidate_source_field_serde_roundtrip() {
        let candidates = vec![
            Candidate {
                id: "1".into(),
                text: "hi".into(),
                embedding: vec![0.1f32, 0.2],
                timestamp: "2026-04-17T00:00:00Z".into(),
                source: Some(MemorySource::User),
            },
            Candidate {
                id: "2".into(),
                text: "legacy".into(),
                embedding: vec![0.1f32, 0.2],
                timestamp: String::new(),
                source: None,
            },
        ];
        let json = serde_json::to_string(&candidates).unwrap();
        assert!(json.contains("\"source\":\"user\""));
        assert!(!json.contains("\"source\":null"));
        let back: Vec<Candidate> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back[0].source, Some(MemorySource::User));
        assert_eq!(back[1].source, None);
    }

    // The empty-input shortcut also applies when source weighting is enabled.
    #[test]
    fn test_rerank_empty_with_flag_on_returns_empty() {
        let results = rerank_with_config(
            "query",
            &[0.5f32; 4],
            &[],
            3,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();
        assert!(results.is_empty());
    }

    // The result reports the weight that was applied (1.0 for a user source).
    #[test]
    fn test_ranked_result_preserves_source_weight_field() {
        let candidates = vec![cand(
            "u",
            "hello world",
            vec![0.5f32, 0.5],
            Some(MemorySource::User),
        )];
        let ranked = rerank_with_config(
            "hello",
            &[0.5f32, 0.5],
            &candidates,
            10,
            RerankerConfig {
                apply_source_weights: true,
            },
        )
        .unwrap();
        assert_eq!(ranked.len(), 1);
        assert!((ranked[0].source_weight - 1.0).abs() < 1e-12);
    }
}