Skip to main content

sediment/
graph.rs

1//! Graph store using SQLite for relationship tracking between memories.
2//!
3//! Provides a graph layer alongside LanceDB for tracking
4//! relationships like RELATED, SUPERSEDES, and CO_ACCESSED between items.
5
6use std::path::Path;
7
8use rusqlite::{Connection, params};
9use tracing::debug;
10
11use crate::error::{Result, SedimentError};
12
/// A relationship between two memory items in the graph.
///
/// Derives `PartialEq` so edges can be compared as whole values.
#[derive(Debug, Clone, PartialEq)]
pub struct Edge {
    /// ID of the item at the other end of the relationship.
    pub target_id: String,
    /// Relationship label.
    pub rel_type: String,
    /// Edge weight.
    pub strength: f64,
}
20
/// A co-access relationship between two memory items.
///
/// Derives `PartialEq`, `Eq` and `Hash` so values can be compared directly
/// and stored in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CoAccessEdge {
    /// ID of the item at the other end of the relationship.
    pub target_id: String,
    /// Co-access counter for the pair.
    pub count: i64,
}
27
/// Full connection info for a memory item.
///
/// Derives `PartialEq` so query results can be compared as whole values.
#[derive(Debug, Clone, PartialEq)]
pub struct ConnectionInfo {
    /// ID of the item at the other end of the connection.
    pub target_id: String,
    /// Relationship label shown for this connection.
    pub rel_type: String,
    /// Edge weight.
    pub strength: f64,
    /// Optional counter attached to the connection; `None` when not applicable.
    pub count: Option<i64>,
}
36
/// SQLite-backed graph store for memory relationships.
///
/// Owns a single SQLite connection; every method on the store runs its
/// statements through it.
pub struct GraphStore {
    // Private connection handle; created by `GraphStore::open`.
    conn: Connection,
}
41
42impl GraphStore {
43    /// Open or create the graph store using the given SQLite database path.
44    /// Shares the same file as access.db.
45    pub fn open(path: &Path) -> Result<Self> {
46        let conn = Connection::open(path).map_err(|e| {
47            SedimentError::Database(format!("Failed to open graph database: {}", e))
48        })?;
49
50        if let Err(e) = conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=5000;") {
51            tracing::warn!("Failed to set SQLite PRAGMAs (graph): {}", e);
52        }
53
54        conn.execute_batch(
55            "CREATE TABLE IF NOT EXISTS graph_nodes (
56                id TEXT PRIMARY KEY,
57                project_id TEXT NOT NULL DEFAULT '',
58                created_at INTEGER NOT NULL
59            );
60
61            CREATE TABLE IF NOT EXISTS graph_edges (
62                from_id TEXT NOT NULL,
63                to_id TEXT NOT NULL,
64                edge_type TEXT NOT NULL,
65                strength REAL NOT NULL DEFAULT 0.0,
66                rel_type TEXT NOT NULL DEFAULT '',
67                count INTEGER NOT NULL DEFAULT 0,
68                last_at INTEGER NOT NULL DEFAULT 0,
69                created_at INTEGER NOT NULL,
70                UNIQUE(from_id, to_id, edge_type)
71            );
72
73            CREATE INDEX IF NOT EXISTS idx_edges_from ON graph_edges(from_id);
74            CREATE INDEX IF NOT EXISTS idx_edges_to ON graph_edges(to_id);",
75        )
76        .map_err(|e| SedimentError::Database(format!("Failed to create graph tables: {}", e)))?;
77
78        Ok(Self { conn })
79    }
80
81    /// Add a Memory node to the graph.
82    pub fn add_node(&self, id: &str, project_id: Option<&str>, created_at: i64) -> Result<()> {
83        let pid = project_id.unwrap_or("");
84
85        self.conn
86            .execute(
87                "INSERT OR IGNORE INTO graph_nodes (id, project_id, created_at) VALUES (?1, ?2, ?3)",
88                params![id, pid, created_at],
89            )
90            .map_err(|e| SedimentError::Database(format!("Failed to add node: {}", e)))?;
91
92        debug!("Added graph node: {}", id);
93        Ok(())
94    }
95
    /// Ensure a node exists in the graph. Creates it if missing (for backfill).
    ///
    /// This simply delegates to `add_node` with the same arguments and returns
    /// its result unchanged.
    pub fn ensure_node_exists(
        &self,
        id: &str,
        project_id: Option<&str>,
        created_at: i64,
    ) -> Result<()> {
        self.add_node(id, project_id, created_at)
    }
