1use rusqlite::params;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6
7use crate::error::{Error, Result};
8use crate::graph::entity::Entity;
9
/// A directed, weighted edge between two entities in the knowledge graph.
///
/// Construct it with [`Relation::new`] to get weight validation. Building the struct
/// literally (all fields are public) or deserializing it skips that check.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// Row id in `kg_relations`. It is `None` for relations created with [`Relation::new`]
    /// and `Some` for relations read back from the database.
    pub id: Option<i64>,
    /// Id of the entity the edge starts from.
    pub source_id: i64,
    /// Id of the entity the edge points to.
    pub target_id: i64,
    /// Free-form relation label (e.g. `"cites"`).
    pub rel_type: String,
    /// Edge strength; [`Relation::new`] only accepts values in `0.0..=1.0`.
    pub weight: f64,
    /// Arbitrary metadata, stored as a JSON object in the `properties` column.
    pub properties: HashMap<String, serde_json::Value>,
    /// Creation timestamp as read from the database; `None` until persisted and re-read.
    /// NOTE(review): presumably Unix seconds set by a column default — confirm against schema.
    pub created_at: Option<i64>,
}
21
22impl Relation {
23 pub fn new(
25 source_id: i64,
26 target_id: i64,
27 rel_type: impl Into<String>,
28 weight: f64,
29 ) -> Result<Self> {
30 if !(0.0..=1.0).contains(&weight) {
31 return Err(Error::InvalidWeight(weight));
32 }
33
34 Ok(Self {
35 id: None,
36 source_id,
37 target_id,
38 rel_type: rel_type.into(),
39 weight,
40 properties: HashMap::new(),
41 created_at: None,
42 })
43 }
44
45 pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
47 self.properties.insert(key.into(), value);
48 }
49
50 pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
52 self.properties.get(key)
53 }
54}
55
/// One entity reached by [`get_neighbors`], together with the edge that reached it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Neighbor {
    /// The entity that was reached.
    pub entity: Entity,
    /// The relation through which `entity` was first discovered.
    ///
    /// It may point in either direction (outgoing or incoming). Beyond depth 1 it links
    /// `entity` to whichever already-visited entity discovered it, not necessarily to
    /// the traversal origin.
    pub relation: Relation,
}
62
63pub fn insert_relation(conn: &rusqlite::Connection, relation: &Relation) -> Result<i64> {
65 crate::graph::entity::get_entity(conn, relation.source_id)?;
67 crate::graph::entity::get_entity(conn, relation.target_id)?;
68
69 let properties_json = serde_json::to_string(&relation.properties)?;
70
71 conn.execute(
72 r#"
73 INSERT INTO kg_relations (source_id, target_id, rel_type, weight, properties)
74 VALUES (?1, ?2, ?3, ?4, ?5)
75 "#,
76 params![
77 relation.source_id,
78 relation.target_id,
79 relation.rel_type,
80 relation.weight,
81 properties_json
82 ],
83 )?;
84
85 Ok(conn.last_insert_rowid())
86}
87
88pub fn get_neighbors(
90 conn: &rusqlite::Connection,
91 entity_id: i64,
92 depth: u32,
93) -> Result<Vec<Neighbor>> {
94 if depth == 0 {
95 return Ok(Vec::new());
96 }
97
98 if depth > 5 {
99 return Err(Error::InvalidDepth(depth));
100 }
101
102 crate::graph::entity::get_entity(conn, entity_id)?;
104
105 let mut result = Vec::new();
106 let mut visited = std::collections::HashSet::new();
107 let mut queue = VecDeque::new();
108
109 visited.insert(entity_id);
111 let direct_relations = get_direct_relations(conn, entity_id)?;
112
113 for (relation, neighbor_entity) in direct_relations {
114 let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
115
116 if !visited.contains(&neighbor_id) {
117 visited.insert(neighbor_id);
118 queue.push_back((neighbor_id, 1));
119 result.push(Neighbor {
120 entity: neighbor_entity,
121 relation,
122 });
123 }
124 }
125
126 while let Some((current_id, current_depth)) = queue.pop_front() {
128 if current_depth >= depth {
129 continue;
130 }
131
132 let relations = get_direct_relations(conn, current_id)?;
133
134 for (relation, neighbor_entity) in relations {
135 let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
136
137 if !visited.contains(&neighbor_id) {
138 visited.insert(neighbor_id);
139 queue.push_back((neighbor_id, current_depth + 1));
140 result.push(Neighbor {
141 entity: neighbor_entity,
142 relation,
143 });
144 }
145 }
146 }
147
148 Ok(result)
149}
150
151fn get_direct_relations(
153 conn: &rusqlite::Connection,
154 entity_id: i64,
155) -> Result<Vec<(Relation, Entity)>> {
156 let mut result = Vec::new();
157
158 let mut stmt = conn.prepare(
160 r#"
161 SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
162 e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
163 FROM kg_relations r
164 JOIN kg_entities e ON r.target_id = e.id
165 WHERE r.source_id = ?1
166 "#,
167 )?;
168
169 let rows = stmt.query_map(params![entity_id], |row| {
170 let properties_json: String = row.get(5)?;
171 let properties: HashMap<String, serde_json::Value> =
172 serde_json::from_str(&properties_json).unwrap_or_default();
173
174 let entity_props_json: String = row.get(10)?;
175 let entity_props: HashMap<String, serde_json::Value> =
176 serde_json::from_str(&entity_props_json).unwrap_or_default();
177
178 Ok((
179 Relation {
180 id: Some(row.get(0)?),
181 source_id: row.get(1)?,
182 target_id: row.get(2)?,
183 rel_type: row.get(3)?,
184 weight: row.get(4)?,
185 properties,
186 created_at: row.get(6)?,
187 },
188 Entity {
189 id: Some(row.get(7)?),
190 entity_type: row.get(8)?,
191 name: row.get(9)?,
192 properties: entity_props,
193 created_at: row.get(11)?,
194 updated_at: row.get(12)?,
195 },
196 ))
197 })?;
198
199 for row in rows {
200 result.push(row?);
201 }
202
203 let mut stmt = conn.prepare(
205 r#"
206 SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
207 e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
208 FROM kg_relations r
209 JOIN kg_entities e ON r.source_id = e.id
210 WHERE r.target_id = ?1
211 "#,
212 )?;
213
214 let rows = stmt.query_map(params![entity_id], |row| {
215 let properties_json: String = row.get(5)?;
216 let properties: HashMap<String, serde_json::Value> =
217 serde_json::from_str(&properties_json).unwrap_or_default();
218
219 let entity_props_json: String = row.get(10)?;
220 let entity_props: HashMap<String, serde_json::Value> =
221 serde_json::from_str(&entity_props_json).unwrap_or_default();
222
223 Ok((
224 Relation {
225 id: Some(row.get(0)?),
226 source_id: row.get(1)?,
227 target_id: row.get(2)?,
228 rel_type: row.get(3)?,
229 weight: row.get(4)?,
230 properties,
231 created_at: row.get(6)?,
232 },
233 Entity {
234 id: Some(row.get(7)?),
235 entity_type: row.get(8)?,
236 name: row.get(9)?,
237 properties: entity_props,
238 created_at: row.get(11)?,
239 updated_at: row.get(12)?,
240 },
241 ))
242 })?;
243
244 for row in rows {
245 result.push(row?);
246 }
247
248 Ok(result)
249}
250
251pub fn get_relations_by_source(
253 conn: &rusqlite::Connection,
254 source_id: i64,
255) -> Result<Vec<Relation>> {
256 let mut stmt = conn.prepare(
257 r#"
258 SELECT id, source_id, target_id, rel_type, weight, properties, created_at
259 FROM kg_relations
260 WHERE source_id = ?1
261 "#,
262 )?;
263
264 let relations = stmt.query_map(params![source_id], |row| {
265 let properties_json: String = row.get(5)?;
266 let properties: HashMap<String, serde_json::Value> =
267 serde_json::from_str(&properties_json).unwrap_or_default();
268
269 Ok(Relation {
270 id: Some(row.get(0)?),
271 source_id: row.get(1)?,
272 target_id: row.get(2)?,
273 rel_type: row.get(3)?,
274 weight: row.get(4)?,
275 properties,
276 created_at: row.get(6)?,
277 })
278 })?;
279
280 let mut result = Vec::new();
281 for rel in relations {
282 result.push(rel?);
283 }
284
285 Ok(result)
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Opens a fresh in-memory database with the project schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Inserts a `paper` entity called `name` and returns its id.
    fn add_paper(conn: &Connection, name: &'static str) -> i64 {
        insert_entity(conn, &Entity::new("paper", name)).unwrap()
    }

    /// Stores a `from -> to` relation of type `cites` with the given weight.
    fn cite(conn: &Connection, from: i64, to: i64, weight: f64) {
        let relation = Relation::new(from, to, "cites", weight).unwrap();
        insert_relation(conn, &relation).unwrap();
    }

    /// Builds the chain Paper 1 -> Paper 2 -> Paper 3 and returns the three ids.
    fn citation_chain(conn: &Connection) -> (i64, i64, i64) {
        let p1 = add_paper(conn, "Paper 1");
        let p2 = add_paper(conn, "Paper 2");
        let p3 = add_paper(conn, "Paper 3");
        cite(conn, p1, p2, 0.8);
        cite(conn, p2, p3, 0.9);
        (p1, p2, p3)
    }

    #[test]
    fn test_insert_relation() {
        let conn = setup();
        let p1 = add_paper(&conn, "Paper 1");
        let p2 = add_paper(&conn, "Paper 2");

        let relation = Relation::new(p1, p2, "cites", 0.8).unwrap();
        let id = insert_relation(&conn, &relation).unwrap();
        assert!(id > 0);
    }

    #[test]
    fn test_get_neighbors_depth_0_is_empty() {
        let conn = setup();
        let (p1, _, _) = citation_chain(&conn);

        assert!(get_neighbors(&conn, p1, 0).unwrap().is_empty());
    }

    #[test]
    fn test_get_neighbors_depth_1() {
        let conn = setup();
        let (p1, _, _) = citation_chain(&conn);

        let neighbors = get_neighbors(&conn, p1, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "Paper 2");
    }

    #[test]
    fn test_get_neighbors_depth_2() {
        let conn = setup();
        let (p1, _, _) = citation_chain(&conn);

        let neighbors = get_neighbors(&conn, p1, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 2"));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));
    }

    #[test]
    fn test_get_neighbors_follows_incoming_edges() {
        let conn = setup();
        let (p1, p2, _) = citation_chain(&conn);

        // Paper 2 is cited by Paper 1 (incoming edge) and cites Paper 3 (outgoing edge).
        let neighbors = get_neighbors(&conn, p2, 1).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));

        let citer = neighbors
            .iter()
            .find(|n| n.entity.name == "Paper 1")
            .expect("incoming neighbour should be reported");
        assert_eq!(citer.relation.source_id, p1);
        assert_eq!(citer.relation.target_id, p2);
    }

    #[test]
    fn test_get_relations_by_source_round_trips_properties() {
        let conn = setup();
        let p1 = add_paper(&conn, "Paper 1");
        let p2 = add_paper(&conn, "Paper 2");

        let mut relation = Relation::new(p1, p2, "cites", 0.5).unwrap();
        relation.set_property("section", serde_json::Value::from("intro"));
        insert_relation(&conn, &relation).unwrap();

        let stored = get_relations_by_source(&conn, p1).unwrap();
        assert_eq!(stored.len(), 1);
        assert_eq!(stored[0].target_id, p2);
        assert_eq!(stored[0].rel_type, "cites");
        assert_eq!(
            stored[0].get_property("section"),
            Some(&serde_json::Value::from("intro"))
        );

        // Only outgoing edges are returned.
        assert!(get_relations_by_source(&conn, p2).unwrap().is_empty());
    }

    #[test]
    fn test_invalid_weight() {
        assert!(Relation::new(1, 2, "test", 1.5).is_err());
        assert!(Relation::new(1, 2, "test", -0.1).is_err());
        assert!(Relation::new(1, 2, "test", f64::NAN).is_err());
        // Both ends of the accepted range are inclusive.
        assert!(Relation::new(1, 2, "test", 0.0).is_ok());
        assert!(Relation::new(1, 2, "test", 1.0).is_ok());
    }

    #[test]
    fn test_invalid_depth() {
        let conn = setup();
        let p1 = add_paper(&conn, "Paper 1");

        let result = get_neighbors(&conn, p1, 10);
        assert!(result.is_err());
    }
}