Skip to main content

spool/engine/selector/
mod.rs

1use crate::config::ProjectConfig;
2use crate::domain::{
3    CandidateNote, LifecycleCandidate, MatchedModule, MatchedProject, MatchedScene, MemoryRecord,
4    MemoryScope, Note, RouteInput, ScoredNote,
5};
6use crate::engine::scorer;
7use std::collections::HashSet;
8
/// Score multiplier applied to relation-expanded candidates (1-hop).
/// A 30% penalty means expanded items retain 70% of their original score.
const HOP_PENALTY: f64 = 0.7;

/// Score multiplier applied to cross-project (user/agent/team scope) lifecycle
/// candidates when a project is matched. Ensures project-scoped memories are
/// preferred over global ones when both are relevant.
const CROSS_PROJECT_PENALTY: f64 = 0.6;

/// Reciprocal Rank Fusion constant: each ranking contributes `1 / (RRF_K + rank)`
/// (1-based rank) to an item's fused score. k = 60 is the value used in the original
/// RRF paper and is the common default in the literature.
///
/// Only the BM25 / embedding fusion paths use it, so it exists whenever at least one of
/// those features is enabled.
#[cfg(any(feature = "bm25", feature = "embedding"))]
const RRF_K: f64 = 60.0;
24
25pub fn select_scored_notes(
26    project_config: Option<&ProjectConfig>,
27    project: Option<&MatchedProject>,
28    modules: &[MatchedModule],
29    scenes: &[MatchedScene],
30    notes: &[Note],
31    input: &RouteInput,
32    limit: usize,
33) -> Vec<ScoredNote> {
34    let mut scored_notes: Vec<ScoredNote> = notes
35        .iter()
36        .filter_map(|note| {
37            let (score, reasons, score_breakdown, confidence) =
38                scorer::score_note(project_config, project, modules, scenes, note, input);
39            if score <= 0 {
40                return None;
41            }
42            Some(ScoredNote {
43                note: note.clone(),
44                score,
45                reasons,
46                score_breakdown,
47                confidence,
48                excerpt: note.excerpt_for_input(input, 220),
49            })
50        })
51        .collect();
52
53    scored_notes.sort_by(|left, right| {
54        right
55            .score
56            .cmp(&left.score)
57            .then_with(|| left.note.relative_path.cmp(&right.note.relative_path))
58    });
59
60    // Initial top-K selection
61    let initial: Vec<ScoredNote> = scored_notes.iter().take(limit).cloned().collect();
62
63    // Relation expansion: collect wikilinks and related_memory from selected notes
64    let selected_paths: HashSet<String> = initial
65        .iter()
66        .map(|s| s.note.relative_path.clone())
67        .collect();
68    let mut expand_targets: HashSet<String> = HashSet::new();
69    for scored in &initial {
70        // Wikilinks from note content
71        for link in &scored.note.wikilinks {
72            expand_targets.insert(link.to_lowercase());
73        }
74        // related_memory frontmatter (may contain wikilink-style references)
75        if let Some(related) = scored.note.frontmatter.get("related_memory")
76            && let Some(arr) = related.as_array()
77        {
78            for item in arr {
79                if let Some(s) = item.as_str() {
80                    let cleaned = s.trim_start_matches("[[").trim_end_matches("]]");
81                    expand_targets.insert(cleaned.to_lowercase());
82                }
83            }
84        }
85    }
86
87    // Find related notes from the full scored pool that aren't already selected
88    let mut expanded = initial;
89    if !expand_targets.is_empty() {
90        for scored in &scored_notes {
91            if selected_paths.contains(&scored.note.relative_path) {
92                continue;
93            }
94            let title_lc = scored.note.title.to_lowercase();
95            let path_lc = scored.note.relative_path.to_lowercase();
96            let is_related = expand_targets.iter().any(|target| {
97                title_lc.contains(target) || path_lc.contains(target) || target.contains(&title_lc)
98            });
99            if is_related {
100                let penalized_score = ((scored.score as f64) * HOP_PENALTY) as i32;
101                if penalized_score > 0 {
102                    let mut expanded_note = scored.clone();
103                    expanded_note.score = penalized_score;
104                    expanded_note.reasons.push(format!(
105                        "relation-expanded (1-hop, {:.0}% penalty)",
106                        (1.0 - HOP_PENALTY) * 100.0
107                    ));
108                    expanded.push(expanded_note);
109                }
110            }
111        }
112    }
113
114    // Re-sort and re-truncate
115    expanded.sort_by(|left, right| {
116        right
117            .score
118            .cmp(&left.score)
119            .then_with(|| left.note.relative_path.cmp(&right.note.relative_path))
120    });
121    expanded.truncate(limit);
122    expanded
123}
124
125pub fn select_candidates(
126    project_config: Option<&ProjectConfig>,
127    project: Option<&MatchedProject>,
128    modules: &[MatchedModule],
129    scenes: &[MatchedScene],
130    notes: &[Note],
131    input: &RouteInput,
132    limit: usize,
133) -> Vec<CandidateNote> {
134    select_scored_notes(
135        project_config,
136        project,
137        modules,
138        scenes,
139        notes,
140        input,
141        limit,
142    )
143    .into_iter()
144    .map(CandidateNote::from)
145    .collect()
146}
147
148/// 从 lifecycle ledger 的 `(record_id, record)` 列表里打分 + 过滤 + 截断,产出 top-N lifecycle 候选。
149///
150/// `excluded_record_ids` 用于去重:已经被 canonical vault note 覆盖的 record_id 不再作为
151/// lifecycle candidate 返回,避免同一条记忆在 context 里双计(note 一次 + lifecycle candidate 一次)。
152///
153/// `reference_map` 提供 staleness 信息,传 `None` 跳过 staleness 惩罚。
154///
155/// After initial top-K selection, performs 1-hop relation expansion via `related_records`
156/// fields. Expanded candidates receive a 30% score penalty and are merged into the result
157/// before final truncation. Records that didn't score on their own but are referenced by
158/// a top-K record receive a base score derived from the referencing record's score.
159pub fn select_lifecycle_candidates(
160    project: Option<&MatchedProject>,
161    records: &[(String, MemoryRecord)],
162    input: &RouteInput,
163    limit: usize,
164    excluded_record_ids: &HashSet<String>,
165    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
166) -> Vec<LifecycleCandidate> {
167    if limit == 0 {
168        return Vec::new();
169    }
170    let mut candidates: Vec<LifecycleCandidate> = records
171        .iter()
172        .filter(|(record_id, _)| !excluded_record_ids.contains(record_id))
173        .filter_map(|(record_id, record)| {
174            scorer::score_lifecycle_candidate(
175                project,
176                record_id,
177                record,
178                input,
179                reference_map,
180                Some(records),
181            )
182        })
183        .collect();
184
185    // Project-first: when a project is matched, apply a penalty to cross-project
186    // (user/agent/team scope) candidates so project-scoped memories are preferred.
187    if project.is_some() {
188        for candidate in &mut candidates {
189            if matches!(
190                candidate.scope,
191                MemoryScope::User | MemoryScope::Agent | MemoryScope::Team
192            ) {
193                let penalized = ((candidate.score as f64) * CROSS_PROJECT_PENALTY) as i32;
194                if penalized != candidate.score {
195                    candidate.score = penalized;
196                    candidate.reasons.push(format!(
197                        "cross-project penalty ({:.0}%)",
198                        (1.0 - CROSS_PROJECT_PENALTY) * 100.0
199                    ));
200                }
201            }
202        }
203    }
204
205    candidates.sort_by(|left, right| {
206        right
207            .score
208            .cmp(&left.score)
209            .then_with(|| left.record_id.cmp(&right.record_id))
210    });
211
212    // Initial top-K selection
213    let initial: Vec<LifecycleCandidate> = candidates.iter().take(limit).cloned().collect();
214
215    // Relation expansion: collect related_records from selected candidates
216    let selected_ids: HashSet<String> = initial.iter().map(|c| c.record_id.clone()).collect();
217    let candidate_ids: HashSet<String> = candidates.iter().map(|c| c.record_id.clone()).collect();
218    let mut expand_targets: HashSet<String> = HashSet::new();
219    for candidate in &initial {
220        if let Some((_, record)) = records.iter().find(|(id, _)| id == &candidate.record_id) {
221            for related_id in &record.related_records {
222                if !selected_ids.contains(related_id) && !excluded_record_ids.contains(related_id) {
223                    expand_targets.insert(related_id.clone());
224                }
225            }
226        }
227    }
228
229    // Find related candidates: from scored pool OR from raw records (for items that scored 0)
230    let mut expanded = initial;
231    if !expand_targets.is_empty() {
232        for target_id in &expand_targets {
233            // First check if it's in the scored candidates pool
234            if let Some(candidate) = candidates.iter().find(|c| &c.record_id == target_id) {
235                let penalized_score = ((candidate.score as f64) * HOP_PENALTY) as i32;
236                if penalized_score > 0 {
237                    let mut expanded_candidate = candidate.clone();
238                    expanded_candidate.score = penalized_score;
239                    expanded_candidate.reasons.push(format!(
240                        "relation-expanded (1-hop, {:.0}% penalty)",
241                        (1.0 - HOP_PENALTY) * 100.0
242                    ));
243                    expanded.push(expanded_candidate);
244                }
245            } else if !candidate_ids.contains(target_id) {
246                // Record scored 0 on its own but is referenced by a top record.
247                // Create a minimal candidate with a relation-based score.
248                if let Some((_, record)) = records.iter().find(|(id, _)| id == target_id) {
249                    // Use the referencing record's score * HOP_PENALTY as base
250                    let referrer_score = expanded
251                        .iter()
252                        .filter(|c| {
253                            records
254                                .iter()
255                                .find(|(id, _)| id == &c.record_id)
256                                .map(|(_, r)| r.related_records.contains(target_id))
257                                .unwrap_or(false)
258                        })
259                        .map(|c| c.score)
260                        .max()
261                        .unwrap_or(0);
262                    let penalized_score = ((referrer_score as f64) * HOP_PENALTY) as i32;
263                    if penalized_score > 0 {
264                        let confidence = crate::domain::ConfidenceTier::Medium;
265                        expanded.push(LifecycleCandidate {
266                            record_id: target_id.clone(),
267                            title: record.title.clone(),
268                            summary: record.summary.clone(),
269                            memory_type: record.memory_type.clone(),
270                            scope: record.scope,
271                            state: record.state,
272                            score: penalized_score,
273                            reasons: vec![format!(
274                                "relation-expanded (1-hop, {:.0}% penalty, no direct score)",
275                                (1.0 - HOP_PENALTY) * 100.0
276                            )],
277                            project_id: record.project_id.clone(),
278                            confidence,
279                            contradicts: Vec::new(),
280                        });
281                    }
282                }
283            }
284        }
285    }
286
287    // Re-sort and re-truncate
288    expanded.sort_by(|left, right| {
289        right
290            .score
291            .cmp(&left.score)
292            .then_with(|| left.record_id.cmp(&right.record_id))
293    });
294    expanded.truncate(limit);
295    expanded
296}
297
298/// 从 scored notes 里提取 frontmatter `record_id` 作为 lifecycle candidate 的排除集。
299pub fn excluded_record_ids_from_scored(scored: &[ScoredNote]) -> HashSet<String> {
300    scored
301        .iter()
302        .filter_map(|s| {
303            s.note
304                .frontmatter
305                .get("record_id")
306                .and_then(|v| v.as_str())
307                .map(ToString::to_string)
308        })
309        .collect()
310}
311
312/// 当 knowledge (wiki 综合页) 记录处于 accepted / canonical 状态时,其 `related_records`
313/// 里列出的源碎片视为被综合页吸收,不应再作为独立 lifecycle candidate 返回。
314///
315/// 同时包含显式 `supersedes` 字段指向的被替代记录。
316///
317/// Karpathy LLM Wiki 的 "compiled 页优先于源碎片" 语义在此落地。
318pub fn superseded_record_ids(records: &[(String, MemoryRecord)]) -> HashSet<String> {
319    use crate::domain::MemoryLifecycleState;
320
321    let mut superseded: HashSet<String> = HashSet::new();
322    for (_record_id, record) in records {
323        if !matches!(
324            record.state,
325            MemoryLifecycleState::Accepted | MemoryLifecycleState::Canonical
326        ) {
327            continue;
328        }
329        if record.memory_type == "knowledge" {
330            for source_id in &record.related_records {
331                superseded.insert(source_id.clone());
332            }
333        }
334        if let Some(ref replaces) = record.supersedes {
335            superseded.insert(replaces.clone());
336        }
337    }
338    superseded
339}
340
/// BM25-fused lifecycle candidate selection.
///
/// Runs the structured selection ([`select_lifecycle_candidates`]) and a BM25 search over
/// `input.task`. Each over-fetches `2 * limit` results (saturating on overflow). The two
/// rankings are fused with Reciprocal Rank Fusion (RRF):
/// - Each ranking adds `1 / (RRF_K + rank)` (1-based rank) to an id's fused value.
/// - An id repeated within one ranking counts once, at its first position.
/// - Excluded ids get no contribution.
///
/// Structured candidates get `score = (rrf * 1000) as i32`. BM25-only hits are added as
/// `Medium`-confidence candidates if they exist in `records`, are not excluded, and have a
/// positive integer score. The result is ordered by the unrounded RRF value (descending,
/// ties by record id), so adjacent ranks are not collapsed by rounding; it is therefore also
/// sorted by `score`. It is truncated to `limit`.
///
/// Falls back to the structured ranking (truncated to `limit`) in any of these cases:
/// - no index path is given, or the path does not exist;
/// - the index cannot be opened;
/// - the search fails or returns no hits.
///
/// Returns an empty vector when `limit == 0`.
#[cfg(feature = "bm25")]
pub fn select_lifecycle_candidates_with_bm25(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
    bm25_index_path: Option<&std::path::Path>,
) -> Vec<LifecycleCandidate> {
    /// Adds one ranking's RRF contributions to `scores`. Ranks are positions in the
    /// original list. Excluded ids and repeat occurrences of an id are skipped.
    fn add_ranking<'a>(
        scores: &mut std::collections::HashMap<String, f64>,
        ranked_ids: impl IntoIterator<Item = &'a String>,
        excluded: &HashSet<String>,
    ) {
        let mut seen: HashSet<&str> = HashSet::new();
        for (rank, record_id) in ranked_ids.into_iter().enumerate() {
            if excluded.contains(record_id) || !seen.insert(record_id.as_str()) {
                continue;
            }
            *scores.entry(record_id.clone()).or_default() += 1.0 / (RRF_K + (rank as f64) + 1.0);
        }
    }

    if limit == 0 {
        return Vec::new();
    }
    // Over-fetch from each ranker so fusion has room to re-order; never overflow.
    let pool_size = limit.saturating_mul(2);

    let mut structured_candidates = select_lifecycle_candidates(
        project,
        records,
        input,
        pool_size,
        excluded_record_ids,
        reference_map,
    );

    // Any BM25 problem (no path, missing index, open or search error) yields no hits,
    // which degrades to the structured ranking below.
    let bm25_results = bm25_index_path
        .filter(|path| path.exists())
        .and_then(|path| crate::engine::bm25::Bm25Index::open_or_create(path).ok())
        .and_then(|index| index.search(&input.task, pool_size).ok())
        .unwrap_or_default();

    if bm25_results.is_empty() {
        structured_candidates.truncate(limit);
        return structured_candidates;
    }

    // Build RRF scores from the structured and BM25 rankings.
    let mut rrf_scores: std::collections::HashMap<String, f64> =
        std::collections::HashMap::new();
    add_ranking(
        &mut rrf_scores,
        structured_candidates.iter().map(|c| &c.record_id),
        excluded_record_ids,
    );
    add_ranking(
        &mut rrf_scores,
        bm25_results.iter().map(|(record_id, _)| record_id),
        excluded_record_ids,
    );

    let mut seen_ids: HashSet<String> = structured_candidates
        .iter()
        .map(|c| c.record_id.clone())
        .collect();

    // Re-score structured candidates by RRF, keeping the exact value for ordering.
    let mut fused: Vec<(f64, LifecycleCandidate)> = structured_candidates
        .into_iter()
        .map(|mut candidate| {
            let rrf = rrf_scores
                .get(&candidate.record_id)
                .copied()
                .unwrap_or(0.0);
            // Integer score for display/compatibility (scale by 1000 for precision).
            candidate.score = (rrf * 1000.0) as i32;
            candidate
                .reasons
                .push(format!("RRF fused score (bm25+structured): {:.4}", rrf));
            (rrf, candidate)
        })
        .collect();

    // Add BM25-only hits that weren't in the structured results (each at most once).
    // NOTE(review): these bypass the structured scorer's filters and the cross-project
    // penalty — confirm that is intended.
    for (record_id, _bm25_score) in &bm25_results {
        if excluded_record_ids.contains(record_id) || !seen_ids.insert(record_id.clone()) {
            continue;
        }
        let Some((_, record)) = records.iter().find(|(id, _)| id == record_id) else {
            continue;
        };
        let rrf = rrf_scores.get(record_id).copied().unwrap_or(0.0);
        let score = (rrf * 1000.0) as i32;
        if score > 0 {
            fused.push((
                rrf,
                LifecycleCandidate {
                    record_id: record_id.clone(),
                    title: record.title.clone(),
                    summary: record.summary.clone(),
                    memory_type: record.memory_type.clone(),
                    scope: record.scope,
                    state: record.state,
                    score,
                    reasons: vec![format!("BM25-only hit, RRF score: {:.4}", rrf)],
                    project_id: record.project_id.clone(),
                    confidence: crate::domain::ConfidenceTier::Medium,
                    contradicts: Vec::new(),
                },
            ));
        }
    }

    // Order by exact RRF (monotone with the integer score), ties by record id.
    fused.sort_by(|(left_rrf, left), (right_rrf, right)| {
        right_rrf
            .total_cmp(left_rrf)
            .then_with(|| left.record_id.cmp(&right.record_id))
    });
    fused
        .into_iter()
        .take(limit)
        .map(|(_, candidate)| candidate)
        .collect()
}
460
/// Three-way RRF fusion: structured + BM25 + embedding.
///
/// The structured selection ([`select_lifecycle_candidates`]) and, when the `bm25` feature
/// is on, a BM25 search over `input.task` each over-fetch `2 * limit` results (saturating
/// on overflow). They are fused with the caller-supplied `embedding_results` using
/// Reciprocal Rank Fusion (RRF):
/// - Each ranking adds `1 / (RRF_K + rank)` (1-based rank) to an id's fused value.
/// - An id repeated within one ranking counts once, at its first position.
/// - Excluded ids get no contribution.
///
/// Structured candidates get `score = (rrf * 1000) as i32` and a reason naming the signals
/// present. Extra BM25/embedding hits are added once as `Medium`-confidence candidates if
/// they exist in `records`, are not excluded, and have a positive integer score. The result
/// is ordered by the unrounded RRF value (descending, ties by record id), which is also
/// descending `score` order, and truncated to `limit`.
///
/// Falls back gracefully when any signal is unavailable. If neither BM25 nor embedding
/// produced hits, the structured ranking (truncated to `limit`) is returned. Returns an
/// empty vector when `limit == 0`.
#[cfg(feature = "embedding")]
pub fn select_lifecycle_candidates_fused(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
    #[cfg(feature = "bm25")] bm25_index_path: Option<&std::path::Path>,
    embedding_results: &[(String, f32)],
) -> Vec<LifecycleCandidate> {
    /// Adds one ranking's RRF contributions to `scores`. Ranks are positions in the
    /// original list. Excluded ids and repeat occurrences of an id are skipped.
    fn add_ranking<'a>(
        scores: &mut std::collections::HashMap<String, f64>,
        ranked_ids: impl IntoIterator<Item = &'a String>,
        excluded: &HashSet<String>,
    ) {
        let mut seen: HashSet<&str> = HashSet::new();
        for (rank, record_id) in ranked_ids.into_iter().enumerate() {
            if excluded.contains(record_id) || !seen.insert(record_id.as_str()) {
                continue;
            }
            *scores.entry(record_id.clone()).or_default() += 1.0 / (RRF_K + (rank as f64) + 1.0);
        }
    }

    if limit == 0 {
        return Vec::new();
    }
    // Over-fetch from each ranker so fusion has room to re-order; never overflow.
    let pool_size = limit.saturating_mul(2);

    let mut structured_candidates = select_lifecycle_candidates(
        project,
        records,
        input,
        pool_size,
        excluded_record_ids,
        reference_map,
    );

    #[cfg(feature = "bm25")]
    let bm25_results: Vec<(String, f32)> = bm25_index_path
        .filter(|p| p.exists())
        .and_then(|p| crate::engine::bm25::Bm25Index::open_or_create(p).ok())
        .and_then(|idx| idx.search(&input.task, pool_size).ok())
        .unwrap_or_default();

    #[cfg(not(feature = "bm25"))]
    let bm25_results: Vec<(String, f32)> = Vec::new();

    let has_bm25 = !bm25_results.is_empty();
    let has_embedding = !embedding_results.is_empty();

    if !has_bm25 && !has_embedding {
        structured_candidates.truncate(limit);
        return structured_candidates;
    }

    let mut rrf_scores: std::collections::HashMap<String, f64> =
        std::collections::HashMap::new();
    add_ranking(
        &mut rrf_scores,
        structured_candidates.iter().map(|c| &c.record_id),
        excluded_record_ids,
    );
    add_ranking(
        &mut rrf_scores,
        bm25_results.iter().map(|(record_id, _)| record_id),
        excluded_record_ids,
    );
    add_ranking(
        &mut rrf_scores,
        embedding_results.iter().map(|(record_id, _)| record_id),
        excluded_record_ids,
    );

    // The contributing-signal label is the same for every candidate; build it once.
    let signal_label = [
        Some("structured"),
        has_bm25.then_some("bm25"),
        has_embedding.then_some("embedding"),
    ]
    .into_iter()
    .flatten()
    .collect::<Vec<_>>()
    .join("+");

    let mut seen_ids: HashSet<String> = structured_candidates
        .iter()
        .map(|c| c.record_id.clone())
        .collect();

    // Re-score structured candidates by RRF, keeping the exact value for ordering.
    let mut fused: Vec<(f64, LifecycleCandidate)> = structured_candidates
        .into_iter()
        .map(|mut candidate| {
            let rrf = rrf_scores
                .get(&candidate.record_id)
                .copied()
                .unwrap_or(0.0);
            candidate.score = (rrf * 1000.0) as i32;
            candidate
                .reasons
                .push(format!("RRF fused ({}): {:.4}", signal_label, rrf));
            (rrf, candidate)
        })
        .collect();

    // Extra BM25/embedding hits not in the structured results, each added at most once.
    // NOTE(review): these bypass the structured scorer's filters and the cross-project
    // penalty — confirm that is intended.
    let extra_hits = bm25_results
        .iter()
        .chain(embedding_results.iter())
        .map(|(record_id, _)| record_id);
    for record_id in extra_hits {
        if excluded_record_ids.contains(record_id) || !seen_ids.insert(record_id.clone()) {
            continue;
        }
        let Some((_, record)) = records.iter().find(|(id, _)| id == record_id) else {
            continue;
        };
        let rrf = rrf_scores.get(record_id).copied().unwrap_or(0.0);
        let score = (rrf * 1000.0) as i32;
        if score > 0 {
            fused.push((
                rrf,
                LifecycleCandidate {
                    record_id: record_id.clone(),
                    title: record.title.clone(),
                    summary: record.summary.clone(),
                    memory_type: record.memory_type.clone(),
                    scope: record.scope,
                    state: record.state,
                    score,
                    reasons: vec![format!("RRF extra hit: {:.4}", rrf)],
                    project_id: record.project_id.clone(),
                    confidence: crate::domain::ConfidenceTier::Medium,
                    contradicts: Vec::new(),
                },
            ));
        }
    }

    // Order by exact RRF (monotone with the integer score), ties by record id.
    fused.sort_by(|(left_rrf, left), (right_rrf, right)| {
        right_rrf
            .total_cmp(left_rrf)
            .then_with(|| left.record_id.cmp(&right.record_id))
    });
    fused
        .into_iter()
        .take(limit)
        .map(|(_, candidate)| candidate)
        .collect()
}
588
589#[cfg(test)]
590mod tests;