1use std::collections::{HashMap, HashSet};
22use std::sync::Arc;
23use std::time::Instant;
24
25use anyhow::Result;
26use chrono::{DateTime, Utc};
27use parking_lot::RwLock;
28use regex::Regex;
29use serde::{Deserialize, Serialize};
30use uuid::Uuid;
31
32use crate::embeddings::NeuralNer;
33use crate::graph_memory::GraphMemory;
34use crate::memory::feedback::FeedbackStore;
35use crate::memory::{Memory, MemorySystem};
36
37fn contains_word(text: &str, word: &str) -> bool {
44 if word.is_empty() {
45 return false;
46 }
47 let escaped = regex::escape(word);
50 let pattern = format!(r"(?i)\b{}\b", escaped);
51 match Regex::new(&pattern) {
52 Ok(re) => re.is_match(text),
53 Err(_) => text.contains(word), }
55}
56
// Default fusion weights for the learned ranking model. Only their relative
// magnitudes matter: `LearnedWeights::normalize` rescales them to sum to 1.0.
const DEFAULT_SEMANTIC_WEIGHT: f32 = 0.18;

const DEFAULT_ENTITY_WEIGHT: f32 = 0.17;

const DEFAULT_TAG_WEIGHT: f32 = 0.05;

const DEFAULT_IMPORTANCE_WEIGHT: f32 = 0.05;

const DEFAULT_MOMENTUM_WEIGHT: f32 = 0.28;

const DEFAULT_ACCESS_COUNT_WEIGHT: f32 = 0.14;

const DEFAULT_GRAPH_STRENGTH_WEIGHT: f32 = 0.13;

// Step size applied per feedback event in `LearnedWeights::apply_feedback`.
const WEIGHT_LEARNING_RATE: f32 = 0.05;

// Floor preventing negative feedback from driving any weight to zero.
const MIN_WEIGHT: f32 = 0.05;

// Parameters of the logistic curve used by `calibrate_score`: steepness of
// the slope and the input value that maps to 0.5.
const SIGMOID_STEEPNESS: f32 = 10.0;

const SIGMOID_MIDPOINT: f32 = 0.5;
/// Tunable knobs for a relevance-surfacing request.
///
/// Every field has a serde default, so an empty JSON object deserializes to
/// usable values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceConfig {
    /// Minimum semantic-similarity score a recalled memory must reach.
    #[serde(default = "default_semantic_threshold")]
    pub semantic_threshold: f32,

    /// Minimum normalized entity-match score a memory must reach.
    #[serde(default = "default_entity_threshold")]
    pub entity_threshold: f32,

    /// Maximum number of memories returned after ranking.
    #[serde(default = "default_max_results")]
    pub max_results: usize,

    /// If non-empty, only memories whose `experience_type` (Debug-formatted)
    /// case-insensitively matches one of these names are returned.
    #[serde(default)]
    pub memory_types: Vec<String>,

    /// Toggle NER-driven entity matching.
    #[serde(default = "default_true")]
    pub enable_entity_matching: bool,

    /// Toggle semantic (embedding/recall) matching.
    #[serde(default = "default_true")]
    pub enable_semantic_matching: bool,

    /// Memories with importance below this are filtered out entirely.
    #[serde(default = "default_min_importance")]
    pub min_importance: f32,

    /// Window (in hours) during which newly created memories get a score
    /// boost; 0 disables the boost.
    #[serde(default = "default_recency_hours")]
    pub recency_boost_hours: u64,

    /// Multiplier applied to a brand-new memory's score, decaying linearly
    /// to 1.0 over `recency_boost_hours`.
    #[serde(default = "default_recency_multiplier")]
    pub recency_boost_multiplier: f32,

    /// Multiplier applied to results that had any entity match.
    #[serde(default = "default_graph_boost_multiplier")]
    pub graph_boost_multiplier: f32,
}
146
// Serde default helpers for `RelevanceConfig`. Keep these in sync with
// `RelevanceConfig::default()`.

fn default_graph_boost_multiplier() -> f32 {
    1.15
}

fn default_semantic_threshold() -> f32 {
    0.45
}

fn default_entity_threshold() -> f32 {
    0.5
}

fn default_max_results() -> usize {
    5
}

fn default_true() -> bool {
    true
}

fn default_min_importance() -> f32 {
    0.3
}

fn default_recency_hours() -> u64 {
    24
}

fn default_recency_multiplier() -> f32 {
    1.2
}
178
179impl Default for RelevanceConfig {
180 fn default() -> Self {
181 Self {
182 semantic_threshold: default_semantic_threshold(),
183 entity_threshold: default_entity_threshold(),
184 max_results: default_max_results(),
185 memory_types: Vec::new(),
186 enable_entity_matching: true,
187 enable_semantic_matching: true,
188 min_importance: default_min_importance(),
189 recency_boost_hours: default_recency_hours(),
190 recency_boost_multiplier: default_recency_multiplier(),
191 graph_boost_multiplier: default_graph_boost_multiplier(),
192 }
193 }
194}
195
/// A single memory returned to the caller, annotated with how and why it
/// was judged relevant.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SurfacedMemory {
    /// Memory id (UUID rendered as a string).
    pub id: String,

    /// The memory's textual content.
    pub content: String,

    /// Debug-formatted `experience_type` of the memory.
    pub memory_type: String,

    /// The memory's intrinsic importance at surfacing time.
    pub importance: f32,

    /// Final fused + boosted relevance score used for ranking.
    pub relevance_score: f32,

    /// Which signal(s) put this memory in the result set.
    pub relevance_reason: RelevanceReason,

    /// Names of detected entities that matched this memory.
    pub matched_entities: Vec<String>,

    /// Raw semantic similarity, present only when semantic matching fired.
    pub semantic_similarity: Option<f32>,

    /// Creation timestamp of the memory.
    pub created_at: DateTime<Utc>,

    /// Tags copied from the memory's experience.
    pub tags: Vec<String>,
}
229
/// Why a memory was surfaced (which scoring signals were non-zero).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum RelevanceReason {
    /// Only the entity-match score was positive.
    EntityMatch,
    /// Only the semantic-similarity score was positive.
    SemanticSimilarity,
    /// Both entity and semantic scores were positive.
    Combined,
    /// Neither fired; surfaced on importance/recency alone.
    RecentImportant,
}
242
/// Incoming request payload for relevance surfacing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceRequest {
    /// Id of the requesting user (used e.g. for A/B variant assignment).
    pub user_id: String,

    /// Free-text context to surface memories against.
    pub context: String,

    /// Caller-supplied entities; may be empty, in which case NER runs on
    /// the context instead.
    #[serde(default)]
    pub entities: Vec<String>,

    /// Per-request configuration; falls back to `RelevanceConfig::default()`.
    #[serde(default)]
    pub config: RelevanceConfig,
}
260
/// Result of a relevance-surfacing call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceResponse {
    /// Ranked, filtered, truncated memories.
    pub memories: Vec<SurfacedMemory>,

    /// Entities the NER stage detected in the context.
    pub detected_entities: Vec<DetectedEntity>,

    /// Total wall-clock time of the call in milliseconds.
    pub latency_ms: f64,

    /// Whether the call finished under the 30 ms latency target.
    pub latency_target_met: bool,

    /// Per-stage timing breakdown; populated only in debug builds.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub debug: Option<RelevanceDebug>,
}
280
/// An entity extracted from the request context by the NER model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DetectedEntity {
    /// Surface form of the entity as found in the text.
    pub name: String,

    /// Debug-formatted entity type (e.g. "Person", "Organization").
    pub entity_type: String,

    /// NER confidence in [0, 1]; multiplies into entity-match scores.
    pub confidence: f32,
}
293
/// Per-stage timing and counts for one surfacing call (debug builds only).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelevanceDebug {
    /// Time spent in named-entity recognition, in ms.
    pub ner_ms: f64,

    /// Time spent matching entities to memories, in ms.
    pub entity_match_ms: f64,

    /// Time spent in semantic recall, in ms.
    pub semantic_search_ms: f64,

    /// Time spent fusing, boosting and sorting candidates, in ms.
    pub ranking_ms: f64,

    /// Number of distinct candidate memories considered.
    pub memories_scanned: usize,

    /// Candidates produced by the entity-matching stage.
    pub entity_matches: usize,

    /// Candidates produced by the semantic stage.
    pub semantic_matches: usize,
}
318
/// Inverted-index entry: all memories that mention one (lowercased) entity.
#[derive(Debug, Clone, Default)]
struct EntityIndexEntry {
    /// Ids of memories/episodes referencing the entity.
    memory_ids: HashSet<Uuid>,
    /// When this entry was last (re)built; currently only written, never read.
    #[allow(dead_code)]
    last_updated: Option<DateTime<Utc>>,
}
328
/// Fusion weights for the ranking components, adapted online from feedback.
///
/// Serialized for persistence; the three newer fields carry serde defaults
/// so blobs stored before they existed still deserialize.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LearnedWeights {
    /// Weight of the semantic-similarity component.
    pub semantic: f32,
    /// Weight of the entity-match component.
    pub entity: f32,
    /// Weight of tag overlap with the context.
    pub tag: f32,
    /// Weight of the memory's intrinsic importance.
    pub importance: f32,
    /// Weight of the feedback-momentum EMA component.
    #[serde(default = "default_momentum_weight")]
    pub momentum: f32,
    /// Weight of the (log-compressed) access-count component.
    #[serde(default = "default_access_count_weight")]
    pub access_count: f32,
    /// Weight of the graph Hebbian-strength component.
    #[serde(default = "default_graph_strength_weight")]
    pub graph_strength: f32,
    /// Number of feedback updates applied so far.
    pub update_count: u32,
    /// When feedback last adjusted these weights, if ever.
    pub last_updated: Option<DateTime<Utc>>,
}
370
// Serde default helpers for the `LearnedWeights` fields added after the
// initial schema; they forward to the shared constants.

