Skip to main content

sqlite_graph/
graph.rs

1use std::fs;
2use std::path::{Path, PathBuf};
3
4use crate::error::{Error, Result};
5use crate::storage::Storage;
6use crate::types::*;
7
8/// An embeddable graph database built on SQLite.
9///
10/// Provides entity/edge storage, bi-temporal edges, FTS5 search,
11/// vector embeddings with RRF fusion, and recursive CTE traversal.
12pub struct Graph {
13    storage: Storage,
14    #[allow(dead_code)]
15    db_path: PathBuf,
16}
17
18impl Graph {
19    /// Open an existing database file.
20    pub fn open(db_path: &Path) -> Result<Self> {
21        if !db_path.exists() {
22            return Err(Error::NotFound(format!(
23                "database not found at {}",
24                db_path.display()
25            )));
26        }
27        let storage = Storage::open(db_path)?;
28        Ok(Self {
29            storage,
30            db_path: db_path.to_path_buf(),
31        })
32    }
33
34    /// Open an existing database or create a new one at the given path.
35    pub fn open_or_create(db_path: &Path) -> Result<Self> {
36        if let Some(parent) = db_path.parent()
37            && !parent.as_os_str().is_empty()
38        {
39            fs::create_dir_all(parent)?;
40        }
41        let storage = Storage::open(db_path)?;
42        Ok(Self {
43            storage,
44            db_path: db_path.to_path_buf(),
45        })
46    }
47
48    /// Open an in-memory database (for testing or ephemeral use).
49    pub fn in_memory() -> Result<Self> {
50        let storage = Storage::open_in_memory()?;
51        Ok(Self {
52            storage,
53            db_path: PathBuf::from(":memory:"),
54        })
55    }
56
57    // ── Episodes ──
58
59    /// Add an episode (event, decision, message) to the graph.
60    pub fn add_episode(&self, episode: Episode) -> Result<EpisodeResult> {
61        self.storage.insert_episode(&episode)?;
62        Ok(EpisodeResult {
63            episode_id: episode.id,
64        })
65    }
66
67    /// Get an episode by ID.
68    pub fn get_episode(&self, id: &str) -> Result<Option<Episode>> {
69        self.storage.get_episode(id)
70    }
71
72    /// List episodes with pagination.
73    pub fn list_episodes(&self, limit: usize, offset: usize) -> Result<Vec<Episode>> {
74        self.storage.list_episodes(limit, offset)
75    }
76
77    // ── Entities ──
78
79    /// Add an entity (node) to the graph.
80    pub fn add_entity(&self, entity: Entity) -> Result<()> {
81        self.storage.insert_entity(&entity)
82    }
83
84    /// Add an entity with fuzzy deduplication against existing entities of the same type.
85    ///
86    /// If an existing entity with Jaro-Winkler similarity >= threshold exists,
87    /// returns that entity's ID and stores the new name as an alias.
88    /// Otherwise creates a new entity.
89    ///
90    /// Returns (entity_id, was_merged: bool).
91    pub fn add_entity_deduped(&self, entity: Entity, threshold: f64) -> Result<(String, bool)> {
92        // 1. Check alias table first (exact alias match)
93        if let Some(canonical_id) = self.storage.find_by_alias(&entity.name)? {
94            return Ok((canonical_id, true));
95        }
96
97        // 2. Get all existing entities of same type
98        let existing = self.storage.get_entity_names_by_type(&entity.entity_type)?;
99
100        // 3. Compute Jaro-Winkler similarity to each
101        let name_lower = entity.name.to_lowercase();
102        let mut best: Option<(String, f64)> = None;
103        for (existing_id, existing_name) in &existing {
104            let sim = strsim::jaro_winkler(&name_lower, &existing_name.to_lowercase());
105            if sim >= threshold && best.as_ref().is_none_or(|(_, best_sim)| sim > *best_sim) {
106                best = Some((existing_id.clone(), sim));
107            }
108        }
109
110        // 4. If match found: add alias and return existing id
111        if let Some((canonical_id, sim)) = best {
112            self.storage.add_alias(&canonical_id, &entity.name, sim)?;
113            return Ok((canonical_id, true));
114        }
115
116        // 5. Otherwise: insert new entity
117        let id = entity.id.clone();
118        self.storage.insert_entity(&entity)?;
119        Ok((id, false))
120    }
121
122    /// Get an entity by ID.
123    pub fn get_entity(&self, id: &str) -> Result<Option<Entity>> {
124        self.storage.get_entity(id)
125    }
126
127    /// Get an entity by name.
128    pub fn get_entity_by_name(&self, name: &str) -> Result<Option<Entity>> {
129        self.storage.get_entity_by_name(name)
130    }
131
132    /// List entities, optionally filtered by type.
133    pub fn list_entities(&self, entity_type: Option<&str>, limit: usize) -> Result<Vec<Entity>> {
134        self.storage.list_entities(entity_type, limit)
135    }
136
137    // ── Edges ──
138
139    /// Add an edge (relationship) between two entities.
140    pub fn add_edge(&self, edge: Edge) -> Result<()> {
141        self.storage.insert_edge(&edge)
142    }
143
144    /// Get all edges for an entity (both as source and target).
145    pub fn get_edges_for_entity(&self, entity_id: &str) -> Result<Vec<Edge>> {
146        self.storage.get_edges_for_entity(entity_id)
147    }
148
149    /// Invalidate an edge (set valid_until to now).
150    pub fn invalidate_edge(&self, edge_id: &str) -> Result<()> {
151        self.storage.invalidate_edge(edge_id, chrono::Utc::now())
152    }
153
154    /// Link an episode to an entity (with optional character span).
155    pub fn link_episode_entity(
156        &self,
157        episode_id: &str,
158        entity_id: &str,
159        span_start: Option<usize>,
160        span_end: Option<usize>,
161    ) -> Result<()> {
162        self.storage
163            .link_episode_entity(episode_id, entity_id, span_start, span_end)
164    }
165
166    // ── Embeddings ──
167
168    /// Store an embedding for an episode (serialized as little-endian f32 bytes).
169    pub fn store_embedding(&self, episode_id: &str, embedding: &[f32]) -> Result<()> {
170        let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
171        self.storage.store_episode_embedding(episode_id, &bytes)
172    }
173
174    /// Store an embedding for an entity.
175    pub fn store_entity_embedding(&self, entity_id: &str, embedding: &[f32]) -> Result<()> {
176        let bytes: Vec<u8> = embedding.iter().flat_map(|f| f.to_le_bytes()).collect();
177        self.storage.store_entity_embedding(entity_id, &bytes)
178    }
179
180    /// Load all episode embeddings as (episode_id, Vec<f32>) pairs.
181    pub fn get_embeddings(&self) -> Result<Vec<(String, Vec<f32>)>> {
182        let raw = self.storage.get_all_episode_embeddings()?;
183        let result = raw
184            .into_iter()
185            .map(|(id, bytes)| {
186                let floats: Vec<f32> = bytes
187                    .chunks_exact(4)
188                    .map(|c| f32::from_le_bytes(c.try_into().unwrap()))
189                    .collect();
190                (id, floats)
191            })
192            .collect();
193        Ok(result)
194    }
195
196    // ── Search ──
197
198    /// Search episodes via FTS5 full-text search.
199    pub fn search(&self, query: &str, limit: usize) -> Result<Vec<(Episode, f64)>> {
200        self.storage.search_episodes(query, limit)
201    }
202
203    /// Search entities via FTS5.
204    pub fn search_entities(&self, query: &str, limit: usize) -> Result<Vec<(Entity, f64)>> {
205        self.storage.search_entities(query, limit)
206    }
207
208    /// Fused search using Reciprocal Rank Fusion (RRF) over FTS5 + semantic results.
209    ///
210    /// `query_embedding` should be the pre-computed embedding for `query`.
211    /// Returns episodes ranked by combined RRF score.
212    pub fn search_fused(
213        &self,
214        query: &str,
215        query_embedding: &[f32],
216        limit: usize,
217    ) -> Result<Vec<FusedEpisodeResult>> {
218        const K: f64 = 60.0;
219
220        let mut scores: std::collections::HashMap<String, f64> = std::collections::HashMap::new();
221        let mut episodes_map: std::collections::HashMap<String, Episode> =
222            std::collections::HashMap::new();
223
224        // --- FTS5 ranked list ---
225        let fts_pool = (limit * 10).max(200);
226        let fts_results = self.storage.search_episodes(query, fts_pool);
227        if let Ok(fts) = fts_results {
228            for (rank, (episode, _)) in fts.into_iter().enumerate() {
229                let rrf = 1.0 / (K + rank as f64 + 1.0);
230                *scores.entry(episode.id.clone()).or_insert(0.0) += rrf;
231                episodes_map.insert(episode.id.clone(), episode);
232            }
233        }
234
235        // --- Semantic (cosine similarity) ranked list ---
236        let all_embeddings = self.get_embeddings()?;
237        if !all_embeddings.is_empty() && !query_embedding.is_empty() {
238            let mut semantic: Vec<(String, f32)> = all_embeddings
239                .into_iter()
240                .map(|(id, vec)| {
241                    let sim = cosine_similarity(query_embedding, &vec);
242                    (id, sim)
243                })
244                .collect();
245            semantic.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
246
247            for (rank, (ep_id, _sim)) in semantic.into_iter().enumerate() {
248                let rrf = 1.0 / (K + rank as f64 + 1.0);
249                *scores.entry(ep_id.clone()).or_insert(0.0) += rrf;
250                if let std::collections::hash_map::Entry::Vacant(e) = episodes_map.entry(ep_id)
251                    && let Ok(Some(ep)) = self.storage.get_episode(e.key())
252                {
253                    e.insert(ep);
254                }
255            }
256        }
257
258        // Sort by total RRF score descending
259        let mut fused: Vec<(String, f64)> = scores.into_iter().collect();
260        fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
261
262        let results = fused
263            .into_iter()
264            .take(limit)
265            .filter_map(|(id, score)| {
266                episodes_map
267                    .remove(&id)
268                    .map(|episode| FusedEpisodeResult { episode, score })
269            })
270            .collect();
271
272        Ok(results)
273    }
274
275    // ── Traversal ──
276
277    /// Get context around an entity — its neighbors and connecting edges.
278    pub fn get_entity_context(&self, entity_id: &str) -> Result<EntityContext> {
279        let entity = self
280            .storage
281            .get_entity(entity_id)?
282            .ok_or_else(|| Error::NotFound(format!("entity {entity_id}")))?;
283
284        let edges = self.storage.get_current_edges_for_entity(entity_id)?;
285
286        let mut neighbors = Vec::new();
287        for edge in &edges {
288            let nid = if edge.source_id == entity_id {
289                &edge.target_id
290            } else {
291                &edge.source_id
292            };
293            if let Some(n) = self.storage.get_entity(nid)? {
294                neighbors.push(n);
295            }
296        }
297
298        Ok(EntityContext {
299            entity,
300            edges,
301            neighbors,
302        })
303    }
304
305    /// Multi-hop graph traversal from a starting entity.
306    ///
307    /// Returns all entities reachable within `max_depth` hops and the edges connecting them.
308    pub fn traverse(
309        &self,
310        start_entity_id: &str,
311        max_depth: usize,
312    ) -> Result<(Vec<Entity>, Vec<Edge>)> {
313        self.storage.traverse(start_entity_id, max_depth, true)
314    }
315
316    /// Multi-hop traversal including historical (invalidated) edges.
317    pub fn traverse_with_history(
318        &self,
319        start_entity_id: &str,
320        max_depth: usize,
321    ) -> Result<(Vec<Entity>, Vec<Edge>)> {
322        self.storage.traverse(start_entity_id, max_depth, false)
323    }
324
325    // ── Stats ──
326
327    /// Get graph-wide statistics.
328    pub fn stats(&self) -> Result<GraphStats> {
329        self.storage.stats()
330    }
331}
332
/// Compute the cosine similarity between two `f32` vectors.
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when:
/// - the vectors have different lengths or are empty;
/// - either vector has zero magnitude;
/// - the result is not finite (e.g. an input contains `NaN` or infinity).
///
/// Sums are accumulated in `f64`. Large-magnitude components therefore cannot overflow
/// to `inf` (which would yield `inf / inf = NaN`), and long vectors lose less precision
/// to rounding.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) =
        a.iter()
            .zip(b)
            .fold((0.0_f64, 0.0_f64, 0.0_f64), |(dot, na, nb), (&x, &y)| {
                let (x, y) = (f64::from(x), f64::from(y));
                (dot + x * y, na + x * x, nb + y * y)
            });
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    let sim = dot / denom;
    if sim.is_finite() {
        // Rounding can push the ratio marginally outside [-1, 1].
        sim.clamp(-1.0, 1.0) as f32
    } else {
        0.0
    }
}