1use anyhow::Result;
29use std::collections::HashMap;
30use std::time::Instant;
31use uuid::Uuid;
32
33use crate::constants::{
34 BIDIRECTIONAL_DENSITY_DENSE, BIDIRECTIONAL_DENSITY_SPARSE, BIDIRECTIONAL_HOPS_DENSE,
35 BIDIRECTIONAL_HOPS_MEDIUM, BIDIRECTIONAL_HOPS_SPARSE, BIDIRECTIONAL_INTERSECTION_BOOST,
36 BIDIRECTIONAL_INTERSECTION_MIN, BIDIRECTIONAL_MIN_ENTITIES, DENSITY_GRAPH_WEIGHT_MAX,
37 DENSITY_GRAPH_WEIGHT_MIN, DENSITY_LINGUISTIC_WEIGHT, DENSITY_THRESHOLD_MAX,
38 DENSITY_THRESHOLD_MIN, EDGE_TIER_TRUST_L1, EDGE_TIER_TRUST_L2, EDGE_TIER_TRUST_L3,
39 EDGE_TIER_TRUST_LTP, HYBRID_GRAPH_WEIGHT, HYBRID_LINGUISTIC_WEIGHT, HYBRID_SEMANTIC_WEIGHT,
40 IMPORTANCE_DECAY_MAX, IMPORTANCE_DECAY_MIN, MEMORY_TIER_GRAPH_MULT_ARCHIVE,
41 MEMORY_TIER_GRAPH_MULT_LONGTERM, MEMORY_TIER_GRAPH_MULT_SESSION,
42 MEMORY_TIER_GRAPH_MULT_WORKING, SALIENCE_BOOST_FACTOR, SPREADING_ACTIVATION_THRESHOLD,
43 SPREADING_DEGREE_NORMALIZATION, SPREADING_EARLY_TERMINATION_CANDIDATES,
44 SPREADING_EARLY_TERMINATION_RATIO, SPREADING_MAX_HOPS, SPREADING_MIN_CANDIDATES,
45 SPREADING_MIN_HOPS, SPREADING_NORMALIZATION_FACTOR, SPREADING_RELAXED_THRESHOLD,
46};
47use crate::embeddings::Embedder;
48use crate::graph_memory::{EdgeTier, EpisodicNode, GraphMemory};
49use crate::memory::types::MemoryTier;
50use crate::memory::query_parser::{analyze_query, QueryAnalysis};
52use crate::memory::types::{Memory, Query, RetrievalStats, SharedMemory};
53use crate::similarity::cosine_similarity;
54
/// A memory retrieved through spreading activation, carrying the component
/// scores that produced its final ranking.
#[derive(Debug, Clone)]
pub struct ActivatedMemory {
    /// The retrieved memory itself.
    pub memory: SharedMemory,
    /// Graph activation accumulated from the entities that reached this
    /// memory's episode (kept for diagnostics).
    #[allow(dead_code)] pub activation_score: f32,
    /// Cosine similarity between the query embedding and the memory's
    /// embedding (0.0 when the memory has no embedding).
    #[allow(dead_code)] pub semantic_score: f32,
    /// Lexical overlap with the query analysis (see `calculate_linguistic_match`).
    #[allow(dead_code)] pub linguistic_score: f32,
    /// Blended score (weighted hybrid plus recency/arousal/credibility boosts)
    /// used to rank results.
    pub final_score: f32,
}
67
68pub fn calculate_density_weights(graph_density: f32) -> (f32, f32, f32) {
80 let graph_weight = if graph_density <= DENSITY_THRESHOLD_MIN {
82 DENSITY_GRAPH_WEIGHT_MAX } else if graph_density >= DENSITY_THRESHOLD_MAX {
84 DENSITY_GRAPH_WEIGHT_MIN } else {
86 let ratio = (graph_density - DENSITY_THRESHOLD_MIN)
88 / (DENSITY_THRESHOLD_MAX - DENSITY_THRESHOLD_MIN);
89 DENSITY_GRAPH_WEIGHT_MAX - ratio * (DENSITY_GRAPH_WEIGHT_MAX - DENSITY_GRAPH_WEIGHT_MIN)
90 };
91
92 let linguistic_weight = DENSITY_LINGUISTIC_WEIGHT;
93 let semantic_weight = 1.0 - graph_weight - linguistic_weight;
94
95 (semantic_weight, graph_weight, linguistic_weight)
96}
97
98pub fn calculate_importance_weighted_decay(importance: f32) -> f32 {
109 let clamped_importance = importance.clamp(0.0, 1.0);
110 IMPORTANCE_DECAY_MIN
111 + (1.0 - clamped_importance) * (IMPORTANCE_DECAY_MAX - IMPORTANCE_DECAY_MIN)
112}
113
114pub fn calculate_adaptive_hops(graph_density: Option<f32>) -> usize {
139 match graph_density {
140 Some(density) if density > BIDIRECTIONAL_DENSITY_DENSE => {
141 BIDIRECTIONAL_HOPS_DENSE
143 }
144 Some(density) if density < BIDIRECTIONAL_DENSITY_SPARSE => {
145 BIDIRECTIONAL_HOPS_SPARSE
147 }
148 Some(_) => {
149 BIDIRECTIONAL_HOPS_MEDIUM
151 }
152 None => {
153 BIDIRECTIONAL_HOPS_MEDIUM
155 }
156 }
157}
158
/// Spreads activation outward from `seeds` for up to `max_hops` hops.
///
/// Each hop: every entity at or above `threshold` pushes activation along
/// its outgoing edges (hop-decayed, tier-trust-weighted, and optionally
/// degree-normalized), then the map is rescaled if its peak exceeds
/// `SPREADING_NORMALIZATION_FACTOR`, and sub-threshold entries are pruned.
///
/// Returns the final activation map plus the ids of edges that carried a
/// non-trivial (> 0.01) amount of activation.
fn spread_single_direction(
    seeds: &[(Uuid, f32)],
    graph: &GraphMemory,
    max_hops: usize,
    threshold: f32,
) -> Result<(HashMap<Uuid, f32>, Vec<Uuid>)> {
    let mut activation_map: HashMap<Uuid, f32> = seeds.iter().cloned().collect();
    let mut traversed_edges: Vec<Uuid> = Vec::new();

    for hop in 1..=max_hops {
        // Snapshot the frontier: spreading mutates the map mid-hop, so we
        // iterate over a copy of the state at the start of the hop.
        let current_activated: Vec<(Uuid, f32)> =
            activation_map.iter().map(|(id, act)| (*id, *act)).collect();

        for (entity_uuid, source_activation) in current_activated {
            if source_activation < threshold {
                continue;
            }

            // Cap per-entity fan-out to bound work on high-degree hubs.
            const MAX_EDGES_PER_SPREAD: usize = 100;
            let edges =
                graph.get_entity_relationships_limited(&entity_uuid, Some(MAX_EDGES_PER_SPREAD))?;

            // Optional hub damping: 1 / sqrt(1 + degree).
            let degree_norm = if SPREADING_DEGREE_NORMALIZATION {
                1.0 / (1.0 + edges.len() as f32).sqrt()
            } else {
                1.0
            };

            for edge in edges {
                let target_uuid = edge.to_entity;

                // Potentiated (LTP) edges get the highest trust; otherwise
                // trust follows the edge's consolidation tier.
                let tier_trust = if edge.is_potentiated() {
                    EDGE_TIER_TRUST_LTP
                } else {
                    match edge.tier {
                        EdgeTier::L3Semantic => EDGE_TIER_TRUST_L3,
                        EdgeTier::L2Episodic => EDGE_TIER_TRUST_L2,
                        EdgeTier::L1Working => EDGE_TIER_TRUST_L1,
                    }
                };

                // Stronger edges decay more slowly with hop distance.
                let effective = edge.effective_strength();
                let decay_rate = calculate_importance_weighted_decay(effective);
                let decay = (-decay_rate * hop as f32).exp();

                let spread_amount =
                    source_activation * decay * effective * tier_trust * degree_norm;

                let new_activation = activation_map.entry(target_uuid).or_insert(0.0);
                *new_activation += spread_amount;

                // Record only edges that carried a meaningful amount.
                if spread_amount > 0.01 {
                    traversed_edges.push(edge.uuid);
                }
            }
        }

        // Rescale so the peak activation stays bounded, preventing runaway
        // growth across hops.
        let max_activation = activation_map
            .values()
            .cloned()
            .max_by(|a, b| a.total_cmp(b))
            .unwrap_or(1.0);

        if max_activation > SPREADING_NORMALIZATION_FACTOR {
            let scale = SPREADING_NORMALIZATION_FACTOR / max_activation;
            for activation in activation_map.values_mut() {
                *activation *= scale;
            }
        }

        // Prune sub-threshold entries so later hops expand only live nodes.
        activation_map.retain(|_, activation| *activation > threshold);
    }

    Ok((activation_map, traversed_edges))
}
245
246fn bidirectional_spread(
259 entity_data: &[(Uuid, String, f32, f32)], graph: &GraphMemory,
261 total_salience: f32,
262 hops_per_direction: usize,
263) -> Result<(HashMap<Uuid, f32>, Vec<Uuid>, usize)> {
264 let mut forward_seeds: Vec<(Uuid, f32)> = Vec::new();
267 let mut backward_seeds: Vec<(Uuid, f32)> = Vec::new();
268
269 for (i, (uuid, _name, ic_weight, salience)) in entity_data.iter().enumerate() {
270 let normalized_salience = salience / total_salience;
271 let salience_boost = SALIENCE_BOOST_FACTOR * normalized_salience;
272 let initial_activation = ic_weight * (1.0 + salience_boost);
273
274 if i % 2 == 0 {
275 forward_seeds.push((*uuid, initial_activation));
276 } else {
277 backward_seeds.push((*uuid, initial_activation));
278 }
279 }
280
281 if backward_seeds.is_empty() && !forward_seeds.is_empty() {
283 backward_seeds.push(forward_seeds[forward_seeds.len() - 1]);
284 }
285
286 tracing::debug!(
287 "🔀 Bidirectional spread: {} forward seeds, {} backward seeds",
288 forward_seeds.len(),
289 backward_seeds.len()
290 );
291
292 let threshold = SPREADING_ACTIVATION_THRESHOLD;
294 let (forward_map, forward_edges) =
295 spread_single_direction(&forward_seeds, graph, hops_per_direction, threshold)?;
296
297 let (backward_map, backward_edges) =
298 spread_single_direction(&backward_seeds, graph, hops_per_direction, threshold)?;
299
300 let mut combined_map: HashMap<Uuid, f32> = HashMap::new();
302 let mut intersection_count = 0;
303
304 let all_entities: std::collections::HashSet<Uuid> = forward_map
306 .keys()
307 .chain(backward_map.keys())
308 .cloned()
309 .collect();
310
311 for entity_uuid in all_entities {
312 let forward_activation = forward_map.get(&entity_uuid).cloned().unwrap_or(0.0);
313 let backward_activation = backward_map.get(&entity_uuid).cloned().unwrap_or(0.0);
314
315 let is_intersection = forward_activation >= BIDIRECTIONAL_INTERSECTION_MIN
317 && backward_activation >= BIDIRECTIONAL_INTERSECTION_MIN;
318
319 let combined_activation = if is_intersection {
320 intersection_count += 1;
321 (forward_activation + backward_activation) * BIDIRECTIONAL_INTERSECTION_BOOST
323 } else {
324 forward_activation + backward_activation
326 };
327
328 combined_map.insert(entity_uuid, combined_activation);
329 }
330
331 let mut all_edges = forward_edges;
333 all_edges.extend(backward_edges);
334
335 tracing::debug!(
336 "🔀 Bidirectional result: {} entities ({} intersections), {} edges",
337 combined_map.len(),
338 intersection_count,
339 all_edges.len()
340 );
341
342 Ok((combined_map, all_edges, intersection_count))
343}
344
345pub fn spreading_activation_retrieve(
352 query_text: &str,
353 query: &Query,
354 graph: &GraphMemory,
355 embedder: &dyn Embedder,
356 episode_to_memory_fn: impl Fn(&EpisodicNode) -> Result<Option<SharedMemory>>,
357) -> Result<Vec<ActivatedMemory>> {
358 let (memories, _stats) = spreading_activation_retrieve_with_stats(
360 query_text,
361 query,
362 graph,
363 embedder,
364 None, episode_to_memory_fn,
366 )?;
367 Ok(memories)
368}
369
/// Core retrieval pipeline: spreading activation over the entity graph,
/// blended with semantic (embedding) and linguistic (lexical) scoring.
///
/// Phases:
/// 1. Pick score weights — density-adaptive when `graph_density` is known,
///    otherwise the fixed hybrid constants.
/// 2. Parse the query into focal entities / modifiers / relations.
/// 3. Seed activation on focal entities found in the graph (salience-boosted).
/// 4. Spread activation — bidirectional when there are enough seeds, else a
///    unidirectional loop with adaptive threshold and early termination.
/// 5. Map activated entities to episodes, score each memory, rank, truncate.
///
/// Returns the ranked memories and a `RetrievalStats` describing the run.
/// Returns an empty result early if no focal entity is found in the graph.
pub fn spreading_activation_retrieve_with_stats(
    query_text: &str,
    query: &Query,
    graph: &GraphMemory,
    embedder: &dyn Embedder,
    graph_density: Option<f32>,
    episode_to_memory_fn: impl Fn(&EpisodicNode) -> Result<Option<SharedMemory>>,
) -> Result<(Vec<ActivatedMemory>, RetrievalStats)> {
    let start_time = Instant::now();
    let mut stats = RetrievalStats::default();

    // Phase 1: density-aware weights when the caller supplied a density;
    // otherwise fall back to the fixed hybrid split.
    let (semantic_weight, graph_weight, linguistic_weight) = if let Some(density) = graph_density {
        stats.mode = "associative".to_string();
        stats.graph_density = density;
        calculate_density_weights(density)
    } else {
        stats.mode = "hybrid".to_string();
        stats.graph_density = 0.0;
        (
            HYBRID_SEMANTIC_WEIGHT,
            HYBRID_GRAPH_WEIGHT,
            HYBRID_LINGUISTIC_WEIGHT,
        )
    };

    stats.semantic_weight = semantic_weight;
    stats.graph_weight = graph_weight;
    stats.linguistic_weight = linguistic_weight;

    // Phase 2: linguistic decomposition of the query.
    let analysis = analyze_query(query_text);

    tracing::info!("🔍 Query Analysis:");
    tracing::info!(
        " Focal Entities: {:?}",
        analysis
            .focal_entities
            .iter()
            .map(|e| &e.text)
            .collect::<Vec<_>>()
    );
    tracing::info!(
        " Modifiers: {:?}",
        analysis
            .discriminative_modifiers
            .iter()
            .map(|m| &m.text)
            .collect::<Vec<_>>()
    );
    tracing::info!(
        " Relations: {:?}",
        analysis
            .relational_context
            .iter()
            .map(|r| &r.text)
            .collect::<Vec<_>>()
    );
    tracing::info!(
        " Weights: semantic={:.2}, graph={:.2}, linguistic={:.2}",
        semantic_weight,
        graph_weight,
        linguistic_weight
    );

    let mut activation_map: HashMap<Uuid, f32> = HashMap::new();

    // Phase 3: resolve focal entities to graph nodes, collecting
    // (uuid, name, IC weight, salience) for each one that exists.
    let mut entity_data: Vec<(Uuid, String, f32, f32)> = Vec::new(); for entity in &analysis.focal_entities {
        if let Some(entity_node) = graph.find_entity_by_name(&entity.text)? {
            entity_data.push((
                entity_node.uuid,
                entity.text.clone(),
                entity.ic_weight,
                entity_node.salience,
            ));
        } else {
            tracing::debug!(" ✗ Entity '{}' not found in graph", entity.text);
        }
    }

    // Floor total salience at 0.1 to avoid divide-by-zero on all-zero salience.
    let total_salience: f32 = entity_data.iter().map(|(_, _, _, s)| s).sum();
    let total_salience = total_salience.max(0.1); let mut total_boost = 0.0_f32;
    for (uuid, name, ic_weight, salience) in &entity_data {
        let normalized_salience = salience / total_salience;

        // Seed activation = IC weight scaled by salience share.
        let salience_boost = SALIENCE_BOOST_FACTOR * normalized_salience;
        let initial_activation = ic_weight * (1.0 + salience_boost);

        activation_map.insert(*uuid, initial_activation);
        stats.entities_activated += 1;
        total_boost += salience_boost;

        tracing::debug!(
            " ✓ Activated '{}' (IC={:.2}, salience={:.2}, norm={:.2}, boost={:.2}, activation={:.2})",
            name,
            ic_weight,
            salience,
            normalized_salience,
            salience_boost,
            initial_activation
        );
    }

    stats.avg_salience_boost = if !entity_data.is_empty() {
        total_boost / entity_data.len() as f32
    } else {
        0.0
    };

    // No graph seeds → nothing to spread; caller falls back to pure semantic search.
    if activation_map.is_empty() {
        tracing::warn!("No entities found in graph, falling back to semantic search");
        stats.retrieval_time_us = start_time.elapsed().as_micros() as u64;
        return Ok((Vec::new(), stats)); }

    // Phase 4: spread activation through the graph.
    let graph_start = Instant::now();
    let mut traversed_edges: Vec<Uuid>;

    if entity_data.len() >= BIDIRECTIONAL_MIN_ENTITIES {
        // Enough seeds to split into two wavefronts.
        let adaptive_hops = calculate_adaptive_hops(graph_density);

        tracing::info!(
            "🔀 Using bidirectional spreading ({} focal entities, {} hops/direction, density={:.2})",
            entity_data.len(),
            adaptive_hops,
            graph_density.unwrap_or(0.0)
        );

        let (bidirectional_map, edges, intersection_count) =
            bidirectional_spread(&entity_data, graph, total_salience, adaptive_hops)?;

        activation_map = bidirectional_map;
        traversed_edges = edges;
        stats.entities_activated = activation_map.len();
        stats.graph_hops = adaptive_hops * 2; tracing::info!(
            "🔀 Bidirectional complete: {} entities, {} intersections",
            activation_map.len(),
            intersection_count
        );
    } else {
        // Single-seed case: unidirectional hop loop with a threshold that
        // relaxes when too few candidates survive, and early termination
        // when activation saturates or coverage is sufficient.
        tracing::info!(
            "📊 Using unidirectional spreading ({} focal entity)",
            entity_data.len()
        );

        let mut edges_collected: Vec<Uuid> = Vec::new();
        let mut current_threshold = SPREADING_ACTIVATION_THRESHOLD;

        for hop in 1..=SPREADING_MAX_HOPS {
            stats.graph_hops = hop;
            let count_before = activation_map.len();

            tracing::debug!(
                "📊 Spreading activation (hop {}/{}, threshold={:.4})",
                hop,
                SPREADING_MAX_HOPS,
                current_threshold
            );

            // Snapshot the frontier before mutating the map this hop.
            let current_activated: Vec<(Uuid, f32)> =
                activation_map.iter().map(|(id, act)| (*id, *act)).collect();

            for (entity_uuid, source_activation) in current_activated {
                if source_activation < current_threshold {
                    continue;
                }

                // Cap per-entity fan-out to bound work on hub nodes.
                const MAX_EDGES_PER_SPREAD: usize = 100;
                let edges = graph
                    .get_entity_relationships_limited(&entity_uuid, Some(MAX_EDGES_PER_SPREAD))?;

                for edge in edges {
                    let target_uuid = edge.to_entity;

                    // LTP edges get top trust; otherwise trust follows tier.
                    let tier_trust = if edge.is_potentiated() {
                        EDGE_TIER_TRUST_LTP
                    } else {
                        match edge.tier {
                            EdgeTier::L3Semantic => EDGE_TIER_TRUST_L3,
                            EdgeTier::L2Episodic => EDGE_TIER_TRUST_L2,
                            EdgeTier::L1Working => EDGE_TIER_TRUST_L1,
                        }
                    };

                    // Stronger edges decay more slowly with hop distance.
                    let effective = edge.effective_strength();
                    let decay_rate = calculate_importance_weighted_decay(effective);
                    let decay = (-decay_rate * hop as f32).exp();

                    let spread_amount = source_activation * decay * effective * tier_trust;

                    let new_activation = activation_map.entry(target_uuid).or_insert(0.0);
                    *new_activation += spread_amount;

                    if spread_amount > 0.01 {
                        edges_collected.push(edge.uuid);
                    }

                    // Count the target as newly activated only when this
                    // spread pushed it across the threshold.
                    if *new_activation >= current_threshold
                        && *new_activation - spread_amount < current_threshold
                    {
                        stats.entities_activated += 1;
                    }
                }
            }

            // Normalize to keep peak activation bounded across hops.
            let max_activation = activation_map
                .values()
                .cloned()
                .max_by(|a, b| a.total_cmp(b))
                .unwrap_or(1.0);

            if max_activation > SPREADING_NORMALIZATION_FACTOR {
                let scale = SPREADING_NORMALIZATION_FACTOR / max_activation;
                for activation in activation_map.values_mut() {
                    *activation *= scale;
                }
            }

            // Prune sub-threshold entries before the next hop.
            activation_map.retain(|_, activation| *activation > current_threshold);

            let count_after = activation_map.len();
            let new_activations = count_after.saturating_sub(count_before);

            tracing::debug!(
                " Activated entities: {} (+{} new)",
                count_after,
                new_activations
            );

            // Too few survivors → relax the threshold for later hops.
            if count_after < SPREADING_MIN_CANDIDATES
                && current_threshold > SPREADING_RELAXED_THRESHOLD
            {
                current_threshold = SPREADING_RELAXED_THRESHOLD;
                tracing::debug!(
                    " Relaxing threshold to {:.4} (only {} candidates)",
                    current_threshold,
                    count_after
                );
            }

            // Early termination checks only apply after the minimum hop count.
            if hop >= SPREADING_MIN_HOPS {
                let new_ratio = if count_after > 0 {
                    new_activations as f32 / count_after as f32
                } else {
                    0.0
                };

                // Stop when almost nothing new is being reached.
                if new_ratio < SPREADING_EARLY_TERMINATION_RATIO && count_after > 0 {
                    tracing::debug!(
                        " Early termination: activation saturated ({:.1}% new)",
                        new_ratio * 100.0
                    );
                    break;
                }

                // Stop when we already have plenty of candidates.
                if count_after >= SPREADING_EARLY_TERMINATION_CANDIDATES {
                    tracing::debug!(
                        " Early termination: sufficient coverage ({} candidates)",
                        count_after
                    );
                    break;
                }
            }
        }

        traversed_edges = edges_collected;
    }

    stats.graph_time_us = graph_start.elapsed().as_micros() as u64;
    tracing::info!("📊 Final activated entities: {}", activation_map.len());

    // Phase 5a: fold entity activations into the episodes they mention;
    // an episode touched by several entities accumulates all their scores.
    let mut activated_memories: HashMap<Uuid, (f32, EpisodicNode)> = HashMap::new();

    for (entity_uuid, entity_activation) in &activation_map {
        let episodes = graph.get_episodes_by_entity(entity_uuid)?;

        for episode in episodes {
            let current = activated_memories
                .entry(episode.uuid)
                .or_insert((0.0, episode.clone()));

            current.0 += entity_activation;
        }
    }

    stats.graph_candidates = activated_memories.len();
    tracing::info!(
        "📊 Retrieved {} episodic memories via graph",
        activated_memories.len()
    );

    // Phase 5b: score each candidate memory.
    let mut scored_memories = Vec::new();

    let embedding_start = Instant::now();
    let query_embedding = embedder.encode(query_text)?;
    stats.embedding_time_us = embedding_start.elapsed().as_micros() as u64;

    let now = chrono::Utc::now();

    for (_episode_uuid, (graph_activation, episode)) in activated_memories {
        if let Some(memory) = episode_to_memory_fn(&episode)? {
            // Embedding similarity; memories without embeddings score 0.
            let semantic_score = if let Some(mem_emb) = &memory.experience.embeddings {
                cosine_similarity(&query_embedding, mem_emb)
            } else {
                0.0
            };

            // Lexical overlap with the query, plus the memory tier's graph
            // multiplier (working memory weighted most, archive least).
            let linguistic_raw = calculate_linguistic_match(&memory, &analysis);
            let linguistic_score = linguistic_raw; let tier_graph_mult = match memory.tier {
                MemoryTier::Working => MEMORY_TIER_GRAPH_MULT_WORKING, MemoryTier::Session => MEMORY_TIER_GRAPH_MULT_SESSION, MemoryTier::LongTerm => MEMORY_TIER_GRAPH_MULT_LONGTERM, MemoryTier::Archive => MEMORY_TIER_GRAPH_MULT_ARCHIVE, };

            // Re-normalize the three weights after the tier adjustment so
            // they still sum to 1 for this memory.
            let tier_adjusted_graph_weight = graph_weight * tier_graph_mult;
            let weight_sum = semantic_weight + tier_adjusted_graph_weight + linguistic_weight;
            let norm_semantic = semantic_weight / weight_sum;
            let norm_graph = tier_adjusted_graph_weight / weight_sum;
            let norm_linguistic = linguistic_weight / weight_sum;

            let hybrid_score = semantic_score * norm_semantic
                + graph_activation * norm_graph
                + linguistic_score * norm_linguistic;

            // Small additive boosts: exponential recency decay (max +0.1),
            // emotional arousal (max +0.05), and above-baseline source
            // credibility (max +0.05).
            const RECENCY_DECAY_RATE: f32 = 0.01;
            let hours_old = (now - memory.created_at).num_hours().max(0) as f32;
            let recency_boost = (-RECENCY_DECAY_RATE * hours_old).exp() * 0.1;

            let arousal_boost = memory
                .experience
                .context
                .as_ref()
                .map(|c| c.emotional.arousal * 0.05)
                .unwrap_or(0.0);

            let credibility_boost = memory
                .experience
                .context
                .as_ref()
                .map(|c| (c.source.credibility - 0.5).max(0.0) * 0.1)
                .unwrap_or(0.0);

            let final_score = hybrid_score + recency_boost + arousal_boost + credibility_boost;

            scored_memories.push(ActivatedMemory {
                memory,
                activation_score: graph_activation,
                semantic_score,
                linguistic_score,
                final_score,
            });
        }
    }

    // Rank descending by final score and keep the requested number.
    scored_memories.sort_by(|a, b| b.final_score.total_cmp(&a.final_score));

    scored_memories.truncate(query.max_results);

    stats.retrieval_time_us = start_time.elapsed().as_micros() as u64;

    // Deduplicate traversed edge ids for the stats payload.
    traversed_edges.sort();
    traversed_edges.dedup();
    stats.traversed_edges = traversed_edges;

    tracing::info!(
        "🎯 Returning {} memories (top scores: {:?}), {} edges traversed",
        scored_memories.len(),
        scored_memories
            .iter()
            .take(3)
            .map(|m| m.final_score)
            .collect::<Vec<_>>(),
        stats.traversed_edges.len()
    );

    Ok((scored_memories, stats))
}
819
820fn calculate_linguistic_match(memory: &Memory, analysis: &QueryAnalysis) -> f32 {
827 let content_lower = memory.experience.content.to_lowercase();
828 let mut score = 0.0;
829
830 for entity in &analysis.focal_entities {
832 if content_lower.contains(&entity.text.to_lowercase()) {
833 score += 1.0;
834 }
835 }
836
837 for modifier in &analysis.discriminative_modifiers {
839 if content_lower.contains(&modifier.text.to_lowercase()) {
840 score += 0.5;
841 }
842 }
843
844 for relation in &analysis.relational_context {
846 if content_lower.contains(&relation.text.to_lowercase()) {
847 score += 0.2;
848 }
849 }
850
851 let max_possible = analysis.focal_entities.len() as f32 * 1.0
853 + analysis.discriminative_modifiers.len() as f32 * 0.5
854 + analysis.relational_context.len() as f32 * 0.2;
855
856 if max_possible > 0.0 {
857 score / max_possible
858 } else {
859 0.0
860 }
861}
862
#[cfg(test)]
mod tests {
    use super::*;

