Skip to main content

seshat_storage/repository/
edge_repository.rs

//! SQLite implementation of [`EdgeRepository`].
2
3use std::sync::{Arc, Mutex};
4
5use rusqlite::{Connection, params};
6use seshat_core::{BranchId, Edge, EdgeId, EdgeType, NodeId};
7
8use super::EdgeRepository;
9use crate::StorageError;
10
11/// SQLite-backed edge repository.
12#[derive(Debug, Clone)]
13pub struct SqliteEdgeRepository {
14    conn: Arc<Mutex<Connection>>,
15}
16
17impl SqliteEdgeRepository {
18    /// Create a new repository backed by the given connection.
19    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
20        Self { conn }
21    }
22
23    fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
24        self.conn.lock().map_err(|e| {
25            StorageError::QueryError(format!("Failed to acquire connection lock: {e}"))
26        })
27    }
28}
29
30impl EdgeRepository for SqliteEdgeRepository {
31    fn insert(&self, edge: &Edge) -> Result<Edge, StorageError> {
32        let conn = self.conn()?;
33
34        let metadata_str = edge
35            .metadata
36            .as_ref()
37            .map(serde_json::to_string)
38            .transpose()
39            .map_err(|e| StorageError::SerializationError(e.to_string()))?;
40
41        conn.execute(
42            "INSERT INTO edges (source_id, target_id, edge_type, branch_id, weight, metadata)
43             VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
44            params![
45                edge.source_id.0,
46                edge.target_id.0,
47                edge.edge_type.as_str(),
48                edge.branch_id.0,
49                edge.weight,
50                metadata_str,
51            ],
52        )?;
53
54        let id = conn.last_insert_rowid();
55
56        let mut inserted = edge.clone();
57        inserted.id = EdgeId(id);
58        Ok(inserted)
59    }
60
61    fn find_by_source(&self, source_id: NodeId) -> Result<Vec<Edge>, StorageError> {
62        self.query_edges(
63            "SELECT id, source_id, target_id, edge_type, branch_id, weight, metadata
64             FROM edges WHERE source_id = ?1",
65            &source_id.0,
66        )
67    }
68
69    fn find_by_target(&self, target_id: NodeId) -> Result<Vec<Edge>, StorageError> {
70        self.query_edges(
71            "SELECT id, source_id, target_id, edge_type, branch_id, weight, metadata
72             FROM edges WHERE target_id = ?1",
73            &target_id.0,
74        )
75    }
76
77    fn find_by_type(&self, edge_type: EdgeType) -> Result<Vec<Edge>, StorageError> {
78        self.query_edges(
79            "SELECT id, source_id, target_id, edge_type, branch_id, weight, metadata
80             FROM edges WHERE edge_type = ?1",
81            &edge_type.as_str(),
82        )
83    }
84
85    fn delete(&self, id: EdgeId) -> Result<(), StorageError> {
86        let conn = self.conn()?;
87
88        let affected = conn.execute("DELETE FROM edges WHERE id = ?1", params![id.0])?;
89
90        if affected == 0 {
91            return Err(StorageError::NotFound {
92                entity: "Edge",
93                id: id.0.to_string(),
94            });
95        }
96
97        Ok(())
98    }
99
100    fn delete_by_branch(&self, branch_id: &BranchId) -> Result<usize, StorageError> {
101        let conn = self.conn()?;
102
103        let affected = conn.execute(
104            "DELETE FROM edges WHERE branch_id = ?1",
105            params![branch_id.0],
106        )?;
107
108        Ok(affected)
109    }
110}
111
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
115
116impl SqliteEdgeRepository {
117    /// Run a parameterised edge query and collect the results.
118    fn query_edges(
119        &self,
120        sql: &str,
121        param: &dyn rusqlite::types::ToSql,
122    ) -> Result<Vec<Edge>, StorageError> {
123        let conn = self.conn()?;
124        let mut stmt = conn.prepare(sql)?;
125        let rows = stmt.query_map([param], row_to_edge)?;
126        rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
127    }
128}
129
130/// Map a rusqlite Row to an `Edge`.
131fn row_to_edge(row: &rusqlite::Row<'_>) -> rusqlite::Result<Edge> {
132    let id: i64 = row.get(0)?;
133    let source_id: i64 = row.get(1)?;
134    let target_id: i64 = row.get(2)?;
135    let edge_type_str: String = row.get(3)?;
136    let branch_id: String = row.get(4)?;
137    let weight: f64 = row.get(5)?;
138    let metadata_str: Option<String> = row.get(6)?;
139
140    let edge_type: EdgeType = edge_type_str.parse().map_err(|e| {
141        rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
142    })?;
143
144    let metadata = metadata_str
145        .map(|s| serde_json::from_str(&s))
146        .transpose()
147        .map_err(|e| {
148            rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
149        })?;
150
151    Ok(Edge {
152        id: EdgeId(id),
153        source_id: NodeId(source_id),
154        target_id: NodeId(target_id),
155        edge_type,
156        branch_id: BranchId(branch_id),
157        weight,
158        metadata,
159    })
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Database;
    use crate::repository::NodeRepository;
    use crate::repository::node_repository::SqliteNodeRepository;
    use seshat_core::KnowledgeNature;
    use seshat_core::test_helpers::make_knowledge_node;

    /// Helper: create an in-memory DB and return both repos (edges need nodes for FK).
    fn test_repos() -> (SqliteNodeRepository, SqliteEdgeRepository) {
        let db = Database::open(":memory:").expect("in-memory DB");
        let conn = db.connection().clone();
        (
            SqliteNodeRepository::new(conn.clone()),
            SqliteEdgeRepository::new(conn),
        )
    }

    /// Helper: insert two nodes and return their IDs.
    fn insert_two_nodes(node_repo: &SqliteNodeRepository) -> (NodeId, NodeId) {
        let n1 = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        let n2 = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        let id1 = node_repo.insert(&n1).unwrap().id;
        let id2 = node_repo.insert(&n2).unwrap().id;
        (id1, id2)
    }

    /// Helper: build an unsaved edge on branch `main` with weight 1.0 and no metadata.
    fn make_edge(source_id: NodeId, target_id: NodeId, edge_type: EdgeType) -> Edge {
        Edge {
            id: EdgeId(0),
            source_id,
            target_id,
            edge_type,
            branch_id: BranchId::from("main"),
            weight: 1.0,
            metadata: None,
        }
    }

    #[test]
    fn insert_and_find_by_source() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let edge = make_edge(n1, n2, EdgeType::DependsOn);
        let inserted = edge_repo.insert(&edge).expect("insert should succeed");
        assert_ne!(inserted.id.0, 0, "should get assigned ID");

        let edges = edge_repo.find_by_source(n1).expect("find_by_source");
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].id, inserted.id, "stored id should match returned id");
        assert_eq!(edges[0].edge_type, EdgeType::DependsOn);
        assert_eq!(edges[0].source_id, n1);
        assert_eq!(edges[0].target_id, n2);
        assert_eq!(edges[0].branch_id.0, "main", "branch should round-trip");
        assert!(
            (edges[0].weight - 1.0).abs() < f64::EPSILON,
            "weight should round-trip"
        );
        assert!(edges[0].metadata.is_none(), "absent metadata should stay absent");
    }

    #[test]
    fn find_by_target() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let edge = make_edge(n1, n2, EdgeType::RelatedTo);
        edge_repo.insert(&edge).unwrap();

        let edges = edge_repo.find_by_target(n2).unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].target_id, n2);
        assert_eq!(edges[0].source_id, n1);
    }

    #[test]
    fn find_by_type() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let e1 = make_edge(n1, n2, EdgeType::DependsOn);
        let e2 = make_edge(n2, n1, EdgeType::RelatedTo);
        let e3 = make_edge(n1, n2, EdgeType::DependsOn);
        edge_repo.insert(&e1).unwrap();
        edge_repo.insert(&e2).unwrap();
        edge_repo.insert(&e3).unwrap();

        let depends = edge_repo.find_by_type(EdgeType::DependsOn).unwrap();
        assert_eq!(depends.len(), 2);
        assert!(depends.iter().all(|e| e.edge_type == EdgeType::DependsOn));

        let related = edge_repo.find_by_type(EdgeType::RelatedTo).unwrap();
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].edge_type, EdgeType::RelatedTo);
    }

    #[test]
    fn find_on_empty_table_returns_nothing() {
        let (_node_repo, edge_repo) = test_repos();

        assert!(edge_repo.find_by_source(NodeId(1)).unwrap().is_empty());
        assert!(edge_repo.find_by_target(NodeId(1)).unwrap().is_empty());
        assert!(edge_repo.find_by_type(EdgeType::RelatedTo).unwrap().is_empty());
    }

    #[test]
    fn delete_edge() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let edge = make_edge(n1, n2, EdgeType::PartOf);
        let inserted = edge_repo.insert(&edge).unwrap();

        edge_repo
            .delete(inserted.id)
            .expect("delete should succeed");

        let edges = edge_repo.find_by_source(n1).unwrap();
        assert!(edges.is_empty(), "edge should be deleted");

        // A second delete of the same id must now report NotFound.
        let again = edge_repo.delete(inserted.id);
        assert!(matches!(again, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn delete_not_found() {
        let (_node_repo, edge_repo) = test_repos();
        let result = edge_repo.delete(EdgeId(999));
        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn insert_with_metadata() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let mut edge = make_edge(n1, n2, EdgeType::Implements);
        edge.metadata = Some(serde_json::json!({"via": "trait impl"}));

        let inserted = edge_repo.insert(&edge).unwrap();
        let edges = edge_repo.find_by_source(n1).unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].id, inserted.id);
        assert_eq!(edges[0].metadata.as_ref().unwrap()["via"], "trait impl");
    }

    #[test]
    fn all_edge_type_variants_roundtrip() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let types = [
            EdgeType::RelatedTo,
            EdgeType::Updates,
            EdgeType::Contradicts,
            EdgeType::PartOf,
            EdgeType::DependsOn,
            EdgeType::Implements,
        ];

        for et in types {
            let edge = make_edge(n1, n2, et);
            edge_repo.insert(&edge).unwrap();
        }

        // All 6 should be retrievable via find_by_source
        let all_edges = edge_repo.find_by_source(n1).unwrap();
        assert_eq!(all_edges.len(), 6);

        // Exactly one edge of each type was inserted, and it must decode back
        // to the same variant (checks the as_str / parse round-trip).
        for et in types {
            let found = edge_repo.find_by_type(et).unwrap();
            assert_eq!(found.len(), 1, "exactly one edge of type {et} was inserted");
            assert_eq!(found[0].edge_type, et);
        }
    }

    #[test]
    fn delete_by_branch() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let mut e1 = make_edge(n1, n2, EdgeType::DependsOn);
        e1.branch_id = BranchId::from("branch-a");

        let mut e2 = make_edge(n2, n1, EdgeType::PartOf);
        e2.branch_id = BranchId::from("branch-a");

        let mut e3 = make_edge(n1, n2, EdgeType::RelatedTo);
        e3.branch_id = BranchId::from("branch-b");

        edge_repo.insert(&e1).unwrap();
        edge_repo.insert(&e2).unwrap();
        edge_repo.insert(&e3).unwrap();

        let deleted = edge_repo
            .delete_by_branch(&BranchId::from("branch-a"))
            .unwrap();
        assert_eq!(deleted, 2, "should delete 2 edges from branch-a");

        // branch-a edges should be gone
        let depends = edge_repo.find_by_type(EdgeType::DependsOn).unwrap();
        assert!(depends.is_empty(), "DependsOn from branch-a should be gone");
        let part_of = edge_repo.find_by_type(EdgeType::PartOf).unwrap();
        assert!(part_of.is_empty(), "PartOf from branch-a should be gone");

        // branch-b edge should still exist
        let related = edge_repo.find_by_type(EdgeType::RelatedTo).unwrap();
        assert_eq!(related.len(), 1, "branch-b edge should still exist");
        assert_eq!(related[0].branch_id.0, "branch-b");
    }

    #[test]
    fn delete_by_branch_empty() {
        let (_node_repo, edge_repo) = test_repos();
        let deleted = edge_repo
            .delete_by_branch(&BranchId::from("empty-branch"))
            .unwrap();
        assert_eq!(deleted, 0, "should delete 0 edges from empty branch");
    }
}