Skip to main content

zeph_memory/graph/
retrieval.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::{HashMap, HashSet};
5use std::time::{SystemTime, UNIX_EPOCH};
6
7use crate::error::MemoryError;
8
9use super::activation::{ActivatedFact, SpreadingActivation, SpreadingActivationParams};
10use super::store::GraphStore;
11use super::types::{EdgeType, GraphFact};
12
13/// Retrieve graph facts relevant to `query` via BFS traversal from matched seed entities.
14///
15/// Algorithm:
16/// 1. Split query into words and search for entity matches via fuzzy LIKE for each word.
17/// 2. For each matched seed entity, run BFS up to `max_hops` hops (temporal BFS when
18///    `at_timestamp` is `Some`, typed BFS when `edge_types` is non-empty).
19/// 3. Build `GraphFact` structs from edges, using depth map for `hop_distance`.
20/// 4. Deduplicate by `(entity_name, relation, target_name, edge_type)` keeping highest score.
21/// 5. Sort by score desc, truncate to `limit`.
22///
23/// # Parameters
24///
25/// - `at_timestamp`: `SQLite` datetime string (`"YYYY-MM-DD HH:MM:SS"`). When `Some`, only edges
26///   valid at that point in time are traversed. When `None`, only currently active edges are used.
27/// - `temporal_decay_rate`: non-negative decay rate (units: 1/day). `0.0` preserves the original
28///   `composite_score` ordering with no temporal adjustment.
29/// - `edge_types`: MAGMA subgraph filter. When non-empty, only traverses edges of the given types.
30///   When empty, traverses all active edges (backward-compatible).
31///
32/// # Errors
33///
34/// Returns an error if any database query fails.
35#[allow(clippy::too_many_arguments)]
36pub async fn graph_recall(
37    store: &GraphStore,
38    _embeddings: Option<&crate::embedding_store::EmbeddingStore>,
39    _provider: &zeph_llm::any::AnyProvider,
40    query: &str,
41    limit: usize,
42    max_hops: u32,
43    at_timestamp: Option<&str>,
44    temporal_decay_rate: f64,
45    edge_types: &[EdgeType],
46) -> Result<Vec<GraphFact>, MemoryError> {
47    if limit == 0 {
48        return Ok(Vec::new());
49    }
50
51    // Step 1: fuzzy search per query word — shared helper (DRY with graph_recall_activated).
52    let entity_scores = find_seed_entities(store, query, limit).await?;
53
54    if entity_scores.is_empty() {
55        return Ok(Vec::new());
56    }
57
58    // Capture current time once for consistent decay scoring across all facts.
59    let now_secs: i64 = SystemTime::now()
60        .duration_since(UNIX_EPOCH)
61        .map(|d| d.as_secs().cast_signed())
62        .unwrap_or(0);
63
64    // Step 2: BFS from each seed entity, collect facts
65    let mut all_facts: Vec<GraphFact> = Vec::new();
66
67    for (seed_id, seed_score) in &entity_scores {
68        let (entities, edges, depth_map) = if let Some(ts) = at_timestamp {
69            store.bfs_at_timestamp(*seed_id, max_hops, ts).await?
70        } else if !edge_types.is_empty() {
71            store.bfs_typed(*seed_id, max_hops, edge_types).await?
72        } else {
73            store.bfs_with_depth(*seed_id, max_hops).await?
74        };
75
76        // Use canonical_name for stable dedup keys (S5 fix): entities reached via different
77        // aliases have different display names but share canonical_name, preventing duplicates.
78        let name_map: HashMap<i64, &str> = entities
79            .iter()
80            .map(|e| (e.id, e.canonical_name.as_str()))
81            .collect();
82
83        for edge in &edges {
84            let Some(&hop_distance) = depth_map
85                .get(&edge.source_entity_id)
86                .or_else(|| depth_map.get(&edge.target_entity_id))
87            else {
88                continue;
89            };
90
91            let entity_name = name_map
92                .get(&edge.source_entity_id)
93                .copied()
94                .unwrap_or_default();
95            let target_name = name_map
96                .get(&edge.target_entity_id)
97                .copied()
98                .unwrap_or_default();
99
100            if entity_name.is_empty() || target_name.is_empty() {
101                continue;
102            }
103
104            all_facts.push(GraphFact {
105                entity_name: entity_name.to_owned(),
106                relation: edge.relation.clone(),
107                target_name: target_name.to_owned(),
108                fact: edge.fact.clone(),
109                entity_match_score: *seed_score,
110                hop_distance,
111                confidence: edge.confidence,
112                valid_from: Some(edge.valid_from.clone()),
113                edge_type: edge.edge_type,
114            });
115        }
116    }
117
118    // Step 3: sort by score desc (total_cmp for deterministic NaN ordering),
119    // then dedup keeping highest-scored fact per (entity, relation, target) key.
120    // Pre-compute scores to avoid recomputing composite_score() O(n log n) times.
121    let mut scored: Vec<(f32, GraphFact)> = all_facts
122        .into_iter()
123        .map(|f| {
124            let s = f.score_with_decay(temporal_decay_rate, now_secs);
125            (s, f)
126        })
127        .collect();
128    scored.sort_by(|(sa, _), (sb, _)| sb.total_cmp(sa));
129    let mut all_facts: Vec<GraphFact> = scored.into_iter().map(|(_, f)| f).collect();
130
131    // Dedup key includes edge_type (critic mitigation): the same (entity, relation, target)
132    // triple can legitimately exist with different edge types. Without edge_type in the key,
133    // typed BFS would return fewer facts than expected.
134    let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
135    all_facts.retain(|f| {
136        seen.insert((
137            f.entity_name.clone(),
138            f.relation.clone(),
139            f.target_name.clone(),
140            f.edge_type,
141        ))
142    });
143
144    // Step 4: truncate to limit
145    all_facts.truncate(limit);
146
147    Ok(all_facts)
148}
149
150/// Find seed entities by fuzzy-matching query words against the entity store.
151///
152/// Returns a map of `entity_id -> match_score`. Match score is always `1.0` (fuzzy search
153/// returns binary match/no-match; seeds are query anchors per SYNAPSE semantics).
154///
155/// Shared by [`graph_recall`] and [`graph_recall_activated`] to avoid duplication (SUGGEST-04).
156///
157/// # Errors
158///
159/// Returns an error if any database query fails.
160async fn find_seed_entities(
161    store: &GraphStore,
162    query: &str,
163    limit: usize,
164) -> Result<HashMap<i64, f32>, MemoryError> {
165    const MAX_WORDS: usize = 5;
166
167    let filtered: Vec<&str> = query
168        .split_whitespace()
169        .filter(|w| w.len() >= 3)
170        .take(MAX_WORDS)
171        .collect();
172    let words: Vec<&str> = if filtered.is_empty() && !query.is_empty() {
173        vec![query]
174    } else {
175        filtered
176    };
177
178    let mut entity_scores: HashMap<i64, f32> = HashMap::new();
179    for word in &words {
180        let matches = store.find_entities_fuzzy(word, limit * 2).await?;
181        for entity in matches {
182            entity_scores
183                .entry(entity.id)
184                .and_modify(|s| *s = s.max(1.0))
185                .or_insert(1.0);
186        }
187    }
188
189    Ok(entity_scores)
190}
191
192/// Retrieve graph facts via SYNAPSE spreading activation from seed entities.
193///
194/// Algorithm:
195/// 1. Find seed entities via fuzzy word search (same as [`graph_recall`]).
196/// 2. Run spreading activation from seeds using `config`.
197/// 3. Return `ActivatedFact` records (edges collected during propagation) sorted by
198///    activation score descending, truncated to `limit`.
199///
200/// Edge type filtering via `edge_types` ensures MAGMA subgraph scoping is preserved
201/// (mirrors [`graph_recall`]'s `bfs_typed` path, MAJOR-05 fix).
202///
203/// # Errors
204///
205/// Returns an error if any database query fails.
206pub async fn graph_recall_activated(
207    store: &GraphStore,
208    query: &str,
209    limit: usize,
210    params: SpreadingActivationParams,
211    edge_types: &[EdgeType],
212) -> Result<Vec<ActivatedFact>, MemoryError> {
213    if limit == 0 {
214        return Ok(Vec::new());
215    }
216
217    let entity_scores = find_seed_entities(store, query, limit).await?;
218
219    if entity_scores.is_empty() {
220        return Ok(Vec::new());
221    }
222
223    tracing::debug!(
224        seeds = entity_scores.len(),
225        "spreading activation: starting recall"
226    );
227
228    let sa = SpreadingActivation::new(params);
229    let (_, mut facts) = sa.spread(store, entity_scores, edge_types).await?;
230
231    // Sort by activation score descending and truncate to limit.
232    facts.sort_by(|a, b| b.activation_score.total_cmp(&a.activation_score));
233
234    // Deduplicate by (source, relation, target, edge_type) keeping highest activation.
235    let mut seen: HashSet<(i64, String, i64, EdgeType)> = HashSet::new();
236    facts.retain(|f| {
237        seen.insert((
238            f.edge.source_entity_id,
239            f.edge.relation.clone(),
240            f.edge.target_entity_id,
241            f.edge.edge_type,
242        ))
243    });
244
245    facts.truncate(limit);
246
247    tracing::debug!(
248        result_count = facts.len(),
249        "spreading activation: recall complete"
250    );
251
252    Ok(facts)
253}
254
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::sqlite::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Builds a `GraphStore` backed by a fresh `":memory:"` SQLite store.
    async fn setup_store() -> GraphStore {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(sqlite.pool().clone())
    }

    /// Provider placeholder; `graph_recall` requires one as an argument.
    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Upserts an entity whose surface and canonical names are both `name`.
    async fn add_entity(store: &GraphStore, name: &str, kind: EntityType) -> i64 {
        store.upsert_entity(name, name, kind, None).await.unwrap()
    }

    /// Calls `graph_recall` with no embedding store, a mock provider, zero temporal decay
    /// and no edge-type filter, panicking on error.
    async fn recall(
        store: &GraphStore,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Vec<GraphFact> {
        let provider = mock_provider();
        graph_recall(
            store,
            None,
            &provider,
            query,
            limit,
            max_hops,
            at_timestamp,
            0.0,
            &[],
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn graph_recall_empty_graph_returns_empty() {
        let store = setup_store().await;
        let facts = recall(&store, "anything", 10, 2, None).await;
        assert!(facts.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_zero_limit_returns_empty() {
        let store = setup_store().await;
        let facts = recall(&store, "user", 0, 2, None).await;
        assert!(facts.is_empty());
    }

    #[tokio::test]
    async fn graph_recall_fuzzy_match_returns_facts() {
        let store = setup_store().await;
        let person = add_entity(&store, "Alice", EntityType::Person).await;
        let editor = add_entity(&store, "neovim", EntityType::Tool).await;
        store
            .insert_edge(person, editor, "uses", "Alice uses neovim", 0.9, None)
            .await
            .unwrap();

        // The "Ali" prefix is expected to fuzzy-match the "Alice" entity.
        let facts = recall(&store, "Ali neovim", 10, 2, None).await;
        assert!(!facts.is_empty());
        assert_eq!(facts[0].relation, "uses");
    }

    #[tokio::test]
    async fn graph_recall_respects_max_hops() {
        let store = setup_store().await;
        let alpha = add_entity(&store, "Alpha", EntityType::Person).await;
        let beta = add_entity(&store, "Beta", EntityType::Person).await;
        let gamma = add_entity(&store, "Gamma", EntityType::Person).await;
        store
            .insert_edge(alpha, beta, "knows", "Alpha knows Beta", 0.8, None)
            .await
            .unwrap();
        store
            .insert_edge(beta, gamma, "knows", "Beta knows Gamma", 0.8, None)
            .await
            .unwrap();

        // With a single hop allowed, nothing further than one hop from Alpha may be reported.
        let facts = recall(&store, "Alp", 10, 1, None).await;
        assert!(facts.iter().all(|f| f.hop_distance <= 1));
    }

    #[tokio::test]
    async fn graph_recall_deduplicates_facts() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let bob = add_entity(&store, "Bob", EntityType::Person).await;
        store
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        // Both query words are seeds, and each seed's BFS reaches the same single edge.
        let facts = recall(&store, "Ali Bob", 10, 2, None).await;

        // Every (entity, relation, target) triple may appear at most once.
        let mut triples = std::collections::HashSet::new();
        for f in &facts {
            let triple = (&f.entity_name, &f.relation, &f.target_name);
            assert!(triples.insert(triple), "duplicate fact found: {f:?}");
        }
    }

    #[tokio::test]
    async fn graph_recall_sorts_by_composite_score() {
        let store = setup_store().await;
        let alpha = add_entity(&store, "Alpha", EntityType::Person).await;
        let beta = add_entity(&store, "Beta", EntityType::Tool).await;
        let gadget = add_entity(&store, "AlphaGadget", EntityType::Tool).await;
        // One strong and one weak edge leaving Alpha.
        store
            .insert_edge(alpha, beta, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(
                alpha,
                gadget,
                "mentions",
                "Alpha mentions AlphaGadget",
                0.1,
                None,
            )
            .await
            .unwrap();

        let facts = recall(&store, "Alp", 10, 2, None).await;

        // The leading fact must not score below the one after it.
        assert!(facts.len() >= 2);
        let s0 = facts[0].composite_score();
        let s1 = facts[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }

    #[tokio::test]
    async fn graph_recall_limit_truncates() {
        let store = setup_store().await;
        let root = add_entity(&store, "Root", EntityType::Person).await;
        for i in 0..10 {
            let name = format!("Target{i}");
            let target = add_entity(&store, &name, EntityType::Tool).await;
            let fact = format!("Root has {name}");
            store
                .insert_edge(root, target, "has", &fact, 0.8, None)
                .await
                .unwrap();
        }

        let facts = recall(&store, "Roo", 3, 2, None).await;
        assert!(facts.len() <= 3);
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_future_edges() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let bob = add_entity(&store, "Bob", EntityType::Person).await;
        // Raw insert so that valid_from can be pinned to the year 2100.
        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, valid_from)
             VALUES (?1, ?2, 'knows', 'Alice knows Bob', 0.9, '2100-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(bob)
        .execute(store.pool())
        .await
        .unwrap();

        // As of 2026 the edge has not started yet.
        let facts = recall(&store, "Ali", 10, 2, Some("2026-01-01 00:00:00")).await;
        assert!(facts.is_empty(), "future edge should be excluded");
    }

    #[tokio::test]
    async fn graph_recall_at_timestamp_excludes_invalidated_edges() {
        let store = setup_store().await;
        let alice = add_entity(&store, "Alice", EntityType::Person).await;
        let carol = add_entity(&store, "Carol", EntityType::Person).await;
        // Raw insert of an edge whose validity window is 2020-01-01 → 2021-01-01.
        sqlx::query(
            "INSERT INTO graph_edges
             (source_entity_id, target_entity_id, relation, fact, confidence, valid_from, valid_to, expired_at)
             VALUES (?1, ?2, 'manages', 'Alice manages Carol', 0.8,
                     '2020-01-01 00:00:00', '2021-01-01 00:00:00', '2021-01-01 00:00:00')",
        )
        .bind(alice)
        .bind(carol)
        .execute(store.pool())
        .await
        .unwrap();

        // No timestamp: only currently active edges count, and this one has expired.
        let current = recall(&store, "Ali", 10, 2, None).await;
        assert!(
            current.is_empty(),
            "expired edge should be invisible at current time"
        );

        // A timestamp inside the validity window brings the edge back.
        let historical = recall(&store, "Ali", 10, 2, Some("2020-06-01 00:00:00")).await;
        assert!(
            !historical.is_empty(),
            "edge should be visible within its validity window"
        );
    }

    #[tokio::test]
    async fn graph_recall_temporal_decay_preserves_order_with_zero_rate() {
        let store = setup_store().await;
        let alpha = add_entity(&store, "Alpha", EntityType::Person).await;
        let beta = add_entity(&store, "Beta", EntityType::Tool).await;
        let gadget = add_entity(&store, "AlphaGadget", EntityType::Tool).await;
        store
            .insert_edge(alpha, beta, "uses", "Alpha uses Beta", 1.0, None)
            .await
            .unwrap();
        store
            .insert_edge(
                alpha,
                gadget,
                "mentions",
                "Alpha mentions AlphaGadget",
                0.1,
                None,
            )
            .await
            .unwrap();

        // A decay rate of 0.0 must leave the plain composite_score ordering in place.
        let facts = recall(&store, "Alp", 10, 2, None).await;
        assert!(facts.len() >= 2);
        let s0 = facts[0].composite_score();
        let s1 = facts[1].composite_score();
        assert!(s0 >= s1, "expected sorted desc: {s0} >= {s1}");
    }
}