fn default_momentum_weight() -> f32 {
    DEFAULT_MOMENTUM_WEIGHT
}

fn default_access_count_weight() -> f32 {
    DEFAULT_ACCESS_COUNT_WEIGHT
}

fn default_graph_strength_weight() -> f32 {
    DEFAULT_GRAPH_STRENGTH_WEIGHT
}
382
impl Default for LearnedWeights {
    /// Fresh, never-updated weights seeded from the DEFAULT_* constants.
    fn default() -> Self {
        Self {
            semantic: DEFAULT_SEMANTIC_WEIGHT,
            entity: DEFAULT_ENTITY_WEIGHT,
            tag: DEFAULT_TAG_WEIGHT,
            importance: DEFAULT_IMPORTANCE_WEIGHT,
            momentum: DEFAULT_MOMENTUM_WEIGHT,
            access_count: DEFAULT_ACCESS_COUNT_WEIGHT,
            graph_strength: DEFAULT_GRAPH_STRENGTH_WEIGHT,
            update_count: 0,
            last_updated: None,
        }
    }
}
398
impl LearnedWeights {
    /// Rescale all seven weights in place so they sum to 1.0.
    /// A non-positive sum is left untouched to avoid division by zero.
    pub fn normalize(&mut self) {
        let sum = self.semantic
            + self.entity
            + self.tag
            + self.importance
            + self.momentum
            + self.access_count
            + self.graph_strength;
        if sum > 0.0 {
            self.semantic /= sum;
            self.entity /= sum;
            self.tag /= sum;
            self.importance /= sum;
            self.momentum /= sum;
            self.access_count /= sum;
            self.graph_strength /= sum;
        }
    }

    /// Adjust the weights from one feedback signal.
    ///
    /// Each `*_contributed` flag marks a component that took part in the
    /// surfaced result; `helpful` sets the sign of the step. Contributing
    /// components move by `WEIGHT_LEARNING_RATE`; the auxiliary signals
    /// (momentum, access count, graph strength) always move at half that
    /// rate. A helpful result with no contributing component credits the
    /// importance weight instead. All weights are floored at `MIN_WEIGHT`,
    /// then renormalized, and the update counter/timestamp are bumped.
    pub fn apply_feedback(
        &mut self,
        semantic_contributed: bool,
        entity_contributed: bool,
        tag_contributed: bool,
        helpful: bool,
    ) {
        let direction = if helpful { 1.0 } else { -1.0 };
        let delta = WEIGHT_LEARNING_RATE * direction;

        if semantic_contributed {
            self.semantic = (self.semantic + delta).max(MIN_WEIGHT);
        }
        if entity_contributed {
            self.entity = (self.entity + delta).max(MIN_WEIGHT);
        }
        if tag_contributed {
            self.tag = (self.tag + delta).max(MIN_WEIGHT);
        }

        // Helpful result with no attributable component: credit importance.
        if helpful && !semantic_contributed && !entity_contributed && !tag_contributed {
            self.importance = (self.importance + delta).max(MIN_WEIGHT);
        }

        // Auxiliary signals are nudged on every update, at half strength.
        let aux_delta = WEIGHT_LEARNING_RATE * direction * 0.5;
        self.momentum = (self.momentum + aux_delta).max(MIN_WEIGHT);
        self.access_count = (self.access_count + aux_delta).max(MIN_WEIGHT);
        self.graph_strength = (self.graph_strength + aux_delta).max(MIN_WEIGHT);

        self.normalize();
        self.update_count += 1;
        self.last_updated = Some(Utc::now());
    }

    /// Four-component fusion. Neutral values are substituted for the
    /// remaining signals: momentum EMA 0.0, access count 0, graph 0.5.
    pub fn fuse_scores(
        &self,
        semantic_score: f32,
        entity_score: f32,
        tag_score: f32,
        importance_score: f32,
    ) -> f32 {
        self.fuse_scores_full(
            semantic_score,
            entity_score,
            tag_score,
            importance_score,
            0.0,
            0,
            0.5,
        )
    }

    /// Five-component fusion: like `fuse_scores`, plus a momentum EMA.
    pub fn fuse_scores_with_momentum(
        &self,
        semantic_score: f32,
        entity_score: f32,
        tag_score: f32,
        importance_score: f32,
        momentum_ema: f32,
    ) -> f32 {
        self.fuse_scores_full(
            semantic_score,
            entity_score,
            tag_score,
            importance_score,
            momentum_ema,
            0,
            0.5,
        )
    }

    /// Full weighted fusion of all seven ranking components.
    ///
    /// Each component is passed through the sigmoid `calibrate_score`
    /// before weighting. `momentum_ema` is mapped from [-1, 1] to [0, 1]
    /// and then amplified: strong positive momentum (> 0.65) is scaled up
    /// 1.5x (capped at 1.0), weak momentum (< 0.40) is suppressed to 0.3x.
    /// `access_count` is log2-compressed and capped at 1.0 (saturates at
    /// 15 accesses). Non-finite results collapse to 0.0.
    pub fn fuse_scores_full(
        &self,
        semantic_score: f32,
        entity_score: f32,
        tag_score: f32,
        importance_score: f32,
        momentum_ema: f32,
        access_count: u32,
        graph_strength: f32,
    ) -> f32 {
        let calibrated_semantic = calibrate_score(semantic_score);
        let calibrated_entity = calibrate_score(entity_score);
        let calibrated_tag = calibrate_score(tag_score);
        let calibrated_importance = calibrate_score(importance_score);

        // EMA arrives in [-1, 1]; shift into [0, 1] for calibration.
        let normalized_momentum = (momentum_ema + 1.0) / 2.0;

        // Push confident momentum toward the extremes before calibrating.
        let amplified_momentum = if normalized_momentum > 0.65 {
            (normalized_momentum * 1.5).min(1.0)
        } else if normalized_momentum < 0.40 {
            (normalized_momentum * 0.3).max(0.0)
        } else {
            normalized_momentum
        };
        let calibrated_momentum = calibrate_score(amplified_momentum);

        // Diminishing returns on repeated access: log2(count + 1) / 4.
        let access_score = if access_count == 0 {
            0.0
        } else {
            let log_access = (access_count as f32 + 1.0).log2();
            (log_access / 4.0).min(1.0)
        };
        let calibrated_access = calibrate_score(access_score);

        let calibrated_graph = calibrate_score(graph_strength);

        let result = self.semantic * calibrated_semantic
            + self.entity * calibrated_entity
            + self.tag * calibrated_tag
            + self.importance * calibrated_importance
            + self.momentum * calibrated_momentum
            + self.access_count * calibrated_access
            + self.graph_strength * calibrated_graph;

        // Guard against NaN/inf from pathological inputs.
        if result.is_finite() {
            result
        } else {
            0.0
        }
    }
}
595
596fn calibrate_score(score: f32) -> f32 {
601 if !score.is_finite() {
602 return 0.0;
603 }
604 1.0 / (1.0 + (-SIGMOID_STEEPNESS * (score - SIGMOID_MIDPOINT)).exp())
605}
606
/// Surfaces contextually relevant memories by fusing NER entity matches,
/// semantic recall, tag overlap, importance, and feedback-derived signals.
pub struct RelevanceEngine {
    /// Neural named-entity recognizer run over incoming context.
    ner: Arc<NeuralNer>,

