1use std::collections::{HashMap, HashSet};
11use std::time::{SystemTime, UNIX_EPOCH};
12
13use crate::embedding_store::EmbeddingStore;
14use crate::error::MemoryError;
15use crate::graph::retrieval::find_seed_entities;
16use crate::graph::store::GraphStore;
17use crate::graph::types::{EdgeType, GraphFact};
18
// Structural weight handed to `find_seed_entities` when picking seed entities.
// NOTE(review): presumably balances structural vs. similarity scoring — confirm
// against `find_seed_entities`.
const DEFAULT_STRUCTURAL_WEIGHT: f32 = 0.4;
// Community cap handed to `find_seed_entities` when picking seed entities.
// NOTE(review): presumably limits seeds taken per community — confirm against
// `find_seed_entities`.
const DEFAULT_COMMUNITY_CAP: usize = 3;
21
22#[allow(clippy::too_many_arguments, clippy::too_many_lines)] pub async fn graph_recall_beam(
34 store: &GraphStore,
35 embeddings: Option<&EmbeddingStore>,
36 provider: &zeph_llm::any::AnyProvider,
37 query: &str,
38 limit: usize,
39 beam_width: usize,
40 max_hops: u32,
41 edge_types: &[EdgeType],
42 temporal_decay_rate: f64,
43 hebbian_enabled: bool,
44 hebbian_lr: f32,
45) -> Result<Vec<GraphFact>, MemoryError> {
46 let _span = tracing::info_span!("memory.graph.beam", query_len = query.len()).entered();
47
48 if limit == 0 {
49 return Ok(Vec::new());
50 }
51
52 let entity_scores = find_seed_entities(
53 store,
54 embeddings,
55 provider,
56 query,
57 limit,
58 DEFAULT_STRUCTURAL_WEIGHT,
59 DEFAULT_COMMUNITY_CAP,
60 )
61 .await?;
62
63 if entity_scores.is_empty() {
64 return Ok(Vec::new());
65 }
66
67 let now_secs: i64 = SystemTime::now()
68 .duration_since(UNIX_EPOCH)
69 .map_or(0, |d| d.as_secs().cast_signed());
70
71 let mut beam_scores: Vec<(i64, f32)> = entity_scores.into_iter().collect();
73 beam_scores.sort_by(|(_, sa), (_, sb)| sb.total_cmp(sa));
74 beam_scores.truncate(beam_width);
75
76 let mut beam_ids: Vec<i64> = beam_scores.iter().map(|(id, _)| *id).collect();
77 let mut beam_score_map: HashMap<i64, f32> = beam_scores.into_iter().collect();
78
79 let mut all_db_edges: Vec<crate::graph::types::Edge> = Vec::new();
80 let mut entity_name_map: HashMap<i64, String> = HashMap::new();
81
82 for _hop in 0..max_hops {
83 if beam_ids.is_empty() {
84 break;
85 }
86
87 let edges = store.edges_for_entities(&beam_ids, edge_types).await?;
88 if edges.is_empty() {
89 break;
90 }
91
92 let new_entity_ids: Vec<i64> = edges
94 .iter()
95 .flat_map(|e| [e.source_entity_id, e.target_entity_id])
96 .filter(|id| !entity_name_map.contains_key(id))
97 .collect::<HashSet<_>>()
98 .into_iter()
99 .collect();
100
101 for id in new_entity_ids {
102 if let Ok(Some(entity)) = store.find_entity_by_id(id).await {
103 entity_name_map.insert(id, entity.canonical_name.clone());
104 }
105 }
106
107 let mut neighbour_scores: HashMap<i64, f32> = HashMap::new();
109 for edge in &edges {
110 let edge_conf = edge.confidence;
111 neighbour_scores
112 .entry(edge.target_entity_id)
113 .and_modify(|s| *s = s.max(edge_conf))
114 .or_insert(edge_conf);
115 neighbour_scores
116 .entry(edge.source_entity_id)
117 .and_modify(|s| *s = s.max(edge_conf))
118 .or_insert(edge_conf);
119 }
120
121 let mut candidates: Vec<(i64, f32)> = neighbour_scores
123 .into_iter()
124 .filter(|(id, _)| !beam_score_map.contains_key(id))
125 .collect();
126 candidates.sort_by(|(_, sa), (_, sb)| sb.total_cmp(sa));
127 candidates.truncate(beam_width);
128
129 beam_ids = candidates.iter().map(|(id, _)| *id).collect();
130 for (id, cand_score) in candidates {
131 beam_score_map.insert(id, cand_score);
132 }
133
134 all_db_edges.extend(edges);
135 }
136
137 if all_db_edges.is_empty() {
138 return Ok(Vec::new());
139 }
140
141 let edge_ids: Vec<i64> = all_db_edges.iter().map(|e| e.id).collect();
143 if let Err(e) = store.record_edge_retrieval(&edge_ids).await {
144 tracing::warn!(error = %e, "graph_recall_beam: failed to record edge retrieval");
145 }
146 if hebbian_enabled
148 && !edge_ids.is_empty()
149 && let Err(e) = store.apply_hebbian_increment(&edge_ids, hebbian_lr).await
150 {
151 tracing::warn!(error = %e, "graph_recall_beam: hebbian increment failed");
152 }
153
154 let mut facts: Vec<GraphFact> = Vec::new();
156 let mut seen: HashSet<(String, String, String, EdgeType)> = HashSet::new();
157
158 for edge in &all_db_edges {
159 let entity_name = entity_name_map
160 .get(&edge.source_entity_id)
161 .cloned()
162 .unwrap_or_default();
163 let target_name = entity_name_map
164 .get(&edge.target_entity_id)
165 .cloned()
166 .unwrap_or_default();
167 if entity_name.is_empty() || target_name.is_empty() {
168 continue;
169 }
170 let key = (
171 entity_name.clone(),
172 edge.relation.clone(),
173 target_name.clone(),
174 edge.edge_type,
175 );
176 if seen.insert(key) {
177 let seed_score = beam_score_map
178 .get(&edge.source_entity_id)
179 .copied()
180 .unwrap_or(0.5);
181 facts.push(GraphFact {
182 entity_name,
183 relation: edge.relation.clone(),
184 target_name,
185 fact: edge.fact.clone(),
186 entity_match_score: seed_score,
187 hop_distance: 1,
188 confidence: edge.confidence,
189 valid_from: Some(edge.valid_from.clone()),
190 edge_type: edge.edge_type,
191 retrieval_count: edge.retrieval_count,
192 });
193 }
194 }
195
196 facts.sort_by(|a, b| {
197 let sa = a.score_with_decay(temporal_decay_rate, now_secs);
198 let sb = b.score_with_decay(temporal_decay_rate, now_secs);
199 sb.total_cmp(&sa)
200 });
201 facts.truncate(limit);
202
203 Ok(facts)
204}
205
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::store::GraphStore;
    use crate::graph::types::EntityType;
    use crate::store::SqliteStore;
    use zeph_llm::any::AnyProvider;
    use zeph_llm::mock::MockProvider;

    /// Creates a `GraphStore` on top of a fresh in-memory SQLite store.
    async fn setup_store() -> GraphStore {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        GraphStore::new(pool)
    }

    /// Wraps a default `MockProvider` in `AnyProvider`.
    fn mock_provider() -> AnyProvider {
        let mock = MockProvider::default();
        AnyProvider::Mock(mock)
    }

    /// Runs `graph_recall_beam` with the settings shared by every test here:
    /// no embedding store, a mock provider, beam width 5, two hops, an empty
    /// edge-type slice, zero decay, and Hebbian updates switched off.
    async fn recall(store: &GraphStore, query: &str, limit: usize) -> Vec<GraphFact> {
        let provider = mock_provider();
        let outcome = graph_recall_beam(
            store,
            None,
            &provider,
            query,
            limit,
            5,
            2,
            &[],
            0.0,
            false,
            0.0,
        )
        .await;
        outcome.unwrap()
    }

    // A store without any entities produces no facts.
    #[tokio::test]
    async fn beam_empty_graph_returns_empty() {
        let graph = setup_store().await;
        let facts = recall(&graph, "anything", 10).await;
        assert!(facts.is_empty());
    }

    // A limit of zero produces no facts.
    #[tokio::test]
    async fn beam_zero_limit_returns_empty() {
        let graph = setup_store().await;
        let facts = recall(&graph, "anything", 0).await;
        assert!(facts.is_empty());
    }

    // One edge between two entities is recalled when querying for its source.
    #[tokio::test]
    async fn beam_finds_direct_edge() {
        let graph = setup_store().await;

        let alice = graph
            .upsert_entity("Alice", "alice", EntityType::Person, None)
            .await
            .unwrap();
        let bob = graph
            .upsert_entity("Bob", "bob", EntityType::Person, None)
            .await
            .unwrap();
        graph
            .insert_edge(alice, bob, "knows", "Alice knows Bob", 0.9, None)
            .await
            .unwrap();

        let facts = recall(&graph, "Alice", 10).await;
        assert!(!facts.is_empty());
    }
}