zeph_memory/five_signal/
causal_distance.rs1use std::collections::HashMap;
5use std::sync::Arc;
6
7use zeph_common::memory::EdgeType;
8
9use crate::graph::GraphStore;
10
11pub struct CausalDistanceComputer {
17 graph_store: Arc<GraphStore>,
18 max_depth: u32,
19 neutral_distance: u32,
20 cache: Option<(i64, HashMap<i64, u32>)>,
22}
23
24impl CausalDistanceComputer {
25 #[must_use]
32 pub fn new(graph_store: Arc<GraphStore>, max_depth: u32, neutral_distance: u32) -> Self {
33 Self {
34 graph_store,
35 max_depth,
36 neutral_distance,
37 cache: None,
38 }
39 }
40
41 #[tracing::instrument(
54 name = "memory.five_signal.causal_distance.compute",
55 skip(self, entity_ids),
56 fields(goal_entity_id, candidate_count = entity_ids.len())
57 )]
58 pub async fn compute(
59 &mut self,
60 goal_entity_id: Option<i64>,
61 entity_ids: &[i64],
62 ) -> Result<HashMap<i64, u32>, crate::error::MemoryError> {
63 tracing::debug!("five_signal: computing causal distances");
64
65 let Some(goal_id) = goal_entity_id else {
66 return Ok(HashMap::new());
67 };
68
69 let neutral = self.neutral_distance;
70 let depth_map = self.ensure_cache(goal_id).await?;
71
72 let result = entity_ids
73 .iter()
74 .map(|&eid| {
75 let dist = depth_map.get(&eid).copied().unwrap_or(neutral);
76 (eid, dist)
77 })
78 .collect();
79
80 Ok(result)
81 }
82
83 #[must_use]
88 #[inline]
89 pub fn distance_to_score(distance: u32) -> f64 {
90 if distance == 0 {
91 1.0
92 } else {
93 (1.0_f64 / f64::from(distance)).min(1.0)
94 }
95 }
96
97 pub fn invalidate_cache(&mut self) {
99 self.cache = None;
100 }
101
102 async fn ensure_cache(
103 &mut self,
104 goal_id: i64,
105 ) -> Result<&HashMap<i64, u32>, crate::error::MemoryError> {
106 if self.cache.as_ref().map(|(id, _)| *id) != Some(goal_id) {
107 let (_, _, depth_map) = self
108 .graph_store
109 .bfs_typed(goal_id, self.max_depth, &[EdgeType::Causal])
110 .await?;
111 self.cache = Some((goal_id, depth_map));
112 }
113 Ok(&self.cache.as_ref().expect("just set above").1)
114 }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks the `1/d` mapping, including the special case `d == 0`.
    #[test]
    fn distance_to_score_values() {
        assert!((CausalDistanceComputer::distance_to_score(0) - 1.0).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(1) - 1.0).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(2) - 0.5).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(5) - 0.2).abs() < 1e-9);
    }

    /// Scores keep shrinking past a typical `max_depth` (10). No floor clamps
    /// them, so depth 20 scores strictly lower than depth 10.
    #[test]
    fn distance_to_score_keeps_decreasing_beyond_max_depth() {
        let score_at_limit = CausalDistanceComputer::distance_to_score(10);
        let score_beyond = CausalDistanceComputer::distance_to_score(20);
        assert!(score_at_limit <= 1.0);
        assert!(score_beyond < score_at_limit, "deeper nodes score lower");
        assert!((score_at_limit - 0.1).abs() < 1e-9);
        assert!((score_beyond - 0.05).abs() < 1e-9);
    }

    /// Entities missing from the depth map are assigned `neutral_distance`,
    /// so their score is `distance_to_score(neutral_distance)`.
    #[test]
    fn neutral_distance_determines_unreachable_score() {
        let neutral = 5_u32;
        let score = CausalDistanceComputer::distance_to_score(neutral);
        assert!((score - 0.2).abs() < 1e-9);
    }

    /// `compute(None, ..)` returns early, so it succeeds even against an empty
    /// in-memory database and leaves the cache untouched.
    #[tokio::test]
    async fn compute_none_goal_returns_empty_map() {
        let pool = sqlx::SqlitePool::connect(":memory:").await.unwrap();
        let graph_store = Arc::new(crate::graph::GraphStore::new(pool));
        let mut computer = CausalDistanceComputer::new(graph_store, 10, 5);

        let result = computer
            .compute(None, &[1, 2, 3])
            .await
            .expect("None goal must not fail");
        assert!(
            result.is_empty(),
            "goal_entity_id=None must return empty map, got: {result:?}"
        );
        assert!(computer.cache.is_none(), "early return must not populate the cache");
    }
}