//! Entity storage module for the knowledge graph.
//!
//! Source path: `sqlite_knowledge_graph/graph/entity.rs`.
2
3use rusqlite::params;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use crate::error::{Error, Result};
8
9/// Represents an entity in the knowledge graph.
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct Entity {
12    pub id: Option<i64>,
13    pub entity_type: String,
14    pub name: String,
15    pub properties: HashMap<String, serde_json::Value>,
16    pub created_at: Option<i64>,
17    pub updated_at: Option<i64>,
18}
19
20impl Entity {
21    /// Create a new entity.
22    pub fn new(entity_type: impl Into<String>, name: impl Into<String>) -> Self {
23        Self {
24            id: None,
25            entity_type: entity_type.into(),
26            name: name.into(),
27            properties: HashMap::new(),
28            created_at: None,
29            updated_at: None,
30        }
31    }
32
33    /// Create a new entity with properties.
34    pub fn with_properties(
35        entity_type: impl Into<String>,
36        name: impl Into<String>,
37        properties: HashMap<String, serde_json::Value>,
38    ) -> Self {
39        Self {
40            id: None,
41            entity_type: entity_type.into(),
42            name: name.into(),
43            properties,
44            created_at: None,
45            updated_at: None,
46        }
47    }
48
49    /// Set a property.
50    pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
51        self.properties.insert(key.into(), value);
52    }
53
54    /// Get a property.
55    pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
56        self.properties.get(key)
57    }
58}
59
60/// Insert a new entity into the database.
61pub fn insert_entity(conn: &rusqlite::Connection, entity: &Entity) -> Result<i64> {
62    let properties_json = serde_json::to_string(&entity.properties)?;
63
64    conn.execute(
65        r#"
66        INSERT INTO kg_entities (entity_type, name, properties)
67        VALUES (?1, ?2, ?3)
68        "#,
69        params![entity.entity_type, entity.name, properties_json],
70    )?;
71
72    Ok(conn.last_insert_rowid())
73}
74
75/// Get an entity by ID.
76pub fn get_entity(conn: &rusqlite::Connection, id: i64) -> Result<Entity> {
77    let mut stmt = conn.prepare(
78        r#"
79        SELECT id, entity_type, name, properties, created_at, updated_at
80        FROM kg_entities
81        WHERE id = ?1
82        "#,
83    )?;
84
85    let entity = stmt.query_row(params![id], |row| {
86        let properties_json: String = row.get(3)?;
87        let properties: HashMap<String, serde_json::Value> =
88            serde_json::from_str(&properties_json).unwrap_or_default();
89
90        Ok(Entity {
91            id: Some(row.get(0)?),
92            entity_type: row.get(1)?,
93            name: row.get(2)?,
94            properties,
95            created_at: row.get(4)?,
96            updated_at: row.get(5)?,
97        })
98    })?;
99
100    Ok(entity)
101}
102
103/// List entities with optional filtering.
104pub fn list_entities(
105    conn: &rusqlite::Connection,
106    entity_type: Option<&str>,
107    limit: Option<i64>,
108) -> Result<Vec<Entity>> {
109    let mut query =
110        "SELECT id, entity_type, name, properties, created_at, updated_at FROM kg_entities"
111            .to_string();
112
113    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
114
115    if let Some(et) = entity_type {
116        query.push_str(" WHERE entity_type = ?1");
117        params_vec.push(Box::new(et.to_string()));
118    }
119
120    query.push_str(" ORDER BY created_at DESC");
121
122    if let Some(lim) = limit {
123        query.push_str(" LIMIT ?");
124        params_vec.push(Box::new(lim));
125    }
126
127    let mut stmt = conn.prepare(&query)?;
128
129    // Convert boxed params to references for query_map
130    let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect();
131
132    let entities = stmt.query_map(params_refs.as_slice(), |row| {
133        // Handle NULL properties column
134        let properties_json: Option<String> = row.get(3)?;
135        let properties: HashMap<String, serde_json::Value> = match properties_json {
136            Some(json) => serde_json::from_str(&json).unwrap_or_default(),
137            None => HashMap::new(),
138        };
139
140        Ok(Entity {
141            id: Some(row.get(0)?),
142            entity_type: row.get(1)?,
143            name: row.get(2)?,
144            properties,
145            created_at: row.get(4)?,
146            updated_at: row.get(5)?,
147        })
148    })?;
149
150    let mut result = Vec::new();
151    for entity in entities {
152        result.push(entity?);
153    }
154
155    Ok(result)
156}
157
158/// Update an entity.
159pub fn update_entity(conn: &rusqlite::Connection, entity: &Entity) -> Result<()> {
160    let id = entity.id.ok_or(Error::EntityNotFound(0))?;
161    let properties_json = serde_json::to_string(&entity.properties)?;
162
163    let updated_at = std::time::SystemTime::now()
164        .duration_since(std::time::UNIX_EPOCH)
165        .map_err(|_| Error::InvalidInput("system clock before UNIX epoch".to_string()))?
166        .as_secs() as i64;
167
168    let affected = conn.execute(
169        r#"
170        UPDATE kg_entities
171        SET entity_type = ?1, name = ?2, properties = ?3, updated_at = ?4
172        WHERE id = ?5
173        "#,
174        params![
175            entity.entity_type,
176            entity.name,
177            properties_json,
178            updated_at,
179            id
180        ],
181    )?;
182
183    if affected == 0 {
184        return Err(Error::EntityNotFound(id));
185    }
186
187    Ok(())
188}
189
190/// Delete an entity by ID.
191pub fn delete_entity(conn: &rusqlite::Connection, id: i64) -> Result<()> {
192    let affected = conn.execute("DELETE FROM kg_entities WHERE id = ?1", params![id])?;
193
194    if affected == 0 {
195        return Err(Error::EntityNotFound(id));
196    }
197
198    Ok(())
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Open an in-memory database with the knowledge-graph schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    #[test]
    fn test_insert_entity() {
        let conn = setup();

        let entity = Entity::new("paper", "Test Paper");
        let id = insert_entity(&conn, &entity).unwrap();
        assert!(id > 0);
    }

    #[test]
    fn test_get_entity() {
        let conn = setup();

        let entity = Entity::new("paper", "Test Paper");
        let id = insert_entity(&conn, &entity).unwrap();

        let retrieved = get_entity(&conn, id).unwrap();
        assert_eq!(retrieved.id, Some(id));
        assert_eq!(retrieved.entity_type, "paper");
        assert_eq!(retrieved.name, "Test Paper");
    }

    #[test]
    fn test_get_missing_entity_is_error() {
        let conn = setup();
        assert!(get_entity(&conn, 12345).is_err());
    }

    #[test]
    fn test_list_entities() {
        let conn = setup();

        insert_entity(&conn, &Entity::new("paper", "Paper 1")).unwrap();
        insert_entity(&conn, &Entity::new("paper", "Paper 2")).unwrap();
        insert_entity(&conn, &Entity::new("skill", "Skill 1")).unwrap();

        let papers = list_entities(&conn, Some("paper"), None).unwrap();
        assert_eq!(papers.len(), 2);

        let all = list_entities(&conn, None, Some(2)).unwrap();
        assert_eq!(all.len(), 2);
    }

    #[test]
    fn test_entity_properties() {
        let conn = setup();

        let mut entity = Entity::new("paper", "Test Paper");
        entity.set_property("author", serde_json::json!("John Doe"));
        entity.set_property("year", serde_json::json!(2024));

        let id = insert_entity(&conn, &entity).unwrap();

        let retrieved = get_entity(&conn, id).unwrap();
        assert_eq!(
            retrieved.get_property("author"),
            Some(&serde_json::json!("John Doe"))
        );
        assert_eq!(
            retrieved.get_property("year"),
            Some(&serde_json::json!(2024))
        );
    }

    #[test]
    fn test_update_entity() {
        let conn = setup();
        let id = insert_entity(&conn, &Entity::new("paper", "Draft")).unwrap();

        let mut entity = get_entity(&conn, id).unwrap();
        entity.name = "Final".to_string();
        entity.set_property("status", serde_json::json!("published"));
        update_entity(&conn, &entity).unwrap();

        let retrieved = get_entity(&conn, id).unwrap();
        assert_eq!(retrieved.name, "Final");
        assert_eq!(
            retrieved.get_property("status"),
            Some(&serde_json::json!("published"))
        );
        // update_entity always stamps updated_at.
        assert!(retrieved.updated_at.is_some());
    }

    #[test]
    fn test_update_entity_errors() {
        let conn = setup();

        // An entity that was never read back has no id to update.
        assert!(update_entity(&conn, &Entity::new("paper", "Unsaved")).is_err());

        let mut missing = Entity::new("paper", "Ghost");
        missing.id = Some(9999);
        assert!(matches!(
            update_entity(&conn, &missing),
            Err(Error::EntityNotFound(9999))
        ));
    }

    #[test]
    fn test_delete_entity() {
        let conn = setup();
        let id = insert_entity(&conn, &Entity::new("paper", "Doomed")).unwrap();

        delete_entity(&conn, id).unwrap();
        assert!(get_entity(&conn, id).is_err());

        // A second delete finds nothing to remove.
        assert!(matches!(
            delete_entity(&conn, id),
            Err(Error::EntityNotFound(x)) if x == id
        ));
    }
}