Skip to main content

totalreclaw_core/
reranker.rs

//! BM25 + Cosine + RRF fusion reranker.
//!
//! Matches the TypeScript reranker in the MCP server.
//!
//! Parameters:
//! - BM25: k1=1.2, b=0.75
//! - RRF: k=60
//! - Intent-weighted fusion:
//!   - `intent_score = clamp(cosine(query_embedding, fact_embedding), 0, 1)`
//!   - `bm25_weight = 0.3 + 0.3 * (1 - intent_score)`
//!   - `cosine_weight = 0.3 + 0.3 * intent_score`
//!   - `final_score = bm25_weight * rrf_bm25 + cosine_weight * rrf_cosine`,
//!     where `rrf_x = 1 / (k + rank_x)` and ranks are 1-based.
//!
//! # Retrieval v2 Tier 1: source-weighted final score
//!
//! When [`RerankerConfig::apply_source_weights`] is true, the final fused score
//! is multiplied by a provenance weight derived from [`MemorySource`] (or by
//! [`LEGACY_CLAIM_FALLBACK_WEIGHT`] when the candidate has no source field).
//!
//! See `docs/specs/totalreclaw/retrieval-v2.md` §Tier 1.
21
22use std::collections::HashMap;
23
24use serde::{Deserialize, Serialize};
25
26use crate::claims::MemorySource;
27use crate::Result;
28
/// BM25 term-frequency saturation parameter (`k1`).
const BM25_K1: f64 = 1.2;

/// BM25 document-length normalisation parameter (`b`).
const BM25_B: f64 = 0.75;