    // Cosine similarity sanity checks: identical, orthogonal, and parallel vectors.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert_eq!(cosine_similarity(&a, &b), 1.0);

        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);

        let a = vec![1.0, 1.0];
        let b = vec![1.0, 1.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
    }

    // Below the sparse threshold, graph weight sits at its maximum.
    #[test]
    fn test_density_weights_sparse() {
        let (semantic, graph, linguistic) = calculate_density_weights(0.3);
        assert!((graph - DENSITY_GRAPH_WEIGHT_MAX).abs() < 0.001);
        assert!((linguistic - DENSITY_LINGUISTIC_WEIGHT).abs() < 0.001);
        assert!((semantic + graph + linguistic - 1.0).abs() < 0.001);
    }

    // Above the dense threshold, graph weight sits at its minimum.
    #[test]
    fn test_density_weights_dense() {
        let (semantic, graph, linguistic) = calculate_density_weights(2.5);
        assert!((graph - DENSITY_GRAPH_WEIGHT_MIN).abs() < 0.001);
        assert!((linguistic - DENSITY_LINGUISTIC_WEIGHT).abs() < 0.001);
        assert!((semantic + graph + linguistic - 1.0).abs() < 0.001);
    }

    // Mid-range density interpolates strictly between the extremes.
    #[test]
    fn test_density_weights_interpolation() {
        let (semantic, graph, linguistic) = calculate_density_weights(1.25);
        assert!(graph > DENSITY_GRAPH_WEIGHT_MIN);
        assert!(graph < DENSITY_GRAPH_WEIGHT_MAX);
        assert!((linguistic - DENSITY_LINGUISTIC_WEIGHT).abs() < 0.001);
        assert!((semantic + graph + linguistic - 1.0).abs() < 0.001);
    }

    // Maximum importance → slowest decay.
    #[test]
    fn test_importance_weighted_decay_high() {
        let decay = calculate_importance_weighted_decay(1.0);
        assert!((decay - IMPORTANCE_DECAY_MIN).abs() < 0.001);
    }

    // Zero importance → fastest decay.
    #[test]
    fn test_importance_weighted_decay_low() {
        let decay = calculate_importance_weighted_decay(0.0);
        assert!((decay - IMPORTANCE_DECAY_MAX).abs() < 0.001);
    }

    // Mid importance → linear midpoint between the two rates.
    #[test]
    fn test_importance_weighted_decay_mid() {
        let decay = calculate_importance_weighted_decay(0.5);
        let expected = IMPORTANCE_DECAY_MIN + 0.5 * (IMPORTANCE_DECAY_MAX - IMPORTANCE_DECAY_MIN);
        assert!((decay - expected).abs() < 0.001);
    }

    // Higher importance must leave more activation after one decay step.
    #[test]
    fn test_activation_decay() {
        let initial_activation = 1.0;

        let high_importance_decay = calculate_importance_weighted_decay(0.9);
        let high_importance_final = initial_activation * (-high_importance_decay).exp();

        let low_importance_decay = calculate_importance_weighted_decay(0.1);
        let low_importance_final = initial_activation * (-low_importance_decay).exp();

        assert!(high_importance_final > low_importance_final);
    }

    // Static consistency checks on the spreading-related constants.
    #[test]
    fn test_adaptive_constants_valid() {
        use crate::constants::*;

        assert!(SPREADING_RELAXED_THRESHOLD < SPREADING_ACTIVATION_THRESHOLD);

        assert!(SPREADING_MIN_HOPS <= SPREADING_MAX_HOPS);

        assert!(SPREADING_EARLY_TERMINATION_RATIO > 0.0);
        assert!(SPREADING_EARLY_TERMINATION_RATIO < 1.0);

        assert!(SPREADING_NORMALIZATION_FACTOR > 0.0);

        assert!(SPREADING_MIN_CANDIDATES > 0);
        assert!(SPREADING_MIN_CANDIDATES < SPREADING_EARLY_TERMINATION_CANDIDATES);
    }

    // Simulate repeated growth + normalization and verify the peak stays bounded.
    #[test]
    fn test_normalization_prevents_explosion() {
        use crate::constants::SPREADING_NORMALIZATION_FACTOR;

        let mut activations: Vec<f32> = vec![1.0, 0.8, 0.5, 0.3];

        for _ in 0..5 {
            // Grow every activation by 50%, then renormalize like the
            // spreading loop does.
            for activation in &mut activations {
                *activation += *activation * 0.5; }

            let max_activation = activations
                .iter()
                .cloned()
                .max_by(|a, b| a.total_cmp(b))
                .unwrap_or(1.0);

            if max_activation > SPREADING_NORMALIZATION_FACTOR {
                let scale = SPREADING_NORMALIZATION_FACTOR / max_activation;
                for activation in &mut activations {
                    *activation *= scale;
                }
            }
        }

        let final_max = activations
            .iter()
            .cloned()
            .max_by(|a, b| a.total_cmp(b))
            .unwrap_or(0.0);

        assert!(final_max <= SPREADING_NORMALIZATION_FACTOR + 0.001);
    }

    // A trickle of new activations should terminate; strong growth should not.
    #[test]
    fn test_early_termination_ratio() {
        use crate::constants::SPREADING_EARLY_TERMINATION_RATIO;

        let total_before = 50;
        let total_after = 52; let new_activations = total_after - total_before;
        let new_ratio = new_activations as f32 / total_after as f32;

        assert!(new_ratio < SPREADING_EARLY_TERMINATION_RATIO);

        let growing_before = 10;
        let growing_after = 25; let growing_new = growing_after - growing_before;
        let growing_ratio = growing_new as f32 / growing_after as f32;

        assert!(growing_ratio >= SPREADING_EARLY_TERMINATION_RATIO);
    }

    // Static consistency checks on the bidirectional-spreading constants.
    #[test]
    fn test_bidirectional_constants_valid() {
        assert!(BIDIRECTIONAL_MIN_ENTITIES >= 2);

        assert!(BIDIRECTIONAL_INTERSECTION_BOOST > 1.0);

        assert!(BIDIRECTIONAL_INTERSECTION_MIN < SPREADING_ACTIVATION_THRESHOLD);

        assert!(BIDIRECTIONAL_DENSITY_SPARSE < BIDIRECTIONAL_DENSITY_DENSE);

        assert!(BIDIRECTIONAL_HOPS_DENSE < BIDIRECTIONAL_HOPS_MEDIUM);
        assert!(BIDIRECTIONAL_HOPS_MEDIUM < BIDIRECTIONAL_HOPS_SPARSE);

        assert!(BIDIRECTIONAL_HOPS_MEDIUM * 2 >= SPREADING_MAX_HOPS);
    }

    // Dense graph → fewest hops per direction.
    #[test]
    fn test_adaptive_hops_dense_graph() {
        let hops = calculate_adaptive_hops(Some(3.0)); assert_eq!(hops, BIDIRECTIONAL_HOPS_DENSE);
        assert_eq!(hops, 2);
    }

    // Sparse graph → most hops per direction.
    #[test]
    fn test_adaptive_hops_sparse_graph() {
        let hops = calculate_adaptive_hops(Some(0.3)); assert_eq!(hops, BIDIRECTIONAL_HOPS_SPARSE);
        assert_eq!(hops, 4);
    }

    // Medium density → medium hop count.
    #[test]
    fn test_adaptive_hops_medium_graph() {
        let hops = calculate_adaptive_hops(Some(1.0)); assert_eq!(hops, BIDIRECTIONAL_HOPS_MEDIUM);
        assert_eq!(hops, 3);
    }

    // Unknown density falls back to the medium setting.
    #[test]
    fn test_adaptive_hops_no_density() {
        let hops = calculate_adaptive_hops(None);
        assert_eq!(hops, BIDIRECTIONAL_HOPS_MEDIUM);
    }

    // Hop budget should grow monotonically as a graph thins out over time.
    #[test]
    fn test_adaptive_hops_lifecycle() {
        let fresh_hops = calculate_adaptive_hops(Some(2.5)); let mid_hops = calculate_adaptive_hops(Some(1.0)); let mature_hops = calculate_adaptive_hops(Some(0.3)); assert!(fresh_hops <= mid_hops);
        assert!(mid_hops <= mature_hops);

        assert_eq!(fresh_hops, 2);
        assert_eq!(mid_hops, 3);
        assert_eq!(mature_hops, 4);
    }

    // Entities reached from both directions get the intersection boost.
    #[test]
    fn test_intersection_boost_calculation() {
        let forward_activation = 0.5;
        let backward_activation = 0.3;

        assert!(forward_activation >= BIDIRECTIONAL_INTERSECTION_MIN);
        assert!(backward_activation >= BIDIRECTIONAL_INTERSECTION_MIN);

        let boosted = (forward_activation + backward_activation) * BIDIRECTIONAL_INTERSECTION_BOOST;
        let unboosted = forward_activation + backward_activation;

        assert!(boosted > unboosted);

        let expected_ratio = BIDIRECTIONAL_INTERSECTION_BOOST;
        assert!((boosted / unboosted - expected_ratio).abs() < 0.001);
    }

    // One-sided activation combines additively with no boost.
    #[test]
    fn test_non_intersection_no_boost() {
        let forward_activation = 0.5;
        let backward_activation = 0.0; assert!(backward_activation < BIDIRECTIONAL_INTERSECTION_MIN);

        let combined = forward_activation + backward_activation;

        assert!((combined - forward_activation).abs() < 0.001);
    }

    // Even/odd index split divides an even entity count evenly.
    #[test]
    fn test_bidirectional_entity_split() {
        let entities = vec![
            (Uuid::new_v4(), "entity1".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity2".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity3".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity4".to_string(), 1.0, 0.5),
        ];

        let mut forward_count = 0;
        let mut backward_count = 0;

        for (i, _) in entities.iter().enumerate() {
            if i % 2 == 0 {
                forward_count += 1;
            } else {
                backward_count += 1;
            }
        }

        assert_eq!(forward_count, 2);
        assert_eq!(backward_count, 2);
    }

    // Odd entity counts leave both seed lists non-empty (forward gets the extra).
    #[test]
    fn test_bidirectional_odd_entities() {
        let entities = vec![
            (Uuid::new_v4(), "entity1".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity2".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity3".to_string(), 1.0, 0.5),
        ];

        let mut forward_seeds = Vec::new();
        let mut backward_seeds = Vec::new();

        for (i, entity) in entities.iter().enumerate() {
            if i % 2 == 0 {
                forward_seeds.push(entity.0);
            } else {
                backward_seeds.push(entity.0);
            }
        }

        assert_eq!(forward_seeds.len(), 2);
        assert_eq!(backward_seeds.len(), 1);

        assert!(!forward_seeds.is_empty());
        assert!(!backward_seeds.is_empty());
    }

    // Entity counts on either side of BIDIRECTIONAL_MIN_ENTITIES select
    // unidirectional vs. bidirectional spreading.
    #[test]
    fn test_bidirectional_threshold_triggers() {
        let single_entity = vec![(Uuid::new_v4(), "entity1".to_string(), 1.0, 0.5)];
        assert!(single_entity.len() < BIDIRECTIONAL_MIN_ENTITIES);

        let two_entities = vec![
            (Uuid::new_v4(), "entity1".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity2".to_string(), 1.0, 0.5),
        ];
        assert!(two_entities.len() >= BIDIRECTIONAL_MIN_ENTITIES);

        let many_entities = vec![
            (Uuid::new_v4(), "entity1".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity2".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity3".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity4".to_string(), 1.0, 0.5),
            (Uuid::new_v4(), "entity5".to_string(), 1.0, 0.5),
        ];
        assert!(many_entities.len() >= BIDIRECTIONAL_MIN_ENTITIES);
    }

    // Two half-depth searches (2·b^(d/2)) explore far fewer nodes than one
    // full-depth search (b^d) for branching factor b and depth d.
    #[test]
    fn test_complexity_improvement() {
        let b: f64 = 10.0;
        let d: f64 = 6.0;

        let unidirectional = b.powf(d);
        let bidirectional = 2.0 * b.powf(d / 2.0);

        assert!(bidirectional < unidirectional);

        let improvement = unidirectional / bidirectional;
        assert!(improvement > 100.0);
    }

    // Intersection minimum is half the activation threshold and in (0, 1).
    #[test]
    fn test_intersection_detection_threshold() {
        let min_threshold = BIDIRECTIONAL_INTERSECTION_MIN;

        let expected = SPREADING_ACTIVATION_THRESHOLD / 2.0;
        assert!((min_threshold - expected).abs() < 0.001);

        assert!(min_threshold > 0.0);

        assert!(min_threshold < 1.0);
    }
}