1use crate::errors::AppError;
9use rusqlite::{params, Connection};
10
11pub fn traverse_from_memories(
45 conn: &Connection,
46 seed_memory_ids: &[i64],
47 namespace: &str,
48 min_weight: f64,
49 max_hops: u32,
50) -> Result<Vec<i64>, AppError> {
51 if seed_memory_ids.is_empty() || max_hops == 0 {
52 return Ok(vec![]);
53 }
54
55 let mut seed_entities: Vec<i64> = Vec::with_capacity(seed_memory_ids.len());
57 for &mem_id in seed_memory_ids {
58 let mut stmt =
59 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
60 let ids: Vec<i64> = stmt
61 .query_map(params![mem_id], |r| r.get(0))?
62 .filter_map(|r| r.ok())
63 .collect();
64 seed_entities.extend(ids);
65 }
66 seed_entities.sort_unstable();
67 seed_entities.dedup();
68
69 if seed_entities.is_empty() {
70 return Ok(vec![]);
71 }
72
73 use std::collections::HashSet;
75 let mut visited: HashSet<i64> = seed_entities.iter().copied().collect();
76 let mut frontier: Vec<i64> = seed_entities.to_vec();
77
78 for _ in 0..max_hops {
79 if frontier.is_empty() {
80 break;
81 }
82 let mut next_frontier = Vec::with_capacity(frontier.len() * 2);
83
84 for &entity_id in &frontier {
85 let mut stmt = conn.prepare_cached(
86 "SELECT target_id FROM relationships
87 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3",
88 )?;
89 let neighbors: Vec<i64> = stmt
90 .query_map(params![entity_id, min_weight, namespace], |r| r.get(0))?
91 .filter_map(|r| r.ok())
92 .filter(|id| !visited.contains(id))
93 .collect();
94
95 for id in neighbors {
96 visited.insert(id);
97 next_frontier.push(id);
98 }
99 }
100 frontier = next_frontier;
101 }
102
103 let seed_set: HashSet<i64> = seed_memory_ids.iter().copied().collect();
105 let graph_only_entities: Vec<i64> = visited
106 .into_iter()
107 .filter(|id| !seed_entities.contains(id))
108 .collect();
109
110 let mut result_ids: Vec<i64> = Vec::with_capacity(graph_only_entities.len());
111 for &entity_id in &graph_only_entities {
112 let mut stmt = conn.prepare_cached(
113 "SELECT DISTINCT me.memory_id
114 FROM memory_entities me
115 JOIN memories m ON m.id = me.memory_id
116 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
117 )?;
118 let mem_ids: Vec<i64> = stmt
119 .query_map(params![entity_id], |r| r.get(0))?
120 .filter_map(|r| r.ok())
121 .filter(|id| !seed_set.contains(id))
122 .collect();
123 result_ids.extend(mem_ids);
124 }
125
126 result_ids.sort_unstable();
127 result_ids.dedup();
128 Ok(result_ids)
129}
130
131pub fn traverse_from_memories_with_hops(
145 conn: &Connection,
146 seed_memory_ids: &[i64],
147 namespace: &str,
148 min_weight: f64,
149 max_hops: u32,
150) -> Result<Vec<(i64, u32)>, AppError> {
151 traverse_from_memories_with_hops_inner(
152 conn,
153 seed_memory_ids,
154 namespace,
155 min_weight,
156 max_hops,
157 None,
158 )
159}
160
161pub fn traverse_from_memories_with_hops_capped(
172 conn: &Connection,
173 seed_memory_ids: &[i64],
174 namespace: &str,
175 min_weight: f64,
176 max_hops: u32,
177 max_neighbors_per_hop: Option<usize>,
178) -> Result<Vec<(i64, u32)>, AppError> {
179 traverse_from_memories_with_hops_inner(
180 conn,
181 seed_memory_ids,
182 namespace,
183 min_weight,
184 max_hops,
185 max_neighbors_per_hop,
186 )
187}
188
189fn traverse_from_memories_with_hops_inner(
190 conn: &Connection,
191 seed_memory_ids: &[i64],
192 namespace: &str,
193 min_weight: f64,
194 max_hops: u32,
195 max_neighbors_per_hop: Option<usize>,
196) -> Result<Vec<(i64, u32)>, AppError> {
197 if seed_memory_ids.is_empty() || max_hops == 0 {
198 return Ok(vec![]);
199 }
200
201 let mut seed_entities: Vec<i64> = Vec::with_capacity(seed_memory_ids.len());
203 for &mem_id in seed_memory_ids {
204 let mut stmt =
205 conn.prepare_cached("SELECT entity_id FROM memory_entities WHERE memory_id = ?1")?;
206 let ids: Vec<i64> = stmt
207 .query_map(params![mem_id], |r| r.get(0))?
208 .filter_map(|r| r.ok())
209 .collect();
210 seed_entities.extend(ids);
211 }
212 seed_entities.sort_unstable();
213 seed_entities.dedup();
214
215 if seed_entities.is_empty() {
216 return Ok(vec![]);
217 }
218
219 use std::collections::HashMap;
221 let mut entity_depth: HashMap<i64, u32> = seed_entities.iter().map(|&id| (id, 0)).collect();
222 let mut frontier: Vec<i64> = seed_entities.to_vec();
223
224 for hop in 1..=max_hops {
225 if frontier.is_empty() {
226 break;
227 }
228 let mut next_frontier = Vec::with_capacity(frontier.len() * 2);
229
230 for &entity_id in &frontier {
231 let mut stmt = conn.prepare_cached(
233 "SELECT target_id, weight FROM relationships
234 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3
235 ORDER BY weight DESC",
236 )?;
237 let mut neighbors: Vec<i64> = stmt
238 .query_map(params![entity_id, min_weight, namespace], |r| {
239 Ok((r.get::<_, i64>(0)?, r.get::<_, f64>(1)?))
240 })?
241 .filter_map(|r| r.ok())
242 .filter(|(id, _)| !entity_depth.contains_key(id))
243 .map(|(id, _)| id)
244 .collect();
245
246 if let Some(cap) = max_neighbors_per_hop {
248 neighbors.truncate(cap);
249 }
250
251 for id in neighbors {
252 entity_depth.insert(id, hop);
253 next_frontier.push(id);
254 }
255 }
256 frontier = next_frontier;
257 }
258
259 let seed_set: std::collections::HashSet<i64> = seed_memory_ids.iter().copied().collect();
261 let seed_entity_set: std::collections::HashSet<i64> = seed_entities.iter().copied().collect();
262
263 let mut result: Vec<(i64, u32)> = Vec::with_capacity(entity_depth.len());
264 let mut seen_memories: std::collections::HashSet<i64> =
265 std::collections::HashSet::with_capacity(entity_depth.len());
266
267 for (&entity_id, &hop) in &entity_depth {
268 if seed_entity_set.contains(&entity_id) {
269 continue;
270 }
271 let mut stmt = conn.prepare_cached(
272 "SELECT DISTINCT me.memory_id
273 FROM memory_entities me
274 JOIN memories m ON m.id = me.memory_id
275 WHERE me.entity_id = ?1 AND m.deleted_at IS NULL",
276 )?;
277 let mem_ids: Vec<i64> = stmt
278 .query_map(params![entity_id], |r| r.get(0))?
279 .filter_map(|r| r.ok())
280 .filter(|id| !seed_set.contains(id) && !seen_memories.contains(id))
281 .collect();
282
283 for mem_id in mem_ids {
284 seen_memories.insert(mem_id);
285 result.push((mem_id, hop));
286 }
287 }
288
289 result.sort_unstable_by_key(|&(id, _)| id);
290 Ok(result)
291}
292
/// Maps an entity id to the hop at which a breadth-first traversal first reached it.
///
/// As produced by [`bfs_with_predecessors`], seed entities are stored with depth `0`.
pub type EntityDepthMap = std::collections::HashMap<i64, u32>;
295
/// Maps a discovered entity id to the edge it was reached through:
/// `(predecessor entity id, relation, weight)`.
///
/// As produced by [`bfs_with_predecessors`], seed entities have no entry.
pub type PredecessorMap = std::collections::HashMap<i64, (i64, String, f64)>;
300
301pub fn bfs_with_predecessors(
314 conn: &Connection,
315 seed_entity_ids: &[i64],
316 namespace: &str,
317 min_weight: f64,
318 max_hops: u32,
319 max_neighbors_per_hop: Option<usize>,
320) -> Result<(EntityDepthMap, PredecessorMap), AppError> {
321 use std::collections::HashMap;
322
323 let mut entity_depth: HashMap<i64, u32> = seed_entity_ids.iter().map(|&id| (id, 0)).collect();
324 let mut predecessor: HashMap<i64, (i64, String, f64)> =
325 HashMap::with_capacity(max_hops as usize * 10);
326 let mut frontier: Vec<i64> = seed_entity_ids.to_vec();
327
328 for hop in 1..=max_hops {
329 if frontier.is_empty() {
330 break;
331 }
332 let mut next_frontier = Vec::with_capacity(frontier.len() * 2);
333
334 for &entity_id in &frontier {
335 let mut stmt = conn.prepare_cached(
336 "SELECT target_id, relation, weight FROM relationships
337 WHERE source_id = ?1 AND weight >= ?2 AND namespace = ?3
338 ORDER BY weight DESC",
339 )?;
340 let mut neighbors: Vec<(i64, String, f64)> = stmt
341 .query_map(params![entity_id, min_weight, namespace], |r| {
342 Ok((
343 r.get::<_, i64>(0)?,
344 r.get::<_, String>(1)?,
345 r.get::<_, f64>(2)?,
346 ))
347 })?
348 .filter_map(|r| r.ok())
349 .filter(|(id, _, _)| !entity_depth.contains_key(id))
350 .collect();
351
352 if let Some(cap) = max_neighbors_per_hop {
353 neighbors.truncate(cap);
354 }
355
356 for (id, relation, weight) in neighbors {
357 entity_depth.insert(id, hop);
358 predecessor.insert(id, (entity_id, relation, weight));
359 next_frontier.push(id);
360 }
361 }
362 frontier = next_frontier;
363 }
364
365 Ok((entity_depth, predecessor))
366}
367
#[cfg(test)]
mod tests {
    use super::*;
    use rusqlite::Connection;

