1use std::collections::{HashSet, VecDeque};
4
5use rusqlite::params;
6
7use super::store;
8use crate::error::{Error, Result};
9use crate::graph::entity::Entity;
10use crate::graph::relation::Relation;
11
12pub fn version_entities(
14 conn: &rusqlite::Connection,
15 version_id: i64,
16 entity_type: Option<&str>,
17 limit: Option<i64>,
18) -> Result<Vec<Entity>> {
19 let bit = store::version_bit_for(conn, version_id)?;
20 let mut query = String::from(
21 "SELECT id, entity_type, name, properties, created_at, updated_at \
22 FROM kg_entities WHERE (validity & ?1) != 0",
23 );
24
25 let mut param_idx = 2;
26 if entity_type.is_some() {
27 query.push_str(&format!(" AND entity_type = ?{param_idx}"));
28 param_idx += 1;
29 }
30
31 if limit.is_some() {
32 query.push_str(&format!(" LIMIT ?{param_idx}"));
33 }
34
35 let mut stmt = conn.prepare(&query)?;
36
37 let mut param_vec: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(bit)];
38 if let Some(et) = entity_type {
39 param_vec.push(Box::new(et.to_string()));
40 }
41 if let Some(lim) = limit {
42 param_vec.push(Box::new(lim));
43 }
44
45 let params_refs: Vec<&dyn rusqlite::ToSql> = param_vec.iter().map(|p| p.as_ref()).collect();
46
47 let entities = stmt.query_map(params_refs.as_slice(), row_to_entity)?;
48
49 let mut result = Vec::new();
50 for e in entities {
51 result.push(e?);
52 }
53 Ok(result)
54}
55
56pub fn version_relations(
58 conn: &rusqlite::Connection,
59 version_id: i64,
60 rel_type: Option<&str>,
61 source_id: Option<i64>,
62 target_id: Option<i64>,
63 limit: Option<i64>,
64) -> Result<Vec<Relation>> {
65 let bit = store::version_bit_for(conn, version_id)?;
66 let mut query = String::from(
67 "SELECT id, source_id, target_id, rel_type, weight, properties, created_at \
68 FROM kg_relations WHERE (validity & ?1) != 0",
69 );
70
71 let mut param_idx = 2;
72 if rel_type.is_some() {
73 query.push_str(&format!(" AND rel_type = ?{param_idx}"));
74 param_idx += 1;
75 }
76 if source_id.is_some() {
77 query.push_str(&format!(" AND source_id = ?{param_idx}"));
78 param_idx += 1;
79 }
80 if target_id.is_some() {
81 query.push_str(&format!(" AND target_id = ?{param_idx}"));
82 param_idx += 1;
83 }
84 if limit.is_some() {
85 query.push_str(&format!(" LIMIT ?{param_idx}"));
86 }
87
88 let mut stmt = conn.prepare(&query)?;
89
90 let mut param_vec: Vec<Box<dyn rusqlite::ToSql>> = vec![Box::new(bit)];
91 if let Some(rt) = rel_type {
92 param_vec.push(Box::new(rt.to_string()));
93 }
94 if let Some(sid) = source_id {
95 param_vec.push(Box::new(sid));
96 }
97 if let Some(tid) = target_id {
98 param_vec.push(Box::new(tid));
99 }
100 if let Some(lim) = limit {
101 param_vec.push(Box::new(lim));
102 }
103
104 let params_refs: Vec<&dyn rusqlite::ToSql> = param_vec.iter().map(|p| p.as_ref()).collect();
105
106 let relations = stmt.query_map(params_refs.as_slice(), row_to_relation)?;
107
108 let mut result = Vec::new();
109 for r in relations {
110 result.push(r?);
111 }
112 Ok(result)
113}
114
115pub fn version_neighbors(
118 conn: &rusqlite::Connection,
119 entity_id: i64,
120 version_id: i64,
121 depth: u32,
122) -> Result<Vec<crate::graph::relation::Neighbor>> {
123 if depth == 0 {
124 return Ok(Vec::new());
125 }
126 if depth > 5 {
127 return Err(Error::InvalidDepth(depth));
128 }
129
130 let bit = store::version_bit_for(conn, version_id)?;
131 store::ensure_entity_exists(conn, entity_id)?;
132
133 if !entity_in_version(conn, entity_id, bit)? {
136 return Ok(Vec::new());
137 }
138
139 let mut result = Vec::new();
140 let mut visited = HashSet::new();
141 let mut queue = VecDeque::new();
142
143 visited.insert(entity_id);
144
145 let direct = get_direct_version_relations(conn, entity_id, bit)?;
146 for (relation, neighbor_entity) in direct {
147 let nid = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
148 if !visited.contains(&nid) {
149 visited.insert(nid);
150 queue.push_back((nid, 1));
151 result.push(crate::graph::relation::Neighbor {
152 entity: neighbor_entity,
153 relation,
154 });
155 }
156 }
157
158 while let Some((current_id, current_depth)) = queue.pop_front() {
159 if current_depth >= depth {
160 continue;
161 }
162
163 let relations = get_direct_version_relations(conn, current_id, bit)?;
164 for (relation, neighbor_entity) in relations {
165 let nid = neighbor_entity.id.ok_or(Error::EntityNotFound(0))?;
166 if !visited.contains(&nid) {
167 visited.insert(nid);
168 queue.push_back((nid, current_depth + 1));
169 result.push(crate::graph::relation::Neighbor {
170 entity: neighbor_entity,
171 relation,
172 });
173 }
174 }
175 }
176
177 Ok(result)
178}
179
180fn entity_in_version(conn: &rusqlite::Connection, entity_id: i64, bit: i64) -> Result<bool> {
182 let present: bool = conn.query_row(
183 "SELECT COALESCE((validity & ?1) != 0, 0) FROM kg_entities WHERE id = ?2",
184 params![bit, entity_id],
185 |r| r.get(0),
186 )?;
187 Ok(present)
188}
189
190fn get_direct_version_relations(
192 conn: &rusqlite::Connection,
193 entity_id: i64,
194 bit: i64,
195) -> Result<Vec<(Relation, Entity)>> {
196 let mut result = Vec::new();
197
198 let mut stmt = conn.prepare(
200 "SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
201 e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
202 FROM kg_relations r
203 JOIN kg_entities e ON r.target_id = e.id
204 WHERE r.source_id = ?1 AND (r.validity & ?2) != 0 AND (e.validity & ?2) != 0",
205 )?;
206 let rows = stmt.query_map(params![entity_id, bit], |row| {
207 Ok((row_to_relation(row)?, row_to_entity_offset(row, 7)?))
208 })?;
209 for row in rows {
210 result.push(row?);
211 }
212
213 let mut stmt = conn.prepare(
215 "SELECT r.id, r.source_id, r.target_id, r.rel_type, r.weight, r.properties, r.created_at,
216 e.id, e.entity_type, e.name, e.properties, e.created_at, e.updated_at
217 FROM kg_relations r
218 JOIN kg_entities e ON r.source_id = e.id
219 WHERE r.target_id = ?1 AND (r.validity & ?2) != 0 AND (e.validity & ?2) != 0",
220 )?;
221 let rows = stmt.query_map(params![entity_id, bit], |row| {
222 Ok((row_to_relation(row)?, row_to_entity_offset(row, 7)?))
223 })?;
224 for row in rows {
225 result.push(row?);
226 }
227
228 Ok(result)
229}
230
231fn row_to_entity(row: &rusqlite::Row) -> rusqlite::Result<Entity> {
232 let props_json: Option<String> = row.get(3)?;
233 let properties = props_json
234 .and_then(|j| serde_json::from_str(&j).ok())
235 .unwrap_or_default();
236 Ok(Entity {
237 id: Some(row.get(0)?),
238 entity_type: row.get(1)?,
239 name: row.get(2)?,
240 properties,
241 created_at: row.get(4)?,
242 updated_at: row.get(5)?,
243 })
244}
245
246fn row_to_entity_offset(row: &rusqlite::Row, offset: usize) -> rusqlite::Result<Entity> {
247 let props_json: Option<String> = row.get(offset + 3)?;
248 let properties = props_json
249 .and_then(|j| serde_json::from_str(&j).ok())
250 .unwrap_or_default();
251 Ok(Entity {
252 id: Some(row.get(offset)?),
253 entity_type: row.get(offset + 1)?,
254 name: row.get(offset + 2)?,
255 properties,
256 created_at: row.get(offset + 4)?,
257 updated_at: row.get(offset + 5)?,
258 })
259}
260
261fn row_to_relation(row: &rusqlite::Row) -> rusqlite::Result<Relation> {
262 let props_json: Option<String> = row.get(5)?;
263 let properties = props_json
264 .and_then(|j| serde_json::from_str(&j).ok())
265 .unwrap_or_default();
266 Ok(Relation {
267 id: Some(row.get(0)?),
268 source_id: row.get(1)?,
269 target_id: row.get(2)?,
270 rel_type: row.get(3)?,
271 weight: crate::row_get_weight(row, 4)?,
272 properties,
273 created_at: row.get(6)?,
274 })
275}
276
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// In-memory database with the crate schema applied.
    fn setup() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        crate::schema::create_schema(&conn).unwrap();
        conn
    }

    /// Inserts an entity of type `test` and returns its id.
    fn add_entity(conn: &Connection, name: &str) -> i64 {
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('test', ?1)",
            [name],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    /// Inserts a relation `src -[rt]-> tgt` and returns its id.
    fn add_relation(conn: &Connection, src: i64, tgt: i64, rt: &str) -> i64 {
        conn.execute(
            "INSERT INTO kg_relations (source_id, target_id, rel_type) VALUES (?1, ?2, ?3)",
            rusqlite::params![src, tgt, rt],
        )
        .unwrap();
        conn.last_insert_rowid()
    }

    fn make_version(conn: &Connection, name: &str) -> i64 {
        super::super::store::create_version(conn, name, "main", None, None).unwrap()
    }

    /// Overwrites the validity bitmask of row `id` in `table` (test-only SQL formatting).
    fn set_validity(conn: &Connection, table: &str, id: i64, val: i64) {
        conn.execute(
            &format!("UPDATE {table} SET validity = ?1 WHERE id = ?2"),
            rusqlite::params![val, id],
        )
        .unwrap();
    }

    #[test]
    fn test_entities_in_version() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let v1 = make_version(&conn, "v1");

        set_validity(&conn, "kg_entities", e1, 0b01);
        set_validity(&conn, "kg_entities", e2, 0b01);
        // Only in a different version's bit.
        set_validity(&conn, "kg_entities", e3, 0b10);

        let ents = version_entities(&conn, v1, None, None).unwrap();
        assert_eq!(ents.len(), 2);
        let names: Vec<&str> = ents.iter().map(|e| e.name.as_str()).collect();
        assert!(names.contains(&"A"));
        assert!(names.contains(&"B"));
    }

    #[test]
    fn test_entities_with_type_filter() {
        let conn = setup();
        conn.execute(
            "INSERT INTO kg_entities (entity_type, name) VALUES ('paper', 'P1')",
            [],
        )
        .unwrap();
        let e2 = add_entity(&conn, "S1");
        let v1 = make_version(&conn, "v1");

        set_validity(&conn, "kg_entities", e2, 0b01);
        conn.execute("UPDATE kg_entities SET validity = 1 WHERE name = 'P1'", [])
            .unwrap();

        let papers = version_entities(&conn, v1, Some("paper"), None).unwrap();
        assert_eq!(papers.len(), 1);
        assert_eq!(papers[0].name, "P1");
    }

    #[test]
    fn test_entities_with_limit() {
        let conn = setup();
        let ids = [
            add_entity(&conn, "A"),
            add_entity(&conn, "B"),
            add_entity(&conn, "C"),
        ];
        let v1 = make_version(&conn, "v1");
        for &id in &ids {
            set_validity(&conn, "kg_entities", id, 0b01);
        }

        let ents = version_entities(&conn, v1, None, Some(2)).unwrap();
        assert_eq!(ents.len(), 2);
    }

    #[test]
    fn test_relations_in_version() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let r1 = add_relation(&conn, e1, e2, "cites");
        let v1 = make_version(&conn, "v1");

        set_validity(&conn, "kg_relations", r1, 0b01);

        let rels = version_relations(&conn, v1, None, None, None, None).unwrap();
        assert_eq!(rels.len(), 1);
    }

    #[test]
    fn test_relations_with_type_filter() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let r1 = add_relation(&conn, e1, e2, "cites");
        let r2 = add_relation(&conn, e2, e1, "related");
        let v1 = make_version(&conn, "v1");

        set_validity(&conn, "kg_relations", r1, 0b01);
        set_validity(&conn, "kg_relations", r2, 0b01);

        let cites = version_relations(&conn, v1, Some("cites"), None, None, None).unwrap();
        assert_eq!(cites.len(), 1);
        assert_eq!(cites[0].rel_type, "cites");
    }

    #[test]
    fn test_relations_with_endpoint_filters() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let r1 = add_relation(&conn, e1, e2, "knows");
        let r2 = add_relation(&conn, e1, e3, "knows");
        let r3 = add_relation(&conn, e2, e3, "knows");
        let v1 = make_version(&conn, "v1");
        for &r in &[r1, r2, r3] {
            set_validity(&conn, "kg_relations", r, 0b01);
        }

        let from_a = version_relations(&conn, v1, None, Some(e1), None, None).unwrap();
        assert_eq!(from_a.len(), 2);

        let a_to_c = version_relations(&conn, v1, None, Some(e1), Some(e3), None).unwrap();
        assert_eq!(a_to_c.len(), 1);
        assert_eq!(a_to_c[0].id, Some(r2));
    }

    #[test]
    fn test_version_neighbors() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let r1 = add_relation(&conn, e1, e2, "knows");
        let r2 = add_relation(&conn, e1, e3, "knows");
        let v1 = make_version(&conn, "v1");

        set_validity(&conn, "kg_entities", e1, 0b01);
        set_validity(&conn, "kg_entities", e2, 0b01);
        set_validity(&conn, "kg_entities", e3, 0b01);
        set_validity(&conn, "kg_relations", r1, 0b01);
        set_validity(&conn, "kg_relations", r2, 0b01);

        let neighbors = version_neighbors(&conn, e1, v1, 1).unwrap();
        assert_eq!(neighbors.len(), 2);
    }

    #[test]
    fn test_version_neighbors_excludes_non_version_entity() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let e2 = add_entity(&conn, "B");
        let e3 = add_entity(&conn, "C");
        let r1 = add_relation(&conn, e1, e2, "knows");
        let r2 = add_relation(&conn, e1, e3, "knows");
        let v1 = make_version(&conn, "v1");

        set_validity(&conn, "kg_entities", e1, 0b01);
        set_validity(&conn, "kg_entities", e2, 0b01);
        set_validity(&conn, "kg_relations", r1, 0b01);
        // The relation to C is in the version, but C itself is not.
        set_validity(&conn, "kg_relations", r2, 0b01);

        let neighbors = version_neighbors(&conn, e1, v1, 1).unwrap();
        assert_eq!(neighbors.len(), 1);
        assert_eq!(neighbors[0].entity.name, "B");
    }

    #[test]
    fn test_version_neighbors_depth_bounds() {
        let conn = setup();
        let e1 = add_entity(&conn, "A");
        let v1 = make_version(&conn, "v1");
        set_validity(&conn, "kg_entities", e1, 0b01);

        assert!(version_neighbors(&conn, e1, v1, 0).unwrap().is_empty());
        assert!(matches!(
            version_neighbors(&conn, e1, v1, 6),
            Err(Error::InvalidDepth(6))
        ));
    }

    #[test]
    fn test_version_neighbors_multi_hop_both_directions() {
        let conn = setup();
        let a = add_entity(&conn, "A");
        let b = add_entity(&conn, "B");
        let c = add_entity(&conn, "C");
        let r1 = add_relation(&conn, a, b, "knows");
        let r2 = add_relation(&conn, b, c, "knows");
        let v1 = make_version(&conn, "v1");
        for &e in &[a, b, c] {
            set_validity(&conn, "kg_entities", e, 0b01);
        }
        for &r in &[r1, r2] {
            set_validity(&conn, "kg_relations", r, 0b01);
        }

        let names = |start: i64, depth: u32| -> Vec<String> {
            version_neighbors(&conn, start, v1, depth)
                .unwrap()
                .into_iter()
                .map(|n| n.entity.name)
                .collect()
        };

        // Following outgoing edges from A.
        assert_eq!(names(a, 1), vec!["B".to_string()]);
        assert_eq!(names(a, 2), vec!["B".to_string(), "C".to_string()]);
        // Following incoming edges back from C.
        assert_eq!(names(c, 2), vec!["B".to_string(), "A".to_string()]);
    }
}