Skip to main content

steel_memory_lib/storage/
knowledge_graph.rs

1use std::path::{Path, PathBuf};
2use rusqlite::{Connection, params};
3use crate::types::{Entity, Triple};
4
/// Handle to a SQLite-backed knowledge graph of entities and subject–predicate–object triples.
///
/// Only the database path is stored; every operation opens its own connection.
#[derive(Debug, Clone)]
pub struct KnowledgeGraph {
    /// Location of the SQLite database file.
    db_path: PathBuf,
}
8
9impl KnowledgeGraph {
10    pub fn new(db_path: &Path) -> anyhow::Result<Self> {
11        let conn = open_conn(db_path)?;
12        conn.execute_batch(
13            "PRAGMA journal_mode=WAL;
14             CREATE TABLE IF NOT EXISTS entities (
15                 id TEXT PRIMARY KEY,
16                 name TEXT NOT NULL,
17                 entity_type TEXT DEFAULT 'unknown',
18                 properties TEXT DEFAULT '{}',
19                 created_at TEXT DEFAULT CURRENT_TIMESTAMP
20             );
21             CREATE TABLE IF NOT EXISTS triples (
22                 id TEXT PRIMARY KEY,
23                 subject TEXT NOT NULL,
24                 predicate TEXT NOT NULL,
25                 object TEXT NOT NULL,
26                 valid_from TEXT,
27                 valid_to TEXT,
28                 confidence REAL DEFAULT 1.0,
29                 source_closet TEXT,
30                 source_file TEXT,
31                 extracted_at TEXT DEFAULT CURRENT_TIMESTAMP
32             );
33             CREATE INDEX IF NOT EXISTS idx_triples_subject ON triples(subject);
34             CREATE INDEX IF NOT EXISTS idx_triples_object ON triples(object);",
35        )?;
36        Ok(Self { db_path: db_path.to_path_buf() })
37    }
38
39    pub fn add_triple(
40        &self,
41        subject: &str,
42        predicate: &str,
43        object: &str,
44        confidence: f64,
45        source_closet: Option<&str>,
46        source_file: Option<&str>,
47    ) -> anyhow::Result<String> {
48        let conn = open_conn(&self.db_path)?;
49        let now = chrono::Utc::now().to_rfc3339();
50
51        // Ensure entities exist
52        let subj_id = normalize_id(subject);
53        let obj_id = normalize_id(object);
54
55        upsert_entity(&conn, &subj_id, subject)?;
56        upsert_entity(&conn, &obj_id, object)?;
57
58        let pred_norm = normalize_id(predicate);
59        let ts = chrono::Utc::now().timestamp_millis();
60        let triple_id = format!("triple_{subj_id}_{pred_norm}_{obj_id}_{ts}");
61
62        conn.execute(
63            "INSERT OR REPLACE INTO triples
64             (id, subject, predicate, object, valid_from, confidence, source_closet, source_file, extracted_at)
65             VALUES (?1,?2,?3,?4,?5,?6,?7,?8,?9)",
66            params![
67                triple_id, subj_id, predicate, obj_id,
68                now, confidence, source_closet, source_file, now
69            ],
70        )?;
71        Ok(triple_id)
72    }
73
74    pub fn invalidate_triple(
75        &self,
76        subject: &str,
77        predicate: &str,
78        object: &str,
79    ) -> anyhow::Result<usize> {
80        let conn = open_conn(&self.db_path)?;
81        let now = chrono::Utc::now().to_rfc3339();
82        let subj_id = normalize_id(subject);
83        let obj_id = normalize_id(object);
84        let n = conn.execute(
85            "UPDATE triples SET valid_to=?1 WHERE subject=?2 AND predicate=?3 AND object=?4 AND valid_to IS NULL",
86            params![now, subj_id, predicate, obj_id],
87        )?;
88        Ok(n)
89    }
90
91    pub fn query_entity(
92        &self,
93        entity: &str,
94        direction: &str,
95    ) -> anyhow::Result<Vec<Triple>> {
96        let conn = open_conn(&self.db_path)?;
97        let entity_id = normalize_id(entity);
98        let triples = match direction {
99            "outgoing" => {
100                let mut stmt = conn.prepare(
101                    "SELECT id,subject,predicate,object,valid_from,valid_to,confidence,source_closet,source_file,extracted_at FROM triples WHERE subject=?1 AND valid_to IS NULL ORDER BY extracted_at DESC"
102                )?;
103                stmt.query_map(params![entity_id], row_to_triple)?
104                    .filter_map(|r| r.ok())
105                    .collect()
106            }
107            "incoming" => {
108                let mut stmt = conn.prepare(
109                    "SELECT id,subject,predicate,object,valid_from,valid_to,confidence,source_closet,source_file,extracted_at FROM triples WHERE object=?1 AND valid_to IS NULL ORDER BY extracted_at DESC"
110                )?;
111                stmt.query_map(params![entity_id], row_to_triple)?
112                    .filter_map(|r| r.ok())
113                    .collect()
114            }
115            _ => {
116                let mut stmt = conn.prepare(
117                    "SELECT id,subject,predicate,object,valid_from,valid_to,confidence,source_closet,source_file,extracted_at FROM triples WHERE (subject=?1 OR object=?1) AND valid_to IS NULL ORDER BY extracted_at DESC"
118                )?;
119                stmt.query_map(params![entity_id], row_to_triple)?
120                    .filter_map(|r| r.ok())
121                    .collect()
122            }
123        };
124        Ok(triples)
125    }
126
127    pub fn timeline(&self, entity: &str, limit: usize) -> anyhow::Result<Vec<Triple>> {
128        let conn = open_conn(&self.db_path)?;
129        let entity_id = normalize_id(entity);
130        let mut stmt = conn.prepare(
131            "SELECT id,subject,predicate,object,valid_from,valid_to,confidence,source_closet,source_file,extracted_at FROM triples WHERE subject=?1 OR object=?1 ORDER BY extracted_at DESC LIMIT ?2"
132        )?;
133        let triples = stmt.query_map(params![entity_id, limit as i64], row_to_triple)?
134            .filter_map(|r| r.ok())
135            .collect();
136        Ok(triples)
137    }
138
139    pub fn all_triples(&self, limit: usize) -> anyhow::Result<Vec<Triple>> {
140        let conn = open_conn(&self.db_path)?;
141        let mut stmt = conn.prepare(
142            "SELECT id,subject,predicate,object,valid_from,valid_to,confidence,source_closet,source_file,extracted_at FROM triples ORDER BY extracted_at DESC LIMIT ?1"
143        )?;
144        let triples = stmt.query_map(params![limit as i64], row_to_triple)?
145            .filter_map(|r| r.ok())
146            .collect();
147        Ok(triples)
148    }
149
150    pub fn stats(&self) -> anyhow::Result<serde_json::Value> {
151        let conn = open_conn(&self.db_path)?;
152        let entity_count: i64 = conn.query_row("SELECT COUNT(*) FROM entities", [], |r| r.get(0))?;
153        let triple_count: i64 = conn.query_row("SELECT COUNT(*) FROM triples", [], |r| r.get(0))?;
154        let valid_count: i64 = conn.query_row("SELECT COUNT(*) FROM triples WHERE valid_to IS NULL", [], |r| r.get(0))?;
155        Ok(serde_json::json!({
156            "entities": entity_count,
157            "triples": triple_count,
158            "valid_triples": valid_count,
159        }))
160    }
161
162    #[allow(dead_code)]
163    pub fn list_entities(&self, limit: usize) -> anyhow::Result<Vec<Entity>> {
164        let conn = open_conn(&self.db_path)?;
165        let mut stmt = conn.prepare(
166            "SELECT id,name,entity_type,properties,created_at FROM entities ORDER BY created_at DESC LIMIT ?1"
167        )?;
168        let entities = stmt.query_map(params![limit as i64], |row| {
169            Ok(Entity {
170                id: row.get(0)?,
171                name: row.get(1)?,
172                entity_type: row.get(2)?,
173                properties: serde_json::from_str(&row.get::<_, String>(3)?).unwrap_or_default(),
174                created_at: row.get(4)?,
175            })
176        })?.filter_map(|r| r.ok()).collect();
177        Ok(entities)
178    }
179}
180
181fn upsert_entity(conn: &Connection, id: &str, name: &str) -> anyhow::Result<()> {
182    conn.execute(
183        "INSERT OR IGNORE INTO entities (id, name) VALUES (?1, ?2)",
184        params![id, name],
185    )?;
186    Ok(())
187}
188
189fn open_conn(path: &Path) -> anyhow::Result<Connection> {
190    let conn = Connection::open(path)?;
191    conn.execute_batch("PRAGMA journal_mode=WAL; PRAGMA foreign_keys=ON;")?;
192    Ok(conn)
193}
194
195fn row_to_triple(row: &rusqlite::Row<'_>) -> rusqlite::Result<Triple> {
196    Ok(Triple {
197        id: row.get(0)?,
198        subject: row.get(1)?,
199        predicate: row.get(2)?,
200        object: row.get(3)?,
201        valid_from: row.get(4)?,
202        valid_to: row.get(5)?,
203        confidence: row.get(6)?,
204        source_closet: row.get(7)?,
205        source_file: row.get(8)?,
206        extracted_at: row.get(9)?,
207    })
208}
209
/// Converts a free-form name into the identifier used for entity ids.
///
/// The name is lowercased. Then every maximal run of characters that are not
/// ASCII letters or digits becomes a single `_`, and separators at either end
/// are dropped. Non-ASCII characters (after lowercasing) therefore act as
/// separators. A name with no ASCII alphanumerics yields an empty string.
pub fn normalize_id(name: &str) -> String {
    name.to_lowercase()
        .split(|ch: char| !ch.is_ascii_alphanumeric())
        .filter(|segment| !segment.is_empty())
        .collect::<Vec<&str>>()
        .join("_")
}
224}