/// Reciprocal Rank Fusion smoothing constant (`k`): a candidate at 1-based
/// rank `r` in one ranked list contributes `1 / (RRF_K + r)` to its fused score.
const RRF_K: f64 = 60.0;
35
36/// Source-weight multipliers applied to the final fused score when
37/// [`RerankerConfig::apply_source_weights`] is enabled.
38///
39/// Values sourced from `docs/specs/totalreclaw/retrieval-v2.md` §Tier 1. The
40/// array is sorted highest-to-lowest trust; values MUST NOT be edited without
41/// updating the spec + recalibrating via the E13 retrieval benchmark.
42pub const SOURCE_WEIGHTS: &[(MemorySource, f64)] = &[
43    (MemorySource::User, 1.00),
44    (MemorySource::UserInferred, 0.90),
45    (MemorySource::Derived, 0.70),
46    (MemorySource::External, 0.70),
47    (MemorySource::Assistant, 0.55),
48];
49
50/// Fallback weight applied to candidates that have no `source` field.
51///
52/// Used for legacy v0 claims written before Memory Taxonomy v1 introduced
53/// the `source` axis. Value (0.85) sits between `user-inferred` (0.90) and
54/// `derived` / `external` (0.70) — mild penalty without erasing legacy data.
55pub const LEGACY_CLAIM_FALLBACK_WEIGHT: f64 = 0.85;
56
57/// Return the source-weight multiplier for a known [`MemorySource`].
58///
59/// Unknown sources (should not happen once `from_str_lossy` has run) fall back
60/// to [`LEGACY_CLAIM_FALLBACK_WEIGHT`].
61pub fn source_weight(source: MemorySource) -> f64 {
62    SOURCE_WEIGHTS
63        .iter()
64        .find(|(s, _)| *s == source)
65        .map(|(_, w)| *w)
66        .unwrap_or(LEGACY_CLAIM_FALLBACK_WEIGHT)
67}
68
69/// Reranker runtime configuration.
70///
71/// v0 callers can stay with the legacy [`rerank`] API (no source awareness).
72/// v1+ callers use [`rerank_with_config`] with `apply_source_weights = true`
73/// so the final RRF score respects provenance per Retrieval v2 Tier 1.
74#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
75pub struct RerankerConfig {
76    /// When true, multiply the final fused score by the source-weight for each
77    /// candidate. When false, behaviour is identical to the v0 [`rerank`] fn.
78    pub apply_source_weights: bool,
79}
80
81impl Default for RerankerConfig {
82    /// Defaults to v0-compatible behaviour (`apply_source_weights = false`) so
83    /// pre-v1 callers can bump the core version without ranking drift.
84    fn default() -> Self {
85        RerankerConfig {
86            apply_source_weights: false,
87        }
88    }
89}
90
91/// A candidate fact for reranking.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct Candidate {
94    /// Unique identifier for the fact.
95    pub id: String,
96    /// Decrypted plaintext of the fact.
97    pub text: String,
98    /// Embedding vector of the fact.
99    pub embedding: Vec<f32>,
100    /// Timestamp (passed through to results).
101    pub timestamp: String,
102    /// Optional Memory Taxonomy v1 provenance source.
103    ///
104    /// If present AND [`RerankerConfig::apply_source_weights`] is true, the
105    /// candidate's final score is multiplied by [`source_weight`]. Absent
106    /// source yields [`LEGACY_CLAIM_FALLBACK_WEIGHT`].
107    #[serde(default, skip_serializing_if = "Option::is_none")]
108    pub source: Option<MemorySource>,
109}
110
111/// A reranked result with scores.
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct RankedResult {
114    /// Unique identifier.
115    pub id: String,
116    /// Decrypted plaintext.
117    pub text: String,
118    /// Final fused score (post source-weight multiplication if enabled).
119    pub score: f64,
120    /// BM25 component score.
121    pub bm25_score: f64,
122    /// Cosine similarity score.
123    pub cosine_score: f64,
124    /// Timestamp (passed through from candidate).
125    pub timestamp: String,
126    /// Source-weight multiplier applied to `score` (1.0 when disabled or
127    /// no source field). Useful for diagnostics / parity tests.
128    #[serde(default, skip_serializing_if = "is_one_f64")]
129    pub source_weight: f64,
130}
131
132fn is_one_f64(v: &f64) -> bool {
133    (*v - 1.0).abs() < f64::EPSILON
134}
135
136/// Rerank candidates using BM25 + Cosine + RRF fusion (v0-compatible).
137///
138/// This function does NOT apply source weights — call [`rerank_with_config`]
139/// with `apply_source_weights = true` to enable Retrieval v2 Tier 1.
140///
141/// # Arguments
142/// - `query` — The search query text
143/// - `query_embedding` — The query's embedding vector
144/// - `candidates` — Candidate facts to rerank
145/// - `top_k` — Number of top results to return
146///
147/// # Returns
148/// Top-K results sorted by descending fused score.
149pub fn rerank(
150    query: &str,
151    query_embedding: &[f32],
152    candidates: &[Candidate],
153    top_k: usize,
154) -> Result<Vec<RankedResult>> {
155    rerank_with_config(
156        query,
157        query_embedding,
158        candidates,
159        top_k,
160        RerankerConfig::default(),
161    )
162}
163
164/// Rerank candidates using BM25 + Cosine + RRF fusion, honouring the supplied
165/// [`RerankerConfig`].
166///
167/// When `config.apply_source_weights` is true the final RRF score is
168/// multiplied by the per-candidate [`source_weight`] AFTER fusion and BEFORE
169/// top-k truncation. Candidates with no `source` field receive
170/// [`LEGACY_CLAIM_FALLBACK_WEIGHT`] so v0 vaults still rank sensibly during
171/// the v0→v1 migration window.
172///
173/// All weights are deterministic — per `retrieval-v2.md` §cross-client, the
174/// same inputs MUST produce the same top-k across TS/Python/Rust bindings.
175pub fn rerank_with_config(
176    query: &str,
177    query_embedding: &[f32],
178    candidates: &[Candidate],
179    top_k: usize,
180    config: RerankerConfig,
181) -> Result<Vec<RankedResult>> {
182    if candidates.is_empty() {
183        return Ok(Vec::new());
184    }
185
186    // Tokenize query
187    let query_tokens = tokenize(query);
188
189    // Build document frequency map
190    let mut df: HashMap<String, usize> = HashMap::new();
191    let mut doc_tokens: Vec<Vec<String>> = Vec::with_capacity(candidates.len());
192    let mut total_doc_len: usize = 0;
193
194    for candidate in candidates {
195        let tokens = tokenize(&candidate.text);
196        total_doc_len += tokens.len();
197        for token in &tokens {
198            *df.entry(token.clone()).or_insert(0) += 1;
199        }
200        doc_tokens.push(tokens);
201    }
202
203    let avg_doc_len = total_doc_len as f64 / candidates.len() as f64;
204    let n_docs = candidates.len() as f64;
205
206    // Compute BM25 scores
207    let mut bm25_scores: Vec<f64> = Vec::with_capacity(candidates.len());
208    for tokens in &doc_tokens {
209        let score = bm25_score(&query_tokens, tokens, &df, n_docs, avg_doc_len);
210        bm25_scores.push(score);
211    }
212
213    // Compute cosine similarities
214    let mut cosine_scores: Vec<f64> = Vec::with_capacity(candidates.len());
215    for candidate in candidates {
216        let sim = cosine_similarity_f32(query_embedding, &candidate.embedding);
217        cosine_scores.push(sim);
218    }
219
220    // Compute RRF ranks
221    let bm25_ranks = compute_ranks(&bm25_scores);
222    let cosine_ranks = compute_ranks(&cosine_scores);
223
224    // Intent-weighted fusion
225    let mut results: Vec<RankedResult> = Vec::with_capacity(candidates.len());
226    for (i, candidate) in candidates.iter().enumerate() {
227        let intent_score = cosine_scores[i].clamp(0.0, 1.0);
228        let bm25_weight = 0.3 + 0.3 * (1.0 - intent_score);
229        let cosine_weight = 0.3 + 0.3 * intent_score;
230
231        let rrf_bm25 = 1.0 / (RRF_K + bm25_ranks[i] as f64);
232        let rrf_cosine = 1.0 / (RRF_K + cosine_ranks[i] as f64);
233
234        let fused = bm25_weight * rrf_bm25 + cosine_weight * rrf_cosine;
235
236        // Tier 1 source weighting (post-fusion, pre-truncation).
237        let src_weight = if config.apply_source_weights {
238            match candidate.source {
239                Some(src) => source_weight(src),
240                None => LEGACY_CLAIM_FALLBACK_WEIGHT,
241            }
242        } else {
243            1.0
244        };
245
246        let final_score = fused * src_weight;
247
248        results.push(RankedResult {
249            id: candidate.id.clone(),
250            text: candidate.text.clone(),
251            score: final_score,
252            bm25_score: bm25_scores[i],
253            cosine_score: cosine_scores[i],
254            timestamp: candidate.timestamp.clone(),
255            source_weight: src_weight,
256        });
257    }
258
259    // Sort by descending score, breaking ties deterministically on id so the
260    // cross-client parity guarantee holds even when two candidates collide.
261    results.sort_by(|a, b| {
262        b.score
263            .partial_cmp(&a.score)
264            .unwrap_or(std::cmp::Ordering::Equal)
265            .then_with(|| a.id.cmp(&b.id))
266    });
267
268    // Take top K
269    results.truncate(top_k);
270
271    Ok(results)
272}
273
274/// Compute BM25 score for a single document.
275fn bm25_score(
276    query_tokens: &[String],
277    doc_tokens: &[String],
278    df: &HashMap<String, usize>,
279    n_docs: f64,
280    avg_doc_len: f64,
281) -> f64 {
282    let doc_len = doc_tokens.len() as f64;
283
284    // Count term frequencies in document
285    let mut tf: HashMap<&str, usize> = HashMap::new();
286    for token in doc_tokens {
287        *tf.entry(token.as_str()).or_insert(0) += 1;
288    }
289
290    let mut score = 0.0;
291    for qt in query_tokens {
292        let term_freq = *tf.get(qt.as_str()).unwrap_or(&0) as f64;
293        if term_freq == 0.0 {
294            continue;
295        }
296
297        let doc_freq = *df.get(qt.as_str()).unwrap_or(&0) as f64;
298        // IDF: log((N - df + 0.5) / (df + 0.5) + 1)
299        let idf = ((n_docs - doc_freq + 0.5) / (doc_freq + 0.5) + 1.0).ln();
300
301        // BM25 TF component
302        let tf_component = (term_freq * (BM25_K1 + 1.0))
303            / (term_freq + BM25_K1 * (1.0 - BM25_B + BM25_B * doc_len / avg_doc_len));
304
305        score += idf * tf_component;
306    }
307
308    score
309}
310
/// Simple tokenization for BM25: lowercase the text, split on every
/// non-alphanumeric character, and keep tokens at least two characters long.
///
/// Token length is measured in `char`s (Unicode scalar values), not UTF-8
/// bytes. A lone non-ASCII letter such as `é` or `中` is therefore dropped
/// exactly like a lone ASCII letter, instead of surviving because it encodes
/// to several bytes.
fn tokenize(text: &str) -> Vec<String> {
    text.to_lowercase()
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.chars().count() >= 2)
        .map(str::to_owned)
        .collect()
}
319
/// Cosine similarity between two `f32` vectors, accumulated in `f64`.
///
/// Returns `dot(a, b) / (|a| * |b|)`. Degenerate inputs yield 0.0 rather than
/// NaN:
/// - vectors of different lengths,
/// - empty vectors,
/// - a zero-norm vector,
/// - any input whose similarity is not finite (e.g. a NaN or infinite
///   component).
pub fn cosine_similarity_f32(a: &[f32], b: &[f32]) -> f64 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, norm_a, norm_b) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a, norm_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a + x * x, norm_b + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }

    // NaN/inf components propagate to a NaN here (e.g. inf / inf). Map that to
    // the same neutral 0.0 so callers that sort by similarity stay well-defined.
    let similarity = dot / denom;
    if similarity.is_finite() {
        similarity
    } else {
        0.0
    }
}
345
/// Compute 1-based ranks from scores (highest score = rank 1).
///
/// Ties use **competition ranking** (aka "1224" ranking): candidates with
/// equal scores share the lowest rank in their tied group. This is critical
/// for the source-weighted reranker. Without it, two identical candidates
/// would receive different RRF positions purely because of input order,
/// breaking cross-client parity and the "uniform multiplier preserves order"
/// invariant.
///
/// NaN scores rank after every non-NaN score and tie with each other. The sort
/// comparator is a total order, so a stray NaN can neither reshuffle the ranks
/// of well-formed scores nor trigger a panic in the sort. For inputs without
/// NaN the ranks are the same as plain descending competition ranking.
fn compute_ranks(scores: &[f64]) -> Vec<usize> {
    /// Descending order with NaN last; a total order (unlike
    /// `partial_cmp(..).unwrap_or(Equal)`, which is inconsistent around NaN).
    fn descending_nan_last(a: f64, b: f64) -> std::cmp::Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => b.total_cmp(&a),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        }
    }

    /// Two scores share a rank when numerically equal (so 0.0 ties with -0.0)
    /// or when both are NaN.
    fn same_score(a: f64, b: f64) -> bool {
        a == b || (a.is_nan() && b.is_nan())
    }

    let mut order: Vec<usize> = (0..scores.len()).collect();
    order.sort_by(|&i, &j| descending_nan_last(scores[i], scores[j]));

    let mut ranks = vec![0usize; scores.len()];
    let mut current_rank = 1usize;
    for (pos, &idx) in order.iter().enumerate() {
        // Only advance the rank when the score changes from the previous
        // position; equal scores share a rank.
        if pos > 0 && !same_score(scores[order[pos - 1]], scores[idx]) {
            current_rank = pos + 1;
        }
        ranks[idx] = current_rank;
    }
    ranks
}
373
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding shared by most fixtures and also used as the query
    /// embedding, so those candidates are maximally similar to the query.
    const SHARED_EMBEDDING: [f32; 4] = [0.9, 0.1, 0.0, 0.0];

    /// Build a candidate with an explicit embedding and an empty timestamp.
    fn cand(id: &str, text: &str, embedding: Vec<f32>, source: Option<MemorySource>) -> Candidate {
        Candidate {
            id: id.to_string(),
            text: text.to_string(),
            embedding,
            timestamp: String::new(),
            source,
        }
    }

    /// Build a candidate whose embedding is [`SHARED_EMBEDDING`].
    fn shared(id: &str, text: &str, source: Option<MemorySource>) -> Candidate {
        cand(id, text, SHARED_EMBEDDING.to_vec(), source)
    }

    /// Rerank with `top_k = 10`, varying only the source-weight flag.
    fn rank(
        query: &str,
        query_embedding: &[f32],
        candidates: &[Candidate],
        weighted: bool,
    ) -> Vec<RankedResult> {
        let config = RerankerConfig {
            apply_source_weights: weighted,
        };
        rerank_with_config(query, query_embedding, candidates, 10, config).unwrap()
    }

    #[test]
    fn test_bm25_basic() {
        let query_tokens = tokenize("dark mode preference");
        let doc_tokens = tokenize("The user prefers dark mode in all applications");

        let mut df: HashMap<String, usize> = HashMap::new();
        for token in &doc_tokens {
            *df.entry(token.clone()).or_default() += 1;
        }

        let avg_doc_len = doc_tokens.len() as f64;
        let score = bm25_score(&query_tokens, &doc_tokens, &df, 1.0, avg_doc_len);
        assert!(
            score > 0.0,
            "BM25 score should be positive for matching terms"
        );
    }

    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0f32, 0.0, 0.0];
        let y_axis = vec![0.0f32, 1.0, 0.0];
        assert!((cosine_similarity_f32(&x_axis, &x_axis) - 1.0).abs() < 1e-10);
        assert!(cosine_similarity_f32(&x_axis, &y_axis).abs() < 1e-10);
    }

    #[test]
    fn test_rerank_returns_top_k() {
        let candidates: Vec<Candidate> = (0..10)
            .map(|i| {
                cand(
                    &format!("fact_{}", i),
                    &format!("fact number {} about dark mode preferences", i),
                    vec![i as f32 / 10.0; 4],
                    None,
                )
            })
            .collect();

        let results = rerank("dark mode", &[0.5f32; 4], &candidates, 3).unwrap();

        assert_eq!(results.len(), 3);
        // Scores must be non-increasing.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_rerank_empty() {
        let results = rerank("query", &[0.5f32; 4], &[], 3).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_intent_weighting() {
        let weights = |intent_score: f64| {
            (
                0.3 + 0.3 * (1.0 - intent_score),
                0.3 + 0.3 * intent_score,
            )
        };

        // High cosine similarity -> the cosine ranking dominates.
        let (bm25_weight, cosine_weight) = weights(0.9);
        assert!(cosine_weight > bm25_weight);
        // 0.3 + 0.3*(1-s) + 0.3 + 0.3*s = 0.9 for every s.
        assert!(((bm25_weight + cosine_weight) - 0.9_f64).abs() < 1e-10);

        // Low cosine similarity -> the BM25 ranking dominates.
        let (bm25_weight, cosine_weight) = weights(0.1);
        assert!(bm25_weight > cosine_weight);
    }

    // === Retrieval v2 Tier 1: source-weighted reranking ===

    #[test]
    fn test_source_weight_table_matches_spec() {
        let expected = [
            (MemorySource::User, 1.00),
            (MemorySource::UserInferred, 0.90),
            (MemorySource::Derived, 0.70),
            (MemorySource::External, 0.70),
            (MemorySource::Assistant, 0.55),
        ];
        for &(source, weight) in &expected {
            assert_eq!(source_weight(source), weight);
        }
    }

    #[test]
    fn test_reranker_config_default_is_v0_compat() {
        assert!(!RerankerConfig::default().apply_source_weights);
    }

    #[test]
    fn test_rerank_source_weight_flag_off_matches_default() {
        // Different sources, flag OFF: the source must be ignored entirely.
        let candidates = vec![
            shared("u", "dark mode preference", Some(MemorySource::User)),
            shared("a", "dark mode preference", Some(MemorySource::Assistant)),
        ];

        let off = rank("dark mode", &SHARED_EMBEDDING, &candidates, false);
        let default = rerank("dark mode", &SHARED_EMBEDDING, &candidates, 10).unwrap();

        assert_eq!(off.len(), default.len());
        for (a, b) in off.iter().zip(&default) {
            assert!(
                (a.score - b.score).abs() < 1e-12,
                "flag off should equal v0 behaviour"
            );
            assert!((a.source_weight - 1.0).abs() < 1e-12, "no weight applied");
        }
    }

    #[test]
    fn test_rerank_source_weight_promotes_user_over_assistant_on_tie() {
        // Identical text and embedding give identical base scores, so only the
        // source weight can separate the two candidates.
        let candidates = vec![
            shared("a", "dark mode preference", Some(MemorySource::Assistant)),
            shared("u", "dark mode preference", Some(MemorySource::User)),
        ];

        let ranked = rank("dark mode", &SHARED_EMBEDDING, &candidates, true);

        assert_eq!(ranked.len(), 2);
        assert_eq!(
            ranked[0].id, "u",
            "user source must outrank assistant on base-score tie"
        );
        assert_eq!(ranked[1].id, "a");
        assert!((ranked[0].source_weight - 1.00).abs() < 1e-12);
        assert!((ranked[1].source_weight - 0.55).abs() < 1e-12);

        // On a tie the assistant score is 55% of the user score.
        let ratio = ranked[1].score / ranked[0].score;
        assert!(
            (ratio - 0.55).abs() < 1e-6,
            "assistant/user ratio should equal 0.55, got {}",
            ratio
        );
    }

    #[test]
    fn test_rerank_source_weight_assistant_score_never_zero() {
        // Spec §Tier 1: "Never drop to zero — all facts remain eligible for
        // top-k." An assistant-sourced fact keeps a positive score.
        let candidates = vec![shared(
            "a",
            "dark mode preference",
            Some(MemorySource::Assistant),
        )];

        let ranked = rank("dark mode", &SHARED_EMBEDDING, &candidates, true);

        assert_eq!(ranked.len(), 1);
        assert!(
            ranked[0].score > 0.0,
            "assistant score must not drop to zero"
        );
        assert!((ranked[0].source_weight - 0.55).abs() < 1e-12);
    }

    #[test]
    fn test_rerank_source_weight_preserves_base_score_multiplier() {
        // Invariant: weighted score == unweighted (fused) score * source_weight.
        let text = "dark mode preference is set";
        let candidates: Vec<Candidate> = [
            ("asst", MemorySource::Assistant),
            ("user", MemorySource::User),
            ("derived", MemorySource::Derived),
            ("ext", MemorySource::External),
            ("inferred", MemorySource::UserInferred),
        ]
        .iter()
        .map(|&(id, source)| shared(id, text, Some(source)))
        .collect();

        let query = "dark mode preference";
        let off = rank(query, &SHARED_EMBEDDING, &candidates, false);
        let on = rank(query, &SHARED_EMBEDDING, &candidates, true);

        let base: HashMap<&str, f64> = off.iter().map(|r| (r.id.as_str(), r.score)).collect();
        for r in &on {
            let unweighted = base[r.id.as_str()];
            let expected = unweighted * r.source_weight;
            assert!(
                (r.score - expected).abs() < 1e-12,
                "id={}: expected score {} * {} = {}, got {}",
                r.id,
                unweighted,
                r.source_weight,
                expected,
                r.score
            );
        }

        // All base scores are equal, so the source weight alone decides the
        // order: user (1.00) > inferred (0.90) > derived = ext (0.70, ascending
        // id tie-break) > asst (0.55).
        let ids: Vec<&str> = on.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(ids, ["user", "inferred", "derived", "ext", "asst"]);
    }

    #[test]
    fn test_rerank_legacy_claim_without_source_uses_fallback_weight() {
        let candidates = vec![
            shared("legacy", "dark mode preference", None),
            shared("asst", "dark mode preference", Some(MemorySource::Assistant)),
            shared("user", "dark mode preference", Some(MemorySource::User)),
        ];

        let ranked = rank("dark mode", &SHARED_EMBEDDING, &candidates, true);

        // Three-way base-score tie: the legacy fallback (0.85) lands between
        // user (1.00) and assistant (0.55).
        assert_eq!(ranked[0].id, "user");
        assert_eq!(ranked[1].id, "legacy");
        assert_eq!(ranked[2].id, "asst");
        assert!((ranked[1].source_weight - LEGACY_CLAIM_FALLBACK_WEIGHT).abs() < 1e-12);
    }

    #[test]
    fn test_rerank_source_weight_stable_on_all_assistant_candidates() {
        // Every candidate is assistant-sourced: the uniform multiplier must
        // leave the base-score ordering untouched.
        let candidates = vec![
            cand(
                "low",
                "weak signal",
                vec![0.0f32, 0.0, 1.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            cand(
                "mid",
                "medium signal dark mode",
                vec![0.5f32, 0.5, 0.0, 0.0],
                Some(MemorySource::Assistant),
            ),
            shared(
                "hi",
                "very strong dark mode signal",
                Some(MemorySource::Assistant),
            ),
        ];

        let off = rank("dark mode", &SHARED_EMBEDDING, &candidates, false);
        let on = rank("dark mode", &SHARED_EMBEDDING, &candidates, true);

        let ids_off: Vec<&str> = off.iter().map(|r| r.id.as_str()).collect();
        let ids_on: Vec<&str> = on.iter().map(|r| r.id.as_str()).collect();
        assert_eq!(
            ids_off, ids_on,
            "uniform source must not change relative ordering"
        );

        // Each weighted score is the unweighted score times 0.55.
        for (weighted, unweighted) in on.iter().zip(&off) {
            assert!((weighted.score - unweighted.score * 0.55).abs() < 1e-12);
            assert!((weighted.source_weight - 0.55).abs() < 1e-12);
        }
    }

    #[test]
    fn test_rerank_deterministic_id_tiebreak() {
        // Identical final scores must be ordered by ascending id so that
        // cross-client parity holds.
        let candidates = vec![
            shared("zzz", "dark mode preference", Some(MemorySource::User)),
            shared("aaa", "dark mode preference", Some(MemorySource::User)),
        ];

        let ranked = rank("dark mode", &SHARED_EMBEDDING, &candidates, true);

        assert_eq!(ranked[0].id, "aaa");
        assert_eq!(ranked[1].id, "zzz");
    }

    #[test]
    fn test_candidate_source_field_serde_roundtrip() {
        let with_source = Candidate {
            id: "1".into(),
            text: "hi".into(),
            embedding: vec![0.1f32, 0.2],
            timestamp: "2026-04-17T00:00:00Z".into(),
            source: Some(MemorySource::User),
        };
        let legacy = Candidate {
            id: "2".into(),
            text: "legacy".into(),
            embedding: vec![0.1f32, 0.2],
            timestamp: String::new(),
            source: None,
        };
        let candidates = vec![with_source, legacy];

        let json = serde_json::to_string(&candidates).unwrap();
        assert!(json.contains("\"source\":\"user\""));
        // A `None` source is skipped rather than serialized as null.
        assert!(!json.contains("\"source\":null"));

        let back: Vec<Candidate> = serde_json::from_str(&json).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(back[0].source, Some(MemorySource::User));
        assert_eq!(back[1].source, None);
    }

    #[test]
    fn test_rerank_empty_with_flag_on_returns_empty() {
        let config = RerankerConfig {
            apply_source_weights: true,
        };
        let results = rerank_with_config("query", &[0.5f32; 4], &[], 3, config).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_ranked_result_preserves_source_weight_field() {
        let candidates = vec![cand(
            "u",
            "hello world",
            vec![0.5f32, 0.5],
            Some(MemorySource::User),
        )];

        let ranked = rank("hello", &[0.5f32, 0.5], &candidates, true);

        assert_eq!(ranked.len(), 1);
        assert!((ranked[0].source_weight - 1.0).abs() < 1e-12);
    }
}