    /// Inverted index: lowercase entity name -> memories mentioning it.
    entity_index: Arc<RwLock<HashMap<String, EntityIndexEntry>>>,

    /// When the entity index was last fully rebuilt (None = never).
    entity_index_timestamp: Arc<RwLock<Option<DateTime<Utc>>>>,

    /// Fusion weights, adapted online via `apply_feedback`.
    learned_weights: Arc<RwLock<LearnedWeights>>,

    /// Id of the currently running A/B test, if any.
    active_ab_test: Arc<RwLock<Option<String>>>,
}
625
626impl RelevanceEngine {
    /// Create an engine with empty caches and default (unlearned) weights.
    pub fn new(ner: Arc<NeuralNer>) -> Self {
        Self {
            ner,
            entity_index: Arc::new(RwLock::new(HashMap::new())),
            entity_index_timestamp: Arc::new(RwLock::new(None)),
            learned_weights: Arc::new(RwLock::new(LearnedWeights::default())),
            active_ab_test: Arc::new(RwLock::new(None)),
        }
    }
637
    /// Set (or clear, with `None`) the id of the active A/B test.
    pub fn set_active_ab_test(&self, test_id: Option<String>) {
        *self.active_ab_test.write() = test_id;
    }
644
    /// Id of the currently active A/B test, if one is set.
    pub fn get_active_ab_test(&self) -> Option<String> {
        self.active_ab_test.read().clone()
    }
649
    /// Snapshot of the current learned fusion weights.
    pub fn get_weights(&self) -> LearnedWeights {
        self.learned_weights.read().clone()
    }
654
    /// Replace the learned fusion weights wholesale (e.g. loaded state).
    pub fn set_weights(&self, weights: LearnedWeights) {
        *self.learned_weights.write() = weights;
    }
659
    /// Apply one feedback event to the shared weights; see
    /// `LearnedWeights::apply_feedback` for the update rule.
    pub fn apply_feedback(
        &self,
        semantic_contributed: bool,
        entity_contributed: bool,
        tag_contributed: bool,
        helpful: bool,
    ) {
        self.learned_weights.write().apply_feedback(
            semantic_contributed,
            entity_contributed,
            tag_contributed,
            helpful,
        );
    }
677
678 fn calculate_tag_score(&self, context: &str, tags: &[String]) -> f32 {
680 if tags.is_empty() {
681 return 0.0;
682 }
683
684 let context_lower = context.to_lowercase();
685 let mut matches = 0;
686
687 for tag in tags {
688 let tag_lower = tag.to_lowercase();
689 if context_lower.contains(&tag_lower) {
691 matches += 1;
692 } else {
693 for word in context_lower.split_whitespace() {
695 if word.starts_with(&tag_lower) || tag_lower.starts_with(word) {
696 matches += 1;
697 break;
698 }
699 }
700 }
701 }
702
703 matches as f32 / tags.len() as f32
704 }
705
    /// Surface memories relevant to `context` using the engine's learned
    /// weights. Thin wrapper over `surface_relevant_inner` with no weight
    /// override.
    pub fn surface_relevant(
        &self,
        context: &str,
        memory_system: &MemorySystem,
        graph_memory: Option<&GraphMemory>,
        config: &RelevanceConfig,
        feedback_store: Option<&RwLock<FeedbackStore>>,
    ) -> Result<RelevanceResponse> {
        self.surface_relevant_inner(
            context,
            memory_system,
            graph_memory,
            config,
            feedback_store,
            None,
        )
    }
727
    /// Core surfacing pipeline shared by the public entry points.
    ///
    /// Stages: (1) NER over `context`; (2) entity-based candidate lookup;
    /// (3) semantic recall; (4) per-candidate score fusion, recency and
    /// entity boosts; then sort, filter by a minimum relevance floor, and
    /// truncate to `config.max_results`. Per-stage timings are collected in
    /// `RelevanceDebug`, returned only in debug builds.
    ///
    /// `weights_override` replaces the engine's learned weights for this
    /// call (used by A/B testing); `feedback_store`, when present, supplies
    /// per-memory momentum EMAs.
    fn surface_relevant_inner(
        &self,
        context: &str,
        memory_system: &MemorySystem,
        graph_memory: Option<&GraphMemory>,
        config: &RelevanceConfig,
        feedback_store: Option<&RwLock<FeedbackStore>>,
        weights_override: Option<LearnedWeights>,
    ) -> Result<RelevanceResponse> {
        let start = Instant::now();
        let mut debug = RelevanceDebug {
            ner_ms: 0.0,
            entity_match_ms: 0.0,
            semantic_search_ms: 0.0,
            ranking_ms: 0.0,
            memories_scanned: 0,
            entity_matches: 0,
            semantic_matches: 0,
        };

        // Stage 1: NER (skipped entirely if entity matching is disabled).
        let ner_start = Instant::now();
        let detected_entities = if config.enable_entity_matching {
            self.extract_entities(context)
        } else {
            Vec::new()
        };
        debug.ner_ms = ner_start.elapsed().as_secs_f64() * 1000.0;

        // Candidate pool: id -> (memory, semantic_score, entity_score, matched entity names).
        let mut candidate_memories: HashMap<Uuid, (Memory, f32, f32, Vec<String>)> = HashMap::new();

        // Stage 2: entity matching populates the pool with entity scores.
        if config.enable_entity_matching && !detected_entities.is_empty() {
            let entity_start = Instant::now();
            let entity_matches =
                self.match_by_entities(&detected_entities, memory_system, graph_memory, config)?;
            debug.entity_match_ms = entity_start.elapsed().as_secs_f64() * 1000.0;
            debug.entity_matches = entity_matches.len();

            for (memory, score, matched) in entity_matches {
                let id = memory.id.0;
                candidate_memories.insert(id, (memory, 0.0, score, matched));
            }
        }

        // Stage 3: semantic recall merges into existing candidates or adds new ones.
        if config.enable_semantic_matching {
            let semantic_start = Instant::now();
            let semantic_matches = self.match_by_semantic(context, memory_system, config)?;
            debug.semantic_search_ms = semantic_start.elapsed().as_secs_f64() * 1000.0;
            debug.semantic_matches = semantic_matches.len();

            for (memory, score) in semantic_matches {
                let id = memory.id.0;
                if let Some((_, semantic_score, _entity_score, _matched)) =
                    candidate_memories.get_mut(&id)
                {
                    *semantic_score = score;
                } else {
                    candidate_memories.insert(id, (memory, score, 0.0, Vec::new()));
                }
            }
        }

        debug.memories_scanned = candidate_memories.len();

        // Stage 4: fuse, boost, rank.
        let ranking_start = Instant::now();
        let weights = weights_override.unwrap_or_else(|| self.learned_weights.read().clone());

        let mut results: Vec<SurfacedMemory> = candidate_memories
            .into_iter()
            .filter_map(
                |(_, (memory, semantic_score, entity_score, matched_entities))| {
                    // Hard filters: importance floor and memory-type allowlist.
                    let importance = memory.importance();
                    if importance < config.min_importance {
                        return None;
                    }

                    if !config.memory_types.is_empty() {
                        let mem_type = format!("{:?}", memory.experience.experience_type);
                        if !config
                            .memory_types
                            .iter()
                            .any(|t| t.eq_ignore_ascii_case(&mem_type))
                        {
                            return None;
                        }
                    }

                    let tag_score = self.calculate_tag_score(context, &memory.experience.tags);

                    let access_count = memory.access_count();

                    // 0.5 is the neutral graph strength when no graph is wired in.
                    let graph_strength = graph_memory
                        .and_then(|g| g.get_memory_hebbian_strength(&memory.id))
                        .unwrap_or(0.5);
                    // Momentum EMA from the feedback store, 0.0 (neutral) if absent.
                    let momentum_ema = feedback_store
                        .and_then(|fs| {
                            let store = fs.read();
                            store.get_momentum(&memory.id).map(|m| m.ema_with_decay())
                        })
                        .unwrap_or(0.0);

                    let fused_score = weights.fuse_scores_full(
                        semantic_score,
                        entity_score,
                        tag_score,
                        importance,
                        momentum_ema,
                        access_count,
                        graph_strength,
                    );

                    // Attribute the result to whichever signal(s) fired.
                    let reason = if semantic_score > 0.0 && entity_score > 0.0 {
                        RelevanceReason::Combined
                    } else if entity_score > 0.0 {
                        RelevanceReason::EntityMatch
                    } else if semantic_score > 0.0 {
                        RelevanceReason::SemanticSimilarity
                    } else {
                        RelevanceReason::RecentImportant
                    };

                    let recency_boosted = self.apply_recency_boost(
                        fused_score,
                        memory.created_at,
                        config.recency_boost_hours,
                        config.recency_boost_multiplier,
                    );

                    // Entity-matched results get an extra multiplier, capped at 1.0.
                    let final_score = if entity_score > 0.0 {
                        (recency_boosted * config.graph_boost_multiplier).min(1.0)
                    } else {
                        recency_boosted
                    };

                    Some(SurfacedMemory {
                        id: memory.id.0.to_string(),
                        content: memory.experience.content.clone(),
                        memory_type: format!("{:?}", memory.experience.experience_type),
                        importance,
                        relevance_score: final_score,
                        relevance_reason: reason.clone(),
                        matched_entities,
                        semantic_similarity: if semantic_score > 0.0 {
                            Some(semantic_score)
                        } else {
                            None
                        },
                        created_at: memory.created_at,
                        tags: memory.experience.tags.clone(),
                    })
                },
            )
            .collect();

        results.sort_by(|a, b| b.relevance_score.total_cmp(&a.relevance_score));

        // Drop low-relevance noise before truncating to the requested count.
        const MIN_RELEVANCE_SCORE: f32 = 0.25;
        results.retain(|r| r.relevance_score >= MIN_RELEVANCE_SCORE);

        results.truncate(config.max_results);

        debug.ranking_ms = ranking_start.elapsed().as_secs_f64() * 1000.0;

        let total_latency = start.elapsed().as_secs_f64() * 1000.0;
        // 30 ms end-to-end latency target.
        let latency_target_met = total_latency < 30.0;

        Ok(RelevanceResponse {
            memories: results,
            detected_entities,
            latency_ms: total_latency,
            latency_target_met,
            // Timing breakdown only ships in debug builds.
            debug: if cfg!(debug_assertions) {
                Some(debug)
            } else {
                None
            },
        })
    }
930
    /// Like `surface_relevant_inner`, but momentum EMAs come from a
    /// caller-provided `momentum_lookup` map instead of a feedback store,
    /// and the engine's own learned weights are always used.
    ///
    /// NOTE(review): this is a near line-for-line duplicate of
    /// `surface_relevant_inner` — a candidate for consolidation by
    /// abstracting the momentum source.
    pub fn surface_relevant_with_momentum(
        &self,
        context: &str,
        memory_system: &MemorySystem,
        graph_memory: Option<&GraphMemory>,
        config: &RelevanceConfig,
        momentum_lookup: &HashMap<Uuid, f32>,
    ) -> Result<RelevanceResponse> {
        let start = Instant::now();
        let mut debug = RelevanceDebug {
            ner_ms: 0.0,
            entity_match_ms: 0.0,
            semantic_search_ms: 0.0,
            ranking_ms: 0.0,
            memories_scanned: 0,
            entity_matches: 0,
            semantic_matches: 0,
        };

        // Stage 1: NER.
        let ner_start = Instant::now();
        let detected_entities = if config.enable_entity_matching {
            self.extract_entities(context)
        } else {
            Vec::new()
        };
        debug.ner_ms = ner_start.elapsed().as_secs_f64() * 1000.0;

        // id -> (memory, semantic_score, entity_score, matched entity names).
        let mut candidate_memories: HashMap<Uuid, (Memory, f32, f32, Vec<String>)> = HashMap::new();

        // Stage 2: entity matching.
        if config.enable_entity_matching && !detected_entities.is_empty() {
            let entity_start = Instant::now();
            let entity_matches =
                self.match_by_entities(&detected_entities, memory_system, graph_memory, config)?;
            debug.entity_match_ms = entity_start.elapsed().as_secs_f64() * 1000.0;
            debug.entity_matches = entity_matches.len();

            for (memory, score, matched) in entity_matches {
                let id = memory.id.0;
                candidate_memories.insert(id, (memory, 0.0, score, matched));
            }
        }

        // Stage 3: semantic recall, merged into the candidate pool.
        if config.enable_semantic_matching {
            let semantic_start = Instant::now();
            let semantic_matches = self.match_by_semantic(context, memory_system, config)?;
            debug.semantic_search_ms = semantic_start.elapsed().as_secs_f64() * 1000.0;
            debug.semantic_matches = semantic_matches.len();

            for (memory, score) in semantic_matches {
                let id = memory.id.0;
                if let Some((_, semantic_score, _entity_score, _matched)) =
                    candidate_memories.get_mut(&id)
                {
                    *semantic_score = score;
                } else {
                    candidate_memories.insert(id, (memory, score, 0.0, Vec::new()));
                }
            }
        }

        debug.memories_scanned = candidate_memories.len();

        // Stage 4: fuse, boost, rank with the engine's learned weights.
        let ranking_start = Instant::now();
        let weights = self.learned_weights.read().clone();

        let mut results: Vec<SurfacedMemory> = candidate_memories
            .into_iter()
            .filter_map(
                |(id, (memory, semantic_score, entity_score, matched_entities))| {
                    let importance = memory.importance();
                    if importance < config.min_importance {
                        return None;
                    }

                    if !config.memory_types.is_empty() {
                        let mem_type = format!("{:?}", memory.experience.experience_type);
                        if !config
                            .memory_types
                            .iter()
                            .any(|t| t.eq_ignore_ascii_case(&mem_type))
                        {
                            return None;
                        }
                    }

                    let tag_score = self.calculate_tag_score(context, &memory.experience.tags);

                    // Momentum comes from the caller's lookup table here.
                    let momentum_ema = momentum_lookup.get(&id).copied().unwrap_or(0.0);

                    let access_count = memory.access_count();

                    // 0.5 is the neutral graph strength when no graph is wired in.
                    let graph_strength = graph_memory
                        .and_then(|g| g.get_memory_hebbian_strength(&memory.id))
                        .unwrap_or(0.5);
                    let fused_score = weights.fuse_scores_full(
                        semantic_score,
                        entity_score,
                        tag_score,
                        importance,
                        momentum_ema,
                        access_count,
                        graph_strength,
                    );

                    let reason = if semantic_score > 0.0 && entity_score > 0.0 {
                        RelevanceReason::Combined
                    } else if entity_score > 0.0 {
                        RelevanceReason::EntityMatch
                    } else if semantic_score > 0.0 {
                        RelevanceReason::SemanticSimilarity
                    } else {
                        RelevanceReason::RecentImportant
                    };

                    let recency_boosted = self.apply_recency_boost(
                        fused_score,
                        memory.created_at,
                        config.recency_boost_hours,
                        config.recency_boost_multiplier,
                    );

                    // Entity-matched results get an extra multiplier, capped at 1.0.
                    let final_score = if entity_score > 0.0 {
                        (recency_boosted * config.graph_boost_multiplier).min(1.0)
                    } else {
                        recency_boosted
                    };

                    Some(SurfacedMemory {
                        id: memory.id.0.to_string(),
                        content: memory.experience.content.clone(),
                        memory_type: format!("{:?}", memory.experience.experience_type),
                        importance,
                        relevance_score: final_score,
                        relevance_reason: reason.clone(),
                        matched_entities,
                        semantic_similarity: if semantic_score > 0.0 {
                            Some(semantic_score)
                        } else {
                            None
                        },
                        created_at: memory.created_at,
                        tags: memory.experience.tags.clone(),
                    })
                },
            )
            .collect();

        results.sort_by(|a, b| b.relevance_score.total_cmp(&a.relevance_score));

        // Drop low-relevance noise before truncating.
        const MIN_RELEVANCE_SCORE: f32 = 0.25;
        results.retain(|r| r.relevance_score >= MIN_RELEVANCE_SCORE);

        results.truncate(config.max_results);

        debug.ranking_ms = ranking_start.elapsed().as_secs_f64() * 1000.0;

        let total_latency = start.elapsed().as_secs_f64() * 1000.0;
        // 30 ms end-to-end latency target.
        let latency_target_met = total_latency < 30.0;

        Ok(RelevanceResponse {
            memories: results,
            detected_entities,
            latency_ms: total_latency,
            latency_target_met,
            debug: if cfg!(debug_assertions) {
                Some(debug)
            } else {
                None
            },
        })
    }
1129
    /// Surface memories using the weight variant assigned to `user_id` by
    /// the active A/B test (if any), recording an impression on success.
    ///
    /// Falls back to the engine's learned weights when no test is active or
    /// the A/B manager cannot resolve weights for the user. Returns the
    /// response together with the variant that was applied, if one was.
    pub fn surface_relevant_with_ab_test(
        &self,
        context: &str,
        user_id: &str,
        memory_system: &MemorySystem,
        graph_memory: Option<&GraphMemory>,
        config: &RelevanceConfig,
        ab_manager: &crate::ab_testing::ABTestManager,
    ) -> Result<(RelevanceResponse, Option<crate::ab_testing::ABTestVariant>)> {
        let start = Instant::now();

        let active_test = self.get_active_ab_test();

        // Resolve per-user weights from the test; any failure degrades to
        // the engine's own weights with no variant attribution.
        let (weights, variant) = if let Some(ref test_id) = active_test {
            match ab_manager.get_weights_for_user(test_id, user_id) {
                Ok(w) => {
                    let v = ab_manager.get_variant(test_id, user_id).ok();
                    (w, v)
                }
                Err(_) => {
                    (self.get_weights(), None)
                }
            }
        } else {
            (self.get_weights(), None)
        };

        // No feedback store here: momentum defaults to neutral in the inner call.
        let response = self.surface_relevant_inner(
            context,
            memory_system,
            graph_memory,
            config,
            None,
            Some(weights),
        )?;

        // Record an impression only when a test AND a variant were resolved.
        if let (Some(ref test_id), Some(ref _v)) = (&active_test, &variant) {
            let latency_us = start.elapsed().as_micros() as u64;
            let avg_score = if response.memories.is_empty() {
                0.0
            } else {
                response
                    .memories
                    .iter()
                    .map(|m| m.relevance_score as f64)
                    .sum::<f64>()
                    / response.memories.len() as f64
            };

            // Best-effort: impression-recording failures are ignored.
            let _ = ab_manager.record_impression(test_id, user_id, avg_score, latency_us);
        }

        Ok((response, variant))
    }
1198
1199 pub fn record_ab_click(
1203 &self,
1204 user_id: &str,
1205 memory_id: Uuid,
1206 ab_manager: &crate::ab_testing::ABTestManager,
1207 ) -> Result<()> {
1208 if let Some(test_id) = self.get_active_ab_test() {
1209 ab_manager
1210 .record_click(&test_id, user_id, memory_id)
1211 .map_err(|e| anyhow::anyhow!("Failed to record A/B click: {}", e))?;
1212 }
1213 Ok(())
1214 }
1215
1216 pub fn record_ab_feedback(
1220 &self,
1221 user_id: &str,
1222 positive: bool,
1223 ab_manager: &crate::ab_testing::ABTestManager,
1224 ) -> Result<()> {
1225 if let Some(test_id) = self.get_active_ab_test() {
1226 ab_manager
1227 .record_feedback(&test_id, user_id, positive)
1228 .map_err(|e| anyhow::anyhow!("Failed to record A/B feedback: {}", e))?;
1229 }
1230 Ok(())
1231 }
1232
1233 fn extract_entities(&self, context: &str) -> Vec<DetectedEntity> {
1235 match self.ner.extract(context) {
1236 Ok(entities) => entities
1237 .into_iter()
1238 .map(|e| DetectedEntity {
1239 name: e.text,
1240 entity_type: format!("{:?}", e.entity_type),
1241 confidence: e.confidence,
1242 })
1243 .collect(),
1244 Err(_) => Vec::new(),
1245 }
1246 }
1247
    /// Find memories that mention the detected entities.
    ///
    /// Tries three strategies, cheapest first:
    /// 1. the in-memory inverted entity index (if populated) — if it yields
    ///    any hits, they are returned immediately;
    /// 2. a linear scan of all memories with word-boundary matching;
    /// 3. graph traversal from the detected entities to related episodes,
    ///    skipped when step 2 already found enough direct matches or no
    ///    graph is available.
    ///
    /// Returns `(memory, normalized score, matched entity names)` triples;
    /// scores below `config.entity_threshold` are dropped.
    fn match_by_entities(
        &self,
        entities: &[DetectedEntity],
        memory_system: &MemorySystem,
        graph_memory: Option<&GraphMemory>,
        config: &RelevanceConfig,
    ) -> Result<Vec<(Memory, f32, Vec<String>)>> {
        // Precompute lowercase names and per-type weights once.
        let entity_lookup: Vec<(String, &DetectedEntity, f32)> = entities
            .iter()
            .map(|e| {
                let weight = self.entity_type_weight(&e.entity_type);
                (e.name.to_lowercase(), e, weight)
            })
            .collect();

        let max_candidates = config.max_results * 3;
        let mut results: Vec<(Memory, f32, Vec<String>)> = Vec::with_capacity(max_candidates);
        let mut found_ids: HashSet<Uuid> = HashSet::new();

        // Strategy 1: inverted index lookup (scoped so the read lock drops).
        {
            let index = self.entity_index.read();
            if !index.is_empty() {
                // Accumulate per-memory scores across all matched entities.
                let mut candidate_ids: HashMap<Uuid, (f32, Vec<String>)> = HashMap::new();

                for (name_lower, entity, weight) in &entity_lookup {
                    if let Some(entry) = index.get(name_lower) {
                        for &memory_id in &entry.memory_ids {
                            let score = entity.confidence * weight;
                            candidate_ids
                                .entry(memory_id)
                                .and_modify(|(existing_score, matched)| {
                                    *existing_score += score;
                                    if !matched.contains(&entity.name) {
                                        matched.push(entity.name.clone());
                                    }
                                })
                                .or_insert((score, vec![entity.name.clone()]));
                        }
                    }
                }

                if !candidate_ids.is_empty() {
                    // Best-first, capping the number of memory fetches.
                    let mut sorted_candidates: Vec<_> = candidate_ids.into_iter().collect();
                    sorted_candidates.sort_by(|a, b| b.1 .0.total_cmp(&a.1 .0));

                    for (memory_id, (score, matched)) in
                        sorted_candidates.into_iter().take(max_candidates)
                    {
                        // Average the accumulated score over the matched entities.
                        let normalized_score = (score / matched.len() as f32).min(1.0);
                        if normalized_score >= config.entity_threshold {
                            let mem_id = crate::memory::MemoryId(memory_id);
                            if let Ok(memory) = memory_system.get_memory(&mem_id) {
                                found_ids.insert(memory_id);
                                results.push((memory, normalized_score, matched));
                            }
                        }
                    }
                }

                // Index produced hits: trust it and skip the full scan.
                if !results.is_empty() {
                    return Ok(results);
                }
            }
        }

        // Strategy 2: linear scan (index empty or produced nothing).
        let all_memories = memory_system.get_all_memories()?;

        if all_memories.is_empty() {
            return Ok(results);
        }

        for shared_memory in &all_memories {
            if results.len() >= max_candidates {
                break;
            }

            if found_ids.contains(&shared_memory.id.0) {
                continue;
            }

            let content_lower = shared_memory.experience.content.to_lowercase();
            let mut matched: Vec<String> = Vec::new();
            let mut match_score = 0.0f32;

            // Whole-word, case-insensitive match against the memory content.
            for (name_lower, entity, weight) in &entity_lookup {
                if contains_word(&content_lower, name_lower) {
                    matched.push(entity.name.clone());
                    match_score += entity.confidence * weight;
                }
            }

            if !matched.is_empty() {
                let normalized_score = (match_score / matched.len() as f32).min(1.0);
                if normalized_score >= config.entity_threshold {
                    results.push(((**shared_memory).clone(), normalized_score, matched));
                }
            }
        }

        // Enough direct matches, or no graph available: stop here.
        if results.len() >= config.max_results * 2 || graph_memory.is_none() {
            return Ok(results);
        }

        // Strategy 3: expand via graph traversal from the detected entities.
        let mut graph_results = results;
        if let Some(graph) = graph_memory {
            let found_ids: HashSet<Uuid> = graph_results.iter().map(|(m, _, _)| m.id.0).collect();

            // Bound graph work: first 5 entities, depth-5 traversal,
            // at most 10 episodes per traversed entity.
            let max_graph_lookups = 5;
            for (idx, entity) in entities.iter().enumerate() {
                if idx >= max_graph_lookups {
                    break;
                }

                if let Ok(Some(entity_node)) = graph.find_entity_by_name(&entity.name) {
                    if let Ok(traversal) = graph.traverse_from_entity(&entity_node.uuid, 5) {
                        for traversed in &traversal.entities {
                            if let Ok(episodes) =
                                graph.get_episodes_by_entity(&traversed.entity.uuid)
                            {
                                for episode in episodes.iter().take(10) {
                                    let memory_id = crate::memory::MemoryId(episode.uuid);

                                    if found_ids.contains(&episode.uuid) {
                                        continue;
                                    }

                                    // Discount by entity salience and traversal decay.
                                    let score = entity.confidence
                                        * traversed.entity.salience
                                        * traversed.decay_factor;
                                    if score >= config.entity_threshold {
                                        if let Ok(memory) = memory_system.get_memory(&memory_id) {
                                            graph_results.push((
                                                memory,
                                                score,
                                                vec![
                                                    entity.name.clone(),
                                                    traversed.entity.name.clone(),
                                                ],
                                            ));
                                        }
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        Ok(graph_results)
    }
1435
1436 fn entity_type_weight(&self, entity_type: &str) -> f32 {
1438 match entity_type.to_lowercase().as_str() {
1439 "person" => 1.0,
1440 "organization" => 0.9,
1441 "location" => 0.8,
1442 "technology" => 0.85,
1443 "product" => 0.9,
1444 "event" => 0.7,
1445 "date" => 0.5,
1446 _ => 0.6,
1447 }
1448 }
1449
    /// Recall memories semantically similar to `context`.
    ///
    /// Primary path: the memory system's own recall with the context as the
    /// query text, keeping anything at or above `config.semantic_threshold`
    /// (falling back to a 1/(rank+1) score when the memory carries none).
    /// If recall itself errors, degrades to a crude word-overlap (Dice-like)
    /// similarity over words longer than 3 characters.
    fn match_by_semantic(
        &self,
        context: &str,
        memory_system: &MemorySystem,
        config: &RelevanceConfig,
    ) -> Result<Vec<(Memory, f32)>> {
        // Over-fetch (2x) so downstream filters still have enough to rank.
        let query = crate::memory::Query {
            query_text: Some(context.to_string()),
            max_results: config.max_results * 2,
            importance_threshold: Some(config.min_importance),
            ..Default::default()
        };

        let mut results: Vec<(Memory, f32)> = Vec::new();

        match memory_system.recall(&query) {
            Ok(shared_memories) => {
                for (rank, shared_memory) in shared_memories.into_iter().enumerate() {
                    let memory = (*shared_memory).clone();
                    // Missing scores degrade to reciprocal rank.
                    let score = shared_memory
                        .get_score()
                        .unwrap_or(1.0 / (rank as f32 + 1.0));
                    if score >= config.semantic_threshold {
                        results.push((memory, score));
                    }
                }
            }
            Err(_) => {
                // Fallback: word-overlap similarity over words longer than 3 chars.
                let context_words: HashSet<&str> =
                    context.split_whitespace().filter(|w| w.len() > 3).collect();

                let all_memories = memory_system.get_all_memories()?;
                for shared_memory in all_memories {
                    let content_words: HashSet<&str> = shared_memory
                        .experience
                        .content
                        .split_whitespace()
                        .filter(|w| w.len() > 3)
                        .collect();

                    // Dice-style coefficient: 2·|A∩B| / (|A| + |B|).
                    let overlap = context_words.intersection(&content_words).count();
                    if overlap > 0 {
                        let score = overlap as f32
                            / (context_words.len() + content_words.len()) as f32
                            * 2.0;
                        if score >= config.semantic_threshold {
                            let memory = (*shared_memory).clone();
                            results.push((memory, score.min(1.0)));
                        }
                    }
                }
            }
        }

        Ok(results)
    }
1512
1513 fn apply_recency_boost(
1515 &self,
1516 base_score: f32,
1517 created_at: DateTime<Utc>,
1518 boost_hours: u64,
1519 multiplier: f32,
1520 ) -> f32 {
1521 if boost_hours == 0 {
1522 return base_score;
1523 }
1524
1525 let now = Utc::now();
1526 let age = now.signed_duration_since(created_at);
1527 let age_hours = age.num_hours() as u64;
1528
1529 if age_hours <= boost_hours {
1530 let decay = 1.0 - (age_hours as f32 / boost_hours as f32);
1532 let boost = 1.0 + (multiplier - 1.0) * decay;
1533 (base_score * boost).min(1.0)
1534 } else {
1535 base_score
1536 }
1537 }
1538
1539 pub fn refresh_entity_index(&self, graph_memory: &GraphMemory) -> Result<()> {
1551 let mut index = self.entity_index.write();
1552 index.clear();
1553
1554 let now = Utc::now();
1555 let entities = graph_memory.get_all_entities()?;
1556
1557 for entity in entities {
1558 let episodes = graph_memory.get_episodes_by_entity(&entity.uuid)?;
1559 let episode_ids: HashSet<Uuid> = episodes.iter().map(|e| e.uuid).collect();
1560
1561 let name_lower = entity.name.to_lowercase();
1562 index.insert(
1563 name_lower,
1564 EntityIndexEntry {
1565 memory_ids: episode_ids,
1566 last_updated: Some(now),
1567 },
1568 );
1569 }
1570
1571 *self.entity_index_timestamp.write() = Some(now);
1573
1574 Ok(())
1575 }
1576
1577 pub fn get_memories_for_entity(
1582 &self,
1583 entity_name: &str,
1584 graph_memory: Option<&GraphMemory>,
1585 ) -> Option<HashSet<Uuid>> {
1586 let name_lower = entity_name.to_lowercase();
1587
1588 {
1590 let index = self.entity_index.read();
1591 if let Some(entry) = index.get(&name_lower) {
1592 return Some(entry.memory_ids.clone());
1593 }
1594 }
1595
1596 if let Some(graph) = graph_memory {
1598 if let Ok(Some(entity)) = graph.find_entity_by_name(&name_lower) {
1599 if let Ok(episodes) = graph.get_episodes_by_entity(&entity.uuid) {
1600 let memory_ids: HashSet<Uuid> = episodes.iter().map(|e| e.uuid).collect();
1601
1602 let mut index = self.entity_index.write();
1604 index.insert(
1605 name_lower,
1606 EntityIndexEntry {
1607 memory_ids: memory_ids.clone(),
1608 last_updated: Some(Utc::now()),
1609 },
1610 );
1611
1612 return Some(memory_ids);
1613 }
1614 }
1615 }
1616
1617 None
1618 }
1619
1620 pub fn entity_index_needs_refresh(&self, max_age_hours: i64) -> bool {
1622 let timestamp = self.entity_index_timestamp.read();
1623 match *timestamp {
1624 None => true,
1625 Some(ts) => {
1626 let age = Utc::now().signed_duration_since(ts);
1627 age.num_hours() > max_age_hours
1628 }
1629 }
1630 }
1631
1632 pub fn entity_index_stats(&self) -> (usize, Option<DateTime<Utc>>) {
1634 let index = self.entity_index.read();
1635 let timestamp = *self.entity_index_timestamp.read();
1636 (index.len(), timestamp)
1637 }
1638
1639 pub fn clear_entity_index(&self) {
1641 self.entity_index.write().clear();
1642 *self.entity_index_timestamp.write() = None;
1643 }
1644}
1645
/// First message a client sends when opening a context-monitor session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextMonitorHandshake {
    /// Identifier of the user whose memories are being monitored.
    pub user_id: String,

    /// Optional per-session relevance configuration; when absent the
    /// monitor's default config is used.
    #[serde(default)]
    pub config: Option<RelevanceConfig>,

    /// Debounce interval for context updates, in milliseconds
    /// (defaults to 100 via `default_debounce_ms`).
    #[serde(default = "default_debounce_ms")]
    pub debounce_ms: u64,
}
1664
/// Serde default for `ContextMonitorHandshake::debounce_ms`: 100 ms.
fn default_debounce_ms() -> u64 {
    100u64
}
1668
/// A context snapshot pushed by the client for relevance evaluation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextUpdate {
    /// Raw context text (e.g. the user's current conversation or buffer).
    pub context: String,

    /// Entities the client already detected; may be empty, in which case
    /// the server presumably runs its own detection — confirm with callers.
    #[serde(default)]
    pub entities: Vec<String>,

    /// Optional per-update config override; falls back to the session config.
    #[serde(default)]
    pub config: Option<RelevanceConfig>,
}
1683
/// Server → client messages for a context-monitor session.
///
/// Serialized with an external `"type"` tag, e.g. `{"type":"ack", ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ContextMonitorResponse {
    /// Acknowledges a handshake or update without results.
    #[serde(rename = "ack")]
    Ack { timestamp: DateTime<Utc> },

    /// Relevant memories were found for the submitted context.
    #[serde(rename = "relevant")]
    Relevant {
        /// Surfaced memories, scored and ready for display.
        memories: Vec<SurfacedMemory>,
        /// Entities detected in the context during this evaluation.
        detected_entities: Vec<DetectedEntity>,
        /// Server-side processing time for this update, in milliseconds.
        latency_ms: f64,
        timestamp: DateTime<Utc>,
    },

    /// The context was evaluated but nothing relevant surfaced.
    #[serde(rename = "none")]
    None { timestamp: DateTime<Utc> },

    /// Something went wrong; `fatal` tells the client whether the session
    /// can continue.
    #[serde(rename = "error")]
    Error {
        code: String,
        message: String,
        fatal: bool,
        timestamp: DateTime<Utc>,
    },
}
1714
/// Drives relevance surfacing for a stream of context updates.
pub struct ContextMonitor {
    // Shared relevance engine that performs the actual scoring/surfacing.
    engine: Arc<RelevanceEngine>,

    // Used when a call/session does not supply its own RelevanceConfig.
    default_config: RelevanceConfig,

    // Client-facing debounce interval in milliseconds; the monitor itself
    // only stores and reports it (see `debounce_ms()`).
    debounce_ms: u64,
}
1726
1727impl ContextMonitor {
1728 pub fn new(engine: Arc<RelevanceEngine>, debounce_ms: u64) -> Self {
1730 Self {
1731 engine,
1732 default_config: RelevanceConfig::default(),
1733 debounce_ms,
1734 }
1735 }
1736
1737 pub fn debounce_ms(&self) -> u64 {
1739 self.debounce_ms
1740 }
1741
1742 pub fn set_config(&mut self, config: RelevanceConfig) {
1744 self.default_config = config;
1745 }
1746
1747 pub fn engine(&self) -> &Arc<RelevanceEngine> {
1749 &self.engine
1750 }
1751
1752 pub fn process_context(
1754 &self,
1755 context: &str,
1756 memory_system: &MemorySystem,
1757 graph_memory: Option<&GraphMemory>,
1758 config: Option<&RelevanceConfig>,
1759 ) -> Result<Option<RelevanceResponse>> {
1760 let cfg = config.unwrap_or(&self.default_config);
1761
1762 if context.len() < 10 {
1764 return Ok(None);
1765 }
1766
1767 let response =
1768 self.engine
1769 .surface_relevant(context, memory_system, graph_memory, cfg, None)?;
1770
1771 if response.memories.is_empty() {
1773 Ok(None)
1774 } else {
1775 Ok(Some(response))
1776 }
1777 }
1778}
1779
#[cfg(test)]
mod tests {
    use super::*;

    // RelevanceConfig::default() must expose the documented thresholds.
    #[test]
    fn test_relevance_config_defaults() {
        let config = RelevanceConfig::default();
        assert_eq!(config.semantic_threshold, 0.45);
        assert_eq!(config.entity_threshold, 0.5);
        assert_eq!(config.max_results, 5);
        assert!(config.enable_entity_matching);
        assert!(config.enable_semantic_matching);
    }

    // Recency boost fires inside the window and is a no-op past it.
    #[test]
    fn test_recency_boost() {
        let engine = RelevanceEngine::new(Arc::new(crate::embeddings::NeuralNer::new_fallback(
            crate::embeddings::NerConfig::default(),
        )));

        // Brand-new memory inside the 24h window: score must increase.
        let recent = Utc::now();
        let boosted = engine.apply_recency_boost(0.5, recent, 24, 1.2);
        assert!(boosted > 0.5);

        // 48h old, outside the 24h window: score unchanged.
        let old = Utc::now() - chrono::Duration::hours(48);
        let not_boosted = engine.apply_recency_boost(0.5, old, 24, 1.2);
        assert!((not_boosted - 0.5).abs() < 0.001);
    }

    // Entity-type weighting is case-insensitive and penalizes unknown types.
    #[test]
    fn test_entity_type_weight() {
        let engine = RelevanceEngine::new(Arc::new(crate::embeddings::NeuralNer::new_fallback(
            crate::embeddings::NerConfig::default(),
        )));

        assert_eq!(engine.entity_type_weight("Person"), 1.0);
        assert_eq!(engine.entity_type_weight("organization"), 0.9);
        assert!(engine.entity_type_weight("unknown") < 1.0);
    }

    // DetectedEntity round-trips its fields through serde_json.
    #[test]
    fn test_detected_entity_serialization() {
        let entity = DetectedEntity {
            name: "Rust".to_string(),
            entity_type: "Technology".to_string(),
            confidence: 0.95,
        };

        let json = serde_json::to_string(&entity).unwrap();
        assert!(json.contains("Rust"));
        assert!(json.contains("Technology"));
    }

    // Default weights match the module constants and sum to 1.0.
    #[test]
    fn test_learned_weights_default() {
        let weights = LearnedWeights::default();

        let sum = weights.semantic
            + weights.entity
            + weights.tag
            + weights.importance
            + weights.momentum
            + weights.access_count
            + weights.graph_strength;
        assert!((sum - 1.0).abs() < 0.001);

        assert_eq!(weights.semantic, DEFAULT_SEMANTIC_WEIGHT);
        assert_eq!(weights.entity, DEFAULT_ENTITY_WEIGHT);
        assert_eq!(weights.tag, DEFAULT_TAG_WEIGHT);
        assert_eq!(weights.importance, DEFAULT_IMPORTANCE_WEIGHT);
        assert_eq!(weights.momentum, DEFAULT_MOMENTUM_WEIGHT);
        assert_eq!(weights.access_count, DEFAULT_ACCESS_COUNT_WEIGHT);
        assert_eq!(weights.graph_strength, DEFAULT_GRAPH_STRENGTH_WEIGHT);
    }

    // normalize() rescales arbitrary weights back to a unit sum
    // (7 equal weights of 0.5 each become 1/7 apiece).
    #[test]
    fn test_learned_weights_normalize() {
        let mut weights = LearnedWeights {
            semantic: 0.5,
            entity: 0.5,
            tag: 0.5,
            importance: 0.5,
            momentum: 0.5,
            access_count: 0.5,
            graph_strength: 0.5,
            update_count: 0,
            last_updated: None,
        };

        weights.normalize();

        let sum = weights.semantic
            + weights.entity
            + weights.tag
            + weights.importance
            + weights.momentum
            + weights.access_count
            + weights.graph_strength;
        assert!((sum - 1.0).abs() < 0.001);
        assert!((weights.semantic - 1.0 / 7.0).abs() < 0.001);
    }

    // Positive feedback bumps the update counter/timestamp, keeps the
    // weights normalized, and should not sharply reduce the combined
    // semantic+entity share for signals marked as contributing.
    #[test]
    fn test_learned_weights_feedback_helpful() {
        let mut weights = LearnedWeights::default();
        let initial_semantic = weights.semantic;
        let initial_entity = weights.entity;

        weights.apply_feedback(true, true, false, true);

        assert_eq!(weights.update_count, 1);
        assert!(weights.last_updated.is_some());

        let sum = weights.semantic
            + weights.entity
            + weights.tag
            + weights.importance
            + weights.momentum
            + weights.access_count
            + weights.graph_strength;
        assert!((sum - 1.0).abs() < 0.001);

        // Loose bound: renormalization may shift things slightly, so only
        // assert the combined share did not drop by more than 0.1.
        let new_se = weights.semantic + weights.entity;
        let old_se = initial_semantic + initial_entity;
        assert!(new_se >= old_se - 0.1);
    }

    // Negative feedback must still leave the weights normalized.
    #[test]
    fn test_learned_weights_feedback_not_helpful() {
        let mut weights = LearnedWeights::default();

        weights.apply_feedback(true, false, false, false);

        let sum = weights.semantic
            + weights.entity
            + weights.tag
            + weights.importance
            + weights.momentum
            + weights.access_count
            + weights.graph_strength;
        assert!((sum - 1.0).abs() < 0.001);
    }

    // Sigmoid calibration: pushes high scores higher, low scores lower,
    // and is a fixed point (0.5) at the midpoint.
    #[test]
    fn test_calibrate_score() {
        let high = calibrate_score(0.9);
        assert!(high > 0.9);

        let low = calibrate_score(0.1);
        assert!(low < 0.1);

        let mid = calibrate_score(SIGMOID_MIDPOINT);
        assert!((mid - 0.5).abs() < 0.001);
    }

    // Fused scores track their inputs: uniformly high in, high out;
    // uniformly low in, low out; mixed in, mid out. The legacy 4-signal
    // fuse_scores path must remain usable too.
    #[test]
    fn test_score_fusion() {
        let weights = LearnedWeights::default();

        let high = weights.fuse_scores_full(0.9, 0.9, 0.9, 0.9, 0.9, 16, 0.9);
        assert!(high > 0.8, "high score was {}", high);

        let low = weights.fuse_scores_full(0.1, 0.1, 0.1, 0.1, -0.9, 0, 0.1);
        assert!(low < 0.3, "low score was {}", low);

        let mixed = weights.fuse_scores_full(0.9, 0.1, 0.5, 0.7, 0.0, 2, 0.5);
        assert!(mixed > 0.2 && mixed < 0.8, "mixed score was {}", mixed);

        let legacy = weights.fuse_scores(0.9, 0.9, 0.9, 0.9);
        assert!(legacy > 0.5, "legacy score was {}", legacy);
    }

    // Tag score = fraction of tags present in the context (word-boundary,
    // case-insensitive); empty tag list scores 0.
    #[test]
    fn test_tag_score_calculation() {
        let engine = RelevanceEngine::new(Arc::new(crate::embeddings::NeuralNer::new_fallback(
            crate::embeddings::NerConfig::default(),
        )));

        let score = engine.calculate_tag_score("I love Rust programming", &["rust".to_string()]);
        assert_eq!(score, 1.0);

        let score = engine
            .calculate_tag_score("Learning Rust", &["rust".to_string(), "python".to_string()]);
        assert_eq!(score, 0.5);

        let score = engine.calculate_tag_score("Hello world", &["rust".to_string()]);
        assert_eq!(score, 0.0);

        let score = engine.calculate_tag_score("Test", &[]);
        assert_eq!(score, 0.0);
    }

    // Feedback that penalizes a signal must never push its weight below
    // MIN_WEIGHT, and the result stays normalized.
    #[test]
    fn test_min_weight_enforcement() {
        let mut weights = LearnedWeights {
            semantic: 0.1,
            entity: MIN_WEIGHT + 0.01, // just above the floor before feedback
            tag: 0.3,
            importance: 0.1,
            momentum: 0.1,
            access_count: 0.1,
            graph_strength: 0.1,
            update_count: 0,
            last_updated: None,
        };

        weights.apply_feedback(false, true, false, false);

        assert!(
            weights.entity >= MIN_WEIGHT,
            "entity {} < MIN_WEIGHT {}",
            weights.entity,
            MIN_WEIGHT
        );

        let sum = weights.semantic
            + weights.entity
            + weights.tag
            + weights.importance
            + weights.momentum
            + weights.access_count
            + weights.graph_strength;
        assert!((sum - 1.0).abs() < 0.001);
    }

    // Non-finite inputs (NaN / ±Inf) must not propagate into the fused
    // score; a normal mixed input stays finite and positive.
    #[test]
    fn test_fuse_scores_nan_inf_guard() {
        let weights = LearnedWeights::default();

        let result = weights.fuse_scores_full(f32::NAN, 0.5, 0.5, 0.5, 0.0, 1, 0.5);
        assert!(result.is_finite(), "NaN input should produce finite output");

        let result = weights.fuse_scores_full(0.5, f32::INFINITY, 0.5, 0.5, 0.0, 1, 0.5);
        assert!(result.is_finite(), "Inf input should produce finite output");

        let result = weights.fuse_scores_full(0.5, 0.5, f32::NEG_INFINITY, 0.5, 0.0, 1, 0.5);
        assert!(
            result.is_finite(),
            "-Inf input should produce finite output"
        );

        let result = weights.fuse_scores_full(0.8, 0.5, 0.3, 0.6, 0.2, 3, 0.4);
        assert!(result.is_finite());
        assert!(result > 0.0);
    }
}