105
106    /// Remove a Memory node and all its edges from the graph.
107    pub fn remove_node(&self, id: &str) -> Result<()> {
108        // Preserve incoming SUPERSEDES edges (where this node is the to_id)
109        // so that lineage/provenance chains remain intact after node removal.
110        self.conn
111            .execute(
112                "DELETE FROM graph_edges WHERE from_id = ?1 OR (to_id = ?1 AND edge_type != 'supersedes')",
113                params![id],
114            )
115            .map_err(|e| SedimentError::Database(format!("Failed to remove edges: {}", e)))?;
116
117        self.conn
118            .execute("DELETE FROM graph_nodes WHERE id = ?1", params![id])
119            .map_err(|e| SedimentError::Database(format!("Failed to remove node: {}", e)))?;
120
121        debug!("Removed graph node: {}", id);
122        Ok(())
123    }
124
125    /// Add a RELATED edge between two Memory nodes.
126    pub fn add_related_edge(
127        &self,
128        from_id: &str,
129        to_id: &str,
130        strength: f64,
131        rel_type: &str,
132    ) -> Result<()> {
133        let now = chrono::Utc::now().timestamp();
134
135        self.conn
136            .execute(
137                "INSERT OR IGNORE INTO graph_edges (from_id, to_id, edge_type, strength, rel_type, created_at)
138                 VALUES (?1, ?2, 'related', ?3, ?4, ?5)",
139                params![from_id, to_id, strength, rel_type, now],
140            )
141            .map_err(|e| SedimentError::Database(format!("Failed to add related edge: {}", e)))?;
142
143        debug!(
144            "Added RELATED edge: {} -> {} ({})",
145            from_id, to_id, rel_type
146        );
147        Ok(())
148    }
149
150    /// Add a SUPERSEDES edge from new_id to old_id.
151    pub fn add_supersedes_edge(&self, new_id: &str, old_id: &str) -> Result<()> {
152        let now = chrono::Utc::now().timestamp();
153
154        self.conn
155            .execute(
156                "INSERT OR IGNORE INTO graph_edges (from_id, to_id, edge_type, strength, created_at)
157                 VALUES (?1, ?2, 'supersedes', 1.0, ?3)",
158                params![new_id, old_id, now],
159            )
160            .map_err(|e| SedimentError::Database(format!("Failed to add supersedes edge: {}", e)))?;
161
162        debug!("Added SUPERSEDES edge: {} -> {}", new_id, old_id);
163        Ok(())
164    }
165
166    /// Get 1-hop neighbors of the given item IDs via RELATED or SUPERSEDES edges.
167    /// Returns (neighbor_id, rel_type, strength) tuples.
168    ///
169    /// Note on parameter binding: SQLite reuses the same positional parameters (?1..?N)
170    /// across all three IN clauses and the CASE expression. This is correct because
171    /// SQLite binds by position, so the same parameter set is applied to each reference.
172    pub fn get_neighbors(
173        &self,
174        ids: &[&str],
175        min_strength: f64,
176    ) -> Result<Vec<(String, String, f64)>> {
177        if ids.is_empty() {
178            return Ok(Vec::new());
179        }
180
181        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{}", i)).collect();
182        let ph = placeholders.join(",");
183        let strength_idx = ids.len() + 1;
184
185        let sql = format!(
186            "SELECT
187                CASE WHEN from_id IN ({ph}) THEN to_id ELSE from_id END AS neighbor,
188                CASE WHEN edge_type = 'related' THEN rel_type ELSE 'supersedes' END AS rtype,
189                strength
190             FROM graph_edges
191             WHERE (from_id IN ({ph}) OR to_id IN ({ph}))
192               AND edge_type IN ('related', 'supersedes')
193               AND strength >= ?{strength_idx}
194             LIMIT 100"
195        );
196
197        let mut stmt = self.conn.prepare(&sql).map_err(|e| {
198            SedimentError::Database(format!("Failed to prepare neighbors query: {}", e))
199        })?;
200
201        let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
202        for id in ids {
203            param_values.push(Box::new(id.to_string()));
204        }
205        param_values.push(Box::new(min_strength));
206
207        let params_ref: Vec<&dyn rusqlite::types::ToSql> =
208            param_values.iter().map(|b| b.as_ref()).collect();
209
210        let rows = stmt
211            .query_map(params_ref.as_slice(), |row| {
212                Ok((
213                    row.get::<_, String>(0)?,
214                    row.get::<_, String>(1)?,
215                    row.get::<_, f64>(2)?,
216                ))
217            })
218            .map_err(|e| SedimentError::Database(format!("Failed to query neighbors: {}", e)))?;
219
220        // Filter out input IDs from results so we never return a query ID as its own neighbor
221        let input_set: std::collections::HashSet<&str> = ids.iter().copied().collect();
222        let mut results = Vec::new();
223        for row in rows {
224            let r = row
225                .map_err(|e| SedimentError::Database(format!("Failed to read neighbor: {}", e)))?;
226            if !input_set.contains(r.0.as_str()) {
227                results.push(r);
228            }
229        }
230
231        Ok(results)
232    }
233
234    /// Record co-access between pairs of item IDs.
235    /// Creates or increments CO_ACCESSED edges.
236    pub fn record_co_access(&self, item_ids: &[String]) -> Result<()> {
237        if item_ids.len() < 2 {
238            return Ok(());
239        }
240
241        // Performance optimization: limit co-access recording to the top 3 items
242        // by position to bound the O(n^2) pair generation. Items beyond position 3
243        // in recall results won't have co-access edges recorded.
244        let item_ids = if item_ids.len() > 3 {
245            &item_ids[..3]
246        } else {
247            item_ids
248        };
249
250        let now = chrono::Utc::now().timestamp();
251
252        for i in 0..item_ids.len() {
253            for j in (i + 1)..item_ids.len() {
254                // Normalize edge direction: smaller ID always goes first to prevent
255                // duplicate edges (A,B) and (B,A) from accumulating separately.
256                let (a, b) = if item_ids[i] <= item_ids[j] {
257                    (&item_ids[i], &item_ids[j])
258                } else {
259                    (&item_ids[j], &item_ids[i])
260                };
261
262                self.conn
263                    .execute(
264                        "INSERT INTO graph_edges (from_id, to_id, edge_type, count, last_at, created_at)
265                         VALUES (?1, ?2, 'co_accessed', 1, ?3, ?3)
266                         ON CONFLICT(from_id, to_id, edge_type)
267                         DO UPDATE SET count = count + 1, last_at = ?3",
268                        params![a, b, now],
269                    )
270                    .map_err(|e| {
271                        SedimentError::Database(format!("Failed to record co-access: {}", e))
272                    })?;
273            }
274        }
275
276        Ok(())
277    }
278
279    /// Get items that are frequently co-accessed with the given IDs.
280    /// Returns (neighbor_id, co_access_count) tuples.
281    pub fn get_co_accessed(&self, ids: &[&str], min_count: i64) -> Result<Vec<(String, i64)>> {
282        if ids.is_empty() {
283            return Ok(Vec::new());
284        }
285
286        let placeholders: Vec<String> = (1..=ids.len()).map(|i| format!("?{}", i)).collect();
287        let ph = placeholders.join(",");
288        let min_idx = ids.len() + 1;
289
290        let sql = format!(
291            "SELECT
292                CASE WHEN from_id IN ({ph}) THEN to_id ELSE from_id END AS neighbor,
293                count
294             FROM graph_edges
295             WHERE (from_id IN ({ph}) OR to_id IN ({ph}))
296               AND edge_type = 'co_accessed'
297               AND count >= ?{min_idx}
298             ORDER BY count DESC"
299        );
300
301        let mut stmt = self.conn.prepare(&sql).map_err(|e| {
302            SedimentError::Database(format!("Failed to prepare co-access query: {}", e))
303        })?;
304
305        let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
306        for id in ids {
307            param_values.push(Box::new(id.to_string()));
308        }
309        param_values.push(Box::new(min_count));
310
311        let params_ref: Vec<&dyn rusqlite::types::ToSql> =
312            param_values.iter().map(|b| b.as_ref()).collect();
313
314        let rows = stmt
315            .query_map(params_ref.as_slice(), |row| {
316                Ok((row.get::<_, String>(0)?, row.get::<_, i64>(1)?))
317            })
318            .map_err(|e| SedimentError::Database(format!("Failed to query co-access: {}", e)))?;
319
320        let mut results = Vec::new();
321        for row in rows {
322            let r = row
323                .map_err(|e| SedimentError::Database(format!("Failed to read co-access: {}", e)))?;
324            results.push(r);
325        }
326
327        // Deduplicate by target_id, keeping highest count
328        results.sort_by(|a, b| b.1.cmp(&a.1));
329        let mut seen = std::collections::HashSet::new();
330        results.retain(|(id, _)| seen.insert(id.clone()));
331
332        Ok(results)
333    }
334
335    /// Transfer all edges from one node to another (used during consolidation merge).
336    pub fn transfer_edges(&self, from_id: &str, to_id: &str) -> Result<()> {
337        // Get all RELATED edges connected to from_id (excluding edges to to_id)
338        let mut stmt = self
339            .conn
340            .prepare(
341                "SELECT from_id, to_id, strength, rel_type, created_at
342             FROM graph_edges
343             WHERE (from_id = ?1 OR to_id = ?1)
344               AND edge_type = 'related'
345               AND from_id != ?2 AND to_id != ?2",
346            )
347            .map_err(|e| {
348                SedimentError::Database(format!("Failed to prepare transfer query: {}", e))
349            })?;
350
351        let edges: Vec<(String, f64, String, i64)> = stmt
352            .query_map(params![from_id, to_id], |row| {
353                let fid: String = row.get(0)?;
354                let tid: String = row.get(1)?;
355                let neighbor = if fid == from_id { tid } else { fid };
356                Ok((neighbor, row.get(2)?, row.get(3)?, row.get(4)?))
357            })
358            .map_err(|e| {
359                SedimentError::Database(format!("Failed to query edges for transfer: {}", e))
360            })?
361            .filter_map(|r| match r {
362                Ok(v) => Some(v),
363                Err(e) => {
364                    tracing::warn!("transfer_edges: failed to read row: {}", e);
365                    None
366                }
367            })
368            .collect();
369
370        // Create edges on the new node
371        for (neighbor, strength, rel_type, _) in &edges {
372            if let Err(e) = self.add_related_edge(to_id, neighbor, *strength, rel_type) {
373                tracing::warn!("transfer edge to {} failed: {}", neighbor, e);
374            }
375        }
376
377        Ok(())
378    }
379
380    /// Detect triangles of RELATED items (for clustering).
381    /// Returns sets of 3 item IDs that form triangles.
382    pub fn detect_clusters(&self) -> Result<Vec<(String, String, String)>> {
383        let mut stmt = self
384            .conn
385            .prepare(
386                "WITH biedges AS (
387                SELECT from_id AS a, to_id AS b FROM graph_edges WHERE edge_type = 'related'
388                UNION ALL
389                SELECT to_id AS a, from_id AS b FROM graph_edges WHERE edge_type = 'related'
390            )
391            SELECT DISTINCT e1.a, e1.b, e2.b
392            FROM biedges e1
393            JOIN biedges e2 ON e1.b = e2.a
394            JOIN biedges e3 ON e2.b = e3.a AND e3.b = e1.a
395            WHERE e1.a < e1.b AND e1.b < e2.b
396            LIMIT 50",
397            )
398            .map_err(|e| SedimentError::Database(format!("Failed to detect clusters: {}", e)))?;
399
400        let rows = stmt
401            .query_map([], |row| {
402                Ok((
403                    row.get::<_, String>(0)?,
404                    row.get::<_, String>(1)?,
405                    row.get::<_, String>(2)?,
406                ))
407            })
408            .map_err(|e| SedimentError::Database(format!("Failed to read clusters: {}", e)))?;
409
410        let mut clusters = Vec::new();
411        for r in rows.flatten() {
412            clusters.push(r);
413        }
414
415        Ok(clusters)
416    }
417
418    /// Get full connection info for an item (all edge types).
419    pub fn get_full_connections(&self, item_id: &str) -> Result<Vec<ConnectionInfo>> {
420        let mut stmt = self
421            .conn
422            .prepare(
423                "SELECT
424                CASE WHEN from_id = ?1 THEN to_id ELSE from_id END AS neighbor,
425                edge_type,
426                strength,
427                rel_type,
428                count
429             FROM graph_edges
430             WHERE from_id = ?1 OR to_id = ?1",
431            )
432            .map_err(|e| {
433                SedimentError::Database(format!("Failed to prepare connections query: {}", e))
434            })?;
435
436        let rows = stmt
437            .query_map(params![item_id], |row| {
438                let neighbor: String = row.get(0)?;
439                let edge_type: String = row.get(1)?;
440                let strength: f64 = row.get(2)?;
441                let rel_type_val: String = row.get(3)?;
442                let count: i64 = row.get(4)?;
443
444                let display_type = match edge_type.as_str() {
445                    "related" => rel_type_val.clone(),
446                    "supersedes" => "supersedes".to_string(),
447                    "co_accessed" => "co_accessed".to_string(),
448                    _ => edge_type.clone(),
449                };
450
451                Ok(ConnectionInfo {
452                    target_id: neighbor,
453                    rel_type: display_type,
454                    strength,
455                    count: if edge_type == "co_accessed" {
456                        Some(count)
457                    } else {
458                        None
459                    },
460                })
461            })
462            .map_err(|e| SedimentError::Database(format!("Failed to query connections: {}", e)))?;
463
464        let mut connections = Vec::new();
465        for row in rows {
466            let r = row.map_err(|e| {
467                SedimentError::Database(format!("Failed to read connection: {}", e))
468            })?;
469            connections.push(r);
470        }
471
472        Ok(connections)
473    }
474
475    /// Get the edge count for an item (total number of edges of all types).
476    pub fn get_edge_count(&self, item_id: &str) -> Result<u32> {
477        let count: i64 = self
478            .conn
479            .query_row(
480                "SELECT COUNT(*) FROM graph_edges WHERE from_id = ?1 OR to_id = ?1",
481                params![item_id],
482                |row| row.get(0),
483            )
484            .map_err(|e| SedimentError::Database(format!("Failed to count edges: {}", e)))?;
485
486        Ok(count as u32)
487    }
488}
489
490#[cfg(test)]
491mod tests {
492    use super::*;
493    use tempfile::NamedTempFile;
494
495    fn open_test_graph() -> GraphStore {
496        let tmp = NamedTempFile::new().unwrap();
497        GraphStore::open(tmp.path()).unwrap()
498    }
499
500    #[test]
501    fn test_get_neighbors_excludes_input_ids() {
502        // Fix #7: get_neighbors should never return an input ID as a neighbor
503        let graph = open_test_graph();
504        let now = chrono::Utc::now().timestamp();
505        graph.add_node("A", Some("proj"), now).unwrap();
506        graph.add_node("B", Some("proj"), now).unwrap();
507        graph.add_node("C", Some("proj"), now).unwrap();
508
509        // Create edges A-B, B-C
510        graph.add_related_edge("A", "B", 0.9, "test").unwrap();
511        graph.add_related_edge("B", "C", 0.9, "test").unwrap();
512
513        // Query neighbors for [A, B] — should only return C, not A or B
514        let neighbors = graph.get_neighbors(&["A", "B"], 0.0).unwrap();
515        let neighbor_ids: Vec<&str> = neighbors.iter().map(|(id, _, _)| id.as_str()).collect();
516        assert!(neighbor_ids.contains(&"C"));
517        assert!(!neighbor_ids.contains(&"A"));
518        assert!(!neighbor_ids.contains(&"B"));
519    }
520
521    #[test]
522    fn test_co_access_normalized_direction() {
523        // Fix #8: co-access edges should be normalized so (A,B) and (B,A) don't create duplicates
524        let graph = open_test_graph();
525        let now = chrono::Utc::now().timestamp();
526        graph.add_node("Z", Some("proj"), now).unwrap();
527        graph.add_node("A", Some("proj"), now).unwrap();
528
529        // Record co-access with Z before A (Z > A lexicographically)
530        graph
531            .record_co_access(&["Z".to_string(), "A".to_string()])
532            .unwrap();
533        // Record again with reversed order
534        graph
535            .record_co_access(&["A".to_string(), "Z".to_string()])
536            .unwrap();
537
538        // Should only have 1 edge with count=2 (not 2 edges with count=1 each)
539        let count: i64 = graph
540            .conn
541            .query_row(
542                "SELECT COUNT(*) FROM graph_edges WHERE edge_type = 'co_accessed'",
543                [],
544                |row| row.get(0),
545            )
546            .unwrap();
547        assert_eq!(count, 1, "Should have exactly 1 co-access edge");
548
549        let edge_count: i64 = graph
550            .conn
551            .query_row(
552                "SELECT count FROM graph_edges WHERE edge_type = 'co_accessed'",
553                [],
554                |row| row.get(0),
555            )
556            .unwrap();
557        assert_eq!(edge_count, 2, "Edge count should be 2 (incremented twice)");
558    }
559
560    #[test]
561    fn test_transfer_edges_preserves_relationships() {
562        // Fix #6: transfer_edges should move edges from old node to new node
563        let graph = open_test_graph();
564        let now = chrono::Utc::now().timestamp();
565        graph.add_node("old", Some("proj"), now).unwrap();
566        graph.add_node("new", Some("proj"), now).unwrap();
567        graph.add_node("friend", Some("proj"), now).unwrap();
568
569        graph
570            .add_related_edge("old", "friend", 0.9, "test")
571            .unwrap();
572
573        // Transfer edges from old to new
574        graph.transfer_edges("old", "new").unwrap();
575
576        // New node should now have edge to friend
577        let neighbors = graph.get_neighbors(&["new"], 0.0).unwrap();
578        assert!(
579            !neighbors.is_empty(),
580            "New node should have inherited edges"
581        );
582        let neighbor_ids: Vec<&str> = neighbors.iter().map(|(id, _, _)| id.as_str()).collect();
583        assert!(neighbor_ids.contains(&"friend"));
584    }
585
586    #[test]
587    fn test_remove_node_preserves_incoming_supersedes() {
588        // Bug #4: remove_node should preserve incoming SUPERSEDES edges for lineage
589        let graph = open_test_graph();
590        let now = chrono::Utc::now().timestamp();
591        graph.add_node("new", Some("proj"), now).unwrap();
592        graph.add_node("old", Some("proj"), now).unwrap();
593
594        // Simulate replace workflow: create SUPERSEDES edge new -> old
595        graph.add_supersedes_edge("new", "old").unwrap();
596
597        // Now remove the old node (as execute_store does)
598        graph.remove_node("old").unwrap();
599
600        // The SUPERSEDES edge (new -> old) should survive because old is the to_id
601        let connections = graph.get_full_connections("new").unwrap();
602        assert_eq!(connections.len(), 1, "SUPERSEDES edge should be preserved");
603        assert_eq!(connections[0].target_id, "old");
604        assert_eq!(connections[0].rel_type, "supersedes");
605    }
606
607    #[test]
608    fn test_get_neighbors_bounded() {
609        // Bug #3: get_neighbors should return at most 100 results
610        let graph = open_test_graph();
611        let now = chrono::Utc::now().timestamp();
612        graph.add_node("center", Some("proj"), now).unwrap();
613
614        for i in 0..150 {
615            let id = format!("n{}", i);
616            graph.add_node(&id, Some("proj"), now).unwrap();
617            graph.add_related_edge("center", &id, 0.9, "test").unwrap();
618        }
619
620        let neighbors = graph.get_neighbors(&["center"], 0.0).unwrap();
621        assert!(
622            neighbors.len() <= 100,
623            "get_neighbors should return at most 100, got {}",
624            neighbors.len()
625        );
626    }
627}