1use std::fs;
2use std::path::{Path, PathBuf};
3
4use crate::error::{Error, Result};
5use crate::storage::Storage;
6use crate::types::*;
7
8pub struct Graph {
13 storage: Storage,
14 #[allow(dead_code)]
15 db_path: PathBuf,
16}
17
18impl Graph {
19 pub fn open(db_path: &Path) -> Result<Self> {
21 if !db_path.exists() {
22 return Err(Error::NotFound(format!(
23 "database not found at {}",
24 db_path.display()
25 )));
26 }
27 let storage = Storage::open(db_path)?;
28 Ok(Self {
29 storage,
30 db_path: db_path.to_path_buf(),
31 })
32 }
33
34 pub fn open_or_create(db_path: &Path) -> Result<Self> {
36 if let Some(parent) = db_path.parent()
37 && !parent.as_os_str().is_empty()
38 {
39 fs::create_dir_all(parent)?;
40 }
41 let storage = Storage::open(db_path)?;
42 Ok(Self {
43 storage,
44 db_path: db_path.to_path_buf(),
45 })
46 }
47
48 pub fn in_memory() -> Result<Self> {
50 let storage = Storage::open_in_memory()?;
51 Ok(Self {
52 storage,
53 db_path: PathBuf::from(":memory:"),
54 })
55 }
56
57 pub fn add_episode(&self, episode: Episode) -> Result<EpisodeResult> {
61 self.storage.insert_episode(&episode)?;
62 Ok(EpisodeResult {
63 episode_id: episode.id,
64 })
65 }
66
67 pub fn get_episode(&self, id: &str) -> Result<Option<Episode>> {
69 self.storage.get_episode(id)
70 }
71
72 pub fn list_episodes(&self, limit: usize, offset: usize) -> Result<Vec<Episode>> {
74 self.storage.list_episodes(limit, offset)
75 }
76
77 pub fn add_entity(&self, entity: Entity) -> Result<()> {
81 self.storage.insert_entity(&entity)
82 }
83
84 pub fn add_entity_deduped(&self, entity: Entity, threshold: f64) -> Result<(String, bool)> {
92 if let Some(canonical_id) = self.storage.find_by_alias(&entity.name)? {
94 return Ok((canonical_id, true));
95 }
96
97 let existing = self.storage.get_entity_names_by_type(&entity.entity_type)?;
99
100 let name_lower = entity.name.to_lowercase();
102 let mut best: Option<(String, f64)> = None;
103 for (existing_id, existing_name) in &existing {
104 let sim = strsim::jaro_winkler(&name_lower, &existing_name.to_lowercase());
105 if sim >= threshold && best.as_ref().is_none_or(|(_, best_sim)| sim > *best_sim) {
106 best = Some((existing_id.clone(), sim));
107 }
108 }
109
110 if let Some((canonical_id, sim)) = best {
112 self.storage.add_alias(&canonical_id, &entity.name, sim)?;
113 return Ok((canonical_id, true));
114 }
115
116 let id = entity.id.clone();
118 self.storage.insert_entity(&entity)?;
119 Ok((id, false))
120 }
121
122 pub fn get_entity(&self, id: &str) -> Result<Option<Entity>> {
124 self.storage.get_entity(id)
125 }
126
127 pub fn get_entity_by_name(&self, name: &str) -> Result<Option<Entity>> {
129 self.storage.get_entity_by_name(name)
130 }
131
132 pub fn list_entities(&self, entity_type: Option<&str>, limit: usize) -> Result<Vec<Entity>> {
134 self.storage.list_entities(entity_type, limit)
135 }
136
137 pub fn add_edge(&self, edge: Edge) -> Result<()> {
141 self.storage.insert_edge(&edge)
142 }
143
144 pub fn get_edges_for_entity(&self, entity_id: &str) -> Result<Vec<Edge>> {
146 self.storage.get_edges_for_entity(entity_id)
147 }
148
149 pub fn invalidate_edge(&self, edge_id: &str) -> Result<()> {
151 self.storage.invalidate_edge(edge_id, chrono::Utc::now())
152 }
153
154 pub fn link_episode_entity(
156 &self,
157 episode_id: &str,
158 entity_id: &str,
159 span_start: Option<usize>,
160 span_end: Option<usize>,
161 ) -> Result<()> {
162 self.storage
163 .link_episode_entity(episode_id, entity_id, span_start, span_end)
164 }
165
166 pub fn store_embedding(&self, episode_id: &str, embedding: &[f32]) -> Result<()> {
170 let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
171 self.storage.store_episode_embedding(episode_id, &bytes)
172 }
173
174 pub fn store_entity_embedding(&self, entity_id: &str, embedding: &[f32]) -> Result<()> {
176 let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
177 self.storage.store_entity_embedding(entity_id, &bytes)
178 }
179
180 pub fn get_embeddings(&self) -> Result<Vec<(String, Vec<f32>)>> {
182 let raw = self.storage.get_all_episode_embeddings()?;
183 let result = raw
184 .into_iter()
185 .map(|(id, bytes)| {
186 let floats: Vec<f32> = bytes
187 .chunks_exact(4)
188 .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
189 .collect();
190 (id, floats)
191 })
192 .collect();
193 Ok(result)
194 }
195
196 pub fn search(&self, query: &str, limit: usize) -> Result<Vec<(Episode, f64)>> {
200 self.storage.search_episodes(query, limit)
201 }
202
203 pub fn search_entities(&self, query: &str, limit: usize) -> Result<Vec<(Entity, f64)>> {
205 self.storage.search_entities(query, limit)
206 }
207
208 pub fn search_fused(
213 &self,
214 query: &str,
215 query_embedding: &[f32],
216 limit: usize,
217 ) -> Result<Vec<FusedEpisodeResult>> {
218 const K: f64 = 60.0;
219
220 let mut scores: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
221 let mut episodes_map: std::collections::HashMap<String, Episode> =
222 std::collections::HashMap::new();
223
224 let fts_pool = (limit * 10).max(200);
226 let fts_results = self.storage.search_episodes(query, fts_pool);
227 if let Ok(fts) = fts_results {
228 for (rank, (episode, _)) in fts.into_iter().enumerate() {
229 let rrf = 1.0 / (K + rank as f64 + 1.0);
230 *scores.entry(episode.id.clone()).or_insert(0.0) += rrf;
231 episodes_map.insert(episode.id.clone(), episode);
232 }
233 }
234
235 let all_embeddings = self.get_embeddings()?;
237 if !all_embeddings.is_empty() && !query_embedding.is_empty() {
238 let mut semantic: Vec<(String, f32)> = all_embeddings
239 .into_iter()
240 .map(|(id, vec)| {
241 let sim = cosine_similarity(query_embedding, &vec);
242 (id, sim)
243 })
244 .collect();
245 semantic.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246
247 for (rank, (ep_id, _sim)) in semantic.into_iter().enumerate() {
248 let rrf = 1.0 / (K + rank as f64 + 1.0);
249 *scores.entry(ep_id.clone()).or_insert(0.0) += rrf;
250 if let std::collections::hash_map::Entry::Vacant(e) = episodes_map.entry(ep_id)
251 && let Ok(Some(ep)) = self.storage.get_episode(e.key())
252 {
253 e.insert(ep);
254 }
255 }
256 }
257
258 let mut fused: Vec<(String, f64)> = scores.into_iter().collect();
260 fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
261
262 let results = fused
263 .into_iter()
264 .take(limit)
265 .filter_map(|(id, score)| {
266 episodes_map
267 .remove(&id)
268 .map(|episode| FusedEpisodeResult { episode, score })
269 })
270 .collect();
271
272 Ok(results)
273 }
274
275 pub fn get_entity_context(&self, entity_id: &str) -> Result<EntityContext> {
279 let entity = self
280 .storage
281 .get_entity(entity_id)?
282 .ok_or_else(|| Error::NotFound(format!("entity {entity_id}")))?;
283
284 let edges = self.storage.get_current_edges_for_entity(entity_id)?;
285
286 let mut neighbors = Vec::new();
287 for edge in &edges {
288 let nid = if edge.source_id == entity_id {
289 &edge.target_id
290 } else {
291 &edge.source_id
292 };
293 if let Some(n) = self.storage.get_entity(nid)? {
294 neighbors.push(n);
295 }
296 }
297
298 Ok(EntityContext {
299 entity,
300 edges,
301 neighbors,
302 })
303 }
304
305 pub fn traverse(
309 &self,
310 start_entity_id: &str,
311 max_depth: usize,
312 ) -> Result<(Vec<Entity>, Vec<Edge>)> {
313 self.storage.traverse(start_entity_id, max_depth, true)
314 }
315
316 pub fn traverse_with_history(
318 &self,
319 start_entity_id: &str,
320 max_depth: usize,
321 ) -> Result<(Vec<Entity>, Vec<Edge>)> {
322 self.storage.traverse(start_entity_id, max_depth, false)
323 }
324
325 pub fn stats(&self) -> Result<GraphStats> {
329 self.storage.stats()
330 }
331}
332
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` instead of dividing when the slices differ in length, are
/// empty, or either one has zero magnitude.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // One pass accumulates the dot product and both squared norms.
    let (dot, norm_sq_a, norm_sq_b) = a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
    );

    let (len_a, len_b) = (norm_sq_a.sqrt(), norm_sq_b.sqrt());
    if len_a == 0.0 || len_b == 0.0 {
        return 0.0;
    }
    dot / (len_a * len_b)
}