Skip to main content

sqlite_knowledge_graph/graph/
relation.rs

1//! Relation storage module for the knowledge graph.
2
3use rusqlite::params;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6
7use crate::error::{Error, Result};
8use crate::graph::entity::Entity;
9
10/// Represents a relation between entities in the knowledge graph.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Relation {
13    pub id: Option<i64>,
14    pub source_id: i64,
15    pub target_id: i64,
16    pub rel_type: String,
17    pub weight: f64,
18    pub properties: HashMap<String, serde_json::Value>,
19    pub created_at: Option<i64>,
20}
21
22impl Relation {
23    /// Create a new relation.
24    pub fn new(
25        source_id: i64,
26        target_id: i64,
27        rel_type: impl Into<String>,
28        weight: f64,
29    ) -> Result<Self> {
30        if !(0.0..=1.0).contains(&weight) {
31            return Err(Error::InvalidWeight(weight));
32        }
33
34        Ok(Self {
35            id: None,
36            source_id,
37            target_id,
38            rel_type: rel_type.into(),
39            weight,
40            properties: HashMap::new(),
41            created_at: None,
42        })
43    }
44
45    /// Set a property.
46    pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
47        self.properties.insert(key.into(), value);
48    }
49
50    /// Get a property.
51    pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
52        self.properties.get(key)
53    }
54}
55
56/// Represents a neighbor in a graph traversal.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct Neighbor {
59    pub entity: Entity,
60    pub relation: Relation,
61}
62
63/// Insert a new relation into the database.
64pub fn insert_relation(conn: &rusqlite::Connection, relation: &Relation) -> Result<i64> {
65    // Validate entities exist
66    crate::graph::entity::get_entity(conn, relation.source_id)?;
67    crate::graph::entity::get_entity(conn, relation.target_id)?;
68
69    let properties_json = serde_json::to_string(&relation.properties)?;
70
71    conn.execute(
72        r#"
73        INSERT INTO kg_relations (source_id, target_id, rel_type, weight, properties)
74        VALUES (?1, ?2, ?3, ?4, ?5)
75        "#,
76        params![
77            relation.source_id,
78            relation.target_id,
79            relation.rel_type,
80            relation.weight,
81            properties_json
82        ],
83    )?;
84
85    Ok(conn.last_insert_rowid())
86}
87
88/// Get neighbors of an entity using BFS traversal.
89pub fn get_neighbors(
90    conn: &rusqlite::Connection,
91    entity_id: i64,
92    depth: u32,
93) -> Result<Vec<Neighbor>> {
94    if depth == 0 {
95        return Ok(Vec::new());
96    }
97
98    if depth > 5 {
99        return Err(Error::InvalidDepth(depth));
100    }
101
102    // Validate entity exists
103    crate::graph::entity::get_entity(conn, entity_id)?;
104
105    let mut result = Vec::new();
106    let mut visited = std::collections::HashSet::new();
107    let mut queue = VecDeque::new();
108
109    // Start with direct neighbors
110    visited.insert(entity_id);
111    let direct_relations = get_direct_relations(conn, entity_id)?;
112
113    for (relation, neighbor_entity) in direct_relations {
114        let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
115
116        if !visited.contains(&neighbor_id) {
117            visited.insert(neighbor_id);
118            queue.push_back((neighbor_id, 1));
119            result.push(Neighbor {
120                entity: neighbor_entity,
121                relation,
122            });
123        }
124    }
125
126    // BFS traversal
127    while let Some((current_id, current_depth)) = queue.pop_front() {
128        if current_depth >= depth {
129            continue;
130        }
131
132        let relations = get_direct_relations(conn, current_id)?;
133
134        for (relation, neighbor_entity) in relations {
135            let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
136
137            if !visited.contains(&neighbor_id) {
138                visited.insert(neighbor_id);
139                queue.push_back((neighbor_id, current_depth + 1));
140                result.push(Neighbor {
141                    entity: neighbor_entity,
142                    relation,
143                });
144            }
145        }
146    }
147
148    Ok(result)
149}
150
151/// Get direct relations for an entity (both incoming and outgoing).
152fn get_direct_relations(
153    conn: &rusqlite::Connection,
154    entity_id: i64,
155) -> Result<Vec<(Relation, Entity)>> {
156    let mut result = Vec::new();
157
158    // Outgoing relations (entity_id is source)
159    let mut stmt = conn.prepare(
160        r#"
161        SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
162               e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
163        FROM kg_relations r
164        JOIN kg_entities e ON r.target_id = e.id
165        WHERE r.source_id = ?1
166        "#,
167    )?;
168
169    let rows = stmt.query_map(params![entity_id], |row| {
170        let properties_json: String = row.get(5)?;
171        let properties: HashMap<String, serde_json::Value> =
172            serde_json::from_str(&properties_json).unwrap_or_default();
173
174        let entity_props_json: String = row.get(10)?;
175        let entity_props: HashMap<String, serde_json::Value> =
176            serde_json::from_str(&entity_props_json).unwrap_or_default();
177
178        Ok((
179            Relation {
180                id: Some(row.get(0)?),
181                source_id: row.get(1)?,
182                target_id: row.get(2)?,
183                rel_type: row.get(3)?,
184                weight: row.get(4)?,
185                properties,
186                created_at: row.get(6)?,
187            },
188            Entity {
189                id: Some(row.get(7)?),
190                entity_type: row.get(8)?,
191                name: row.get(9)?,
192                properties: entity_props,
193                created_at: row.get(11)?,
194                updated_at: row.get(12)?,
195            },
196        ))
197    })?;
198
199    for row in rows {
200        result.push(row?);
201    }
202
203    // Incoming relations (entity_id is target)
204    let mut stmt = conn.prepare(
205        r#"
206        SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
207               e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
208        FROM kg_relations r
209        JOIN kg_entities e ON r.source_id = e.id
210        WHERE r.target_id = ?1
211        "#,
212    )?;
213
214    let rows = stmt.query_map(params![entity_id], |row| {
215        let properties_json: String = row.get(5)?;
216        let properties: HashMap<String, serde_json::Value> =
217            serde_json::from_str(&properties_json).unwrap_or_default();
218
219        let entity_props_json: String = row.get(10)?;
220        let entity_props: HashMap<String, serde_json::Value> =
221            serde_json::from_str(&entity_props_json).unwrap_or_default();
222
223        Ok((
224            Relation {
225                id: Some(row.get(0)?),
226                source_id: row.get(1)?,
227                target_id: row.get(2)?,
228                rel_type: row.get(3)?,
229                weight: row.get(4)?,
230                properties,
231                created_at: row.get(6)?,
232            },
233            Entity {
234                id: Some(row.get(7)?),
235                entity_type: row.get(8)?,
236                name: row.get(9)?,
237                properties: entity_props,
238                created_at: row.get(11)?,
239                updated_at: row.get(12)?,
240            },
241        ))
242    })?;
243
244    for row in rows {
245        result.push(row?);
246    }
247
248    Ok(result)
249}
250
251/// Get relations by source ID.
252pub fn get_relations_by_source(
253    conn: &rusqlite::Connection,
254    source_id: i64,
255) -> Result<Vec<Relation>> {
256    let mut stmt = conn.prepare(
257        r#"
258        SELECT id, source_id, target_id, rel_type, weight, properties, created_at
259        FROM kg_relations
260        WHERE source_id = ?1
261        "#,
262    )?;
263
264    let relations = stmt.query_map(params![source_id], |row| {
265        let properties_json: String = row.get(5)?;
266        let properties: HashMap<String, serde_json::Value> =
267            serde_json::from_str(&properties_json).unwrap_or_default();
268
269        Ok(Relation {
270            id: Some(row.get(0)?),
271            source_id: row.get(1)?,
272            target_id: row.get(2)?,
273            rel_type: row.get(3)?,
274            weight: row.get(4)?,
275            properties,
276            created_at: row.get(6)?,
277        })
278    })?;
279
280    let mut result = Vec::new();
281    for rel in relations {
282        result.push(rel?);
283    }
284
285    Ok(result)
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Opens an in-memory database with the knowledge-graph schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Inserts a `paper` entity with the given name and returns its id.
    fn paper(conn: &Connection, name: &str) -> i64 {
        insert_entity(conn, &Entity::new("paper", name)).unwrap()
    }

    /// Inserts a `cites` relation `from -> to` with the given weight.
    fn cite(conn: &Connection, from: i64, to: i64, weight: f64) {
        let relation = Relation::new(from, to, "cites", weight).unwrap();
        insert_relation(conn, &relation).unwrap();
    }

    #[test]
    fn test_insert_relation() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");

        let relation = Relation::new(entity1_id, entity2_id, "cites", 0.8).unwrap();
        let id = insert_relation(&conn, &relation).unwrap();
        assert!(id > 0);
    }

    #[test]
    fn test_insert_relation_missing_entity() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");

        // The target id does not correspond to any inserted entity.
        let relation = Relation::new(entity1_id, entity1_id + 1000, "cites", 0.5).unwrap();
        assert!(insert_relation(&conn, &relation).is_err());
    }

    #[test]
    fn test_properties_round_trip() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");

        let mut relation = Relation::new(entity1_id, entity2_id, "cites", 0.5).unwrap();
        relation.set_property("page", serde_json::json!(42));
        assert_eq!(relation.get_property("page"), Some(&serde_json::json!(42)));
        assert_eq!(relation.get_property("missing"), None);
        insert_relation(&conn, &relation).unwrap();

        let stored = get_relations_by_source(&conn, entity1_id).unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].get_property("page"), Some(&serde_json::json!(42)));
        assert!(stored[0].id.is_some());
    }

    #[test]
    fn test_get_neighbors_depth_1() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let neighbors = get_neighbors(&conn, entity1_id, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "Paper 2");
    }

    #[test]
    fn test_get_neighbors_depth_2() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let neighbors = get_neighbors(&conn, entity1_id, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 2"));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));
    }

    #[test]
    fn test_get_neighbors_follows_incoming_relations() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");

        cite(&conn, entity1_id, entity2_id, 0.8);

        // Paper 2 is only the target of the relation, yet Paper 1 is its neighbor.
        let neighbors = get_neighbors(&conn, entity2_id, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "Paper 1");
    }

    #[test]
    fn test_get_neighbors_depth_zero_is_empty() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        cite(&conn, entity1_id, entity2_id, 0.8);

        assert!(get_neighbors(&conn, entity1_id, 0).unwrap().is_empty());
    }

    #[test]
    fn test_get_neighbors_handles_cycles() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity2_id, entity3_id, 0.9);
        cite(&conn, entity3_id, entity1_id, 0.7);

        let neighbors = get_neighbors(&conn, entity1_id, 5).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().all(|n| n.entity.name != "Paper 1"));
    }

    #[test]
    fn test_get_relations_by_source() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity1_id, entity3_id, 0.6);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let relations = get_relations_by_source(&conn, entity1_id).unwrap();
        assert_eq!(relations.len(), 2);
        assert!(relations.iter().all(|r| r.source_id == entity1_id));

        assert!(get_relations_by_source(&conn, entity3_id).unwrap().is_empty());
    }

    #[test]
    fn test_invalid_weight() {
        assert!(Relation::new(1, 2, "test", 1.5).is_err());
        assert!(Relation::new(1, 2, "test", -0.1).is_err());
        assert!(Relation::new(1, 2, "test", f64::NAN).is_err());
        // Both bounds of the closed interval are accepted.
        assert!(Relation::new(1, 2, "test", 0.0).is_ok());
        assert!(Relation::new(1, 2, "test", 1.0).is_ok());
    }

    #[test]
    fn test_invalid_depth() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");

        let result = get_neighbors(&conn, entity1_id, 10);
        assert!(result.is_err());
    }
}