Skip to main content

sqlite_knowledge_graph/graph/
relation.rs

1//! Relation storage module for the knowledge graph.
2
3use rusqlite::params;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6
7use crate::error::{Error, Result};
8use crate::graph::entity::Entity;
9
10/// Represents a relation between entities in the knowledge graph.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Relation {
13    pub id: Option<i64>,
14    pub source_id: i64,
15    pub target_id: i64,
16    pub rel_type: String,
17    pub weight: f64,
18    pub properties: HashMap<String, serde_json::Value>,
19    pub created_at: Option<i64>,
20}
21
22impl Relation {
23    /// Create a new relation.
24    pub fn new(
25        source_id: i64,
26        target_id: i64,
27        rel_type: impl Into<String>,
28        weight: f64,
29    ) -> Result<Self> {
30        if !(0.0..=1.0).contains(&weight) {
31            return Err(Error::InvalidWeight(weight));
32        }
33
34        Ok(Self {
35            id: None,
36            source_id,
37            target_id,
38            rel_type: rel_type.into(),
39            weight,
40            properties: HashMap::new(),
41            created_at: None,
42        })
43    }
44
45    /// Set a property.
46    pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
47        self.properties.insert(key.into(), value);
48    }
49
50    /// Get a property.
51    pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
52        self.properties.get(key)
53    }
54}
55
56/// Represents a neighbor in a graph traversal.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Neighbor {
59    pub entity: Entity,
60    pub relation: Relation,
61}
62
63/// Insert a new relation into the database.
64pub fn insert_relation(conn: &rusqlite::Connection, relation: &Relation) -> Result<i64> {
65    // Validate entities exist
66    crate::graph::entity::get_entity(conn, relation.source_id)?;
67    crate::graph::entity::get_entity(conn, relation.target_id)?;
68
69    let properties_json = serde_json::to_string(&relation.properties)?;
70
71    conn.execute(
72        r#"
73        INSERT INTO kg_relations (source_id, target_id, rel_type, weight, properties)
74        VALUES (?1, ?2, ?3, ?4, ?5)
75        "#,
76        params![
77            relation.source_id,
78            relation.target_id,
79            relation.rel_type,
80            relation.weight,
81            properties_json
82        ],
83    )?;
84
85    Ok(conn.last_insert_rowid())
86}
87
88/// Get neighbors of an entity using BFS traversal.
89pub fn get_neighbors(
90    conn: &rusqlite::Connection,
91    entity_id: i64,
92    depth: u32,
93) -> Result<Vec<Neighbor>> {
94    if depth == 0 {
95        return Ok(Vec::new());
96    }
97
98    if depth > 5 {
99        return Err(Error::InvalidDepth(depth));
100    }
101
102    // Validate entity exists
103    crate::graph::entity::get_entity(conn, entity_id)?;
104
105    let mut result = Vec::new();
106    let mut visited = std::collections::HashSet::new();
107    let mut queue = VecDeque::new();
108    let mut level_queue = VecDeque::new();
109
110    // Start with direct neighbors
111    visited.insert(entity_id);
112    let direct_relations = get_direct_relations(conn, entity_id)?;
113
114    for (relation, neighbor_entity) in direct_relations {
115        let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
116
117        if !visited.contains(&neighbor_id) {
118            visited.insert(neighbor_id);
119            queue.push_back((neighbor_id, 1));
120            level_queue.push_back((neighbor_entity.clone(), relation.clone()));
121            result.push(Neighbor {
122                entity: neighbor_entity,
123                relation,
124            });
125        }
126    }
127
128    // BFS traversal
129    while let Some((current_id, current_depth)) = queue.pop_front() {
130        if current_depth >= depth {
131            continue;
132        }
133
134        let relations = get_direct_relations(conn, current_id)?;
135
136        for (relation, neighbor_entity) in relations {
137            let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
138
139            if !visited.contains(&neighbor_id) {
140                visited.insert(neighbor_id);
141                queue.push_back((neighbor_id, current_depth + 1));
142                level_queue.push_back((neighbor_entity.clone(), relation.clone()));
143                result.push(Neighbor {
144                    entity: neighbor_entity,
145                    relation,
146                });
147            }
148        }
149    }
150
151    Ok(result)
152}
153
154/// Get direct relations for an entity (both incoming and outgoing).
155fn get_direct_relations(
156    conn: &rusqlite::Connection,
157    entity_id: i64,
158) -> Result<Vec<(Relation, Entity)>> {
159    let mut result = Vec::new();
160
161    // Outgoing relations (entity_id is source)
162    let mut stmt = conn.prepare(
163        r#"
164        SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
165               e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
166        FROM kg_relations r
167        JOIN kg_entities e ON r.target_id = e.id
168        WHERE r.source_id = ?1
169        "#,
170    )?;
171
172    let rows = stmt.query_map(params![entity_id], |row| {
173        let properties_json: String = row.get(5)?;
174        let properties: HashMap<String, serde_json::Value> =
175            serde_json::from_str(&properties_json).unwrap_or_default();
176
177        let entity_props_json: String = row.get(10)?;
178        let entity_props: HashMap<String, serde_json::Value> =
179            serde_json::from_str(&entity_props_json).unwrap_or_default();
180
181        Ok((
182            Relation {
183                id: Some(row.get(0)?),
184                source_id: row.get(1)?,
185                target_id: row.get(2)?,
186                rel_type: row.get(3)?,
187                weight: row.get(4)?,
188                properties,
189                created_at: row.get(6)?,
190            },
191            Entity {
192                id: Some(row.get(7)?),
193                entity_type: row.get(8)?,
194                name: row.get(9)?,
195                properties: entity_props,
196                created_at: row.get(11)?,
197                updated_at: row.get(12)?,
198            },
199        ))
200    })?;
201
202    for row in rows {
203        result.push(row?);
204    }
205
206    // Incoming relations (entity_id is target)
207    let mut stmt = conn.prepare(
208        r#"
209        SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
210               e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
211        FROM kg_relations r
212        JOIN kg_entities e ON r.source_id = e.id
213        WHERE r.target_id = ?1
214        "#,
215    )?;
216
217    let rows = stmt.query_map(params![entity_id], |row| {
218        let properties_json: String = row.get(5)?;
219        let properties: HashMap<String, serde_json::Value> =
220            serde_json::from_str(&properties_json).unwrap_or_default();
221
222        let entity_props_json: String = row.get(10)?;
223        let entity_props: HashMap<String, serde_json::Value> =
224            serde_json::from_str(&entity_props_json).unwrap_or_default();
225
226        Ok((
227            Relation {
228                id: Some(row.get(0)?),
229                source_id: row.get(1)?,
230                target_id: row.get(2)?,
231                rel_type: row.get(3)?,
232                weight: row.get(4)?,
233                properties,
234                created_at: row.get(6)?,
235            },
236            Entity {
237                id: Some(row.get(7)?),
238                entity_type: row.get(8)?,
239                name: row.get(9)?,
240                properties: entity_props,
241                created_at: row.get(11)?,
242                updated_at: row.get(12)?,
243            },
244        ))
245    })?;
246
247    for row in rows {
248        result.push(row?);
249    }
250
251    Ok(result)
252}
253
254/// Get relations by source ID.
255pub fn get_relations_by_source(
256    conn: &rusqlite::Connection,
257    source_id: i64,
258) -> Result<Vec<Relation>> {
259    let mut stmt = conn.prepare(
260        r#"
261        SELECT id, source_id, target_id, rel_type, weight, properties, created_at
262        FROM kg_relations
263        WHERE source_id = ?1
264        "#,
265    )?;
266
267    let relations = stmt.query_map(params![source_id], |row| {
268        let properties_json: String = row.get(5)?;
269        let properties: HashMap<String, serde_json::Value> =
270            serde_json::from_str(&properties_json).unwrap_or_default();
271
272        Ok(Relation {
273            id: Some(row.get(0)?),
274            source_id: row.get(1)?,
275            target_id: row.get(2)?,
276            rel_type: row.get(3)?,
277            weight: row.get(4)?,
278            properties,
279            created_at: row.get(6)?,
280        })
281    })?;
282
283    let mut result = Vec::new();
284    for rel in relations {
285        result.push(rel?);
286    }
287
288    Ok(result)
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Open an in-memory database with the knowledge-graph schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Insert a `paper` entity with the given name and return its id.
    fn paper(conn: &Connection, name: &str) -> i64 {
        insert_entity(conn, &Entity::new("paper", name)).unwrap()
    }

    /// Insert a `cites` relation from `source` to `target`.
    fn cite(conn: &Connection, source: i64, target: i64, weight: f64) {
        let relation = Relation::new(source, target, "cites", weight).unwrap();
        insert_relation(conn, &relation).unwrap();
    }

    /// Build the chain `Paper 1 -> Paper 2 -> Paper 3` and return the ids.
    fn chain(conn: &Connection) -> (i64, i64, i64) {
        let p1 = paper(conn, "Paper 1");
        let p2 = paper(conn, "Paper 2");
        let p3 = paper(conn, "Paper 3");
        cite(conn, p1, p2, 0.8);
        cite(conn, p2, p3, 0.9);
        (p1, p2, p3)
    }

    #[test]
    fn test_insert_relation() {
        let conn = setup();
        let p1 = paper(&conn, "Paper 1");
        let p2 = paper(&conn, "Paper 2");

        let relation = Relation::new(p1, p2, "cites", 0.8).unwrap();
        let id = insert_relation(&conn, &relation).unwrap();
        assert!(id > 0);
    }

    #[test]
    fn test_insert_relation_missing_entity() {
        let conn = setup();
        let p1 = paper(&conn, "Paper 1");

        let relation = Relation::new(p1, p1 + 1000, "cites", 0.5).unwrap();
        assert!(insert_relation(&conn, &relation).is_err());
    }

    #[test]
    fn test_get_neighbors_depth_1() {
        let conn = setup();
        let (p1, _, _) = chain(&conn);

        let neighbors = get_neighbors(&conn, p1, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "Paper 2");
    }

    #[test]
    fn test_get_neighbors_depth_2() {
        let conn = setup();
        let (p1, _, _) = chain(&conn);

        let neighbors = get_neighbors(&conn, p1, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 2"));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));
    }

    #[test]
    fn test_get_neighbors_depth_zero_is_empty() {
        let conn = setup();
        let (p1, _, _) = chain(&conn);

        assert!(get_neighbors(&conn, p1, 0).unwrap().is_empty());
    }

    #[test]
    fn test_get_neighbors_follows_incoming_relations() {
        let conn = setup();
        let (_, p2, _) = chain(&conn);

        // Paper 2 has an incoming edge from Paper 1 and an outgoing one to Paper 3.
        let neighbors = get_neighbors(&conn, p2, 1).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 1"));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));
    }

    #[test]
    fn test_get_neighbors_cycle_visits_each_entity_once() {
        let conn = setup();
        let (p1, _, p3) = chain(&conn);
        cite(&conn, p3, p1, 0.5);

        let neighbors = get_neighbors(&conn, p1, 5).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().all(|n| n.entity.name != "Paper 1"));
    }

    #[test]
    fn test_get_neighbors_unknown_entity() {
        let conn = setup();
        assert!(get_neighbors(&conn, 9999, 1).is_err());
    }

    #[test]
    fn test_invalid_weight() {
        assert!(Relation::new(1, 2, "test", 1.5).is_err());
        assert!(Relation::new(1, 2, "test", -0.1).is_err());
        assert!(Relation::new(1, 2, "test", f64::NAN).is_err());
        assert!(Relation::new(1, 2, "test", 0.0).is_ok());
        assert!(Relation::new(1, 2, "test", 1.0).is_ok());
    }

    #[test]
    fn test_invalid_depth() {
        let conn = setup();
        let p1 = paper(&conn, "Paper 1");

        assert!(get_neighbors(&conn, p1, 10).is_err());
        assert!(get_neighbors(&conn, p1, 6).is_err());
        assert!(get_neighbors(&conn, p1, 5).is_ok());
    }

    #[test]
    fn test_get_relations_by_source_round_trips_properties() {
        let conn = setup();
        let p1 = paper(&conn, "Paper 1");
        let p2 = paper(&conn, "Paper 2");

        let mut relation = Relation::new(p1, p2, "cites", 0.8).unwrap();
        relation.set_property("page", serde_json::json!(12));
        insert_relation(&conn, &relation).unwrap();

        let relations = get_relations_by_source(&conn, p1).unwrap();
        assert_eq!(relations.len(), 1);
        assert!(relations[0].id.is_some());
        assert_eq!(relations[0].target_id, p2);
        assert_eq!(relations[0].rel_type, "cites");
        assert_eq!(
            relations[0].get_property("page"),
            Some(&serde_json::json!(12))
        );

        assert!(get_relations_by_source(&conn, p2).unwrap().is_empty());
    }
}