Skip to main content

zeph_memory/graph/
retrieval.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use crate::error::MemoryError;
8
9use super::store::GraphStore;
10use super::types::GraphFact;
11
12/// Retrieve graph facts relevant to `query` via BFS traversal from matched seed entities.
13///
14/// Algorithm:
15/// 1. Split query into words and search for entity matches via fuzzy LIKE for each word.
16/// 2. For each matched seed entity, run BFS up to `max_hops` hops (temporal BFS when
17///    `at_timestamp` is `Some`).
18/// 3. Build `GraphFact` structs from edges, using depth map for `hop_distance`.
19/// 4. Deduplicate by `(entity_name, relation, target_name)` keeping highest score.
20/// 5. Sort by score desc, truncate to `limit`.
21///
22/// # Parameters
23///
24/// - `at_timestamp`: `SQLite` datetime string (`"YYYY-MM-DD HH:MM:SS"`). When `Some`, only edges
25///   valid at that point in time are traversed. When `None`, only currently active edges are used.
26/// - `temporal_decay_rate`: non-negative decay rate (units: 1/day). `0.0` preserves the original
27///   `composite_score` ordering with no temporal adjustment.
28///
29/// # Errors
30///
31/// Returns an error if any database query fails.
32#[allow(clippy::too_many_arguments)]
33pub async fn graph_recall(
34    store: &GraphStore,
35    _embeddings: Option<&crate::embedding_store::EmbeddingStore>,
36    _provider: &zeph_llm::any::AnyProvider,
37    query: &str,
38    limit: usize,
39    max_hops: u32,
40    at_timestamp: Option<&str>,
41    temporal_decay_rate: f64,
42) -> Result<Vec<GraphFact>, MemoryError> {
43    // Cap at MAX_WORDS to bound the number of sequential full-table-scan LIKE queries.
44    const MAX_WORDS: usize = 5;
45
46    if limit == 0 {
47        return Ok(Vec::new());
48    }
49
50    // Step 1: fuzzy search per query word (avoids full-sentence LIKE misses).
51    // Fall back to the full query string when all words are too short (len < 3).
52    let filtered: Vec<&str> = query
53        .split_whitespace()
54        .filter(|w| w.len() >= 3)
55        .take(MAX_WORDS)
56        .collect();
57    let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
58        vec![query]
59    } else {
60        filtered
61    };
62
63    let mut entity_scores: HashMap<i64, f32> = HashMap::new();
64
65    for word in &words {
66        let matches = store.find_entities_fuzzy(word, limit * 2).await?;
67        for entity in matches {
68            entity_scores
69                .entry(entity.id)
70                .and_modify(|s| *s = s.max(1.0))
71                .or_insert(1.0);
72        }
73    }
74
75    if entity_scores.is_empty() {
76        return Ok(Vec::new());
77    }
78
79    // Capture current time once for consistent decay scoring across all facts.
80    let now_secs: i64 = SystemTime::now()
81        .duration_since(UNIX_EPOCH)
82        .map(|d| d.as_secs().cast_signed())
83        .unwrap_or(0);
84
85    // Step 2: BFS from each seed entity, collect facts
86    let mut all_facts: Vec<GraphFact> = Vec::new();
87
88    for (seed_id, seed_score) in &entity_scores {
89        let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
90            store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
91        } else {
92            store.bfs_with_depth(*seed_id, max_hops).await?
93        };
94
95        // Use canonical_name for stable dedup keys (S5 fix): entities reached via different
96        // aliases have different display names but share canonical_name, preventing duplicates.
97        let name_map: HashMap<i64, &str> = entities
98            .iter()
99            .map(|e| (e.id, e.canonical_name.as_str()))
100            .collect();
101
102        for edge in &edges {
103            let Some(&hop_distance) = depth_map
104                .get(&edge.source_entity_id)
105                .or_else(|| depth_map.get(&edge.target_entity_id))
106            else {
107                continue;
108            };
109
110            let entity_name = name_map
111                .get(&edge.source_entity_id)
112                .copied()
113                .unwrap_or_default();
114            let target_name = name_map
115                .get(&edge.target_entity_id)
116                .copied()
117                .unwrap_or_default();
118
119            if entity_name.is_empty() || target_name.is_empty() {
120                continue;
121            }
122
123            all_facts.push(GraphFact {
124                entity_name: entity_name.to_owned(),
125                relation: edge.relation.clone(),
126                target_name: target_name.to_owned(),
127                fact: edge.fact.clone(),
128                entity_match_score: *seed_score,
129                hop_distance,
130                confidence: edge.confidence,
131                valid_from: Some(edge.valid_from.clone()),
132            });
133        }
134    }
135
136    // Step 3: sort by score desc (total_cmp for deterministic NaN ordering),
137    // then dedup keeping highest-scored fact per (entity, relation, target) key.
138    // Pre-compute scores to avoid recomputing composite_score() O(n log n) times.
139    let mut scored: Vec<(f32, GraphFact)> = all_facts
140        .into_iter()
141        .map(|f| {
142            let s = f.score_with_decay(temporal_decay_rate, now_secs);
143            (s, f)
144        })
145        .collect();
146    scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
147    let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
148
149    let mut seen: HashSet<(String, String, String)> = HashSet::new();
150    all_facts.retain(|f| {
151        seen.insert((
152            f.entity_name.clone(),
153            f.relation.clone(),
154            f.target_name.clone(),
155        ))
156    });
157
158    // Step 4: truncate to limit
159    all_facts.truncate(limit);
160
161    Ok(all_facts)
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Builds a `GraphStore` backed by a fresh `:memory:` SQLite database.
    async fn setup_store() -> GraphStore {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        GraphStore::new(pool)
    }

    /// Wraps a default `MockProvider` in the `AnyProvider` enum expected by `graph_recall`.
    fn mock_provider() -> AnyProvider {
        let mock = MockProvider::default();
        AnyProvider::Mock(mock)
    }

    /// Calls `graph_recall` with no embedding store, a mock provider and a decay rate of `0.0`,
    /// unwrapping the result.
    async fn recall(
        store: &GraphStore,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Vec<GraphFact> {
        let provider = mock_provider();
        graph_recall(
            store,
            None,
            &provider,
            query,
            limit,
            max_hops,
            at_timestamp,
            0.0,
        )
        .await
        .unwrap()
    }

    /// A store with no entities yields no facts.
    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;

        let facts = recall(&store, "anything", 10, 2, None).await;

        assert!(facts.is_empty());
    }

    /// `limit == 0` yields no facts.
    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;

        let facts = recall(&store, "user", 0, 2, None).await;

        assert!(facts.is_empty());
    }

    /// A partial word in the query is enough to reach the connected edge.
    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let alice_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let neovim_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(alice_id, neovim_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        // The "Ali" prefix is expected to hit "Alice" through the LIKE search.
        let facts = recall(&store, "Ali neovim", 10, 2, None).await;

        assert!(!facts.is_empty());
        assert_eq!(facts[0].relation, "uses");
    }

    /// With `max_hops = 1` no returned fact may report a hop distance above one.
    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let alpha = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let beta = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap();
        let gamma = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap();
        // Chain: Alpha -> Beta -> Gamma.
        store
            .insert_edge(alpha, beta, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(beta, gamma, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        // Seeded at Alpha with a single hop: Alpha -> Beta is in range, Beta -> Gamma
        // (two hops away from Alpha) is not.
        let facts = recall(&store, "Alp", 10, 1, None).await;

        assert!(facts.iter().all(|fact| fact.hop_distance <= 1));
    }

    /// The same edge reached from two different seeds is reported only once.
    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        // "Ali" and "Bob" each match a seed, and the BFS from either seed finds the same edge.
        let facts = recall(&store, "Ali Bob", 10, 2, None).await;

        // Every (entity, relation, target) triple may appear at most once.
        let mut keys = std::collections::HashSet::new();
        for f in &facts {
            let triple = (&f.entity_name, &f.relation, &f.target_name);
            assert!(keys.insert(triple), "duplicate fact found: {f:?}");
        }
    }

    /// Results come back ordered by descending composite score.
    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let alpha = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let beta = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let gadget = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        // Direct edge with high confidence.
        store
            .insert_edge(alpha, beta, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        // Direct edge with low confidence.
        store
            .insert_edge(
                alpha,
                gadget,
                "mentions",
                "Alpha mentions AlphaGadget",
                0.1,
                None,
            )
            .await
            .unwrap();

        let facts = recall(&store, "Alp", 10, 2, None).await;

        // The leading fact must not score below the one that follows it.
        assert!(facts.len() >= 2);
        let s0 = facts[0].composite_score();
        let s1 = facts[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    /// No more than `limit` facts are returned even when more edges match.
    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        // Ten tools, each attached to the root by its own edge.
        for i in 0..10 {
            let name = format!("Target{i}");
            let target = store
                .upsert_entity(&name, &name, EntityType::Tool, None)
                .await
                .unwrap();
            let fact_text = format!("Root has Target{i}");
            store
                .insert_edge(root, target, "has", &fact_text, 0.8, None)
                .await
                .unwrap();
        }

        let facts = recall(&store, "Roo", 3, 2, None).await;

        assert!(facts.len() <= 3);
    }

    /// An edge whose `valid_from` lies after the query timestamp is not returned.
    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        // Written with raw SQL so that valid_from can be set to the year 2100.
        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        // Asking as of 2026 must not surface the edge that only starts in 2100.
        let facts = recall(&store, "Ali", 10, 2, Some("2026-01-01 00:00:00")).await;

        assert!(facts.is_empty(), "future edge should be excluded");
    }

    /// An expired edge is hidden from a current query but visible at a timestamp inside its
    /// validity window.
    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let carol = store
            .upsert_entity("Carol", "Carol", EntityType::Person, None)
            .await
            .unwrap();
        // Validity window 2020-01-01 .. 2021-01-01, i.e. the edge has already expired.
        sqlx::query(
            "INSERT INTO graph_edges
             (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
             VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
                     '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        // No timestamp: only currently active edges count, so nothing is found.
        let result_current = recall(&store, "Ali", 10, 2, None).await;
        assert!(
            result_current.is_empty(),
            "expired edge should be invisible at current time"
        );

        // 2020-06-01 lies inside the validity window, so the edge shows up.
        let result_historical = recall(&store, "Ali", 10, 2, Some("2020-06-01 00:00:00")).await;
        assert!(
            !result_historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    /// With a decay rate of `0.0` the result order follows `composite_score`.
    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        let alpha = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let beta = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let gadget = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(alpha, beta, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(
                alpha,
                gadget,
                "mentions",
                "Alpha mentions AlphaGadget",
                0.1,
                None,
            )
            .await
            .unwrap();

        // Decay disabled: ordering has to agree with plain composite_score ordering.
        let facts = recall(&store, "Alp", 10, 2, None).await;

        assert!(facts.len() >= 2);
        let s0 = facts[0].composite_score();
        let s1 = facts[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }
}