Skip to main content

zeph_memory/graph/
retrieval_beam.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4//! Beam search graph recall.
5//!
6//! [`graph_recall_beam`] keeps only the top-K candidate entities at each hop,
7//! enabling multi-hop reasoning paths to be explored efficiently without
8//! unbounded BFS expansion.
9
10use std::collections::{HashMap, HashSet};
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use crate::embedding_store::EmbeddingStore;
14use crate::error::MemoryError;
15use crate::graph::retrieval::find_seed_entities;
16use crate::graph::store::GraphStore;
17use crate::graph::types::{EdgeType, GraphFact};
18
19const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
20const DEFAULT_COMMUNITY_CAP: usize = 3;
21
22/// Retrieve graph facts using beam search.
23///
24/// Algorithm:
25/// 1. Find seed entities via hybrid FTS5 + structural scoring; take top `beam_width` as initial beam.
26/// 2. Per hop: fetch edges for beam entities, score each neighbour, keep top `beam_width` entity IDs.
27/// 3. Collect all traversed edges; convert to [`GraphFact`]; dedup; sort; truncate.
28///
29/// # Errors
30///
31/// Returns an error if any database query fails.
32#[allow(clippy::too_many_arguments, clippy::too_many_lines)] // complex algorithm function; both suppressions justified until the function is decomposed in a future refactor
33pub async fn graph_recall_beam(
34    store: &GraphStore,
35    embeddings: Option<&EmbeddingStore>,
36    provider: &zeph_llm::any::AnyProvider,
37    query: &str,
38    limit: usize,
39    beam_width: usize,
40    max_hops: u32,
41    edge_types: &[EdgeType],
42    temporal_decay_rate: f64,
43    hebbian_enabled: bool,
44    hebbian_lr: f32,
45) -> Result<Vec<GraphFact>, MemoryError> {
46    let _span = tracing::info_span!("memory.graph.beam", query_len = query.len()).entered();
47
48    if limit == 0 {
49        return Ok(Vec::new());
50    }
51
52    let entity_scores = find_seed_entities(
53        store,
54        embeddings,
55        provider,
56        query,
57        limit,
58        DEFAULT_STRUCTURAL_WEIGHT,
59        DEFAULT_COMMUNITY_CAP,
60    )
61    .await?;
62
63    if entity_scores.is_empty() {
64        return Ok(Vec::new());
65    }
66
67    let now_secs: i64 = SystemTime::now()
68        .duration_since(UNIX_EPOCH)
69        .map_or(0, |d| d.as_secs().cast_signed());
70
71    // Initial beam: top-`beam_width` seeds by score.
72    let mut beam_scores: Vec<(i64, f32)> = entity_scores.into_iter().collect();
73    beam_scores.sort_by(|(_, sa), (_, sb)| sb.total_cmp(sa));
74    beam_scores.truncate(beam_width);
75
76    let mut beam_ids: Vec<i64> = beam_scores.iter().map(|(id, _)| *id).collect();
77    let mut beam_score_map: HashMap<i64, f32> = beam_scores.into_iter().collect();
78
79    let mut all_db_edges: Vec<crate::graph::types::Edge> = Vec::new();
80    let mut entity_name_map: HashMap<i64, String> = HashMap::new();
81
82    for _hop in 0..max_hops {
83        if beam_ids.is_empty() {
84            break;
85        }
86
87        let edges = store.edges_for_entities(&beam_ids, edge_types).await?;
88        if edges.is_empty() {
89            break;
90        }
91
92        // Collect entity IDs from edges to resolve names.
93        let new_entity_ids: Vec<i64> = edges
94            .iter()
95            .flat_map(|e| [e.source_entity_id, e.target_entity_id])
96            .filter(|id| !entity_name_map.contains_key(id))
97            .collect::<HashSet<_>>()
98            .into_iter()
99            .collect();
100
101        for id in new_entity_ids {
102            if let Ok(Some(entity)) = store.find_entity_by_id(id).await {
103                entity_name_map.insert(id, entity.canonical_name.clone());
104            }
105        }
106
107        // Score each neighbour by edge confidence (proxy for traversal quality).
108        let mut neighbour_scores: HashMap<i64, f32> = HashMap::new();
109        for edge in &edges {
110            let edge_conf = edge.confidence;
111            neighbour_scores
112                .entry(edge.target_entity_id)
113                .and_modify(|s| *s = s.max(edge_conf))
114                .or_insert(edge_conf);
115            neighbour_scores
116                .entry(edge.source_entity_id)
117                .and_modify(|s| *s = s.max(edge_conf))
118                .or_insert(edge_conf);
119        }
120
121        // Next beam: top-`beam_width` by score (excluding current beam members).
122        let mut candidates: Vec<(i64, f32)> = neighbour_scores
123            .into_iter()
124            .filter(|(id, _)| !beam_score_map.contains_key(id))
125            .collect();
126        candidates.sort_by(|(_, sa), (_, sb)| sb.total_cmp(sa));
127        candidates.truncate(beam_width);
128
129        beam_ids = candidates.iter().map(|(id, _)| *id).collect();
130        for (id, cand_score) in candidates {
131            beam_score_map.insert(id, cand_score);
132        }
133
134        all_db_edges.extend(edges);
135    }
136
137    if all_db_edges.is_empty() {
138        return Ok(Vec::new());
139    }
140
141    // Record retrievals fire-and-forget.
142    let edge_ids: Vec<i64> = all_db_edges.iter().map(|e| e.id).collect();
143    if let Err(e) = store.record_edge_retrieval(&edge_ids).await {
144        tracing::warn!(error = %e, "graph_recall_beam: failed to record edge retrieval");
145    }
146    // HL-F2: Hebbian weight reinforcement (fire-and-forget).
147    if hebbian_enabled
148        && !edge_ids.is_empty()
149        && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
150    {
151        tracing::warn!(error = %e, "graph_recall_beam: hebbian increment failed");
152    }
153
154    // Convert to GraphFact, dedup, sort, truncate.
155    let mut facts: Vec<GraphFact> = Vec::new();
156    let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
157
158    for edge in &all_db_edges {
159        let entity_name = entity_name_map
160            .get(&edge.source_entity_id)
161            .cloned()
162            .unwrap_or_default();
163        let target_name = entity_name_map
164            .get(&edge.target_entity_id)
165            .cloned()
166            .unwrap_or_default();
167        if entity_name.is_empty() || target_name.is_empty() {
168            continue;
169        }
170        let key = (
171            entity_name.clone(),
172            edge.relation.clone(),
173            target_name.clone(),
174            edge.edge_type,
175        );
176        if seen.insert(key) {
177            let seed_score = beam_score_map
178                .get(&edge.source_entity_id)
179                .copied()
180                .unwrap_or(0.5);
181            facts.push(GraphFact {
182                entity_name,
183                relation: edge.relation.clone(),
184                target_name,
185                fact: edge.fact.clone(),
186                entity_match_score: seed_score,
187                hop_distance: 1,
188                confidence: edge.confidence,
189                valid_from: Some(edge.valid_from.clone()),
190                edge_type: edge.edge_type,
191                retrieval_count: edge.retrieval_count,
192            });
193        }
194    }
195
196    facts.sort_by(|a, b| {
197        let sa = a.score_with_decay(temporal_decay_rate, now_secs);
198        let sb = b.score_with_decay(temporal_decay_rate, now_secs);
199        sb.total_cmp(&sa)
200    });
201    facts.truncate(limit);
202
203    Ok(facts)
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    async fn setup_store() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    fn mock_provider() -> AnyProvider {
        AnyProvider::Mock(MockProvider::default())
    }

    /// Fresh store holding two person entities joined by a single "knows" edge.
    async fn alice_knows_bob_store() -> GraphStore {
        let store = setup_store().await;
        let a = store
            .upsert_entity("Alice", "alice", EntityType::Person, None)
            .await
            .unwrap();
        let b = store
            .upsert_entity("Bob", "bob", EntityType::Person, None)
            .await
            .unwrap();
        store
            .insert_edge(a, b, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();
        store
    }

    /// Runs [`graph_recall_beam`] without embeddings, edge-type filtering, temporal decay, or
    /// Hebbian updates, and unwraps the result.
    async fn recall(
        store: &GraphStore,
        query: &str,
        limit: usize,
        beam_width: usize,
        max_hops: u32,
    ) -> Vec<GraphFact> {
        let provider = mock_provider();
        graph_recall_beam(
            store,
            None,
            &provider,
            query,
            limit,
            beam_width,
            max_hops,
            &[],
            0.0,
            false,
            0.0,
        )
        .await
        .unwrap()
    }

    #[tokio::test]
    async fn beam_empty_graph_returns_empty() {
        let store = setup_store().await;
        assert!(recall(&store, "anything", 10, 5, 2).await.is_empty());
    }

    #[tokio::test]
    async fn beam_zero_limit_returns_empty() {
        let store = setup_store().await;
        assert!(recall(&store, "anything", 0, 5, 2).await.is_empty());
    }

    #[tokio::test]
    async fn beam_finds_direct_edge() {
        let store = alice_knows_bob_store().await;
        assert!(!recall(&store, "Alice", 10, 5, 2).await.is_empty());
    }

    #[tokio::test]
    async fn beam_zero_hops_returns_empty() {
        let store = alice_knows_bob_store().await;
        assert!(recall(&store, "Alice", 10, 5, 0).await.is_empty());
    }

    #[tokio::test]
    async fn beam_zero_width_returns_empty() {
        let store = alice_knows_bob_store().await;
        assert!(recall(&store, "Alice", 10, 0, 2).await.is_empty());
    }

    #[tokio::test]
    async fn beam_revisited_edge_yields_single_direct_fact() {
        // With several hops the single edge can be fetched again from its other endpoint;
        // it must still surface as exactly one direct (1-hop) fact.
        let store = alice_knows_bob_store().await;
        let facts = recall(&store, "Alice", 10, 5, 3).await;
        assert_eq!(facts.len(), 1);
        assert_eq!(facts[0].hop_distance, 1);
    }
}