Skip to main content

sqlite_knowledge_graph/graph/
entity.rs

1//! Entity storage module for the knowledge graph.
2
3use rusqlite::params;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6
7use crate::error::{Error, Result};
8
/// Represents an entity in the knowledge graph.
///
/// Values built with [`Entity::new`] or [`Entity::with_properties`] start with
/// `id`, `created_at` and `updated_at` set to `None`; entities mapped from a
/// `kg_entities` row carry `Some(id)`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Entity {
    /// Row id in `kg_entities`; `None` for an entity built in memory.
    pub id: Option<i64>,
    /// Value of the `entity_type` column; `list_entities` can filter on it.
    pub entity_type: String,
    /// Value of the `name` column.
    pub name: String,
    /// Key/value properties, persisted as a JSON string in the `properties`
    /// column.
    pub properties: HashMap<String, serde_json::Value>,
    /// Value of the `created_at` column, if any. Presumably seconds since the
    /// UNIX epoch like `updated_at` — TODO confirm against the schema.
    pub created_at: Option<i64>,
    /// Value of the `updated_at` column, if any; `update_entity` writes it as
    /// whole seconds since the UNIX epoch.
    pub updated_at: Option<i64>,
}
19
20impl Entity {
21    /// Create a new entity.
22    pub fn new(entity_type: impl Into<String>, name: impl Into<String>) -> Self {
23        Self {
24            id: None,
25            entity_type: entity_type.into(),
26            name: name.into(),
27            properties: HashMap::new(),
28            created_at: None,
29            updated_at: None,
30        }
31    }
32
33    /// Create a new entity with properties.
34    pub fn with_properties(
35        entity_type: impl Into<String>,
36        name: impl Into<String>,
37        properties: HashMap<String, serde_json::Value>,
38    ) -> Self {
39        Self {
40            id: None,
41            entity_type: entity_type.into(),
42            name: name.into(),
43            properties,
44            created_at: None,
45            updated_at: None,
46        }
47    }
48
49    /// Set a property.
50    pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
51        self.properties.insert(key.into(), value);
52    }
53
54    /// Get a property.
55    pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
56        self.properties.get(key)
57    }
58}
59
60/// Insert a new entity into the database.
61pub fn insert_entity(conn: &rusqlite::Connection, entity: &Entity) -> Result<i64> {
62    let properties_json = serde_json::to_string(&entity.properties)?;
63
64    conn.execute(
65        r#"
66        INSERT INTO kg_entities (entity_type, name, properties)
67        VALUES (?1, ?2, ?3)
68        "#,
69        params![entity.entity_type, entity.name, properties_json],
70    )?;
71
72    Ok(conn.last_insert_rowid())
73}
74
75/// Get an entity by ID.
76pub fn get_entity(conn: &rusqlite::Connection, id: i64) -> Result<Entity> {
77    let mut stmt = conn.prepare(
78        r#"
79        SELECT id, entity_type, name, properties, created_at, updated_at
80        FROM kg_entities
81        WHERE id = ?1
82        "#,
83    )?;
84
85    let entity = stmt.query_row(params![id], row_to_entity)?;
86    Ok(entity)
87}
88
89/// Map a `kg_entities` row (id, entity_type, name, properties, created_at,
90/// updated_at) to an [`Entity`]. Shared by all entity read paths.
91fn row_to_entity(row: &rusqlite::Row) -> rusqlite::Result<Entity> {
92    let properties_json: Option<String> = row.get(3)?;
93    let properties: HashMap<String, serde_json::Value> = match properties_json {
94        Some(json) => serde_json::from_str(&json).unwrap_or_default(),
95        None => HashMap::new(),
96    };
97
98    Ok(Entity {
99        id: Some(row.get(0)?),
100        entity_type: row.get(1)?,
101        name: row.get(2)?,
102        properties,
103        created_at: row.get(4)?,
104        updated_at: row.get(5)?,
105    })
106}
107
108/// Load multiple entities by ID, batching into `IN (...)` queries to avoid the
109/// N+1 pattern of one query per id. Ids are queried in chunks to stay under
110/// SQLite's bound-parameter limit; missing ids are simply absent from the
111/// result, and order is not guaranteed.
112pub(crate) fn get_entities_by_ids(conn: &rusqlite::Connection, ids: &[i64]) -> Result<Vec<Entity>> {
113    // SQLite's default SQLITE_MAX_VARIABLE_NUMBER is 999 on older builds; stay
114    // comfortably below it.
115    const CHUNK: usize = 900;
116
117    let mut result = Vec::with_capacity(ids.len());
118    for chunk in ids.chunks(CHUNK) {
119        let placeholders = std::iter::repeat("?")
120            .take(chunk.len())
121            .collect::<Vec<_>>()
122            .join(",");
123        let sql = format!(
124            "SELECT id, entity_type, name, properties, created_at, updated_at \
125             FROM kg_entities WHERE id IN ({placeholders})"
126        );
127        let mut stmt = conn.prepare(&sql)?;
128        let rows = stmt.query_map(rusqlite::params_from_iter(chunk.iter()), row_to_entity)?;
129        for row in rows {
130            result.push(row?);
131        }
132    }
133    Ok(result)
134}
135
136/// List entities with optional filtering.
137pub fn list_entities(
138    conn: &rusqlite::Connection,
139    entity_type: Option<&str>,
140    limit: Option<i64>,
141) -> Result<Vec<Entity>> {
142    let mut query =
143        "SELECT id, entity_type, name, properties, created_at, updated_at FROM kg_entities"
144            .to_string();
145
146    let mut params_vec: Vec<Box<dyn rusqlite::ToSql>> = Vec::new();
147
148    if let Some(et) = entity_type {
149        query.push_str(" WHERE entity_type = ?1");
150        params_vec.push(Box::new(et.to_string()));
151    }
152
153    query.push_str(" ORDER BY created_at DESC");
154
155    if let Some(lim) = limit {
156        query.push_str(" LIMIT ?");
157        params_vec.push(Box::new(lim));
158    }
159
160    let mut stmt = conn.prepare(&query)?;
161
162    // Convert boxed params to references for query_map
163    let params_refs: Vec<&dyn rusqlite::ToSql> = params_vec.iter().map(|p| p.as_ref()).collect();
164
165    let entities = stmt.query_map(params_refs.as_slice(), row_to_entity)?;
166
167    let mut result = Vec::new();
168    for entity in entities {
169        result.push(entity?);
170    }
171
172    Ok(result)
173}
174
175/// Update an entity.
176pub fn update_entity(conn: &rusqlite::Connection, entity: &Entity) -> Result<()> {
177    let id = entity.id.ok_or(Error::EntityNotFound(0))?;
178    let properties_json = serde_json::to_string(&entity.properties)?;
179
180    let updated_at = std::time::SystemTime::now()
181        .duration_since(std::time::UNIX_EPOCH)
182        .map_err(|_| Error::InvalidInput("system clock before UNIX epoch".to_string()))?
183        .as_secs() as i64;
184
185    let affected = conn.execute(
186        r#"
187        UPDATE kg_entities
188        SET entity_type = ?1, name = ?2, properties = ?3, updated_at = ?4
189        WHERE id = ?5
190        "#,
191        params![
192            entity.entity_type,
193            entity.name,
194            properties_json,
195            updated_at,
196            id
197        ],
198    )?;
199
200    if affected == 0 {
201        return Err(Error::EntityNotFound(id));
202    }
203
204    Ok(())
205}
206
207/// Delete an entity by ID.
208pub fn delete_entity(conn: &rusqlite::Connection, id: i64) -> Result<()> {
209    let affected = conn.execute("DELETE FROM kg_entities WHERE id = ?1", params![id])?;
210
211    if affected == 0 {
212        return Err(Error::EntityNotFound(id));
213    }
214
215    Ok(())
216}
217
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Open a fresh in-memory database with the crate schema applied.
    fn fresh_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Inserting an entity yields a positive row id.
    #[test]
    fn test_insert_entity() {
        let db = fresh_db();

        let paper = Entity::new("paper", "Test Paper");
        let new_id = insert_entity(&db, &paper).unwrap();

        assert!(new_id > 0);
    }

    /// An inserted entity can be read back by id with its type and name intact.
    #[test]
    fn test_get_entity() {
        let db = fresh_db();
        let new_id = insert_entity(&db, &Entity::new("paper", "Test Paper")).unwrap();

        let fetched = get_entity(&db, new_id).unwrap();

        assert_eq!(fetched.id, Some(new_id));
        assert_eq!(fetched.entity_type, "paper");
        assert_eq!(fetched.name, "Test Paper");
    }

    /// Listing honours both the type filter and the row limit.
    #[test]
    fn test_list_entities() {
        let db = fresh_db();
        for (kind, name) in [("paper", "Paper 1"), ("paper", "Paper 2"), ("skill", "Skill 1")] {
            insert_entity(&db, &Entity::new(kind, name)).unwrap();
        }

        let only_papers = list_entities(&db, Some("paper"), None).unwrap();
        assert_eq!(only_papers.len(), 2);

        let first_two = list_entities(&db, None, Some(2)).unwrap();
        assert_eq!(first_two.len(), 2);
    }

    /// Batched lookup returns the rows that exist and drops unknown ids.
    #[test]
    fn test_get_entities_by_ids_batches_and_skips_missing() {
        let db = fresh_db();
        let first = insert_entity(&db, &Entity::new("paper", "Paper 1")).unwrap();
        let second = insert_entity(&db, &Entity::new("paper", "Paper 2")).unwrap();

        // Two real ids with a non-existent one in between.
        let found = get_entities_by_ids(&db, &[first, 99999, second]).unwrap();

        assert_eq!(found.len(), 2);
        assert!(found.iter().any(|e| e.name == "Paper 1"));
        assert!(found.iter().any(|e| e.name == "Paper 2"));
    }

    /// An empty id list yields an empty result.
    #[test]
    fn test_get_entities_by_ids_empty() {
        let db = fresh_db();

        let found = get_entities_by_ids(&db, &[]).unwrap();

        assert!(found.is_empty());
    }

    /// Properties survive the round trip through the database.
    #[test]
    fn test_entity_properties() {
        let db = fresh_db();

        let mut paper = Entity::new("paper", "Test Paper");
        paper.set_property("author", serde_json::json!("John Doe"));
        paper.set_property("year", serde_json::json!(2024));
        let new_id = insert_entity(&db, &paper).unwrap();

        let fetched = get_entity(&db, new_id).unwrap();

        let expected_author = serde_json::json!("John Doe");
        let expected_year = serde_json::json!(2024);
        assert_eq!(fetched.get_property("author"), Some(&expected_author));
        assert_eq!(fetched.get_property("year"), Some(&expected_year));
    }
}