1use std::sync::{Arc, Mutex};
4
5use rusqlite::{Connection, params};
6use seshat_core::{BranchId, Edge, EdgeId, EdgeType, NodeId};
7
8use super::EdgeRepository;
9use crate::StorageError;
10
11#[derive(Debug, Clone)]
13pub struct SqliteEdgeRepository {
14 conn: Arc<Mutex<Connection>>,
15}
16
17impl SqliteEdgeRepository {
18 pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
20 Self { conn }
21 }
22
23 fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
24 self.conn.lock().map_err(|e| {
25 StorageError::QueryError(format!("Failed to acquire connection lock: {e}"))
26 })
27 }
28}
29
30impl EdgeRepository for SqliteEdgeRepository {
31 fn insert(&self, edge: &Edge) -> Result<Edge, StorageError> {
32 let conn = self.conn()?;
33
34 let metadata_str = edge
35 .metadata
36 .as_ref()
37 .map(serde_json::to_string)
38 .transpose()
39 .map_err(|e| StorageError::SerializationError(e.to_string()))?;
40
41 conn.execute(
42 "INSERT INTO edges (source_id, target_id, edge_type, branch_id, weight, metadata)
43 VALUES (?1, ?2, ?3, ?4, ?5, ?6)",
44 params![
45 edge.source_id.0,
46 edge.target_id.0,
47 edge.edge_type.as_str(),
48 edge.branch_id.0,
49 edge.weight,
50 metadata_str,
51 ],
52 )?;
53
54 let id = conn.last_insert_rowid();
55
56 let mut inserted = edge.clone();
57 inserted.id = EdgeId(id);
58 Ok(inserted)
59 }
60
61 fn find_by_source(&self, source_id: NodeId) -> Result<Vec<Edge>, StorageError> {
62 self.query_edges(
63 "SELECT id, source_id, target_id, edge_type, branch_id, weight, metadata
64 FROM edges WHERE source_id = ?1",
65 &source_id.0,
66 )
67 }
68
69 fn find_by_target(&self, target_id: NodeId) -> Result<Vec<Edge>, StorageError> {
70 self.query_edges(
71 "SELECT id, source_id, target_id, edge_type, branch_id, weight, metadata
72 FROM edges WHERE target_id = ?1",
73 &target_id.0,
74 )
75 }
76
77 fn find_by_type(&self, edge_type: EdgeType) -> Result<Vec<Edge>, StorageError> {
78 self.query_edges(
79 "SELECT id, source_id, target_id, edge_type, branch_id, weight, metadata
80 FROM edges WHERE edge_type = ?1",
81 &edge_type.as_str(),
82 )
83 }
84
85 fn delete(&self, id: EdgeId) -> Result<(), StorageError> {
86 let conn = self.conn()?;
87
88 let affected = conn.execute("DELETE FROM edges WHERE id = ?1", params![id.0])?;
89
90 if affected == 0 {
91 return Err(StorageError::NotFound {
92 entity: "Edge",
93 id: id.0.to_string(),
94 });
95 }
96
97 Ok(())
98 }
99
100 fn delete_by_branch(&self, branch_id: &BranchId) -> Result<usize, StorageError> {
101 let conn = self.conn()?;
102
103 let affected = conn.execute(
104 "DELETE FROM edges WHERE branch_id = ?1",
105 params![branch_id.0],
106 )?;
107
108 Ok(affected)
109 }
110}
111
112impl SqliteEdgeRepository {
117 fn query_edges(
119 &self,
120 sql: &str,
121 param: &dyn rusqlite::types::ToSql,
122 ) -> Result<Vec<Edge>, StorageError> {
123 let conn = self.conn()?;
124 let mut stmt = conn.prepare(sql)?;
125 let rows = stmt.query_map([param], row_to_edge)?;
126 rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
127 }
128}
129
130fn row_to_edge(row: &rusqlite::Row<'_>) -> rusqlite::Result<Edge> {
132 let id: i64 = row.get(0)?;
133 let source_id: i64 = row.get(1)?;
134 let target_id: i64 = row.get(2)?;
135 let edge_type_str: String = row.get(3)?;
136 let branch_id: String = row.get(4)?;
137 let weight: f64 = row.get(5)?;
138 let metadata_str: Option<String> = row.get(6)?;
139
140 let edge_type: EdgeType = edge_type_str.parse().map_err(|e| {
141 rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
142 })?;
143
144 let metadata = metadata_str
145 .map(|s| serde_json::from_str(&s))
146 .transpose()
147 .map_err(|e| {
148 rusqlite::Error::FromSqlConversionFailure(6, rusqlite::types::Type::Text, Box::new(e))
149 })?;
150
151 Ok(Edge {
152 id: EdgeId(id),
153 source_id: NodeId(source_id),
154 target_id: NodeId(target_id),
155 edge_type,
156 branch_id: BranchId(branch_id),
157 weight,
158 metadata,
159 })
160}
161
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Database;
    use crate::repository::NodeRepository;
    use crate::repository::node_repository::SqliteNodeRepository;
    use seshat_core::KnowledgeNature;
    use seshat_core::test_helpers::make_knowledge_node;

    /// Builds node and edge repositories sharing one in-memory database connection.
    fn test_repos() -> (SqliteNodeRepository, SqliteEdgeRepository) {
        let db = Database::open(":memory:").expect("in-memory DB");
        let conn = db.connection().clone();
        (
            SqliteNodeRepository::new(conn.clone()),
            SqliteEdgeRepository::new(conn),
        )
    }

    /// Inserts two knowledge nodes so edges have valid endpoints; returns their ids.
    fn insert_two_nodes(node_repo: &SqliteNodeRepository) -> (NodeId, NodeId) {
        let n1 = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        let n2 = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        let id1 = node_repo.insert(&n1).unwrap().id;
        let id2 = node_repo.insert(&n2).unwrap().id;
        (id1, id2)
    }

    /// Builds an unsaved edge on branch "main" with weight 1.0 and no metadata.
    fn make_edge(source_id: NodeId, target_id: NodeId, edge_type: EdgeType) -> Edge {
        Edge {
            id: EdgeId(0),
            source_id,
            target_id,
            edge_type,
            branch_id: BranchId::from("main"),
            weight: 1.0,
            metadata: None,
        }
    }

    #[test]
    fn insert_and_find_by_source() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let edge = make_edge(n1, n2, EdgeType::DependsOn);
        let inserted = edge_repo.insert(&edge).expect("insert should succeed");
        assert_ne!(inserted.id.0, 0, "should get assigned ID");

        let edges = edge_repo.find_by_source(n1).expect("find_by_source");
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].edge_type, EdgeType::DependsOn);
        assert_eq!(edges[0].source_id, n1);
        assert_eq!(edges[0].target_id, n2);
    }

    #[test]
    fn find_by_target() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let edge = make_edge(n1, n2, EdgeType::RelatedTo);
        edge_repo.insert(&edge).unwrap();

        let edges = edge_repo.find_by_target(n2).unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].target_id, n2);
    }

    #[test]
    fn find_by_type() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let e1 = make_edge(n1, n2, EdgeType::DependsOn);
        let e2 = make_edge(n2, n1, EdgeType::RelatedTo);
        let e3 = make_edge(n1, n2, EdgeType::DependsOn);
        edge_repo.insert(&e1).unwrap();
        edge_repo.insert(&e2).unwrap();
        edge_repo.insert(&e3).unwrap();

        let depends = edge_repo.find_by_type(EdgeType::DependsOn).unwrap();
        assert_eq!(depends.len(), 2);
        assert!(
            depends.iter().all(|e| e.edge_type == EdgeType::DependsOn),
            "find_by_type must only return edges of the requested type"
        );

        let related = edge_repo.find_by_type(EdgeType::RelatedTo).unwrap();
        assert_eq!(related.len(), 1);
        assert_eq!(related[0].edge_type, EdgeType::RelatedTo);
    }

    #[test]
    fn find_returns_empty_when_nothing_matches() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        assert!(edge_repo.find_by_source(n1).unwrap().is_empty());
        assert!(edge_repo.find_by_target(n2).unwrap().is_empty());
        assert!(
            edge_repo
                .find_by_type(EdgeType::Contradicts)
                .unwrap()
                .is_empty()
        );
    }

    #[test]
    fn delete_edge() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let edge = make_edge(n1, n2, EdgeType::PartOf);
        let inserted = edge_repo.insert(&edge).unwrap();

        edge_repo
            .delete(inserted.id)
            .expect("delete should succeed");

        let edges = edge_repo.find_by_source(n1).unwrap();
        assert!(edges.is_empty(), "edge should be deleted");
    }

    #[test]
    fn delete_not_found() {
        let (_node_repo, edge_repo) = test_repos();
        let result = edge_repo.delete(EdgeId(999));
        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn insert_with_metadata() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let mut edge = make_edge(n1, n2, EdgeType::Implements);
        edge.metadata = Some(serde_json::json!({"via": "trait impl"}));

        let inserted = edge_repo.insert(&edge).unwrap();
        let edges = edge_repo.find_by_source(n1).unwrap();
        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].id, inserted.id);
        assert_eq!(edges[0].metadata.as_ref().unwrap()["via"], "trait impl");
    }

    #[test]
    fn insert_roundtrips_weight_branch_and_absent_metadata() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let mut edge = make_edge(n1, n2, EdgeType::Updates);
        edge.weight = 0.25;
        edge.branch_id = BranchId::from("feature-x");

        let inserted = edge_repo.insert(&edge).unwrap();
        let edges = edge_repo.find_by_target(n2).unwrap();
        assert_eq!(edges.len(), 1);

        let found = &edges[0];
        assert_eq!(found.id, inserted.id);
        // 0.25 is exactly representable, so an exact comparison is safe here.
        assert_eq!(found.weight, 0.25);
        assert_eq!(found.branch_id.0, "feature-x");
        assert!(found.metadata.is_none(), "absent metadata should stay None");
    }

    #[test]
    fn all_edge_type_variants_roundtrip() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let types = [
            EdgeType::RelatedTo,
            EdgeType::Updates,
            EdgeType::Contradicts,
            EdgeType::PartOf,
            EdgeType::DependsOn,
            EdgeType::Implements,
        ];

        for et in types {
            let edge = make_edge(n1, n2, et);
            edge_repo.insert(&edge).unwrap();
        }

        let all_edges = edge_repo.find_by_source(n1).unwrap();
        assert_eq!(all_edges.len(), 6);

        for et in types {
            let found = edge_repo.find_by_type(et).unwrap();
            assert!(!found.is_empty(), "should find edges of type {et}");
        }
    }

    #[test]
    fn delete_by_branch() {
        let (node_repo, edge_repo) = test_repos();
        let (n1, n2) = insert_two_nodes(&node_repo);

        let mut e1 = make_edge(n1, n2, EdgeType::DependsOn);
        e1.branch_id = BranchId::from("branch-a");

        let mut e2 = make_edge(n2, n1, EdgeType::PartOf);
        e2.branch_id = BranchId::from("branch-a");

        let mut e3 = make_edge(n1, n2, EdgeType::RelatedTo);
        e3.branch_id = BranchId::from("branch-b");

        edge_repo.insert(&e1).unwrap();
        edge_repo.insert(&e2).unwrap();
        edge_repo.insert(&e3).unwrap();

        let deleted = edge_repo
            .delete_by_branch(&BranchId::from("branch-a"))
            .unwrap();
        assert_eq!(deleted, 2, "should delete 2 edges from branch-a");

        let depends = edge_repo.find_by_type(EdgeType::DependsOn).unwrap();
        assert!(depends.is_empty(), "DependsOn from branch-a should be gone");

        let related = edge_repo.find_by_type(EdgeType::RelatedTo).unwrap();
        assert_eq!(related.len(), 1, "branch-b edge should still exist");
    }

    #[test]
    fn delete_by_branch_empty() {
        let (_node_repo, edge_repo) = test_repos();
        let deleted = edge_repo
            .delete_by_branch(&BranchId::from("empty-branch"))
            .unwrap();
        assert_eq!(deleted, 0, "should delete 0 edges from empty branch");
    }
}