Skip to main content

seshat_storage/repository/
node_repository.rs

1//! SQLite implementation of [`NodeRepository`].
2
3use std::sync::{Arc, Mutex};
4
5use rusqlite::{Connection, params};
6use seshat_core::{BranchId, KnowledgeNature, KnowledgeNode, KnowledgeWeight, NodeId};
7
8use super::NodeRepository;
9use crate::StorageError;
10
11/// SQLite-backed node repository.
12#[derive(Debug, Clone)]
13pub struct SqliteNodeRepository {
14    conn: Arc<Mutex<Connection>>,
15}
16
17impl SqliteNodeRepository {
18    /// Create a new repository backed by the given connection.
19    pub fn new(conn: Arc<Mutex<Connection>>) -> Self {
20        Self { conn }
21    }
22
23    fn conn(&self) -> Result<std::sync::MutexGuard<'_, Connection>, StorageError> {
24        self.conn.lock().map_err(|e| {
25            StorageError::QueryError(format!("Failed to acquire connection lock: {e}"))
26        })
27    }
28}
29
30/// Serialize `ext_data` to a JSON string for storage.
31fn serialize_ext_data(data: &Option<serde_json::Value>) -> Result<Option<String>, StorageError> {
32    data.as_ref()
33        .map(serde_json::to_string)
34        .transpose()
35        .map_err(|e| StorageError::SerializationError(e.to_string()))
36}
37
38impl NodeRepository for SqliteNodeRepository {
39    fn insert(&self, node: &KnowledgeNode) -> Result<KnowledgeNode, StorageError> {
40        let conn = self.conn()?;
41
42        let ext_data_str = serialize_ext_data(&node.ext_data)?;
43
44        conn.execute(
45            "INSERT INTO nodes (branch_id, nature, weight, confidence, adoption_count, total_count, description, ext_data)
46             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)",
47            params![
48                node.branch_id.0,
49                node.nature.as_str(),
50                node.weight.as_str(),
51                node.confidence,
52                node.adoption_count,
53                node.total_count,
54                node.description,
55                ext_data_str,
56            ],
57        )?;
58
59        let id = conn.last_insert_rowid();
60
61        let mut inserted = node.clone();
62        inserted.id = NodeId(id);
63        Ok(inserted)
64    }
65
66    fn get_by_id(&self, id: NodeId) -> Result<KnowledgeNode, StorageError> {
67        let conn = self.conn()?;
68
69        conn.query_row(
70            "SELECT id, branch_id, nature, weight, confidence, adoption_count, total_count, description, ext_data
71             FROM nodes WHERE id = ?1",
72            params![id.0],
73            row_to_node,
74        )
75        .map_err(|e| match e {
76            rusqlite::Error::QueryReturnedNoRows => StorageError::NotFound {
77                entity: "Node",
78                id: id.0.to_string(),
79            },
80            other => StorageError::from(other),
81        })
82    }
83
84    fn find_by_nature(&self, nature: KnowledgeNature) -> Result<Vec<KnowledgeNode>, StorageError> {
85        self.query_nodes(
86            "SELECT id, branch_id, nature, weight, confidence, adoption_count, total_count, description, ext_data
87             FROM nodes WHERE nature = ?1",
88            &nature.as_str(),
89        )
90    }
91
92    fn find_by_branch(&self, branch_id: &BranchId) -> Result<Vec<KnowledgeNode>, StorageError> {
93        self.query_nodes(
94            "SELECT id, branch_id, nature, weight, confidence, adoption_count, total_count, description, ext_data
95             FROM nodes WHERE branch_id = ?1",
96            &branch_id.0,
97        )
98    }
99
100    fn update(&self, node: &KnowledgeNode) -> Result<(), StorageError> {
101        let conn = self.conn()?;
102
103        let ext_data_str = serialize_ext_data(&node.ext_data)?;
104
105        let affected = conn.execute(
106            "UPDATE nodes SET branch_id = ?1, nature = ?2, weight = ?3, confidence = ?4,
107             adoption_count = ?5, total_count = ?6, description = ?7, ext_data = ?8
108             WHERE id = ?9",
109            params![
110                node.branch_id.0,
111                node.nature.as_str(),
112                node.weight.as_str(),
113                node.confidence,
114                node.adoption_count,
115                node.total_count,
116                node.description,
117                ext_data_str,
118                node.id.0,
119            ],
120        )?;
121
122        if affected == 0 {
123            return Err(StorageError::NotFound {
124                entity: "Node",
125                id: node.id.0.to_string(),
126            });
127        }
128
129        Ok(())
130    }
131
132    fn delete(&self, id: NodeId) -> Result<(), StorageError> {
133        let conn = self.conn()?;
134
135        let affected = conn.execute("DELETE FROM nodes WHERE id = ?1", params![id.0])?;
136
137        if affected == 0 {
138            return Err(StorageError::NotFound {
139                entity: "Node",
140                id: id.0.to_string(),
141            });
142        }
143
144        Ok(())
145    }
146
147    fn delete_by_branch(&self, branch_id: &BranchId) -> Result<usize, StorageError> {
148        let conn = self.conn()?;
149
150        let affected = conn.execute(
151            "DELETE FROM nodes WHERE branch_id = ?1",
152            params![branch_id.0],
153        )?;
154
155        Ok(affected)
156    }
157
158    fn delete_facts_by_branch(&self, branch_id: &BranchId) -> Result<usize, StorageError> {
159        let conn = self.conn()?;
160
161        let affected = conn.execute(
162            "DELETE FROM nodes WHERE branch_id = ?1 AND nature = 'fact'",
163            params![branch_id.0],
164        )?;
165
166        Ok(affected)
167    }
168
169    fn delete_auto_detected_by_branch(&self, branch_id: &BranchId) -> Result<usize, StorageError> {
170        let conn = self.conn()?;
171
172        let affected = conn.execute(
173            "DELETE FROM nodes WHERE branch_id = ?1
174             AND json_extract(ext_data, '$.source') = 'auto_detected'",
175            params![branch_id.0],
176        )?;
177
178        Ok(affected)
179    }
180
181    fn find_conventions_by_branch(
182        &self,
183        branch_id: &BranchId,
184    ) -> Result<Vec<KnowledgeNode>, StorageError> {
185        let conn = self.conn()?;
186        let mut stmt = conn.prepare(
187            "SELECT id, branch_id, nature, weight, confidence, adoption_count, total_count, description, ext_data
188            FROM nodes
189            WHERE branch_id = ?1
190              AND json_extract(ext_data, '$.source') IN ('auto_detected', 'user')
191              AND COALESCE(json_extract(ext_data, '$.removed'), 0) NOT IN (1, 'true')",
192        )?;
193        let rows = stmt.query_map(params![branch_id.0], row_to_node)?;
194        rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
195    }
196}
197
198// ---------------------------------------------------------------------------
199// Helpers
200// ---------------------------------------------------------------------------
201
202impl SqliteNodeRepository {
203    /// Run a parameterised node query and collect the results.
204    fn query_nodes(
205        &self,
206        sql: &str,
207        param: &dyn rusqlite::types::ToSql,
208    ) -> Result<Vec<KnowledgeNode>, StorageError> {
209        let conn = self.conn()?;
210        let mut stmt = conn.prepare(sql)?;
211        let rows = stmt.query_map([param], row_to_node)?;
212        rows.collect::<Result<Vec<_>, _>>().map_err(Into::into)
213    }
214}
215
216/// Map a rusqlite Row to a `KnowledgeNode`.
217fn row_to_node(row: &rusqlite::Row<'_>) -> rusqlite::Result<KnowledgeNode> {
218    let id: i64 = row.get(0)?;
219    let branch_id: String = row.get(1)?;
220    let nature_str: String = row.get(2)?;
221    let weight_str: String = row.get(3)?;
222    let confidence: f64 = row.get(4)?;
223    let adoption_count: u32 = row.get(5)?;
224    let total_count: u32 = row.get(6)?;
225    let description: String = row.get(7)?;
226    let ext_data_str: Option<String> = row.get(8)?;
227
228    let nature: KnowledgeNature = nature_str.parse().map_err(|e| {
229        rusqlite::Error::FromSqlConversionFailure(2, rusqlite::types::Type::Text, Box::new(e))
230    })?;
231
232    let weight: KnowledgeWeight = weight_str.parse().map_err(|e| {
233        rusqlite::Error::FromSqlConversionFailure(3, rusqlite::types::Type::Text, Box::new(e))
234    })?;
235
236    let ext_data = ext_data_str
237        .map(|s| serde_json::from_str(&s))
238        .transpose()
239        .map_err(|e| {
240            rusqlite::Error::FromSqlConversionFailure(8, rusqlite::types::Type::Text, Box::new(e))
241        })?;
242
243    Ok(KnowledgeNode {
244        id: NodeId(id),
245        branch_id: BranchId(branch_id),
246        nature,
247        weight,
248        confidence,
249        adoption_count,
250        total_count,
251        description,
252        ext_data,
253    })
254}
255
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Database;
    use seshat_core::test_helpers::make_knowledge_node;

    /// Helper: create an in-memory DB and return a `SqliteNodeRepository`.
    fn test_repo() -> SqliteNodeRepository {
        let db = Database::open(":memory:").expect("in-memory DB");
        SqliteNodeRepository::new(db.connection().clone())
    }

    #[test]
    fn insert_and_get_by_id() {
        let repo = test_repo();
        let node = make_knowledge_node(KnowledgeNature::Convention, 0.9);

        let inserted = repo.insert(&node).expect("insert should succeed");
        assert_ne!(inserted.id.0, 0, "should get assigned ID");

        let fetched = repo
            .get_by_id(inserted.id)
            .expect("get_by_id should succeed");
        assert_eq!(fetched.id, inserted.id);
        assert_eq!(fetched.nature, KnowledgeNature::Convention);
        assert_eq!(fetched.weight, KnowledgeWeight::Strong);
        assert!((fetched.confidence - 0.9).abs() < f64::EPSILON);
        assert_eq!(fetched.branch_id, BranchId::from("main"));
    }

    #[test]
    fn insert_with_ext_data() {
        let repo = test_repo();
        let mut node = make_knowledge_node(KnowledgeNature::Decision, 1.0);
        node.ext_data = Some(serde_json::json!({"reasoning": "perf requirement"}));
        node.description = "Use SQLite".to_string();

        let inserted = repo.insert(&node).expect("insert");
        let fetched = repo.get_by_id(inserted.id).expect("get");

        assert_eq!(
            fetched.ext_data.as_ref().unwrap()["reasoning"],
            "perf requirement"
        );
        assert_eq!(fetched.description, "Use SQLite");
    }

    #[test]
    fn get_by_id_not_found() {
        let repo = test_repo();
        let result = repo.get_by_id(NodeId(999));

        assert!(result.is_err());
        match result.unwrap_err() {
            StorageError::NotFound { entity, id } => {
                assert_eq!(entity, "Node");
                assert_eq!(id, "999");
            }
            other => panic!("expected NotFound, got: {other}"),
        }
    }

    #[test]
    fn find_by_nature() {
        let repo = test_repo();

        let n1 = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        let n2 = make_knowledge_node(KnowledgeNature::Fact, 0.5);
        let n3 = make_knowledge_node(KnowledgeNature::Convention, 0.6);

        repo.insert(&n1).unwrap();
        repo.insert(&n2).unwrap();
        repo.insert(&n3).unwrap();

        let conventions = repo.find_by_nature(KnowledgeNature::Convention).unwrap();
        assert_eq!(conventions.len(), 2);

        let facts = repo.find_by_nature(KnowledgeNature::Fact).unwrap();
        assert_eq!(facts.len(), 1);
    }

    #[test]
    fn find_by_branch() {
        let repo = test_repo();

        let mut n1 = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        n1.branch_id = BranchId::from("feature-a");

        let n2 = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        // n2 defaults to branch "main"

        repo.insert(&n1).unwrap();
        repo.insert(&n2).unwrap();

        let feature_nodes = repo.find_by_branch(&BranchId::from("feature-a")).unwrap();
        assert_eq!(feature_nodes.len(), 1);

        let main_nodes = repo.find_by_branch(&BranchId::from("main")).unwrap();
        assert_eq!(main_nodes.len(), 1);
    }

    #[test]
    fn update_node() {
        let repo = test_repo();
        let node = make_knowledge_node(KnowledgeNature::Convention, 0.9);

        let mut inserted = repo.insert(&node).unwrap();
        inserted.description = "Updated description".to_string();
        inserted.confidence = 0.95;
        inserted.adoption_count = 19;
        inserted.total_count = 20;

        repo.update(&inserted).expect("update should succeed");

        let fetched = repo.get_by_id(inserted.id).unwrap();
        assert_eq!(fetched.description, "Updated description");
        assert!((fetched.confidence - 0.95).abs() < f64::EPSILON);
        assert_eq!(fetched.adoption_count, 19);
        assert_eq!(fetched.total_count, 20);
    }

    #[test]
    fn update_not_found() {
        let repo = test_repo();
        let mut node = make_knowledge_node(KnowledgeNature::Fact, 0.5);
        node.id = NodeId(999);

        let result = repo.update(&node);
        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn delete_node() {
        let repo = test_repo();
        let node = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        let inserted = repo.insert(&node).unwrap();

        repo.delete(inserted.id).expect("delete should succeed");

        let result = repo.get_by_id(inserted.id);
        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn delete_not_found() {
        let repo = test_repo();
        let result = repo.delete(NodeId(999));
        assert!(matches!(result, Err(StorageError::NotFound { .. })));
    }

    #[test]
    fn all_nature_variants_roundtrip() {
        let repo = test_repo();

        let natures = [
            KnowledgeNature::Fact,
            KnowledgeNature::Convention,
            KnowledgeNature::Observation,
            KnowledgeNature::Decision,
            KnowledgeNature::Preference,
        ];

        for nature in natures {
            let node = make_knowledge_node(nature, 0.5);
            let inserted = repo.insert(&node).unwrap();
            let fetched = repo.get_by_id(inserted.id).unwrap();
            assert_eq!(
                fetched.nature, nature,
                "nature roundtrip failed for {nature}"
            );
        }
    }

    #[test]
    fn all_weight_variants_roundtrip() {
        let repo = test_repo();

        let cases: [(KnowledgeWeight, f64); 5] = [
            (KnowledgeWeight::Info, 0.1),
            (KnowledgeWeight::Weak, 0.3),
            (KnowledgeWeight::Moderate, 0.6),
            (KnowledgeWeight::Strong, 0.9),
            (KnowledgeWeight::Rule, 1.0),
        ];

        for (expected_weight, confidence) in cases {
            let mut node = make_knowledge_node(KnowledgeNature::Fact, confidence);
            // Override the weight to test independently from auto-assignment
            node.weight = expected_weight;
            let inserted = repo.insert(&node).unwrap();
            let fetched = repo.get_by_id(inserted.id).unwrap();
            assert_eq!(
                fetched.weight, expected_weight,
                "weight roundtrip failed for {expected_weight}"
            );
        }
    }

    #[test]
    fn delete_by_branch() {
        let repo = test_repo();
        let branch_a = BranchId::from("branch-a");
        let branch_b = BranchId::from("branch-b");

        let mut n1 = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        n1.branch_id = branch_a.clone();
        let mut n2 = make_knowledge_node(KnowledgeNature::Fact, 0.7);
        n2.branch_id = branch_a.clone();
        let mut n3 = make_knowledge_node(KnowledgeNature::Fact, 0.6);
        n3.branch_id = branch_b.clone();

        repo.insert(&n1).unwrap();
        repo.insert(&n2).unwrap();
        repo.insert(&n3).unwrap();

        let deleted = repo.delete_by_branch(&branch_a).unwrap();
        assert_eq!(deleted, 2, "should delete 2 nodes from branch-a");

        let a_nodes = repo.find_by_branch(&branch_a).unwrap();
        assert!(a_nodes.is_empty(), "branch-a should have no nodes");

        let b_nodes = repo.find_by_branch(&branch_b).unwrap();
        assert_eq!(b_nodes.len(), 1, "branch-b should still have 1 node");
    }

    #[test]
    fn delete_by_branch_empty() {
        let repo = test_repo();
        let branch = BranchId::from("empty-branch");

        let deleted = repo.delete_by_branch(&branch).unwrap();
        assert_eq!(deleted, 0, "should delete 0 nodes from empty branch");
    }

    #[test]
    fn delete_facts_by_branch_only_removes_facts_on_that_branch() {
        let repo = test_repo();
        let branch_a = BranchId::from("branch-a");
        let branch_b = BranchId::from("branch-b");

        let mut fact_a = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        fact_a.branch_id = branch_a.clone();
        let mut convention_a = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        convention_a.branch_id = branch_a.clone();
        let mut fact_b = make_knowledge_node(KnowledgeNature::Fact, 0.7);
        fact_b.branch_id = branch_b.clone();

        repo.insert(&fact_a).unwrap();
        repo.insert(&convention_a).unwrap();
        repo.insert(&fact_b).unwrap();

        let deleted = repo.delete_facts_by_branch(&branch_a).unwrap();
        assert_eq!(deleted, 1, "only the fact on branch-a should be deleted");

        let remaining_a = repo.find_by_branch(&branch_a).unwrap();
        assert_eq!(remaining_a.len(), 1, "the convention on branch-a should remain");
        assert_eq!(remaining_a[0].nature, KnowledgeNature::Convention);

        let remaining_b = repo.find_by_branch(&branch_b).unwrap();
        assert_eq!(remaining_b.len(), 1, "facts on branch-b must be untouched");
    }

    #[test]
    fn delete_auto_detected_preserves_user_decisions() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        // Auto-detected convention
        let mut auto_node = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        auto_node.branch_id = branch.clone();
        auto_node.description = "Uses thiserror".to_string();
        auto_node.ext_data = Some(serde_json::json!({
            "source": "auto_detected",
            "detector_name": "error_handling"
        }));
        repo.insert(&auto_node).unwrap();

        // User-recorded decision
        let mut user_node = make_knowledge_node(KnowledgeNature::Decision, 1.0);
        user_node.branch_id = branch.clone();
        user_node.description = "Always use Result".to_string();
        user_node.ext_data = Some(serde_json::json!({
            "source": "user",
            "user_confirmed": true
        }));
        repo.insert(&user_node).unwrap();

        // Module fact (no source field in ext_data)
        let mut fact_node = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        fact_node.branch_id = branch.clone();
        fact_node.description = "Module: seshat-core".to_string();
        repo.insert(&fact_node).unwrap();

        let deleted = repo.delete_auto_detected_by_branch(&branch).unwrap();
        assert_eq!(deleted, 1, "should only delete auto_detected node");

        let all_nodes = repo.find_by_branch(&branch).unwrap();
        assert_eq!(all_nodes.len(), 2, "user decision + fact should remain");

        // Verify the user node is still there
        let user = all_nodes
            .iter()
            .find(|n| n.description == "Always use Result");
        assert!(user.is_some(), "user decision should be preserved");

        // Verify the fact node is still there
        let fact = all_nodes
            .iter()
            .find(|n| n.description == "Module: seshat-core");
        assert!(fact.is_some(), "fact node should be preserved");
    }

    #[test]
    fn delete_auto_detected_empty_branch() {
        let repo = test_repo();
        let branch = BranchId::from("empty");

        let deleted = repo.delete_auto_detected_by_branch(&branch).unwrap();
        assert_eq!(deleted, 0);
    }

    #[test]
    fn find_conventions_by_branch_returns_auto_and_user() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        // Auto-detected convention
        let mut auto_node = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        auto_node.branch_id = branch.clone();
        auto_node.description = "Uses thiserror".to_string();
        auto_node.ext_data = Some(serde_json::json!({
            "source": "auto_detected",
            "detector_name": "error_handling"
        }));
        repo.insert(&auto_node).unwrap();

        // User-recorded decision
        let mut user_node = make_knowledge_node(KnowledgeNature::Decision, 1.0);
        user_node.branch_id = branch.clone();
        user_node.description = "Always use Result".to_string();
        user_node.ext_data = Some(serde_json::json!({
            "source": "user",
            "user_confirmed": true
        }));
        repo.insert(&user_node).unwrap();

        // Module fact (no source field — should NOT appear)
        let mut fact_node = make_knowledge_node(KnowledgeNature::Fact, 0.8);
        fact_node.branch_id = branch.clone();
        fact_node.description = "Module: seshat-core".to_string();
        repo.insert(&fact_node).unwrap();

        let conventions = repo.find_conventions_by_branch(&branch).unwrap();
        assert_eq!(
            conventions.len(),
            2,
            "should return auto_detected + user nodes"
        );

        let descriptions: Vec<&str> = conventions.iter().map(|n| n.description.as_str()).collect();
        assert!(descriptions.contains(&"Uses thiserror"));
        assert!(descriptions.contains(&"Always use Result"));
    }

    #[test]
    fn find_conventions_by_branch_excludes_other_branches() {
        let repo = test_repo();

        let mut n1 = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        n1.branch_id = BranchId::from("main");
        n1.ext_data = Some(serde_json::json!({"source": "auto_detected"}));
        repo.insert(&n1).unwrap();

        let mut n2 = make_knowledge_node(KnowledgeNature::Convention, 0.9);
        n2.branch_id = BranchId::from("feature");
        n2.ext_data = Some(serde_json::json!({"source": "auto_detected"}));
        repo.insert(&n2).unwrap();

        let main_conventions = repo
            .find_conventions_by_branch(&BranchId::from("main"))
            .unwrap();
        assert_eq!(main_conventions.len(), 1);

        let feature_conventions = repo
            .find_conventions_by_branch(&BranchId::from("feature"))
            .unwrap();
        assert_eq!(feature_conventions.len(), 1);
    }

    #[test]
    fn find_conventions_by_branch_excludes_removed() {
        let repo = test_repo();
        let branch = BranchId::from("main");

        // `removed` as JSON bool true or as the string "true" hides the node;
        // an explicit `false` or a missing flag keeps it.
        let cases = [
            ("kept", serde_json::json!({"source": "user"})),
            (
                "kept-explicit-false",
                serde_json::json!({"source": "auto_detected", "removed": false}),
            ),
            (
                "removed-bool",
                serde_json::json!({"source": "user", "removed": true}),
            ),
            (
                "removed-string",
                serde_json::json!({"source": "auto_detected", "removed": "true"}),
            ),
        ];

        for (description, ext) in cases {
            let mut node = make_knowledge_node(KnowledgeNature::Convention, 0.9);
            node.branch_id = branch.clone();
            node.description = description.to_string();
            node.ext_data = Some(ext);
            repo.insert(&node).unwrap();
        }

        let conventions = repo.find_conventions_by_branch(&branch).unwrap();
        let mut descriptions: Vec<&str> =
            conventions.iter().map(|n| n.description.as_str()).collect();
        descriptions.sort_unstable();
        assert_eq!(descriptions, ["kept", "kept-explicit-false"]);
    }
}