    /// Minimum relationship weight passed to the traversal in every test.
    const MIN_WEIGHT: f64 = 0.5;

    /// Opens an in-memory database containing the tables read by `traverse_from_memories`.
    fn setup_db() -> Connection {
        let schema = "CREATE TABLE memories (
                id INTEGER PRIMARY KEY,
                namespace TEXT NOT NULL,
                deleted_at TEXT
            );
            CREATE TABLE memory_entities (
                memory_id INTEGER NOT NULL,
                entity_id INTEGER NOT NULL
            );
            CREATE TABLE relationships (
                source_id INTEGER NOT NULL,
                target_id INTEGER NOT NULL,
                weight REAL NOT NULL,
                namespace TEXT NOT NULL
            );";
        let conn = Connection::open_in_memory().unwrap();
        conn.execute_batch(schema).unwrap();
        conn
    }

    /// Inserts a memory row; `deleted` controls whether `deleted_at` is set.
    fn insert_memory(conn: &Connection, id: i64, namespace: &str, deleted: bool) {
        let deleted_at: Option<&str> = if deleted { Some("2024-01-01") } else { None };
        conn.execute(
            "INSERT INTO memories (id, namespace, deleted_at) VALUES (?1, ?2, ?3)",
            params![id, namespace, deleted_at],
        )
        .unwrap();
    }

    /// Links a memory to an entity.
    fn link_memory_entity(conn: &Connection, memory_id: i64, entity_id: i64) {
        conn.execute(
            "INSERT INTO memory_entities (memory_id, entity_id) VALUES (?1, ?2)",
            params![memory_id, entity_id],
        )
        .unwrap();
    }

    /// Inserts a directed, weighted relationship between two entities.
    fn insert_relationship(conn: &Connection, src: i64, tgt: i64, weight: f64, ns: &str) {
        conn.execute(
            "INSERT INTO relationships (source_id, target_id, weight, namespace) VALUES (?1, ?2, ?3, ?4)",
            params![src, tgt, weight, ns],
        )
        .unwrap();
    }

    /// Inserts a non-deleted memory and links it to each of `entity_ids`.
    fn live_memory(conn: &Connection, id: i64, namespace: &str, entity_ids: &[i64]) {
        insert_memory(conn, id, namespace, false);
        for &entity_id in entity_ids {
            link_memory_entity(conn, id, entity_id);
        }
    }

    /// Runs the traversal with the shared minimum weight and unwraps the result.
    fn reachable(conn: &Connection, seeds: &[i64], namespace: &str, max_hops: u32) -> Vec<i64> {
        traverse_from_memories(conn, seeds, namespace, MIN_WEIGHT, max_hops).unwrap()
    }

    #[test]
    fn returns_empty_when_seeds_empty() {
        let conn = setup_db();

        assert!(reachable(&conn, &[], "ns", 3).is_empty());
    }

    #[test]
    fn returns_empty_when_max_hops_zero() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);

        assert!(reachable(&conn, &[1], "ns", 0).is_empty());
    }

    #[test]
    fn returns_empty_when_seed_has_no_entities() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[]);

        assert!(reachable(&conn, &[1], "ns", 3).is_empty());
    }

    #[test]
    fn returns_empty_when_no_relationships() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);

        assert!(reachable(&conn, &[1], "ns", 3).is_empty());
    }

    #[test]
    fn traversal_basic_one_hop() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        assert_eq!(reachable(&conn, &[1], "ns", 1), vec![2]);
    }

    #[test]
    fn traversal_two_hops() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        live_memory(&conn, 3, "ns", &[30]);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        let mut found = reachable(&conn, &[1], "ns", 2);
        found.sort_unstable();
        assert_eq!(found, vec![2, 3]);
    }

    #[test]
    fn max_hops_limits_depth() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        live_memory(&conn, 3, "ns", &[30]);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");

        // Memory 3 sits two hops away and must stay out of a one-hop traversal.
        let found = reachable(&conn, &[1], "ns", 1);
        assert_eq!(found, vec![2]);
        assert!(!found.contains(&3));
    }

    #[test]
    fn relationship_with_weight_below_min_ignored() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        insert_relationship(&conn, 10, 20, 0.3, "ns");

        assert!(reachable(&conn, &[1], "ns", 3).is_empty());
    }

    #[test]
    fn relationship_with_weight_exactly_at_min_included() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        insert_relationship(&conn, 10, 20, 0.5, "ns");

        assert_eq!(reachable(&conn, &[1], "ns", 1), vec![2]);
    }

    #[test]
    fn relationship_from_different_namespace_ignored() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns_a", &[10]);
        live_memory(&conn, 2, "ns_a", &[20]);
        insert_relationship(&conn, 10, 20, 1.0, "ns_b");

        assert!(reachable(&conn, &[1], "ns_a", 3).is_empty());
    }

    #[test]
    fn seeds_do_not_appear_in_result() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        // The back edge leads to the seed's own entity again.
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 10, 1.0, "ns");

        let found = reachable(&conn, &[1], "ns", 3);
        assert!(!found.contains(&1));
        assert_eq!(found, vec![2]);
    }

    #[test]
    fn deleted_memories_not_included() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        insert_memory(&conn, 2, "ns", true);
        link_memory_entity(&conn, 2, 20);
        insert_relationship(&conn, 10, 20, 1.0, "ns");

        assert!(reachable(&conn, &[1], "ns", 3).is_empty());
    }

    #[test]
    fn multiple_seeds_merged_in_result() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        live_memory(&conn, 3, "ns", &[30]);
        live_memory(&conn, 4, "ns", &[40]);
        insert_relationship(&conn, 10, 30, 1.0, "ns");
        insert_relationship(&conn, 20, 40, 1.0, "ns");

        let mut found = reachable(&conn, &[1, 2], "ns", 1);
        found.sort_unstable();
        assert_eq!(found, vec![3, 4]);
    }

    #[test]
    fn result_without_duplicates() {
        let conn = setup_db();
        // Two seed entities both point at entity 20, so memory 2 is reachable twice.
        live_memory(&conn, 1, "ns", &[10, 11]);
        live_memory(&conn, 2, "ns", &[20]);
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 11, 20, 1.0, "ns");

        let found = reachable(&conn, &[1], "ns", 1);
        assert_eq!(found.len(), 1);
        assert_eq!(found, vec![2]);
    }

    #[test]
    fn single_node_without_neighbors_returns_empty() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);

        assert!(reachable(&conn, &[1], "ns", 5).is_empty());
    }

    #[test]
    fn cycle_does_not_cause_infinite_loop() {
        let conn = setup_db();
        live_memory(&conn, 1, "ns", &[10]);
        live_memory(&conn, 2, "ns", &[20]);
        live_memory(&conn, 3, "ns", &[30]);
        // 10 -> 20 -> 30 -> 10 forms a cycle through the seed entity.
        insert_relationship(&conn, 10, 20, 1.0, "ns");
        insert_relationship(&conn, 20, 30, 1.0, "ns");
        insert_relationship(&conn, 30, 10, 1.0, "ns");

        let mut found = reachable(&conn, &[1], "ns", 10);
        found.sort_unstable();
        assert_eq!(found, vec![2, 3]);
    }
}