1use crate::config::ProjectConfig;
2use crate::domain::{
3 CandidateNote, LifecycleCandidate, MatchedModule, MatchedProject, MatchedScene, MemoryRecord,
4 MemoryScope, Note, RouteInput, ScoredNote,
5};
6use crate::engine::scorer;
7use std::collections::HashSet;
8
/// Multiplier applied to the score of an entry that is pulled in through a
/// one-hop relation instead of being selected directly (a 30% penalty).
const HOP_PENALTY: f64 = 0.7;

/// Multiplier applied to user-, agent- and team-scoped lifecycle candidates
/// when a project has been matched (a 40% penalty).
const CROSS_PROJECT_PENALTY: f64 = 0.6;

/// Constant `k` in the reciprocal-rank-fusion weight `1 / (k + rank + 1)`.
///
/// Only compiled when at least one fusion signal (`bm25` or `embedding`) is
/// enabled, so builds without either feature do not carry an unused constant.
#[cfg(any(feature = "bm25", feature = "embedding"))]
const RRF_K: f64 = 60.0;
24
25pub fn select_scored_notes(
26 project_config: Option<&ProjectConfig>,
27 project: Option<&MatchedProject>,
28 modules: &[MatchedModule],
29 scenes: &[MatchedScene],
30 notes: &[Note],
31 input: &RouteInput,
32 limit: usize,
33) -> Vec<ScoredNote> {
34 let mut scored_notes: Vec<ScoredNote> = notes
35 .iter()
36 .filter_map(|note| {
37 let (score, reasons, score_breakdown, confidence) =
38 scorer::score_note(project_config, project, modules, scenes, note, input);
39 if score <= 0 {
40 return None;
41 }
42 Some(ScoredNote {
43 note: note.clone(),
44 score,
45 reasons,
46 score_breakdown,
47 confidence,
48 excerpt: note.excerpt_for_input(input, 220),
49 })
50 })
51 .collect();
52
53 scored_notes.sort_by(|left, right| {
54 right
55 .score
56 .cmp(&left.score)
57 .then_with(|| left.note.relative_path.cmp(&right.note.relative_path))
58 });
59
60 let initial: Vec<ScoredNote> = scored_notes.iter().take(limit).cloned().collect();
62
63 let selected_paths: HashSet<String> = initial
65 .iter()
66 .map(|s| s.note.relative_path.clone())
67 .collect();
68 let mut expand_targets: HashSet<String> = HashSet::new();
69 for scored in &initial {
70 for link in &scored.note.wikilinks {
72 expand_targets.insert(link.to_lowercase());
73 }
74 if let Some(related) = scored.note.frontmatter.get("related_memory")
76 && let Some(arr) = related.as_array()
77 {
78 for item in arr {
79 if let Some(s) = item.as_str() {
80 let cleaned = s.trim_start_matches("[[").trim_end_matches("]]");
81 expand_targets.insert(cleaned.to_lowercase());
82 }
83 }
84 }
85 }
86
87 let mut expanded = initial;
89 if !expand_targets.is_empty() {
90 for scored in &scored_notes {
91 if selected_paths.contains(&scored.note.relative_path) {
92 continue;
93 }
94 let title_lc = scored.note.title.to_lowercase();
95 let path_lc = scored.note.relative_path.to_lowercase();
96 let is_related = expand_targets.iter().any(|target| {
97 title_lc.contains(target) || path_lc.contains(target) || target.contains(&title_lc)
98 });
99 if is_related {
100 let penalized_score = ((scored.score as f64) * HOP_PENALTY) as i32;
101 if penalized_score > 0 {
102 let mut expanded_note = scored.clone();
103 expanded_note.score = penalized_score;
104 expanded_note.reasons.push(format!(
105 "relation-expanded (1-hop, {:.0}% penalty)",
106 (1.0 - HOP_PENALTY) * 100.0
107 ));
108 expanded.push(expanded_note);
109 }
110 }
111 }
112 }
113
114 expanded.sort_by(|left, right| {
116 right
117 .score
118 .cmp(&left.score)
119 .then_with(|| left.note.relative_path.cmp(&right.note.relative_path))
120 });
121 expanded.truncate(limit);
122 expanded
123}
124
125pub fn select_candidates(
126 project_config: Option<&ProjectConfig>,
127 project: Option<&MatchedProject>,
128 modules: &[MatchedModule],
129 scenes: &[MatchedScene],
130 notes: &[Note],
131 input: &RouteInput,
132 limit: usize,
133) -> Vec<CandidateNote> {
134 select_scored_notes(
135 project_config,
136 project,
137 modules,
138 scenes,
139 notes,
140 input,
141 limit,
142 )
143 .into_iter()
144 .map(CandidateNote::from)
145 .collect()
146}
147
148pub fn select_lifecycle_candidates(
160 project: Option<&MatchedProject>,
161 records: &[(String, MemoryRecord)],
162 input: &RouteInput,
163 limit: usize,
164 excluded_record_ids: &HashSet<String>,
165 reference_map: Option<&crate::reference_tracker::ReferenceMap>,
166) -> Vec<LifecycleCandidate> {
167 if limit == 0 {
168 return Vec::new();
169 }
170 let mut candidates: Vec<LifecycleCandidate> = records
171 .iter()
172 .filter(|(record_id, _)| !excluded_record_ids.contains(record_id))
173 .filter_map(|(record_id, record)| {
174 scorer::score_lifecycle_candidate(
175 project,
176 record_id,
177 record,
178 input,
179 reference_map,
180 Some(records),
181 )
182 })
183 .collect();
184
185 if project.is_some() {
188 for candidate in &mut candidates {
189 if matches!(
190 candidate.scope,
191 MemoryScope::User | MemoryScope::Agent | MemoryScope::Team
192 ) {
193 let penalized = ((candidate.score as f64) * CROSS_PROJECT_PENALTY) as i32;
194 if penalized != candidate.score {
195 candidate.score = penalized;
196 candidate.reasons.push(format!(
197 "cross-project penalty ({:.0}%)",
198 (1.0 - CROSS_PROJECT_PENALTY) * 100.0
199 ));
200 }
201 }
202 }
203 }
204
205 candidates.sort_by(|left, right| {
206 right
207 .score
208 .cmp(&left.score)
209 .then_with(|| left.record_id.cmp(&right.record_id))
210 });
211
212 let initial: Vec<LifecycleCandidate> = candidates.iter().take(limit).cloned().collect();
214
215 let selected_ids: HashSet<String> = initial.iter().map(|c| c.record_id.clone()).collect();
217 let candidate_ids: HashSet<String> = candidates.iter().map(|c| c.record_id.clone()).collect();
218 let mut expand_targets: HashSet<String> = HashSet::new();
219 for candidate in &initial {
220 if let Some((_, record)) = records.iter().find(|(id, _)| id == &candidate.record_id) {
221 for related_id in &record.related_records {
222 if !selected_ids.contains(related_id) && !excluded_record_ids.contains(related_id) {
223 expand_targets.insert(related_id.clone());
224 }
225 }
226 }
227 }
228
229 let mut expanded = initial;
231 if !expand_targets.is_empty() {
232 for target_id in &expand_targets {
233 if let Some(candidate) = candidates.iter().find(|c| &c.record_id == target_id) {
235 let penalized_score = ((candidate.score as f64) * HOP_PENALTY) as i32;
236 if penalized_score > 0 {
237 let mut expanded_candidate = candidate.clone();
238 expanded_candidate.score = penalized_score;
239 expanded_candidate.reasons.push(format!(
240 "relation-expanded (1-hop, {:.0}% penalty)",
241 (1.0 - HOP_PENALTY) * 100.0
242 ));
243 expanded.push(expanded_candidate);
244 }
245 } else if !candidate_ids.contains(target_id) {
246 if let Some((_, record)) = records.iter().find(|(id, _)| id == target_id) {
249 let referrer_score = expanded
251 .iter()
252 .filter(|c| {
253 records
254 .iter()
255 .find(|(id, _)| id == &c.record_id)
256 .map(|(_, r)| r.related_records.contains(target_id))
257 .unwrap_or(false)
258 })
259 .map(|c| c.score)
260 .max()
261 .unwrap_or(0);
262 let penalized_score = ((referrer_score as f64) * HOP_PENALTY) as i32;
263 if penalized_score > 0 {
264 let confidence = crate::domain::ConfidenceTier::Medium;
265 expanded.push(LifecycleCandidate {
266 record_id: target_id.clone(),
267 title: record.title.clone(),
268 summary: record.summary.clone(),
269 memory_type: record.memory_type.clone(),
270 scope: record.scope,
271 state: record.state,
272 score: penalized_score,
273 reasons: vec![format!(
274 "relation-expanded (1-hop, {:.0}% penalty, no direct score)",
275 (1.0 - HOP_PENALTY) * 100.0
276 )],
277 project_id: record.project_id.clone(),
278 confidence,
279 contradicts: Vec::new(),
280 });
281 }
282 }
283 }
284 }
285 }
286
287 expanded.sort_by(|left, right| {
289 right
290 .score
291 .cmp(&left.score)
292 .then_with(|| left.record_id.cmp(&right.record_id))
293 });
294 expanded.truncate(limit);
295 expanded
296}
297
298pub fn excluded_record_ids_from_scored(scored: &[ScoredNote]) -> HashSet<String> {
300 scored
301 .iter()
302 .filter_map(|s| {
303 s.note
304 .frontmatter
305 .get("record_id")
306 .and_then(|v| v.as_str())
307 .map(ToString::to_string)
308 })
309 .collect()
310}
311
312pub fn superseded_record_ids(records: &[(String, MemoryRecord)]) -> HashSet<String> {
319 use crate::domain::MemoryLifecycleState;
320
321 let mut superseded: HashSet<String> = HashSet::new();
322 for (_record_id, record) in records {
323 if !matches!(
324 record.state,
325 MemoryLifecycleState::Accepted | MemoryLifecycleState::Canonical
326 ) {
327 continue;
328 }
329 if record.memory_type == "knowledge" {
330 for source_id in &record.related_records {
331 superseded.insert(source_id.clone());
332 }
333 }
334 if let Some(ref replaces) = record.supersedes {
335 superseded.insert(replaces.clone());
336 }
337 }
338 superseded
339}
340
/// Selects lifecycle candidates by fusing the structured ranking with a BM25
/// full-text ranking through reciprocal rank fusion (RRF).
///
/// Both rankings are fetched with twice the requested `limit`. Every list
/// contributes `1 / (RRF_K + rank + 1)` per record (rank is zero-based), and a
/// candidate's integer `score` becomes the summed weight times 1000, truncated.
/// Records found only by BM25 are added when they exist in `records`, are not
/// excluded, and have a positive integer score.
///
/// The result is ordered by the *untruncated* fused weight, with the record id
/// as tie-breaker. Ordering by the truncated integer would collapse
/// neighbouring ranks into ties and fall back to id order. Integer scores in
/// the output are still non-increasing.
///
/// Falls back to the structured ranking (truncated to `limit`) when no index
/// path is given, the path does not exist, the index cannot be opened, the
/// search fails, or the search returns nothing.
///
/// # Returns
///
/// At most `limit` candidates, best first, each listed once per record id
/// coming from the BM25 side.
#[cfg(feature = "bm25")]
pub fn select_lifecycle_candidates_with_bm25(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
    bm25_index_path: Option<&std::path::Path>,
) -> Vec<LifecycleCandidate> {
    use std::collections::HashMap;

    // Over-fetch so fusion has material to reorder; saturate rather than
    // overflow for very large limits.
    let fetch_limit = limit.saturating_mul(2);

    let mut structured_candidates = select_lifecycle_candidates(
        project,
        records,
        input,
        fetch_limit,
        excluded_record_ids,
        reference_map,
    );

    // Any failure on the BM25 side degrades to "no BM25 results".
    let bm25_results = bm25_index_path
        .filter(|path| path.exists())
        .and_then(|path| crate::engine::bm25::Bm25Index::open_or_create(path).ok())
        .and_then(|index| index.search(&input.task, fetch_limit).ok())
        .unwrap_or_default();

    if bm25_results.is_empty() {
        structured_candidates.truncate(limit);
        return structured_candidates;
    }

    let rrf_weight = |rank: usize| 1.0 / (RRF_K + (rank as f64) + 1.0);
    let mut rrf_scores: HashMap<String, f64> = HashMap::new();

    for (rank, candidate) in structured_candidates.iter().enumerate() {
        *rrf_scores.entry(candidate.record_id.clone()).or_default() += rrf_weight(rank);
    }
    for (rank, (record_id, _score)) in bm25_results.iter().enumerate() {
        if excluded_record_ids.contains(record_id) {
            continue;
        }
        *rrf_scores.entry(record_id.clone()).or_default() += rrf_weight(rank);
    }

    // Ids that already have a candidate; grows as BM25-only hits are added so
    // a record id repeated by the index is emitted once.
    let mut seen_ids: HashSet<String> = structured_candidates
        .iter()
        .map(|candidate| candidate.record_id.clone())
        .collect();

    // Each candidate is paired with its precise fused weight for sorting.
    let mut fused: Vec<(f64, LifecycleCandidate)> = structured_candidates
        .into_iter()
        .map(|mut candidate| {
            let rrf = rrf_scores
                .get(&candidate.record_id)
                .copied()
                .unwrap_or(0.0);
            candidate.score = (rrf * 1000.0) as i32;
            candidate
                .reasons
                .push(format!("RRF fused score (bm25+structured): {:.4}", rrf));
            (rrf, candidate)
        })
        .collect();

    for (record_id, _bm25_score) in &bm25_results {
        if excluded_record_ids.contains(record_id) || !seen_ids.insert(record_id.clone()) {
            continue;
        }
        let Some((_, record)) = records.iter().find(|(id, _)| id == record_id) else {
            continue;
        };
        let rrf = rrf_scores.get(record_id).copied().unwrap_or(0.0);
        let score = (rrf * 1000.0) as i32;
        if score > 0 {
            fused.push((
                rrf,
                LifecycleCandidate {
                    record_id: record_id.clone(),
                    title: record.title.clone(),
                    summary: record.summary.clone(),
                    memory_type: record.memory_type.clone(),
                    scope: record.scope,
                    state: record.state,
                    score,
                    reasons: vec![format!("BM25-only hit, RRF score: {:.4}", rrf)],
                    project_id: record.project_id.clone(),
                    confidence: crate::domain::ConfidenceTier::Medium,
                    contradicts: Vec::new(),
                },
            ));
        }
    }

    fused.sort_by(|(left_rrf, left), (right_rrf, right)| {
        right_rrf
            .total_cmp(left_rrf)
            .then_with(|| left.record_id.cmp(&right.record_id))
    });
    fused.truncate(limit);
    fused.into_iter().map(|(_, candidate)| candidate).collect()
}
460
/// Selects lifecycle candidates by fusing up to three rankings through
/// reciprocal rank fusion (RRF): the structured ranking, BM25 full-text
/// results (only with the `bm25` feature) and caller-supplied embedding
/// results.
///
/// The structured and BM25 rankings are fetched with twice the requested
/// `limit`; `embedding_results` is used as given, in its existing order. Every
/// list contributes `1 / (RRF_K + rank + 1)` per record (rank is zero-based),
/// and a candidate's integer `score` becomes the summed weight times 1000,
/// truncated. Records that appear only in the BM25 or embedding lists are
/// added when they exist in `records`, are not excluded, and have a positive
/// integer score.
///
/// The result is ordered by the *untruncated* fused weight, with the record id
/// as tie-breaker. Ordering by the truncated integer would collapse
/// neighbouring ranks into ties and fall back to id order.
///
/// Falls back to the structured ranking (truncated to `limit`) when neither
/// BM25 nor embedding results are available. A missing, unopenable or failing
/// BM25 index counts as "no BM25 results".
///
/// # Returns
///
/// At most `limit` candidates, best first.
#[cfg(feature = "embedding")]
pub fn select_lifecycle_candidates_fused(
    project: Option<&MatchedProject>,
    records: &[(String, MemoryRecord)],
    input: &RouteInput,
    limit: usize,
    excluded_record_ids: &HashSet<String>,
    reference_map: Option<&crate::reference_tracker::ReferenceMap>,
    #[cfg(feature = "bm25")] bm25_index_path: Option<&std::path::Path>,
    embedding_results: &[(String, f32)],
) -> Vec<LifecycleCandidate> {
    use std::collections::HashMap;

    // Over-fetch so fusion has material to reorder; saturate rather than
    // overflow for very large limits.
    let fetch_limit = limit.saturating_mul(2);

    let mut structured_candidates = select_lifecycle_candidates(
        project,
        records,
        input,
        fetch_limit,
        excluded_record_ids,
        reference_map,
    );

    #[cfg(feature = "bm25")]
    let bm25_results: Vec<(String, f32)> = bm25_index_path
        .filter(|path| path.exists())
        .and_then(|path| crate::engine::bm25::Bm25Index::open_or_create(path).ok())
        .and_then(|index| index.search(&input.task, fetch_limit).ok())
        .unwrap_or_default();

    #[cfg(not(feature = "bm25"))]
    let bm25_results: Vec<(String, f32)> = Vec::new();

    let has_bm25 = !bm25_results.is_empty();
    let has_embedding = !embedding_results.is_empty();

    if !has_bm25 && !has_embedding {
        structured_candidates.truncate(limit);
        return structured_candidates;
    }

    let rrf_weight = |rank: usize| 1.0 / (RRF_K + (rank as f64) + 1.0);
    let mut rrf_scores: HashMap<String, f64> = HashMap::new();

    for (rank, candidate) in structured_candidates.iter().enumerate() {
        *rrf_scores.entry(candidate.record_id.clone()).or_default() += rrf_weight(rank);
    }
    // BM25 first, then embeddings, so per-record sums accumulate in a fixed
    // order.
    for ranked in [bm25_results.as_slice(), embedding_results] {
        for (rank, (record_id, _)) in ranked.iter().enumerate() {
            if excluded_record_ids.contains(record_id) {
                continue;
            }
            *rrf_scores.entry(record_id.clone()).or_default() += rrf_weight(rank);
        }
    }

    // The label is the same for every structured candidate; build it once.
    let mut signals: Vec<&str> = vec!["structured"];
    if has_bm25 {
        signals.push("bm25");
    }
    if has_embedding {
        signals.push("embedding");
    }
    let signal_label = signals.join("+");

    // Ids that already have a candidate; grows as extra hits are added so an
    // id present in both auxiliary lists is emitted once.
    let mut seen_ids: HashSet<String> = structured_candidates
        .iter()
        .map(|candidate| candidate.record_id.clone())
        .collect();

    // Each candidate is paired with its precise fused weight for sorting.
    let mut fused: Vec<(f64, LifecycleCandidate)> = structured_candidates
        .into_iter()
        .map(|mut candidate| {
            let rrf = rrf_scores
                .get(&candidate.record_id)
                .copied()
                .unwrap_or(0.0);
            candidate.score = (rrf * 1000.0) as i32;
            candidate
                .reasons
                .push(format!("RRF fused ({}): {:.4}", signal_label, rrf));
            (rrf, candidate)
        })
        .collect();

    for (record_id, _) in bm25_results.iter().chain(embedding_results.iter()) {
        if excluded_record_ids.contains(record_id) || !seen_ids.insert(record_id.clone()) {
            continue;
        }
        let Some((_, record)) = records.iter().find(|(id, _)| id == record_id) else {
            continue;
        };
        let rrf = rrf_scores.get(record_id).copied().unwrap_or(0.0);
        let score = (rrf * 1000.0) as i32;
        if score > 0 {
            fused.push((
                rrf,
                LifecycleCandidate {
                    record_id: record_id.clone(),
                    title: record.title.clone(),
                    summary: record.summary.clone(),
                    memory_type: record.memory_type.clone(),
                    scope: record.scope,
                    state: record.state,
                    score,
                    reasons: vec![format!("RRF extra hit: {:.4}", rrf)],
                    project_id: record.project_id.clone(),
                    confidence: crate::domain::ConfidenceTier::Medium,
                    contradicts: Vec::new(),
                },
            ));
        }
    }

    fused.sort_by(|(left_rrf, left), (right_rrf, right)| {
        right_rrf
            .total_cmp(left_rrf)
            .then_with(|| left.record_id.cmp(&right.record_id))
    });
    fused.truncate(limit);
    fused.into_iter().map(|(_, candidate)| candidate).collect()
}
588
589#[cfg(test)]
590mod tests;