Skip to main content

sediment/
graph.rs

//! Graph store using SQLite for relationship tracking between memories.
//!
//! Provides a graph layer, kept alongside LanceDB, for tracking
//! relationships such as RELATED, SUPERSEDES, and CO_ACCESSED between items.
5
6use std::path::Path;
7
8use rusqlite::{Connection, params};
9use tracing::debug;
10
11use crate::error::{Result, SedimentError};
12
/// A relationship between two memory items in the graph.
#[derive(Debug, Clone, PartialEq)]
pub struct Edge {
    /// Id of the item on the other end of the relationship.
    pub target_id: String,
    /// Label describing the kind of relationship.
    pub rel_type: String,
    /// Strength of the relationship.
    pub strength: f64,
}
20
/// A co-access relationship between two memory items.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CoAccessEdge {
    /// Id of the item that was accessed together with the queried item.
    pub target_id: String,
    /// How many times the two items were recorded as accessed together.
    pub count: i64,
}
27
/// Full connection info for a memory item.
#[derive(Debug, Clone, PartialEq)]
pub struct ConnectionInfo {
    /// Id of the connected item.
    pub target_id: String,
    /// Relation label: the specific relation type for RELATED edges,
    /// otherwise the edge type itself (e.g. `supersedes`, `co_accessed`).
    pub rel_type: String,
    /// Edge strength (edges that never set a strength report the column default).
    pub strength: f64,
    /// Co-access count; `Some` only for co-access edges.
    pub count: Option<i64>,
}
36
37/// SQLite-backed graph store for memory relationships.
38pub struct GraphStore {
39    conn: Connection,
40}
41
42impl GraphStore {
43    /// Open or create the graph store using the given SQLite database path.
44    /// Shares the same file as access.db.
45    pub fn open(path: &Path) -> Result<Self> {
46        let conn = Connection::open(path).map_err(|e| {
47            SedimentError::Database(format!("Failed to open graph database: {}", e))
48        })?;
49
50        conn.execute_batch("PRAGMA journal_mode=WAL;").ok();
51
52        conn.execute_batch(
53            "CREATE TABLE IF NOT EXISTS graph_nodes (
54                id TEXT PRIMARY KEY,
55                project_id TEXT NOT NULL DEFAULT '',
56                created_at INTEGER NOT NULL
57            );
58
59            CREATE TABLE IF NOT EXISTS graph_edges (
60                from_id TEXT NOT NULL,
61                to_id TEXT NOT NULL,
62                edge_type TEXT NOT NULL,
63                strength REAL NOT NULL DEFAULT 0.0,
64                rel_type TEXT NOT NULL DEFAULT '',
65                count INTEGER NOT NULL DEFAULT 0,
66                last_at INTEGER NOT NULL DEFAULT 0,
67                created_at INTEGER NOT NULL,
68                UNIQUE(from_id, to_id, edge_type)
69            );
70
71            CREATE INDEX IF NOT EXISTS idx_edges_from ON graph_edges(from_id);
72            CREATE INDEX IF NOT EXISTS idx_edges_to ON graph_edges(to_id);",
73        )
74        .map_err(|e| SedimentError::Database(format!("Failed to create graph tables: {}", e)))?;
75
76        Ok(Self { conn })
77    }
78
79    /// Add a Memory node to the graph.
80    pub fn add_node(&self, id: &str, project_id: Option<&str>, created_at: i64) -> Result<()> {
81        let pid = project_id.unwrap_or("");
82
83        self.conn
84            .execute(
85                "INSERT OR IGNORE INTO graph_nodes (id, project_id, created_at) VALUES (?1, ?2, ?3)",
86                params![id, pid, created_at],
87            )
88            .map_err(|e| SedimentError::Database(format!("Failed to add node: {}", e)))?;
89
90        debug!("Added graph node: {}", id);
91        Ok(())
92    }
93
94    /// Ensure a node exists in the graph. Creates it if missing (for backfill).
95    pub fn ensure_node_exists(
96        &self,
97        id: &str,
98        project_id: Option<&str>,
99        created_at: i64,
100    ) -> Result<()> {
101        self.add_node(id, project_id, created_at)
102    }
103
104    /// Remove a Memory node and all its edges from the graph.
105    pub fn remove_node(&self, id: &str) -> Result<()> {
106        self.conn
107            .execute(
108                "DELETE FROM graph_edges WHERE from_id = ?1 OR to_id = ?1",
109                params![id],
110            )
111            .map_err(|e| SedimentError::Database(format!("Failed to remove edges: {}", e)))?;
112
113        self.conn
114            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
115            .map_err(|e| SedimentError::Database(format!("Failed to remove node: {}", e)))?;
116
117        debug!("Removed graph node: {}", id);
118        Ok(())
119    }
120
121    /// Add a RELATED edge between two Memory nodes.
122    pub fn add_related_edge(
123        &self,
124        from_id: &str,
125        to_id: &str,
126        strength: f64,
127        rel_type: &str,
128    ) -> Result<()> {
129        let now = chrono::Utc::now().timestamp();
130
131        self.conn
132            .execute(
133                "INSERT OR IGNORE INTO graph_edges (from_id, to_id, edge_type, strength, rel_type, created_at)
134                 VALUES (?1, ?2, 'related', ?3, ?4, ?5)",
135                params![from_id, to_id, strength, rel_type, now],
136            )
137            .map_err(|e| SedimentError::Database(format!("Failed to add related edge: {}", e)))?;
138
139        debug!(
140            "Added RELATED edge: {} -> {} ({})",
141            from_id, to_id, rel_type
142        );
143        Ok(())
144    }
145
146    /// Add a SUPERSEDES edge from new_id to old_id.
147    pub fn add_supersedes_edge(&self, new_id: &str, old_id: &str) -> Result<()> {
148        let now = chrono::Utc::now().timestamp();
149
150        self.conn
151            .execute(
152                "INSERT OR IGNORE INTO graph_edges (from_id, to_id, edge_type, strength, created_at)
153                 VALUES (?1, ?2, 'supersedes', 1.0, ?3)",
154                params![new_id, old_id, now],
155            )
156            .map_err(|e| SedimentError::Database(format!("Failed to add supersedes edge: {}", e)))?;
157
158        debug!("Added SUPERSEDES edge: {} -> {}", new_id, old_id);
159        Ok(())
160    }
161
162    /// Get 1-hop neighbors of the given item IDs via RELATED or SUPERSEDES edges.
163    /// Returns (neighbor_id, rel_type, strength) tuples.
164    pub fn get_neighbors(
165        &self,
166        ids: &[&str],
167        min_strength: f64,
168    ) -> Result<Vec<(String, String, f64)>> {
169        if ids.is_empty() {
170            return Ok(Vec::new());
171        }
172
173        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{}", i)).collect();
174        let ph = placeholders.join(",");
175        let strength_idx = ids.len() + 1;
176
177        let sql = format!(
178            "SELECT
179                CASE WHEN from_id IN ({ph}) THEN to_id ELSE from_id END AS neighbor,
180                CASE WHEN edge_type = 'related' THEN rel_type ELSE 'supersedes' END AS rtype,
181                strength
182             FROM graph_edges
183             WHERE (from_id IN ({ph}) OR to_id IN ({ph}))
184               AND edge_type IN ('related', 'supersedes')
185               AND strength >= ?{strength_idx}"
186        );
187
188        let mut stmt = self.conn.prepare(&sql).map_err(|e| {
189            SedimentError::Database(format!("Failed to prepare neighbors query: {}", e))
190        })?;
191
192        let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
193        for id in ids {
194            param_values.push(Box::new(id.to_string()));
195        }
196        param_values.push(Box::new(min_strength));
197
198        let params_ref: Vec<&dyn rusqlite::types::ToSql> =
199            param_values.iter().map(|b| b.as_ref()).collect();
200
201        let rows = stmt
202            .query_map(params_ref.as_slice(), |row| {
203                Ok((
204                    row.get::<_, String>(0)?,
205                    row.get::<_, String>(1)?,
206                    row.get::<_, f64>(2)?,
207                ))
208            })
209            .map_err(|e| SedimentError::Database(format!("Failed to query neighbors: {}", e)))?;
210
211        let mut results = Vec::new();
212        for row in rows {
213            let r = row
214                .map_err(|e| SedimentError::Database(format!("Failed to read neighbor: {}", e)))?;
215            results.push(r);
216        }
217
218        Ok(results)
219    }
220
221    /// Record co-access between pairs of item IDs.
222    /// Creates or increments CO_ACCESSED edges.
223    pub fn record_co_access(&self, item_ids: &[String]) -> Result<()> {
224        if item_ids.len() < 2 {
225            return Ok(());
226        }
227
228        let now = chrono::Utc::now().timestamp();
229
230        for i in 0..item_ids.len() {
231            for j in (i + 1)..item_ids.len() {
232                let a = &item_ids[i];
233                let b = &item_ids[j];
234
235                self.conn
236                    .execute(
237                        "INSERT INTO graph_edges (from_id, to_id, edge_type, count, last_at, created_at)
238                         VALUES (?1, ?2, 'co_accessed', 1, ?3, ?3)
239                         ON CONFLICT(from_id, to_id, edge_type)
240                         DO UPDATE SET count = count + 1, last_at = ?3",
241                        params![a, b, now],
242                    )
243                    .map_err(|e| {
244                        SedimentError::Database(format!("Failed to record co-access: {}", e))
245                    })?;
246            }
247        }
248
249        Ok(())
250    }
251
252    /// Get items that are frequently co-accessed with the given IDs.
253    /// Returns (neighbor_id, co_access_count) tuples.
254    pub fn get_co_accessed(&self, ids: &[&str], min_count: i64) -> Result<Vec<(String, i64)>> {
255        if ids.is_empty() {
256            return Ok(Vec::new());
257        }
258
259        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{}", i)).collect();
260        let ph = placeholders.join(",");
261        let min_idx = ids.len() + 1;
262
263        let sql = format!(
264            "SELECT
265                CASE WHEN from_id IN ({ph}) THEN to_id ELSE from_id END AS neighbor,
266                count
267             FROM graph_edges
268             WHERE (from_id IN ({ph}) OR to_id IN ({ph}))
269               AND edge_type = 'co_accessed'
270               AND count >= ?{min_idx}
271             ORDER BY count DESC"
272        );
273
274        let mut stmt = self.conn.prepare(&sql).map_err(|e| {
275            SedimentError::Database(format!("Failed to prepare co-access query: {}", e))
276        })?;
277
278        let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
279        for id in ids {
280            param_values.push(Box::new(id.to_string()));
281        }
282        param_values.push(Box::new(min_count));
283
284        let params_ref: Vec<&dyn rusqlite::types::ToSql> =
285            param_values.iter().map(|b| b.as_ref()).collect();
286
287        let rows = stmt
288            .query_map(params_ref.as_slice(), |row| {
289                Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
290            })
291            .map_err(|e| SedimentError::Database(format!("Failed to query co-access: {}", e)))?;
292
293        let mut results = Vec::new();
294        for row in rows {
295            let r = row
296                .map_err(|e| SedimentError::Database(format!("Failed to read co-access: {}", e)))?;
297            results.push(r);
298        }
299
300        // Deduplicate by target_id, keeping highest count
301        results.sort_by(|a, b| b.1.cmp(&a.1));
302        let mut seen = std::collections::HashSet::new();
303        results.retain(|(id, _)| seen.insert(id.clone()));
304
305        Ok(results)
306    }
307
308    /// Transfer all edges from one node to another (used during consolidation merge).
309    pub fn transfer_edges(&self, from_id: &str, to_id: &str) -> Result<()> {
310        // Get all RELATED edges connected to from_id (excluding edges to to_id)
311        let mut stmt = self
312            .conn
313            .prepare(
314                "SELECT from_id, to_id, strength, rel_type, created_at
315             FROM graph_edges
316             WHERE (from_id = ?1 OR to_id = ?1)
317               AND edge_type = 'related'
318               AND from_id != ?2 AND to_id != ?2",
319            )
320            .map_err(|e| {
321                SedimentError::Database(format!("Failed to prepare transfer query: {}", e))
322            })?;
323
324        let edges: Vec<(String, f64, String, i64)> = stmt
325            .query_map(params![from_id, to_id], |row| {
326                let fid: String = row.get(0)?;
327                let tid: String = row.get(1)?;
328                let neighbor = if fid == from_id { tid } else { fid };
329                Ok((neighbor, row.get(2)?, row.get(3)?, row.get(4)?))
330            })
331            .map_err(|e| {
332                SedimentError::Database(format!("Failed to query edges for transfer: {}", e))
333            })?
334            .filter_map(|r| r.ok())
335            .collect();
336
337        // Create edges on the new node
338        for (neighbor, strength, rel_type, _) in &edges {
339            let _ = self.add_related_edge(to_id, neighbor, *strength, rel_type);
340        }
341
342        Ok(())
343    }
344
345    /// Detect triangles of RELATED items (for clustering).
346    /// Returns sets of 3 item IDs that form triangles.
347    pub fn detect_clusters(&self) -> Result<Vec<(String, String, String)>> {
348        let mut stmt = self.conn.prepare(
349            "SELECT e1.from_id, e1.to_id, e2.to_id
350             FROM graph_edges e1
351             JOIN graph_edges e2 ON e1.to_id = e2.from_id AND e1.edge_type = 'related' AND e2.edge_type = 'related'
352             JOIN graph_edges e3 ON e2.to_id = e3.to_id AND e3.from_id = e1.from_id AND e3.edge_type = 'related'
353             WHERE e1.from_id < e1.to_id AND e1.to_id < e2.to_id
354             LIMIT 50"
355        ).map_err(|e| SedimentError::Database(format!("Failed to detect clusters: {}", e)))?;
356
357        let rows = stmt
358            .query_map([], |row| {
359                Ok((
360                    row.get::<_, String>(0)?,
361                    row.get::<_, String>(1)?,
362                    row.get::<_, String>(2)?,
363                ))
364            })
365            .map_err(|e| SedimentError::Database(format!("Failed to read clusters: {}", e)))?;
366
367        let mut clusters = Vec::new();
368        for r in rows.flatten() {
369            clusters.push(r);
370        }
371
372        Ok(clusters)
373    }
374
375    /// Get full connection info for an item (all edge types).
376    pub fn get_full_connections(&self, item_id: &str) -> Result<Vec<ConnectionInfo>> {
377        let mut stmt = self
378            .conn
379            .prepare(
380                "SELECT
381                CASE WHEN from_id = ?1 THEN to_id ELSE from_id END AS neighbor,
382                edge_type,
383                strength,
384                rel_type,
385                count
386             FROM graph_edges
387             WHERE from_id = ?1 OR to_id = ?1",
388            )
389            .map_err(|e| {
390                SedimentError::Database(format!("Failed to prepare connections query: {}", e))
391            })?;
392
393        let rows = stmt
394            .query_map(params![item_id], |row| {
395                let neighbor: String = row.get(0)?;
396                let edge_type: String = row.get(1)?;
397                let strength: f64 = row.get(2)?;
398                let rel_type_val: String = row.get(3)?;
399                let count: i64 = row.get(4)?;
400
401                let display_type = match edge_type.as_str() {
402                    "related" => rel_type_val.clone(),
403                    "supersedes" => "supersedes".to_string(),
404                    "co_accessed" => "co_accessed".to_string(),
405                    _ => edge_type.clone(),
406                };
407
408                Ok(ConnectionInfo {
409                    target_id: neighbor,
410                    rel_type: display_type,
411                    strength,
412                    count: if edge_type == "co_accessed" {
413                        Some(count)
414                    } else {
415                        None
416                    },
417                })
418            })
419            .map_err(|e| SedimentError::Database(format!("Failed to query connections: {}", e)))?;
420
421        let mut connections = Vec::new();
422        for row in rows {
423            let r = row.map_err(|e| {
424                SedimentError::Database(format!("Failed to read connection: {}", e))
425            })?;
426            connections.push(r);
427        }
428
429        Ok(connections)
430    }
431
432    /// Get the edge count for an item (total number of edges of all types).
433    pub fn get_edge_count(&self, item_id: &str) -> Result<u32> {
434        let count: i64 = self
435            .conn
436            .query_row(
437                "SELECT COUNT(*) FROM graph_edges WHERE from_id = ?1 OR to_id = ?1",
438                params![item_id],
439                |row| row.get(0),
440            )
441            .map_err(|e| SedimentError::Database(format!("Failed to count edges: {}", e)))?;
442
443        Ok(count as u32)
444    }
445}