Skip to main content

zeph_memory/five_signal/
causal_distance.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
7use zeph_common::memory::EdgeType;
8
9use crate::graph::GraphStore;
10
11/// Causal distance computer backed by MAGMA graph BFS.
12///
13/// Computes the shortest causal-edge hop count between the current goal entity and each
14/// candidate entity. BFS is bounded by `max_depth` to satisfy NFR-003. Results are cached
15/// per goal entity id to avoid re-traversal within the same turn.
16pub struct CausalDistanceComputer {
17    graph_store: Arc<GraphStore>,
18    max_depth: u32,
19    neutral_distance: u32,
20    /// Last BFS result: `(goal_entity_id, depth_map)`.
21    cache: Option<(i64, HashMap<i64, u32>)>,
22}
23
24impl CausalDistanceComputer {
25    /// Create a new computer.
26    ///
27    /// # Parameters
28    ///
29    /// - `max_depth`: BFS hop limit (default: 10).
30    /// - `neutral_distance`: distance assigned to unreachable entities (default: 5).
31    #[must_use]
32    pub fn new(graph_store: Arc<GraphStore>, max_depth: u32, neutral_distance: u32) -> Self {
33        Self {
34            graph_store,
35            max_depth,
36            neutral_distance,
37            cache: None,
38        }
39    }
40
41    /// Compute causal distances from `goal_entity_id` to each entity in `entity_ids`.
42    ///
43    /// Returns a map of `entity_id → causal distance` where unreachable or missing entities
44    /// receive `neutral_distance`. When `goal_entity_id` is `None`, returns an empty map
45    /// (callers treat absent entries as neutral, contributing zero to the signal per FR-006).
46    ///
47    /// BFS result is cached per `goal_entity_id`; the cache is invalidated only when
48    /// the goal entity changes.
49    ///
50    /// # Errors
51    ///
52    /// Returns an error if the graph BFS query fails.
53    #[tracing::instrument(
54        name = "memory.five_signal.causal_distance.compute",
55        skip(self, entity_ids),
56        fields(goal_entity_id, candidate_count = entity_ids.len())
57    )]
58    pub async fn compute(
59        &mut self,
60        goal_entity_id: Option<i64>,
61        entity_ids: &[i64],
62    ) -> Result<HashMap<i64, u32>, crate::error::MemoryError> {
63        tracing::debug!("five_signal: computing causal distances");
64
65        let Some(goal_id) = goal_entity_id else {
66            return Ok(HashMap::new());
67        };
68
69        let neutral = self.neutral_distance;
70        let depth_map = self.ensure_cache(goal_id).await?;
71
72        let result = entity_ids
73            .iter()
74            .map(|&eid| {
75                let dist = depth_map.get(&eid).copied().unwrap_or(neutral);
76                (eid, dist)
77            })
78            .collect();
79
80        Ok(result)
81    }
82
83    /// Convert a raw causal distance to a score in `[0.0, 1.0]`.
84    ///
85    /// Distance 1 → 1.0, distance 5 → 0.2, `neutral_distance` → neutral value.
86    /// Distance 0 (goal entity itself) → 1.0 (clamped).
87    #[must_use]
88    #[inline]
89    pub fn distance_to_score(distance: u32) -> f64 {
90        if distance == 0 {
91            1.0
92        } else {
93            (1.0_f64 / f64::from(distance)).min(1.0)
94        }
95    }
96
97    /// Invalidate the BFS cache. Call at turn boundaries when the goal entity may change.
98    pub fn invalidate_cache(&mut self) {
99        self.cache = None;
100    }
101
102    async fn ensure_cache(
103        &mut self,
104        goal_id: i64,
105    ) -> Result<&HashMap<i64, u32>, crate::error::MemoryError> {
106        if self.cache.as_ref().map(|(id, _)| *id) != Some(goal_id) {
107            let (_, _, depth_map) = self
108                .graph_store
109                .bfs_typed(goal_id, self.max_depth, &[EdgeType::Causal])
110                .await?;
111            self.cache = Some((goal_id, depth_map));
112        }
113        Ok(&self.cache.as_ref().expect("just set above").1)
114    }
115}
116
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn distance_to_score_values() {
        assert!((CausalDistanceComputer::distance_to_score(0) - 1.0).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(1) - 1.0).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(2) - 0.5).abs() < 1e-9);
        assert!((CausalDistanceComputer::distance_to_score(5) - 0.2).abs() < 1e-9);
    }

    #[test]
    fn distance_to_score_decreases_with_distance() {
        // Scores fall off as 1/distance and never exceed 1.0; there is no lower clamp,
        // so distances beyond max_depth keep shrinking.
        let score_at_limit = CausalDistanceComputer::distance_to_score(10);
        let score_beyond = CausalDistanceComputer::distance_to_score(20);
        assert!(score_at_limit <= 1.0);
        assert!(score_beyond <= score_at_limit, "deeper nodes score lower");
        assert!((score_at_limit - 0.1).abs() < 1e-9);
        assert!((score_beyond - 0.05).abs() < 1e-9);

        for d in 1..20_u32 {
            let near = CausalDistanceComputer::distance_to_score(d);
            let far = CausalDistanceComputer::distance_to_score(d + 1);
            assert!(far < near, "score must strictly decrease from {d} to {}", d + 1);
        }
    }

    #[test]
    fn neutral_distance_determines_unreachable_score() {
        // Unreachable entities receive neutral_distance (default 5) → score = 1/5 = 0.2.
        let neutral = 5_u32;
        let score = CausalDistanceComputer::distance_to_score(neutral);
        assert!((score - 0.2).abs() < 1e-9);
    }

    // Regression test for #4405: goal_entity_id=None returns empty map without touching the DB.
    #[tokio::test]
    async fn compute_none_goal_returns_empty_map() {
        // Build a minimal in-memory graph store so the constructor is satisfied.
        let pool = sqlx::SqlitePool::connect(":memory:")
            .await
            .expect("in-memory SQLite pool must open");
        let graph_store = Arc::new(GraphStore::new(pool));
        let mut computer = CausalDistanceComputer::new(graph_store, 10, 5);

        let result = computer
            .compute(None, &[1, 2, 3])
            .await
            .expect("None goal must not fail");
        assert!(
            result.is_empty(),
            "goal_entity_id=None must return empty map, got: {result:?}"
        );
    }
}
167}