1use crate::errors::AppError;
9use rusqlite::{params, Connection};
10
11pub fn traverse_from_memories(
45 conn: &Connection,
46 seed_memory_ids: &[i64],
47 namespace: &str,
48 min_weight: f64,
49 max_hops: u32,
50) -> Result<Vec<i64>, AppError> {
51 if seed_memory_ids.is_empty() || max_hops == 0 {
52 return Ok(vec![]);
53 }
54
55 let mut seed_entities: Vec<i64> = Vec::with_capacity(seed_memory_ids.len());
57 for &mem_id in seed_memory_ids {
58 let mut stmt =
59 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
60 let ids: Vec<i64> = stmt
61 .query_map(params![mem_id], |r| r.get(0))?
62 .filter_map(|r| r.ok())
63 .collect();
64 seed_entities.extend(ids);
65 }
66 seed_entities.sort_unstable();
67 seed_entities.dedup();
68
69 if seed_entities.is_empty() {
70 return Ok(vec![]);
71 }
72
73 use std::collections::HashSet;
75 let mut visited: HashSet<i64> = seed_entities.iter().copied().collect();
76 let mut frontier: Vec<i64> = seed_entities.to_vec();
77
78 for _ in 0..max_hops {
79 if frontier.is_empty() {
80 break;
81 }
82 let mut next_frontier = Vec::with_capacity(frontier.len() * 2);
83
84 for &entity_id in &frontier {
85 let mut stmt = conn.prepare_cached(
86 "SELECT target_id FROM relationships
87 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
88 )?;
89 let neighbors: Vec<i64> = stmt
90 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
91 .filter_map(|r| r.ok())
92 .filter(|id| !visited.contains(id))
93 .collect();
94
95 for id in neighbors {
96 visited.insert(id);
97 next_frontier.push(id);
98 }
99 }
100 frontier = next_frontier;
101 }
102
103 let seed_set: HashSet<i64> = seed_memory_ids.iter().copied().collect();
105 let graph_only_entities: Vec<i64> = visited
106 .into_iter()
107 .filter(|id| !seed_entities.contains(id))
108 .collect();
109
110 let mut result_ids: Vec<i64> = Vec::with_capacity(graph_only_entities.len());
111 for &entity_id in &graph_only_entities {
112 let mut stmt = conn.prepare_cached(
113 "SELECT DISTINCT me.memory_id
114 FROM memory_entities me
115 JOIN memories m ON m.id = me.memory_id
116 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
117 )?;
118 let mem_ids: Vec<i64> = stmt
119 .query_map(params![entity_id], |r| r.get(0))?
120 .filter_map(|r| r.ok())
121 .filter(|id| !seed_set.contains(id))
122 .collect();
123 result_ids.extend(mem_ids);
124 }
125
126 result_ids.sort_unstable();
127 result_ids.dedup();
128 Ok(result_ids)
129}
130
131pub fn traverse_from_memories_with_hops(
145 conn: &Connection,
146 seed_memory_ids: &[i64],
147 namespace: &str,
148 min_weight: f64,
149 max_hops: u32,
150) -> Result<Vec<(i64, u32)>, AppError> {
151 traverse_from_memories_with_hops_inner(
152 conn,
153 seed_memory_ids,
154 namespace,
155 min_weight,
156 max_hops,
157 None,
158 )
159}
160
161pub fn traverse_from_memories_with_hops_capped(
172 conn: &Connection,
173 seed_memory_ids: &[i64],
174 namespace: &str,
175 min_weight: f64,
176 max_hops: u32,
177 max_neighbors_per_hop: Option<usize>,
178) -> Result<Vec<(i64, u32)>, AppError> {
179 traverse_from_memories_with_hops_inner(
180 conn,
181 seed_memory_ids,
182 namespace,
183 min_weight,
184 max_hops,
185 max_neighbors_per_hop,
186 )
187}
188
189fn traverse_from_memories_with_hops_inner(
190 conn: &Connection,
191 seed_memory_ids: &[i64],
192 namespace: &str,
193 min_weight: f64,
194 max_hops: u32,
195 max_neighbors_per_hop: Option<usize>,
196) -> Result<Vec<(i64, u32)>, AppError> {
197 if seed_memory_ids.is_empty() || max_hops == 0 {
198 return Ok(vec![]);
199 }
200
201 let mut seed_entities: Vec<i64> = Vec::with_capacity(seed_memory_ids.len());
203 for &mem_id in seed_memory_ids {
204 let mut stmt =
205 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
206 let ids: Vec<i64> = stmt
207 .query_map(params![mem_id], |r| r.get(0))?
208 .filter_map(|r| r.ok())
209 .collect();
210 seed_entities.extend(ids);
211 }
212 seed_entities.sort_unstable();
213 seed_entities.dedup();
214
215 if seed_entities.is_empty() {
216 return Ok(vec![]);
217 }
218
219 use std::collections::HashMap;
221 let mut entity_depth: HashMap<i64, u32> = seed_entities.iter().map(|&id| (id, 0)).collect();
222 let mut frontier: Vec<i64> = seed_entities.to_vec();
223
224 for hop in 1..=max_hops {
225 if frontier.is_empty() {
226 break;
227 }
228 let mut next_frontier = Vec::with_capacity(frontier.len() * 2);
229
230 for &entity_id in &frontier {
231 let mut stmt = conn.prepare_cached(
233 "SELECT target_id, weight FROM relationships
234 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3
235 ORDER BY weight DESC",
236 )?;
237 let mut neighbors: Vec<i64> = stmt
238 .query_map(params![entity_id, min_weight, namespace], |r| {
239 Ok((r.get::<_, i64>(0)?, r.get::<_, f64>(1)?))
240 })?
241 .filter_map(|r| r.ok())
242 .filter(|(id, _)| !entity_depth.contains_key(id))
243 .map(|(id, _)| id)
244 .collect();
245
246 if let Some(cap) = max_neighbors_per_hop {
248 neighbors.truncate(cap);
249 }
250
251 for id in neighbors {
252 entity_depth.insert(id, hop);
253 next_frontier.push(id);
254 }
255 }
256 frontier = next_frontier;
257 }
258
259 let seed_set: std::collections::HashSet<i64> = seed_memory_ids.iter().copied().collect();
261 let seed_entity_set: std::collections::HashSet<i64> = seed_entities.iter().copied().collect();
262
263 let mut result: Vec<(i64, u32)> = Vec::with_capacity(entity_depth.len());
264 let mut seen_memories: std::collections::HashSet<i64> = std::collections::HashSet::new();
265
266 for (&entity_id, &hop) in &entity_depth {
267 if seed_entity_set.contains(&entity_id) {
268 continue;
269 }
270 let mut stmt = conn.prepare_cached(
271 "SELECT DISTINCT me.memory_id
272 FROM memory_entities me
273 JOIN memories m ON m.id = me.memory_id
274 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
275 )?;
276 let mem_ids: Vec<i64> = stmt
277 .query_map(params![entity_id], |r| r.get(0))?
278 .filter_map(|r| r.ok())
279 .filter(|id| !seed_set.contains(id) && !seen_memories.contains(id))
280 .collect();
281
282 for mem_id in mem_ids {
283 seen_memories.insert(mem_id);
284 result.push((mem_id, hop));
285 }
286 }
287
288 result.sort_unstable_by_key(|&(id, _)| id);
289 Ok(result)
290}
291
292pub type EntityDepthMap = std::collections::HashMap<i64, u32>;
294
295pub type PredecessorMap = std::collections::HashMap<i64, (i64, String, f64)>;
299
300pub fn bfs_with_predecessors(
313 conn: &Connection,
314 seed_entity_ids: &[i64],
315 namespace: &str,
316 min_weight: f64,
317 max_hops: u32,
318 max_neighbors_per_hop: Option<usize>,
319) -> Result<(EntityDepthMap, PredecessorMap), AppError> {
320 use std::collections::HashMap;
321
322 let mut entity_depth: HashMap<i64, u32> = seed_entity_ids.iter().map(|&id| (id, 0)).collect();
323 let mut predecessor: HashMap<i64, (i64, String, f64)> = HashMap::new();
324 let mut frontier: Vec<i64> = seed_entity_ids.to_vec();
325
326 for hop in 1..=max_hops {
327 if frontier.is_empty() {
328 break;
329 }
330 let mut next_frontier = Vec::with_capacity(frontier.len() * 2);
331
332 for &entity_id in &frontier {
333 let mut stmt = conn.prepare_cached(
334 "SELECT target_id, relation, weight FROM relationships
335 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3
336 ORDER BY weight DESC",
337 )?;
338 let mut neighbors: Vec<(i64, String, f64)> = stmt
339 .query_map(params![entity_id, min_weight, namespace], |r| {
340 Ok((
341 r.get::<_, i64>(0)?,
342 r.get::<_, String>(1)?,
343 r.get::<_, f64>(2)?,
344 ))
345 })?
346 .filter_map(|r| r.ok())
347 .filter(|(id, _, _)| !entity_depth.contains_key(id))
348 .collect();
349
350 if let Some(cap) = max_neighbors_per_hop {
351 neighbors.truncate(cap);
352 }
353
354 for (id, relation, weight) in neighbors {
355 entity_depth.insert(id, hop);
356 predecessor.insert(id, (entity_id, relation, weight));
357 next_frontier.push(id);
358 }
359 }
360 frontier = next_frontier;
361 }
362
363 Ok((entity_depth, predecessor))
364}
365
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Creates an in-memory database holding the tables and columns the
    /// traversal queries read, including `relationships.relation`, which
    /// `bfs_with_predecessors` selects.
    fn setup_db() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(
            "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                relation TEXT NOT NULL DEFAULT 'related',
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );",
        )
        .unwrap();
        conn
    }

    /// Inserts a memory row; `deleted` controls whether `deleted_at` is set.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![
                id,
                namespace,
                if deleted { Some("2024-01-01") } else { None }
            ],
        )
        .unwrap();
    }

    /// Attaches an entity to a memory.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a live memory in namespace `ns` attached to a single entity.
    fn insert_linked_memory(conn: &Connection, memory_id: i64, entity_id: i64) {
        insert_memory(conn, memory_id, "ns", false);
        link_memory_entity(conn, memory_id, entity_id);
    }

    /// Inserts a directed edge using the schema's default relation label.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    /// Inserts a directed edge with an explicit relation label.
    fn insert_labeled_relationship(
        conn: &Connection,
        src: i64,
        tgt: i64,
        relation: &str,
        weight: f64,
        ns: &str,
    ) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, relation, weight, namespace) \
             VALUES (?1, ?2, ?3, ?4, ?5)",
            params![src, tgt, relation, weight, ns],
        )
        .unwrap();
    }

    // --- traverse_from_memories: degenerate inputs ---

    #[test]
    fn returns_empty_when_seeds_empty() {
        let conn = setup_db();
        let result = traverse_from_memories(&conn, &[], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_max_hops_zero() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_seed_has_no_entities() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns", false);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn returns_empty_when_no_relationships() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    // --- traverse_from_memories: depth handling ---

    #[test]
    fn traversal_basic_one_hop() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn traversal_two_hops() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_linked_memory(&conn, 3, 30);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1], "ns", 0.5, 2).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![2, 3]);
    }

    #[test]
    fn max_hops_limits_depth() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_linked_memory(&conn, 3, 30);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
        assert!(!result.contains(&3));
    }

    // --- traverse_from_memories: edge filtering ---

    #[test]
    fn relationship_with_weight_below_min_ignored() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 0.3, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn relationship_with_weight_exactly_at_min_included() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 0.5, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn relationship_from_different_namespace_ignored() {
        let conn = setup_db();
        insert_memory(&conn, 1, "ns_a", false);
        link_memory_entity(&conn, 1, 10);
        insert_memory(&conn, 2, "ns_a", false);
        link_memory_entity(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        let result = traverse_from_memories(&conn, &[1], "ns_a", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    // --- traverse_from_memories: result shaping ---

    #[test]
    fn seeds_do_not_appear_in_result() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(!result.contains(&1));
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn deleted_memories_not_included() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 3).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn multiple_seeds_merged_in_result() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_linked_memory(&conn, 3, 30);
        insert_linked_memory(&conn, 4, 40);
        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1, 2], "ns", 0.5, 1).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![3, 4]);
    }

    #[test]
    fn result_without_duplicates() {
        let conn = setup_db();
        // Memory 1 mentions two entities that both point at entity 20.
        insert_linked_memory(&conn, 1, 10);
        link_memory_entity(&conn, 1, 11);
        insert_linked_memory(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 1).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result, vec![2]);
    }

    #[test]
    fn single_node_without_neighbors_returns_empty() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        let result = traverse_from_memories(&conn, &[1], "ns", 0.5, 5).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn cycle_does_not_cause_infinite_loop() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_linked_memory(&conn, 3, 30);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut result = traverse_from_memories(&conn, &[1], "ns", 0.5, 10).unwrap();
        result.sort_unstable();
        assert_eq!(result, vec![2, 3]);
    }

    // --- hop-aware traversal ---

    #[test]
    fn with_hops_reports_depth_of_each_memory() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_linked_memory(&conn, 3, 30);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 2).unwrap();
        assert_eq!(result, vec![(2, 1), (3, 2)]);
    }

    #[test]
    fn with_hops_returns_empty_when_max_hops_zero() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        let result = traverse_from_memories_with_hops(&conn, &[1], "ns", 0.5, 0).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn capped_traversal_keeps_strongest_neighbor() {
        let conn = setup_db();
        insert_linked_memory(&conn, 1, 10);
        insert_linked_memory(&conn, 2, 20);
        insert_linked_memory(&conn, 3, 30);
        insert_relationship(&conn, 10, 30, 0.6, "ns");
        insert_relationship(&conn, 10, 20, 0.9, "ns");

        let capped =
            traverse_from_memories_with_hops_capped(&conn, &[1], "ns", 0.5, 1, Some(1)).unwrap();
        assert_eq!(capped, vec![(2, 1)]);

        let uncapped =
            traverse_from_memories_with_hops_capped(&conn, &[1], "ns", 0.5, 1, None).unwrap();
        assert_eq!(uncapped, vec![(2, 1), (3, 1)]);
    }

    // --- bfs_with_predecessors ---

    #[test]
    fn bfs_records_depth_and_predecessor_edge() {
        let conn = setup_db();
        insert_labeled_relationship(&conn, 10, 20, "knows", 1.0, "ns");
        insert_labeled_relationship(&conn, 20, 30, "likes", 0.8, "ns");

        let (depth, pred) = bfs_with_predecessors(&conn, &[10], "ns", 0.5, 2, None).unwrap();

        assert_eq!(depth.len(), 3);
        assert_eq!(depth.get(&10), Some(&0));
        assert_eq!(depth.get(&20), Some(&1));
        assert_eq!(depth.get(&30), Some(&2));

        assert_eq!(pred.len(), 2);
        assert_eq!(pred.get(&20), Some(&(10, "knows".to_string(), 1.0)));
        assert_eq!(pred.get(&30), Some(&(20, "likes".to_string(), 0.8)));
    }

    #[test]
    fn bfs_with_zero_hops_returns_only_seeds() {
        let conn = setup_db();
        insert_labeled_relationship(&conn, 10, 20, "knows", 1.0, "ns");

        let (depth, pred) = bfs_with_predecessors(&conn, &[10], "ns", 0.5, 0, None).unwrap();

        assert_eq!(depth.len(), 1);
        assert_eq!(depth.get(&10), Some(&0));
        assert!(pred.is_empty());
    }
}