1use rusqlite::params;
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, VecDeque};
6
7use crate::error::{Error, Result};
8use crate::graph::entity::Entity;
9
/// An edge in the knowledge graph, stored as `source_id -> target_id`.
///
/// Instances are built with [`Relation::new`] before insertion and are also
/// produced when rows of the `kg_relations` table are read back in this module.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Relation {
    /// Row id; `None` for relations built with `Relation::new`, `Some` when read from the database.
    pub id: Option<i64>,
    /// Id of the entity the edge starts from.
    pub source_id: i64,
    /// Id of the entity the edge points to.
    pub target_id: i64,
    /// Free-form relation label (the tests use e.g. `"cites"`).
    pub rel_type: String,
    /// Edge weight; `Relation::new` only accepts values in `[0.0, 1.0]`, but the
    /// field itself is public and not re-validated elsewhere in this module.
    pub weight: f64,
    /// Arbitrary JSON properties, persisted as a serialized JSON string.
    pub properties: HashMap<String, serde_json::Value>,
    /// Value of the `created_at` column when read back; `None` for unsaved relations.
    /// NOTE(review): unit/epoch is defined by the schema, not shown here — confirm.
    pub created_at: Option<i64>,
}
21
22impl Relation {
23 pub fn new(
25 source_id: i64,
26 target_id: i64,
27 rel_type: impl Into<String>,
28 weight: f64,
29 ) -> Result<Self> {
30 if !(0.0..=1.0).contains(&weight) {
31 return Err(Error::InvalidWeight(weight));
32 }
33
34 Ok(Self {
35 id: None,
36 source_id,
37 target_id,
38 rel_type: rel_type.into(),
39 weight,
40 properties: HashMap::new(),
41 created_at: None,
42 })
43 }
44
45 pub fn set_property(&mut self, key: impl Into<String>, value: serde_json::Value) {
47 self.properties.insert(key.into(), value);
48 }
49
50 pub fn get_property(&self, key: &str) -> Option<&serde_json::Value> {
52 self.properties.get(key)
53 }
54}
55
/// An entity reached by `get_neighbors`, paired with the relation through which
/// the traversal first discovered it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Neighbor {
    /// The entity on the far side of `relation`.
    pub entity: Entity,
    /// The relation linking `entity` to an already-visited entity. Traversal follows
    /// edges in both directions, so `entity` may be either its source or its target.
    pub relation: Relation,
}
62
63pub fn insert_relation(conn: &rusqlite::Connection, relation: &Relation) -> Result<i64> {
65 crate::graph::entity::get_entity(conn, relation.source_id)?;
67 crate::graph::entity::get_entity(conn, relation.target_id)?;
68
69 let properties_json = serde_json::to_string(&relation.properties)?;
70
71 conn.execute(
72 r#"
73 INSERT INTO kg_relations (source_id, target_id, rel_type, weight, properties)
74 VALUES (?1, ?2, ?3, ?4, ?5)
75 "#,
76 params![
77 relation.source_id,
78 relation.target_id,
79 relation.rel_type,
80 relation.weight,
81 properties_json
82 ],
83 )?;
84
85 Ok(conn.last_insert_rowid())
86}
87
88pub fn get_neighbors(
90 conn: &rusqlite::Connection,
91 entity_id: i64,
92 depth: u32,
93) -> Result<Vec<Neighbor>> {
94 if depth == 0 {
95 return Ok(Vec::new());
96 }
97
98 if depth > 5 {
99 return Err(Error::InvalidDepth(depth));
100 }
101
102 crate::graph::entity::get_entity(conn, entity_id)?;
104
105 let mut result = Vec::new();
106 let mut visited = std::collections::HashSet::new();
107 let mut queue = VecDeque::new();
108 let mut level_queue = VecDeque::new();
109
110 visited.insert(entity_id);
112 let direct_relations = get_direct_relations(conn, entity_id)?;
113
114 for (relation, neighbor_entity) in direct_relations {
115 let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
116
117 if !visited.contains(&neighbor_id) {
118 visited.insert(neighbor_id);
119 queue.push_back((neighbor_id, 1));
120 level_queue.push_back((neighbor_entity.clone(), relation.clone()));
121 result.push(Neighbor {
122 entity: neighbor_entity,
123 relation,
124 });
125 }
126 }
127
128 while let Some((current_id, current_depth)) = queue.pop_front() {
130 if current_depth >= depth {
131 continue;
132 }
133
134 let relations = get_direct_relations(conn, current_id)?;
135
136 for (relation, neighbor_entity) in relations {
137 let neighbor_id = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
138
139 if !visited.contains(&neighbor_id) {
140 visited.insert(neighbor_id);
141 queue.push_back((neighbor_id, current_depth + 1));
142 level_queue.push_back((neighbor_entity.clone(), relation.clone()));
143 result.push(Neighbor {
144 entity: neighbor_entity,
145 relation,
146 });
147 }
148 }
149 }
150
151 Ok(result)
152}
153
154fn get_direct_relations(
156 conn: &rusqlite::Connection,
157 entity_id: i64,
158) -> Result<Vec<(Relation, Entity)>> {
159 let mut result = Vec::new();
160
161 let mut stmt = conn.prepare(
163 r#"
164 SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
165 e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
166 FROM kg_relations r
167 JOIN kg_entities e ON r.target_id = e.id
168 WHERE r.source_id = ?1
169 "#,
170 )?;
171
172 let rows = stmt.query_map(params![entity_id], |row| {
173 let properties_json: String = row.get(5)?;
174 let properties: HashMap<String, serde_json::Value> =
175 serde_json::from_str(&properties_json).unwrap_or_default();
176
177 let entity_props_json: String = row.get(10)?;
178 let entity_props: HashMap<String, serde_json::Value> =
179 serde_json::from_str(&entity_props_json).unwrap_or_default();
180
181 Ok((
182 Relation {
183 id: Some(row.get(0)?),
184 source_id: row.get(1)?,
185 target_id: row.get(2)?,
186 rel_type: row.get(3)?,
187 weight: row.get(4)?,
188 properties,
189 created_at: row.get(6)?,
190 },
191 Entity {
192 id: Some(row.get(7)?),
193 entity_type: row.get(8)?,
194 name: row.get(9)?,
195 properties: entity_props,
196 created_at: row.get(11)?,
197 updated_at: row.get(12)?,
198 },
199 ))
200 })?;
201
202 for row in rows {
203 result.push(row?);
204 }
205
206 let mut stmt = conn.prepare(
208 r#"
209 SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
210 e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
211 FROM kg_relations r
212 JOIN kg_entities e ON r.source_id = e.id
213 WHERE r.target_id = ?1
214 "#,
215 )?;
216
217 let rows = stmt.query_map(params![entity_id], |row| {
218 let properties_json: String = row.get(5)?;
219 let properties: HashMap<String, serde_json::Value> =
220 serde_json::from_str(&properties_json).unwrap_or_default();
221
222 let entity_props_json: String = row.get(10)?;
223 let entity_props: HashMap<String, serde_json::Value> =
224 serde_json::from_str(&entity_props_json).unwrap_or_default();
225
226 Ok((
227 Relation {
228 id: Some(row.get(0)?),
229 source_id: row.get(1)?,
230 target_id: row.get(2)?,
231 rel_type: row.get(3)?,
232 weight: row.get(4)?,
233 properties,
234 created_at: row.get(6)?,
235 },
236 Entity {
237 id: Some(row.get(7)?),
238 entity_type: row.get(8)?,
239 name: row.get(9)?,
240 properties: entity_props,
241 created_at: row.get(11)?,
242 updated_at: row.get(12)?,
243 },
244 ))
245 })?;
246
247 for row in rows {
248 result.push(row?);
249 }
250
251 Ok(result)
252}
253
254pub fn get_relations_by_source(
256 conn: &rusqlite::Connection,
257 source_id: i64,
258) -> Result<Vec<Relation>> {
259 let mut stmt = conn.prepare(
260 r#"
261 SELECT id, source_id, target_id, rel_type, weight, properties, created_at
262 FROM kg_relations
263 WHERE source_id = ?1
264 "#,
265 )?;
266
267 let relations = stmt.query_map(params![source_id], |row| {
268 let properties_json: String = row.get(5)?;
269 let properties: HashMap<String, serde_json::Value> =
270 serde_json::from_str(&properties_json).unwrap_or_default();
271
272 Ok(Relation {
273 id: Some(row.get(0)?),
274 source_id: row.get(1)?,
275 target_id: row.get(2)?,
276 rel_type: row.get(3)?,
277 weight: row.get(4)?,
278 properties,
279 created_at: row.get(6)?,
280 })
281 })?;
282
283 let mut result = Vec::new();
284 for rel in relations {
285 result.push(rel?);
286 }
287
288 Ok(result)
289}
290
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::entity::{insert_entity, Entity};
    use rusqlite::Connection;

    /// Opens an in-memory database with the project schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Inserts a `"paper"` entity with the given name and returns its id.
    fn paper(conn: &Connection, name: &str) -> i64 {
        insert_entity(conn, &Entity::new("paper", name)).unwrap()
    }

    /// Inserts a `"cites"` relation `source -> target` and returns its id.
    fn cite(conn: &Connection, source: i64, target: i64, weight: f64) -> i64 {
        let relation = Relation::new(source, target, "cites", weight).unwrap();
        insert_relation(conn, &relation).unwrap()
    }

    #[test]
    fn test_insert_relation() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");

        let id = cite(&conn, entity1_id, entity2_id, 0.8);
        assert!(id > 0);
    }

    #[test]
    fn test_insert_relation_rejects_missing_entity() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");

        // The target id does not correspond to any inserted entity.
        let relation = Relation::new(entity1_id, entity1_id + 1000, "cites", 0.5).unwrap();
        assert!(insert_relation(&conn, &relation).is_err());
    }

    #[test]
    fn test_get_neighbors_depth_1() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let neighbors = get_neighbors(&conn, entity1_id, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "Paper 2");
    }

    #[test]
    fn test_get_neighbors_depth_2() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let neighbors = get_neighbors(&conn, entity1_id, 2).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 2"));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));
    }

    #[test]
    fn test_get_neighbors_follows_incoming_relations() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");

        cite(&conn, entity1_id, entity2_id, 0.8);

        // Paper 2 is only a target, but traversal goes both ways.
        let neighbors = get_neighbors(&conn, entity2_id, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "Paper 1");
    }

    #[test]
    fn test_get_neighbors_depth_zero_is_empty() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        cite(&conn, entity1_id, entity2_id, 0.8);

        assert!(get_neighbors(&conn, entity1_id, 0).unwrap().is_empty());
    }

    #[test]
    fn test_get_neighbors_handles_cycles_without_duplicates() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        cite(&conn, entity1_id, entity2_id, 0.8);
        cite(&conn, entity2_id, entity1_id, 0.7);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let neighbors = get_neighbors(&conn, entity1_id, 3).unwrap();
        assert_eq!(neighbors.len(), 2);
        assert!(neighbors.iter().all(|n| n.entity.id != Some(entity1_id)));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 2"));
        assert!(neighbors.iter().any(|n| n.entity.name == "Paper 3"));
    }

    #[test]
    fn test_get_relations_by_source_roundtrips_properties() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");
        let entity2_id = paper(&conn, "Paper 2");
        let entity3_id = paper(&conn, "Paper 3");

        let mut relation = Relation::new(entity1_id, entity2_id, "cites", 0.8).unwrap();
        relation.set_property("year", serde_json::json!(2020));
        insert_relation(&conn, &relation).unwrap();
        cite(&conn, entity1_id, entity3_id, 0.5);
        cite(&conn, entity2_id, entity3_id, 0.9);

        let outgoing = get_relations_by_source(&conn, entity1_id).unwrap();
        assert_eq!(outgoing.len(), 2);
        assert!(outgoing.iter().all(|r| r.source_id == entity1_id));

        let to_paper2 = outgoing.iter().find(|r| r.target_id == entity2_id).unwrap();
        assert_eq!(
            to_paper2.get_property("year").and_then(|v| v.as_i64()),
            Some(2020)
        );
    }

    #[test]
    fn test_invalid_weight() {
        assert!(Relation::new(1, 2, "test", 1.5).is_err());
        assert!(Relation::new(1, 2, "test", -0.1).is_err());
        assert!(Relation::new(1, 2, "test", f64::NAN).is_err());

        // Both bounds are inclusive.
        assert!(Relation::new(1, 2, "test", 0.0).is_ok());
        assert!(Relation::new(1, 2, "test", 1.0).is_ok());
    }

    #[test]
    fn test_invalid_depth() {
        let conn = setup();
        let entity1_id = paper(&conn, "Paper 1");

        let result = get_neighbors(&conn, entity1_id, 10);
        assert!(result.is_err());
    }
}