Skip to main content

zeph_memory/graph/
retrieval.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::{HashMap, HashSet};
5
6use crate::error::MemoryError;
7
8use super::store::GraphStore;
9use super::types::GraphFact;
10
11/// Retrieve graph facts relevant to `query` via BFS traversal from matched seed entities.
12///
13/// Algorithm:
14/// 1. Split query into words and search for entity matches via fuzzy LIKE for each word.
15/// 2. For each matched seed entity, run BFS up to `max_hops` hops.
16/// 3. Build `GraphFact` structs from edges, using depth map for `hop_distance`.
17/// 4. Deduplicate by `(entity_name, relation, target_name)` keeping highest `composite_score`.
18/// 5. Sort by `composite_score` desc, truncate to `limit`.
19///
20/// # Errors
21///
22/// Returns an error if any database query fails.
23pub async fn graph_recall(
24    store: &GraphStore,
25    _embeddings: Option<&crate::embedding_store::EmbeddingStore>,
26    _provider: &zeph_llm::any::AnyProvider,
27    query: &str,
28    limit: usize,
29    max_hops: u32,
30) -> Result<Vec<GraphFact>, MemoryError> {
31    // Cap at MAX_WORDS to bound the number of sequential full-table-scan LIKE queries.
32    const MAX_WORDS: usize = 5;
33
34    if limit == 0 {
35        return Ok(Vec::new());
36    }
37
38    // Step 1: fuzzy search per query word (avoids full-sentence LIKE misses).
39    // Fall back to the full query string when all words are too short (len < 3).
40    let filtered: Vec<&str> = query
41        .split_whitespace()
42        .filter(|w| w.len() >= 3)
43        .take(MAX_WORDS)
44        .collect();
45    let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
46        vec![query]
47    } else {
48        filtered
49    };
50
51    let mut entity_scores: HashMap<i64, f32> = HashMap::new();
52
53    for word in &words {
54        let matches = store.find_entities_fuzzy(word, limit * 2).await?;
55        for entity in matches {
56            entity_scores
57                .entry(entity.id)
58                .and_modify(|s| *s = s.max(1.0))
59                .or_insert(1.0);
60        }
61    }
62
63    if entity_scores.is_empty() {
64        return Ok(Vec::new());
65    }
66
67    // Step 3: BFS from each seed entity, collect facts
68    let mut all_facts: Vec<GraphFact> = Vec::new();
69
70    for (seed_id, seed_score) in &entity_scores {
71        let (entities, edges, depth_map) = store.bfs_with_depth(*seed_id, max_hops).await?;
72
73        // Use canonical_name for stable dedup keys (S5 fix): entities reached via different
74        // aliases have different display names but share canonical_name, preventing duplicates.
75        let name_map: HashMap<i64, &str> = entities
76            .iter()
77            .map(|e| (e.id, e.canonical_name.as_str()))
78            .collect();
79
80        for edge in &edges {
81            let Some(&hop_distance) = depth_map
82                .get(&edge.source_entity_id)
83                .or_else(|| depth_map.get(&edge.target_entity_id))
84            else {
85                continue;
86            };
87
88            let entity_name = name_map
89                .get(&edge.source_entity_id)
90                .copied()
91                .unwrap_or_default();
92            let target_name = name_map
93                .get(&edge.target_entity_id)
94                .copied()
95                .unwrap_or_default();
96
97            if entity_name.is_empty() || target_name.is_empty() {
98                continue;
99            }
100
101            all_facts.push(GraphFact {
102                entity_name: entity_name.to_owned(),
103                relation: edge.relation.clone(),
104                target_name: target_name.to_owned(),
105                fact: edge.fact.clone(),
106                entity_match_score: *seed_score,
107                hop_distance,
108                confidence: edge.confidence,
109            });
110        }
111    }
112
113    // Step 4 & 5: sort by composite_score desc (total_cmp for deterministic NaN ordering),
114    // then dedup keeping highest-scored fact per (entity, relation, target) key.
115    all_facts.sort_by(|a, b| b.composite_score().total_cmp(&a.composite_score()));
116
117    let mut seen: HashSet<(String, String, String)> = HashSet::new();
118    all_facts.retain(|f| {
119        seen.insert((
120            f.entity_name.clone(),
121            f.relation.clone(),
122            f.target_name.clone(),
123        ))
124    });
125
126    // Step 6: truncate to limit
127    all_facts.truncate(limit);
128
129    Ok(all_facts)
130}
131
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Fresh in-memory SQLite-backed graph store.
    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "anything", 10, 2)
            .await
            .unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "user", 0, 2)
            .await
            .unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let user_id = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tool_id = store
            .upsert_entity("neovim", "neovim", EntityType::Tool, None)
            .await
            .unwrap();
        store
            .insert_edge(user_id, tool_id, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        // "Ali" matches "Alice" via LIKE
        let result = graph_recall(&store, None, &provider, "Ali neovim", 10, 2)
            .await
            .unwrap();
        assert!(!result.is_empty());
        assert_eq!(result[0].relation, "uses");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Person, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("Gamma", "Gamma", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(b, c, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        let provider = mock_provider();
        // max_hops=1: only the A→B edge should be reachable from A
        let result = graph_recall(&store, None, &provider, "Alp", 10, 1)
            .await
            .unwrap();

        // The one-hop A→B fact must be present (guards against a vacuous pass on empty output)...
        assert!(
            result
                .iter()
                .any(|f| f.target_name.eq_ignore_ascii_case("beta")),
            "expected the Alpha→Beta fact, got {result:?}"
        );
        // ...while B→C (hop 2 from A) must not be.
        assert!(
            result
                .iter()
                .all(|f| !f.target_name.eq_ignore_ascii_case("gamma")),
            "Beta→Gamma is two hops from the seed: {result:?}"
        );
        assert!(result.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = store
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = store
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let provider = mock_provider();
        // Both "Ali" and "Bob" match and BFS from both seeds may yield the same edge
        let result = graph_recall(&store, None, &provider, "Ali Bob", 10, 2)
            .await
            .unwrap();

        // Exactly one edge exists, so exactly one fact must survive dedup.
        assert_eq!(
            result.len(),
            1,
            "expected a single (Alice, knows, Bob) fact: {result:?}"
        );
        let mut seen = std::collections::HashSet::new();
        for f in &result {
            let key = (&f.entity_name, &f.relation, &f.target_name);
            assert!(seen.insert(key), "duplicate fact found: {f:?}");
        }
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alpha", "Alpha", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Beta", "Beta", EntityType::Tool, None)
            .await
            .unwrap();
        let c = store
            .upsert_entity("AlphaGadget", "AlphaGadget", EntityType::Tool, None)
            .await
            .unwrap();
        // high-confidence direct edge
        store
            .insert_edge(a, b, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        // low-confidence direct edge
        store
            .insert_edge(a, c, "mentions", "Alpha mentions AlphaGadget", 0.1, None)
            .await
            .unwrap();

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Alp", 10, 2)
            .await
            .unwrap();

        // The whole list, not just the first pair, must be ordered by score descending.
        assert!(result.len() >= 2);
        for pair in result.windows(2) {
            let (s0, s1) = (pair[0].composite_score(), pair[1].composite_score());
            assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
        }
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = store
            .upsert_entity("Root", "Root", EntityType::Person, None)
            .await
            .unwrap();
        for i in 0..10 {
            let target = store
                .upsert_entity(
                    &format!("Target{i}"),
                    &format!("Target{i}"),
                    EntityType::Tool,
                    None,
                )
                .await
                .unwrap();
            store
                .insert_edge(
                    root,
                    target,
                    "has",
                    &format!("Root has Target{i}"),
                    0.8,
                    None,
                )
                .await
                .unwrap();
        }

        let provider = mock_provider();
        let result = graph_recall(&store, None, &provider, "Roo", 3, 2)
            .await
            .unwrap();
        // Ten distinct facts are reachable, so the result must be cut to exactly `limit`.
        assert_eq!(result.len(), 3);
    }
}
332}