1use std::collections::HashMap;
5
6use futures::Stream;
7use sqlx::SqlitePool;
8
9use crate::error::MemoryError;
10use crate::sqlite::messages::sanitize_fts5_query;
11use crate::types::MessageId;
12
13use super::types::{Community, Edge, Entity, EntityAlias, EntityType};
14
/// SQLite-backed persistence layer for the knowledge graph: entities,
/// their aliases, temporal edges, communities, and key/value metadata.
pub struct GraphStore {
    // Shared connection pool; all queries in this store go through it.
    pool: SqlitePool,
}
18
19impl GraphStore {
20 #[must_use]
21 pub fn new(pool: SqlitePool) -> Self {
22 Self { pool }
23 }
24
25 #[must_use]
26 pub fn pool(&self) -> &SqlitePool {
27 &self.pool
28 }
29
30 pub async fn upsert_entity(
43 &self,
44 surface_name: &str,
45 canonical_name: &str,
46 entity_type: EntityType,
47 summary: Option<&str>,
48 ) -> Result<i64, MemoryError> {
49 let type_str = entity_type.as_str();
50 let id: i64 = sqlx::query_scalar(
51 "INSERT INTO graph_entities (name, canonical_name, entity_type, summary)
52 VALUES (?1, ?2, ?3, ?4)
53 ON CONFLICT(canonical_name, entity_type) DO UPDATE SET
54 name = excluded.name,
55 summary = COALESCE(excluded.summary, summary),
56 last_seen_at = datetime('now')
57 RETURNING id",
58 )
59 .bind(surface_name)
60 .bind(canonical_name)
61 .bind(type_str)
62 .bind(summary)
63 .fetch_one(&self.pool)
64 .await?;
65 Ok(id)
66 }
67
68 pub async fn find_entity(
74 &self,
75 canonical_name: &str,
76 entity_type: EntityType,
77 ) -> Result<Option<Entity>, MemoryError> {
78 let type_str = entity_type.as_str();
79 let row: Option<EntityRow> = sqlx::query_as(
80 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
81 FROM graph_entities
82 WHERE canonical_name = ?1 AND entity_type = ?2",
83 )
84 .bind(canonical_name)
85 .bind(type_str)
86 .fetch_optional(&self.pool)
87 .await?;
88 row.map(entity_from_row).transpose()
89 }
90
91 pub async fn find_entity_by_id(&self, entity_id: i64) -> Result<Option<Entity>, MemoryError> {
97 let row: Option<EntityRow> = sqlx::query_as(
98 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
99 FROM graph_entities
100 WHERE id = ?1",
101 )
102 .bind(entity_id)
103 .fetch_optional(&self.pool)
104 .await?;
105 row.map(entity_from_row).transpose()
106 }
107
108 pub async fn set_entity_qdrant_point_id(
114 &self,
115 entity_id: i64,
116 point_id: &str,
117 ) -> Result<(), MemoryError> {
118 sqlx::query("UPDATE graph_entities SET qdrant_point_id = ?1 WHERE id = ?2")
119 .bind(point_id)
120 .bind(entity_id)
121 .execute(&self.pool)
122 .await?;
123 Ok(())
124 }
125
    /// Fuzzy entity search: FTS5 prefix matching on entity names, unioned
    /// with a case-insensitive substring LIKE scan over aliases. Returns
    /// at most `limit` entities; empty/unsanitizable queries yield `[]`.
    pub async fn find_entities_fuzzy(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Entity>, MemoryError> {
        // Bare FTS5 keywords would be parsed as operators, not terms.
        const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT", "NEAR"];
        // Cap the input at 512 bytes without splitting a UTF-8 code point.
        // NOTE(review): `str::floor_char_boundary` is unstable — this crate
        // presumably builds on nightly; confirm toolchain.
        let query = &query[..query.floor_char_boundary(512)];
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(vec![]);
        }
        // Turn each surviving token into a prefix term ("tok*").
        let fts_query: String = sanitized
            .split_whitespace()
            .filter(|t| !FTS5_OPERATORS.contains(t))
            .map(|t| format!("{t}*"))
            .collect::<Vec<_>>()
            .join(" ");
        if fts_query.is_empty() {
            return Ok(vec![]);
        }

        let limit = i64::try_from(limit)?;
        // UNION of FTS hits (by name) and alias substring hits; DISTINCT +
        // UNION deduplicate entities matched by both branches.
        let rows: Vec<EntityRow> = sqlx::query_as(
            "SELECT DISTINCT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entities_fts fts
             JOIN graph_entities e ON e.id = fts.rowid
             WHERE graph_entities_fts MATCH ?1
             UNION
             SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name LIKE ?2 ESCAPE '\\' COLLATE NOCASE
             LIMIT ?3",
        )
        .bind(&fts_query)
        // Escape LIKE wildcards in the raw query; backslash first so the
        // escapes themselves are not double-escaped.
        .bind(format!(
            "%{}%",
            query
                .trim()
                .replace('\\', "\\\\")
                .replace('%', "\\%")
                .replace('_', "\\_")
        ))
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        rows.into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()
    }
205
206 pub fn all_entities_stream(&self) -> impl Stream<Item = Result<Entity, MemoryError>> + '_ {
208 use futures::StreamExt as _;
209 sqlx::query_as::<_, EntityRow>(
210 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
211 FROM graph_entities ORDER BY id ASC",
212 )
213 .fetch(&self.pool)
214 .map(|r: Result<EntityRow, sqlx::Error>| {
215 r.map_err(MemoryError::from).and_then(entity_from_row)
216 })
217 }
218
219 pub async fn add_alias(&self, entity_id: i64, alias_name: &str) -> Result<(), MemoryError> {
227 sqlx::query(
228 "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name) VALUES (?1, ?2)",
229 )
230 .bind(entity_id)
231 .bind(alias_name)
232 .execute(&self.pool)
233 .await?;
234 Ok(())
235 }
236
237 pub async fn find_entity_by_alias(
245 &self,
246 alias_name: &str,
247 entity_type: EntityType,
248 ) -> Result<Option<Entity>, MemoryError> {
249 let type_str = entity_type.as_str();
250 let row: Option<EntityRow> = sqlx::query_as(
251 "SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
252 e.first_seen_at, e.last_seen_at, e.qdrant_point_id
253 FROM graph_entity_aliases a
254 JOIN graph_entities e ON e.id = a.entity_id
255 WHERE a.alias_name = ?1 COLLATE NOCASE
256 AND e.entity_type = ?2
257 ORDER BY e.id ASC
258 LIMIT 1",
259 )
260 .bind(alias_name)
261 .bind(type_str)
262 .fetch_optional(&self.pool)
263 .await?;
264 row.map(entity_from_row).transpose()
265 }
266
267 pub async fn aliases_for_entity(
273 &self,
274 entity_id: i64,
275 ) -> Result<Vec<EntityAlias>, MemoryError> {
276 let rows: Vec<AliasRow> = sqlx::query_as(
277 "SELECT id, entity_id, alias_name, created_at
278 FROM graph_entity_aliases
279 WHERE entity_id = ?1
280 ORDER BY id ASC",
281 )
282 .bind(entity_id)
283 .fetch_all(&self.pool)
284 .await?;
285 Ok(rows.into_iter().map(alias_from_row).collect())
286 }
287
288 pub async fn all_entities(&self) -> Result<Vec<Entity>, MemoryError> {
294 use futures::TryStreamExt as _;
295 self.all_entities_stream().try_collect().await
296 }
297
298 pub async fn entity_count(&self) -> Result<i64, MemoryError> {
304 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_entities")
305 .fetch_one(&self.pool)
306 .await?;
307 Ok(count)
308 }
309
310 pub async fn insert_edge(
323 &self,
324 source_entity_id: i64,
325 target_entity_id: i64,
326 relation: &str,
327 fact: &str,
328 confidence: f32,
329 episode_id: Option<MessageId>,
330 ) -> Result<i64, MemoryError> {
331 let confidence = confidence.clamp(0.0, 1.0);
332
333 let existing: Option<(i64, f64)> = sqlx::query_as(
334 "SELECT id, confidence FROM graph_edges
335 WHERE source_entity_id = ?1
336 AND target_entity_id = ?2
337 AND relation = ?3
338 AND valid_to IS NULL
339 LIMIT 1",
340 )
341 .bind(source_entity_id)
342 .bind(target_entity_id)
343 .bind(relation)
344 .fetch_optional(&self.pool)
345 .await?;
346
347 if let Some((existing_id, stored_conf)) = existing {
348 let updated_conf = f64::from(confidence).max(stored_conf);
349 sqlx::query("UPDATE graph_edges SET confidence = ?1 WHERE id = ?2")
350 .bind(updated_conf)
351 .bind(existing_id)
352 .execute(&self.pool)
353 .await?;
354 return Ok(existing_id);
355 }
356
357 let episode_raw: Option<i64> = episode_id.map(|m| m.0);
358 let id: i64 = sqlx::query_scalar(
359 "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, episode_id)
360 VALUES (?1, ?2, ?3, ?4, ?5, ?6)
361 RETURNING id",
362 )
363 .bind(source_entity_id)
364 .bind(target_entity_id)
365 .bind(relation)
366 .bind(fact)
367 .bind(f64::from(confidence))
368 .bind(episode_raw)
369 .fetch_one(&self.pool)
370 .await?;
371 Ok(id)
372 }
373
374 pub async fn invalidate_edge(&self, edge_id: i64) -> Result<(), MemoryError> {
380 sqlx::query(
381 "UPDATE graph_edges SET valid_to = datetime('now'), expired_at = datetime('now')
382 WHERE id = ?1",
383 )
384 .bind(edge_id)
385 .execute(&self.pool)
386 .await?;
387 Ok(())
388 }
389
390 pub async fn edges_for_entity(&self, entity_id: i64) -> Result<Vec<Edge>, MemoryError> {
396 let rows: Vec<EdgeRow> = sqlx::query_as(
397 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
398 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
399 FROM graph_edges
400 WHERE valid_to IS NULL
401 AND (source_entity_id = ?1 OR target_entity_id = ?1)",
402 )
403 .bind(entity_id)
404 .fetch_all(&self.pool)
405 .await?;
406 Ok(rows.into_iter().map(edge_from_row).collect())
407 }
408
409 pub async fn edges_between(
415 &self,
416 entity_a: i64,
417 entity_b: i64,
418 ) -> Result<Vec<Edge>, MemoryError> {
419 let rows: Vec<EdgeRow> = sqlx::query_as(
420 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
421 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
422 FROM graph_edges
423 WHERE valid_to IS NULL
424 AND ((source_entity_id = ?1 AND target_entity_id = ?2)
425 OR (source_entity_id = ?2 AND target_entity_id = ?1))",
426 )
427 .bind(entity_a)
428 .bind(entity_b)
429 .fetch_all(&self.pool)
430 .await?;
431 Ok(rows.into_iter().map(edge_from_row).collect())
432 }
433
434 pub async fn edges_exact(
440 &self,
441 source_entity_id: i64,
442 target_entity_id: i64,
443 ) -> Result<Vec<Edge>, MemoryError> {
444 let rows: Vec<EdgeRow> = sqlx::query_as(
445 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
446 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
447 FROM graph_edges
448 WHERE valid_to IS NULL
449 AND source_entity_id = ?1
450 AND target_entity_id = ?2",
451 )
452 .bind(source_entity_id)
453 .bind(target_entity_id)
454 .fetch_all(&self.pool)
455 .await?;
456 Ok(rows.into_iter().map(edge_from_row).collect())
457 }
458
459 pub async fn active_edge_count(&self) -> Result<i64, MemoryError> {
465 let count: i64 =
466 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
467 .fetch_one(&self.pool)
468 .await?;
469 Ok(count)
470 }
471
472 pub async fn upsert_community(
484 &self,
485 name: &str,
486 summary: &str,
487 entity_ids: &[i64],
488 fingerprint: Option<&str>,
489 ) -> Result<i64, MemoryError> {
490 let entity_ids_json = serde_json::to_string(entity_ids)?;
491 let id: i64 = sqlx::query_scalar(
492 "INSERT INTO graph_communities (name, summary, entity_ids, fingerprint)
493 VALUES (?1, ?2, ?3, ?4)
494 ON CONFLICT(name) DO UPDATE SET
495 summary = excluded.summary,
496 entity_ids = excluded.entity_ids,
497 fingerprint = COALESCE(excluded.fingerprint, fingerprint),
498 updated_at = datetime('now')
499 RETURNING id",
500 )
501 .bind(name)
502 .bind(summary)
503 .bind(entity_ids_json)
504 .bind(fingerprint)
505 .fetch_one(&self.pool)
506 .await?;
507 Ok(id)
508 }
509
510 pub async fn community_fingerprints(&self) -> Result<HashMap<String, i64>, MemoryError> {
517 let rows: Vec<(String, i64)> = sqlx::query_as(
518 "SELECT fingerprint, id FROM graph_communities WHERE fingerprint IS NOT NULL",
519 )
520 .fetch_all(&self.pool)
521 .await?;
522 Ok(rows.into_iter().collect())
523 }
524
525 pub async fn delete_community_by_id(&self, id: i64) -> Result<(), MemoryError> {
531 sqlx::query("DELETE FROM graph_communities WHERE id = ?1")
532 .bind(id)
533 .execute(&self.pool)
534 .await?;
535 Ok(())
536 }
537
538 pub async fn clear_community_fingerprint(&self, id: i64) -> Result<(), MemoryError> {
547 sqlx::query("UPDATE graph_communities SET fingerprint = NULL WHERE id = ?1")
548 .bind(id)
549 .execute(&self.pool)
550 .await?;
551 Ok(())
552 }
553
554 pub async fn community_for_entity(
563 &self,
564 entity_id: i64,
565 ) -> Result<Option<Community>, MemoryError> {
566 let row: Option<CommunityRow> = sqlx::query_as(
567 "SELECT c.id, c.name, c.summary, c.entity_ids, c.fingerprint, c.created_at, c.updated_at
568 FROM graph_communities c, json_each(c.entity_ids) j
569 WHERE CAST(j.value AS INTEGER) = ?1
570 LIMIT 1",
571 )
572 .bind(entity_id)
573 .fetch_optional(&self.pool)
574 .await?;
575 match row {
576 Some(row) => {
577 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
578 Ok(Some(Community {
579 id: row.id,
580 name: row.name,
581 summary: row.summary,
582 entity_ids,
583 fingerprint: row.fingerprint,
584 created_at: row.created_at,
585 updated_at: row.updated_at,
586 }))
587 }
588 None => Ok(None),
589 }
590 }
591
592 pub async fn all_communities(&self) -> Result<Vec<Community>, MemoryError> {
598 let rows: Vec<CommunityRow> = sqlx::query_as(
599 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
600 FROM graph_communities
601 ORDER BY id ASC",
602 )
603 .fetch_all(&self.pool)
604 .await?;
605
606 rows.into_iter()
607 .map(|row| {
608 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
609 Ok(Community {
610 id: row.id,
611 name: row.name,
612 summary: row.summary,
613 entity_ids,
614 fingerprint: row.fingerprint,
615 created_at: row.created_at,
616 updated_at: row.updated_at,
617 })
618 })
619 .collect()
620 }
621
622 pub async fn community_count(&self) -> Result<i64, MemoryError> {
628 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_communities")
629 .fetch_one(&self.pool)
630 .await?;
631 Ok(count)
632 }
633
634 pub async fn get_metadata(&self, key: &str) -> Result<Option<String>, MemoryError> {
642 let val: Option<String> =
643 sqlx::query_scalar("SELECT value FROM graph_metadata WHERE key = ?1")
644 .bind(key)
645 .fetch_optional(&self.pool)
646 .await?;
647 Ok(val)
648 }
649
650 pub async fn set_metadata(&self, key: &str, value: &str) -> Result<(), MemoryError> {
656 sqlx::query(
657 "INSERT INTO graph_metadata (key, value) VALUES (?1, ?2)
658 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
659 )
660 .bind(key)
661 .bind(value)
662 .execute(&self.pool)
663 .await?;
664 Ok(())
665 }
666
667 pub async fn extraction_count(&self) -> Result<i64, MemoryError> {
675 let val = self.get_metadata("extraction_count").await?;
676 Ok(val.and_then(|v| v.parse::<i64>().ok()).unwrap_or(0))
677 }
678
679 pub fn all_active_edges_stream(&self) -> impl Stream<Item = Result<Edge, MemoryError>> + '_ {
681 use futures::StreamExt as _;
682 sqlx::query_as::<_, EdgeRow>(
683 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
684 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
685 FROM graph_edges
686 WHERE valid_to IS NULL
687 ORDER BY id ASC",
688 )
689 .fetch(&self.pool)
690 .map(|r| r.map_err(MemoryError::from).map(edge_from_row))
691 }
692
693 pub async fn edges_after_id(
710 &self,
711 after_id: i64,
712 limit: i64,
713 ) -> Result<Vec<Edge>, MemoryError> {
714 let rows: Vec<EdgeRow> = sqlx::query_as(
715 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
716 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
717 FROM graph_edges
718 WHERE valid_to IS NULL AND id > ?1
719 ORDER BY id ASC
720 LIMIT ?2",
721 )
722 .bind(after_id)
723 .bind(limit)
724 .fetch_all(&self.pool)
725 .await?;
726 Ok(rows.into_iter().map(edge_from_row).collect())
727 }
728
729 pub async fn find_community_by_id(&self, id: i64) -> Result<Option<Community>, MemoryError> {
735 let row: Option<CommunityRow> = sqlx::query_as(
736 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
737 FROM graph_communities
738 WHERE id = ?1",
739 )
740 .bind(id)
741 .fetch_optional(&self.pool)
742 .await?;
743 match row {
744 Some(row) => {
745 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
746 Ok(Some(Community {
747 id: row.id,
748 name: row.name,
749 summary: row.summary,
750 entity_ids,
751 fingerprint: row.fingerprint,
752 created_at: row.created_at,
753 updated_at: row.updated_at,
754 }))
755 }
756 None => Ok(None),
757 }
758 }
759
760 pub async fn delete_all_communities(&self) -> Result<(), MemoryError> {
766 sqlx::query("DELETE FROM graph_communities")
767 .execute(&self.pool)
768 .await?;
769 Ok(())
770 }
771
772 pub async fn delete_expired_edges(&self, retention_days: u32) -> Result<usize, MemoryError> {
778 let days = i64::from(retention_days);
779 let result = sqlx::query(
780 "DELETE FROM graph_edges
781 WHERE expired_at IS NOT NULL
782 AND expired_at < datetime('now', '-' || ?1 || ' days')",
783 )
784 .bind(days)
785 .execute(&self.pool)
786 .await?;
787 Ok(usize::try_from(result.rows_affected())?)
788 }
789
790 pub async fn delete_orphan_entities(&self, retention_days: u32) -> Result<usize, MemoryError> {
796 let days = i64::from(retention_days);
797 let result = sqlx::query(
798 "DELETE FROM graph_entities
799 WHERE id NOT IN (
800 SELECT DISTINCT source_entity_id FROM graph_edges WHERE valid_to IS NULL
801 UNION
802 SELECT DISTINCT target_entity_id FROM graph_edges WHERE valid_to IS NULL
803 )
804 AND last_seen_at < datetime('now', '-' || ?1 || ' days')",
805 )
806 .bind(days)
807 .execute(&self.pool)
808 .await?;
809 Ok(usize::try_from(result.rows_affected())?)
810 }
811
812 pub async fn cap_entities(&self, max_entities: usize) -> Result<usize, MemoryError> {
821 let current = self.entity_count().await?;
822 let max = i64::try_from(max_entities)?;
823 if current <= max {
824 return Ok(0);
825 }
826 let excess = current - max;
827 let result = sqlx::query(
828 "DELETE FROM graph_entities
829 WHERE id IN (
830 SELECT e.id
831 FROM graph_entities e
832 LEFT JOIN (
833 SELECT source_entity_id AS eid, COUNT(*) AS cnt
834 FROM graph_edges WHERE valid_to IS NULL GROUP BY source_entity_id
835 UNION ALL
836 SELECT target_entity_id AS eid, COUNT(*) AS cnt
837 FROM graph_edges WHERE valid_to IS NULL GROUP BY target_entity_id
838 ) edge_counts ON e.id = edge_counts.eid
839 ORDER BY COALESCE(edge_counts.cnt, 0) ASC, e.last_seen_at ASC
840 LIMIT ?1
841 )",
842 )
843 .bind(excess)
844 .execute(&self.pool)
845 .await?;
846 Ok(usize::try_from(result.rows_affected())?)
847 }
848
849 pub async fn bfs(
866 &self,
867 start_entity_id: i64,
868 max_hops: u32,
869 ) -> Result<(Vec<Entity>, Vec<Edge>), MemoryError> {
870 self.bfs_with_depth(start_entity_id, max_hops)
871 .await
872 .map(|(e, ed, _)| (e, ed))
873 }
874
875 pub async fn bfs_with_depth(
886 &self,
887 start_entity_id: i64,
888 max_hops: u32,
889 ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
890 use std::collections::HashMap;
891
892 const MAX_FRONTIER: usize = 300;
895
896 let mut depth_map: HashMap<i64, u32> = HashMap::new();
897 let mut frontier: Vec<i64> = vec![start_entity_id];
898 depth_map.insert(start_entity_id, 0);
899
900 for hop in 0..max_hops {
901 if frontier.is_empty() {
902 break;
903 }
904 frontier.truncate(MAX_FRONTIER);
905 let placeholders = frontier
907 .iter()
908 .enumerate()
909 .map(|(i, _)| format!("?{}", i + 1))
910 .collect::<Vec<_>>()
911 .join(", ");
912 let neighbour_sql = format!(
913 "SELECT DISTINCT CASE
914 WHEN source_entity_id IN ({placeholders}) THEN target_entity_id
915 ELSE source_entity_id
916 END as neighbour_id
917 FROM graph_edges
918 WHERE valid_to IS NULL
919 AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))"
920 );
921 let mut q = sqlx::query_scalar::<_, i64>(&neighbour_sql);
922 for id in &frontier {
923 q = q.bind(*id);
924 }
925 for id in &frontier {
926 q = q.bind(*id);
927 }
928 for id in &frontier {
929 q = q.bind(*id);
930 }
931 let neighbours: Vec<i64> = q.fetch_all(&self.pool).await?;
932
933 let mut next_frontier: Vec<i64> = Vec::new();
934 for nbr in neighbours {
935 if let std::collections::hash_map::Entry::Vacant(e) = depth_map.entry(nbr) {
936 e.insert(hop + 1);
937 next_frontier.push(nbr);
938 }
939 }
940 frontier = next_frontier;
941 }
942
943 let mut visited_ids: Vec<i64> = depth_map.keys().copied().collect();
944 if visited_ids.is_empty() {
945 return Ok((Vec::new(), Vec::new(), depth_map));
946 }
947 visited_ids.truncate(499);
949
950 let placeholders = visited_ids
952 .iter()
953 .enumerate()
954 .map(|(i, _)| format!("?{}", i + 1))
955 .collect::<Vec<_>>()
956 .join(", ");
957
958 let edge_sql = format!(
959 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
960 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
961 FROM graph_edges
962 WHERE valid_to IS NULL
963 AND source_entity_id IN ({placeholders})
964 AND target_entity_id IN ({placeholders})"
965 );
966 let mut edge_query = sqlx::query_as::<_, EdgeRow>(&edge_sql);
967 for id in &visited_ids {
968 edge_query = edge_query.bind(*id);
969 }
970 for id in &visited_ids {
971 edge_query = edge_query.bind(*id);
972 }
973 let edge_rows: Vec<EdgeRow> = edge_query.fetch_all(&self.pool).await?;
974
975 let entity_sql = format!(
976 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
977 FROM graph_entities WHERE id IN ({placeholders})"
978 );
979 let mut entity_query = sqlx::query_as::<_, EntityRow>(&entity_sql);
980 for id in &visited_ids {
981 entity_query = entity_query.bind(*id);
982 }
983 let entity_rows: Vec<EntityRow> = entity_query.fetch_all(&self.pool).await?;
984
985 let entities: Vec<Entity> = entity_rows
986 .into_iter()
987 .map(entity_from_row)
988 .collect::<Result<Vec<_>, _>>()?;
989 let edges: Vec<Edge> = edge_rows.into_iter().map(edge_from_row).collect();
990
991 Ok((entities, edges, depth_map))
992 }
993
994 pub async fn find_entity_by_name(&self, name: &str) -> Result<Vec<Entity>, MemoryError> {
1010 let rows: Vec<EntityRow> = sqlx::query_as(
1011 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
1012 FROM graph_entities
1013 WHERE name = ?1 COLLATE NOCASE OR canonical_name = ?1 COLLATE NOCASE
1014 LIMIT 5",
1015 )
1016 .bind(name)
1017 .fetch_all(&self.pool)
1018 .await?;
1019
1020 if !rows.is_empty() {
1021 return rows.into_iter().map(entity_from_row).collect();
1022 }
1023
1024 self.find_entities_fuzzy(name, 5).await
1025 }
1026
1027 pub async fn unprocessed_messages_for_backfill(
1035 &self,
1036 limit: usize,
1037 ) -> Result<Vec<(crate::types::MessageId, String)>, MemoryError> {
1038 let limit = i64::try_from(limit)?;
1039 let rows: Vec<(i64, String)> = sqlx::query_as(
1040 "SELECT id, content FROM messages
1041 WHERE graph_processed = 0
1042 ORDER BY id ASC
1043 LIMIT ?1",
1044 )
1045 .bind(limit)
1046 .fetch_all(&self.pool)
1047 .await?;
1048 Ok(rows
1049 .into_iter()
1050 .map(|(id, content)| (crate::types::MessageId(id), content))
1051 .collect())
1052 }
1053
1054 pub async fn unprocessed_message_count(&self) -> Result<i64, MemoryError> {
1060 let count: i64 =
1061 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE graph_processed = 0")
1062 .fetch_one(&self.pool)
1063 .await?;
1064 Ok(count)
1065 }
1066
1067 pub async fn mark_messages_graph_processed(
1073 &self,
1074 ids: &[crate::types::MessageId],
1075 ) -> Result<(), MemoryError> {
1076 if ids.is_empty() {
1077 return Ok(());
1078 }
1079 let placeholders = ids
1080 .iter()
1081 .enumerate()
1082 .map(|(i, _)| format!("?{}", i + 1))
1083 .collect::<Vec<_>>()
1084 .join(", ");
1085 let sql = format!("UPDATE messages SET graph_processed = 1 WHERE id IN ({placeholders})");
1086 let mut query = sqlx::query(&sql);
1087 for id in ids {
1088 query = query.bind(id.0);
1089 }
1090 query.execute(&self.pool).await?;
1091 Ok(())
1092 }
1093}
1094
/// Raw `graph_entities` row as stored in SQLite; converted to the public
/// `Entity` by `entity_from_row` (which parses `entity_type`).
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: i64,
    name: String,
    canonical_name: String,
    // Stored as text; parsed into `EntityType` during conversion.
    entity_type: String,
    summary: Option<String>,
    first_seen_at: String,
    last_seen_at: String,
    qdrant_point_id: Option<String>,
}
1108
1109fn entity_from_row(row: EntityRow) -> Result<Entity, MemoryError> {
1110 let entity_type = row
1111 .entity_type
1112 .parse::<EntityType>()
1113 .map_err(MemoryError::GraphStore)?;
1114 Ok(Entity {
1115 id: row.id,
1116 name: row.name,
1117 canonical_name: row.canonical_name,
1118 entity_type,
1119 summary: row.summary,
1120 first_seen_at: row.first_seen_at,
1121 last_seen_at: row.last_seen_at,
1122 qdrant_point_id: row.qdrant_point_id,
1123 })
1124}
1125
/// Raw `graph_entity_aliases` row; converted 1:1 by `alias_from_row`.
#[derive(sqlx::FromRow)]
struct AliasRow {
    id: i64,
    entity_id: i64,
    alias_name: String,
    created_at: String,
}
1133
1134fn alias_from_row(row: AliasRow) -> EntityAlias {
1135 EntityAlias {
1136 id: row.id,
1137 entity_id: row.entity_id,
1138 alias_name: row.alias_name,
1139 created_at: row.created_at,
1140 }
1141}
1142
/// Raw `graph_edges` row; converted to the public `Edge` by
/// `edge_from_row` (which narrows `confidence` to f32 and wraps the
/// episode id).
#[derive(sqlx::FromRow)]
struct EdgeRow {
    id: i64,
    source_entity_id: i64,
    target_entity_id: i64,
    relation: String,
    fact: String,
    // SQLite REAL; narrowed to f32 in the public type.
    confidence: f64,
    valid_from: String,
    // NULL while the edge is active; set when invalidated.
    valid_to: Option<String>,
    created_at: String,
    expired_at: Option<String>,
    episode_id: Option<i64>,
    qdrant_point_id: Option<String>,
}
1158
1159fn edge_from_row(row: EdgeRow) -> Edge {
1160 Edge {
1161 id: row.id,
1162 source_entity_id: row.source_entity_id,
1163 target_entity_id: row.target_entity_id,
1164 relation: row.relation,
1165 fact: row.fact,
1166 #[allow(clippy::cast_possible_truncation)]
1167 confidence: row.confidence as f32,
1168 valid_from: row.valid_from,
1169 valid_to: row.valid_to,
1170 created_at: row.created_at,
1171 expired_at: row.expired_at,
1172 episode_id: row.episode_id.map(MessageId),
1173 qdrant_point_id: row.qdrant_point_id,
1174 }
1175}
1176
/// Raw `graph_communities` row; callers decode `entity_ids` (a JSON array
/// of i64) with serde_json when building the public `Community`.
#[derive(sqlx::FromRow)]
struct CommunityRow {
    id: i64,
    name: String,
    summary: String,
    // JSON-encoded Vec<i64> of member entity ids.
    entity_ids: String,
    fingerprint: Option<String>,
    created_at: String,
    updated_at: String,
}
1187
1188#[cfg(test)]
1191mod tests {
1192 use super::*;
1193 use crate::sqlite::SqliteStore;
1194
    // Build a GraphStore over a fresh in-memory SQLite database.
    async fn setup() -> GraphStore {
        let store = SqliteStore::new(":memory:").await.unwrap();
        GraphStore::new(store.pool().clone())
    }

    // A brand-new entity gets a positive row id.
    #[tokio::test]
    async fn upsert_entity_insert_new() {
        let gs = setup().await;
        let id = gs
            .upsert_entity("Alice", "Alice", EntityType::Person, Some("a person"))
            .await
            .unwrap();
        assert!(id > 0);
    }

    // Re-upserting the same (canonical_name, entity_type) updates in
    // place: same id, refreshed summary.
    #[tokio::test]
    async fn upsert_entity_update_existing() {
        let gs = setup().await;
        let id1 = gs
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let id2 = gs
            .upsert_entity("Alice", "Alice", EntityType::Person, Some("updated"))
            .await
            .unwrap();
        assert_eq!(id1, id2);
        let entity = gs
            .find_entity("Alice", EntityType::Person)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(entity.summary.as_deref(), Some("updated"));
    }

    // Exact lookup by canonical name and type returns the stored entity.
    #[tokio::test]
    async fn find_entity_found() {
        let gs = setup().await;
        gs.upsert_entity("Bob", "Bob", EntityType::Tool, Some("a tool"))
            .await
            .unwrap();
        let entity = gs
            .find_entity("Bob", EntityType::Tool)
            .await
            .unwrap()
            .unwrap();
        assert_eq!(entity.name, "Bob");
        assert_eq!(entity.entity_type, EntityType::Tool);
    }

    // A missing entity yields Ok(None), not an error.
    #[tokio::test]
    async fn find_entity_not_found() {
        let gs = setup().await;
        let result = gs.find_entity("Nobody", EntityType::Person).await.unwrap();
        assert!(result.is_none());
    }
1253
    // Prefix search for "graph" matches "Graph" and "GraphQL" but not an
    // unrelated name.
    #[tokio::test]
    async fn find_entities_fuzzy_partial_match() {
        let gs = setup().await;
        gs.upsert_entity("GraphQL", "GraphQL", EntityType::Concept, None)
            .await
            .unwrap();
        gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
            .await
            .unwrap();
        gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
            .await
            .unwrap();

        let results = gs.find_entities_fuzzy("graph", 10).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().any(|e| e.name == "GraphQL"));
        assert!(results.iter().any(|e| e.name == "Graph"));
    }

    // A fresh store has zero entities.
    #[tokio::test]
    async fn entity_count_empty() {
        let gs = setup().await;
        assert_eq!(gs.entity_count().await.unwrap(), 0);
    }

    // entity_count reflects the number of upserted entities.
    #[tokio::test]
    async fn entity_count_non_empty() {
        let gs = setup().await;
        gs.upsert_entity("A", "A", EntityType::Concept, None)
            .await
            .unwrap();
        gs.upsert_entity("B", "B", EntityType::Concept, None)
            .await
            .unwrap();
        assert_eq!(gs.entity_count().await.unwrap(), 2);
    }

    // Both the Vec accessor and the stream return every entity.
    #[tokio::test]
    async fn all_entities_and_stream() {
        use futures::StreamExt as _;

        let gs = setup().await;
        gs.upsert_entity("X", "X", EntityType::Project, None)
            .await
            .unwrap();
        gs.upsert_entity("Y", "Y", EntityType::Language, None)
            .await
            .unwrap();

        let all = gs.all_entities().await.unwrap();
        assert_eq!(all.len(), 2);
        let streamed: Vec<Result<Entity, _>> = gs.all_entities_stream().collect().await;
        assert_eq!(streamed.len(), 2);
        assert!(streamed.iter().all(Result::is_ok));
    }
1309
    // An edge with no episode id inserts cleanly.
    #[tokio::test]
    async fn insert_edge_without_episode() {
        let gs = setup().await;
        let src = gs
            .upsert_entity("Src", "Src", EntityType::Concept, None)
            .await
            .unwrap();
        let tgt = gs
            .upsert_entity("Tgt", "Tgt", EntityType::Concept, None)
            .await
            .unwrap();
        let eid = gs
            .insert_edge(src, tgt, "relates_to", "Src relates to Tgt", 0.9, None)
            .await
            .unwrap();
        assert!(eid > 0);
    }

    // Re-inserting the same active (source, target, relation) reuses the
    // existing edge and raises its confidence to the max of old and new.
    #[tokio::test]
    async fn insert_edge_deduplicates_active_edge() {
        let gs = setup().await;
        let src = gs
            .upsert_entity("Alice", "Alice", EntityType::Person, None)
            .await
            .unwrap();
        let tgt = gs
            .upsert_entity("Google", "Google", EntityType::Organization, None)
            .await
            .unwrap();

        let id1 = gs
            .insert_edge(src, tgt, "works_at", "Alice works at Google", 0.7, None)
            .await
            .unwrap();

        let id2 = gs
            .insert_edge(src, tgt, "works_at", "Alice works at Google", 0.9, None)
            .await
            .unwrap();
        assert_eq!(id1, id2, "duplicate active edge must not be created");

        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
                .fetch_one(&gs.pool)
                .await
                .unwrap();
        assert_eq!(count, 1, "only one active edge must exist");

        let conf: f64 = sqlx::query_scalar("SELECT confidence FROM graph_edges WHERE id = ?1")
            .bind(id1)
            .fetch_one(&gs.pool)
            .await
            .unwrap();
        // Compare against f64::from(0.9_f32) because the stored value went
        // through an f32 before widening.
        assert!(
            (conf - f64::from(0.9_f32)).abs() < 1e-6,
            "confidence must be updated to max, got {conf}"
        );
    }

    // Same entity pair but different relations produce distinct edges.
    #[tokio::test]
    async fn insert_edge_different_relations_are_distinct() {
        let gs = setup().await;
        let src = gs
            .upsert_entity("Bob", "Bob", EntityType::Person, None)
            .await
            .unwrap();
        let tgt = gs
            .upsert_entity("Acme", "Acme", EntityType::Organization, None)
            .await
            .unwrap();

        let id1 = gs
            .insert_edge(src, tgt, "founded", "Bob founded Acme", 0.8, None)
            .await
            .unwrap();
        let id2 = gs
            .insert_edge(src, tgt, "chairs", "Bob chairs Acme", 0.8, None)
            .await
            .unwrap();
        assert_ne!(id1, id2, "different relations must produce distinct edges");

        let count: i64 =
            sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
                .fetch_one(&gs.pool)
                .await
                .unwrap();
        assert_eq!(count, 2);
    }
1401
    // Inserting an edge with an episode id that references a nonexistent
    // message either succeeds or fails with a SQLite error.
    #[tokio::test]
    async fn insert_edge_with_episode() {
        let gs = setup().await;
        let src = gs
            .upsert_entity("Src2", "Src2", EntityType::Concept, None)
            .await
            .unwrap();
        let tgt = gs
            .upsert_entity("Tgt2", "Tgt2", EntityType::Concept, None)
            .await
            .unwrap();
        // 999 does not exist in `messages`.
        let episode = MessageId(999);
        let result = gs
            .insert_edge(src, tgt, "uses", "Src2 uses Tgt2", 1.0, Some(episode))
            .await;
        match &result {
            Ok(eid) => assert!(*eid > 0, "inserted edge should have positive id"),
            // NOTE(review): presumably depends on whether the episode_id
            // foreign key is enforced in this schema — confirm.
            Err(MemoryError::Sqlite(_)) => {}
            Err(e) => panic!("unexpected error: {e}"),
        }
    }

    // invalidate_edge stamps both valid_to and expired_at.
    #[tokio::test]
    async fn invalidate_edge_sets_timestamps() {
        let gs = setup().await;
        let src = gs
            .upsert_entity("E1", "E1", EntityType::Concept, None)
            .await
            .unwrap();
        let tgt = gs
            .upsert_entity("E2", "E2", EntityType::Concept, None)
            .await
            .unwrap();
        let eid = gs
            .insert_edge(src, tgt, "r", "fact", 1.0, None)
            .await
            .unwrap();
        gs.invalidate_edge(eid).await.unwrap();

        let row: (Option<String>, Option<String>) =
            sqlx::query_as("SELECT valid_to, expired_at FROM graph_edges WHERE id = ?1")
                .bind(eid)
                .fetch_one(&gs.pool)
                .await
                .unwrap();
        assert!(row.0.is_some(), "valid_to should be set");
        assert!(row.1.is_some(), "expired_at should be set");
    }

    // edges_for_entity returns edges where the entity is either endpoint.
    #[tokio::test]
    async fn edges_for_entity_both_directions() {
        let gs = setup().await;
        let a = gs
            .upsert_entity("A", "A", EntityType::Concept, None)
            .await
            .unwrap();
        let b = gs
            .upsert_entity("B", "B", EntityType::Concept, None)
            .await
            .unwrap();
        let c = gs
            .upsert_entity("C", "C", EntityType::Concept, None)
            .await
            .unwrap();
        gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
        gs.insert_edge(c, a, "r", "f2", 1.0, None).await.unwrap();

        let edges = gs.edges_for_entity(a).await.unwrap();
        assert_eq!(edges.len(), 2);
    }
1477
1478 #[tokio::test]
1479 async fn edges_between_both_directions() {
1480 let gs = setup().await;
1481 let a = gs
1482 .upsert_entity("PA", "PA", EntityType::Person, None)
1483 .await
1484 .unwrap();
1485 let b = gs
1486 .upsert_entity("PB", "PB", EntityType::Person, None)
1487 .await
1488 .unwrap();
1489 gs.insert_edge(a, b, "knows", "PA knows PB", 1.0, None)
1490 .await
1491 .unwrap();
1492
1493 let fwd = gs.edges_between(a, b).await.unwrap();
1494 assert_eq!(fwd.len(), 1);
1495 let rev = gs.edges_between(b, a).await.unwrap();
1496 assert_eq!(rev.len(), 1);
1497 }
1498
1499 #[tokio::test]
1500 async fn active_edge_count_excludes_invalidated() {
1501 let gs = setup().await;
1502 let a = gs
1503 .upsert_entity("N1", "N1", EntityType::Concept, None)
1504 .await
1505 .unwrap();
1506 let b = gs
1507 .upsert_entity("N2", "N2", EntityType::Concept, None)
1508 .await
1509 .unwrap();
1510 let e1 = gs.insert_edge(a, b, "r1", "f1", 1.0, None).await.unwrap();
1511 gs.insert_edge(a, b, "r2", "f2", 1.0, None).await.unwrap();
1512 gs.invalidate_edge(e1).await.unwrap();
1513
1514 assert_eq!(gs.active_edge_count().await.unwrap(), 1);
1515 }
1516
1517 #[tokio::test]
1518 async fn upsert_community_insert_and_update() {
1519 let gs = setup().await;
1520 let id1 = gs
1521 .upsert_community("clusterA", "summary A", &[1, 2, 3], None)
1522 .await
1523 .unwrap();
1524 assert!(id1 > 0);
1525 let id2 = gs
1526 .upsert_community("clusterA", "summary A updated", &[1, 2, 3, 4], None)
1527 .await
1528 .unwrap();
1529 assert_eq!(id1, id2);
1530 let communities = gs.all_communities().await.unwrap();
1531 assert_eq!(communities.len(), 1);
1532 assert_eq!(communities[0].summary, "summary A updated");
1533 assert_eq!(communities[0].entity_ids, vec![1, 2, 3, 4]);
1534 }
1535
1536 #[tokio::test]
1537 async fn community_for_entity_found() {
1538 let gs = setup().await;
1539 let a = gs
1540 .upsert_entity("CA", "CA", EntityType::Concept, None)
1541 .await
1542 .unwrap();
1543 let b = gs
1544 .upsert_entity("CB", "CB", EntityType::Concept, None)
1545 .await
1546 .unwrap();
1547 gs.upsert_community("cA", "summary", &[a, b], None)
1548 .await
1549 .unwrap();
1550 let result = gs.community_for_entity(a).await.unwrap();
1551 assert!(result.is_some());
1552 assert_eq!(result.unwrap().name, "cA");
1553 }
1554
1555 #[tokio::test]
1556 async fn community_for_entity_not_found() {
1557 let gs = setup().await;
1558 let result = gs.community_for_entity(999).await.unwrap();
1559 assert!(result.is_none());
1560 }
1561
1562 #[tokio::test]
1563 async fn community_count() {
1564 let gs = setup().await;
1565 assert_eq!(gs.community_count().await.unwrap(), 0);
1566 gs.upsert_community("c1", "s1", &[], None).await.unwrap();
1567 gs.upsert_community("c2", "s2", &[], None).await.unwrap();
1568 assert_eq!(gs.community_count().await.unwrap(), 2);
1569 }
1570
1571 #[tokio::test]
1572 async fn metadata_get_set_round_trip() {
1573 let gs = setup().await;
1574 assert_eq!(gs.get_metadata("counter").await.unwrap(), None);
1575 gs.set_metadata("counter", "42").await.unwrap();
1576 assert_eq!(gs.get_metadata("counter").await.unwrap(), Some("42".into()));
1577 gs.set_metadata("counter", "43").await.unwrap();
1578 assert_eq!(gs.get_metadata("counter").await.unwrap(), Some("43".into()));
1579 }
1580
1581 #[tokio::test]
1582 async fn bfs_max_hops_0_returns_only_start() {
1583 let gs = setup().await;
1584 let a = gs
1585 .upsert_entity("BfsA", "BfsA", EntityType::Concept, None)
1586 .await
1587 .unwrap();
1588 let b = gs
1589 .upsert_entity("BfsB", "BfsB", EntityType::Concept, None)
1590 .await
1591 .unwrap();
1592 gs.insert_edge(a, b, "r", "f", 1.0, None).await.unwrap();
1593
1594 let (entities, edges) = gs.bfs(a, 0).await.unwrap();
1595 assert_eq!(entities.len(), 1);
1596 assert_eq!(entities[0].id, a);
1597 assert!(edges.is_empty());
1598 }
1599
1600 #[tokio::test]
1601 async fn bfs_max_hops_2_chain() {
1602 let gs = setup().await;
1603 let a = gs
1604 .upsert_entity("ChainA", "ChainA", EntityType::Concept, None)
1605 .await
1606 .unwrap();
1607 let b = gs
1608 .upsert_entity("ChainB", "ChainB", EntityType::Concept, None)
1609 .await
1610 .unwrap();
1611 let c = gs
1612 .upsert_entity("ChainC", "ChainC", EntityType::Concept, None)
1613 .await
1614 .unwrap();
1615 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1616 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1617
1618 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1619 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1620 assert!(ids.contains(&a));
1621 assert!(ids.contains(&b));
1622 assert!(ids.contains(&c));
1623 assert_eq!(edges.len(), 2);
1624 }
1625
1626 #[tokio::test]
1627 async fn bfs_cycle_no_infinite_loop() {
1628 let gs = setup().await;
1629 let a = gs
1630 .upsert_entity("CycA", "CycA", EntityType::Concept, None)
1631 .await
1632 .unwrap();
1633 let b = gs
1634 .upsert_entity("CycB", "CycB", EntityType::Concept, None)
1635 .await
1636 .unwrap();
1637 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1638 gs.insert_edge(b, a, "r", "f2", 1.0, None).await.unwrap();
1639
1640 let (entities, _edges) = gs.bfs(a, 3).await.unwrap();
1641 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1642 assert!(ids.contains(&a));
1644 assert!(ids.contains(&b));
1645 assert_eq!(ids.len(), 2);
1646 }
1647
1648 #[tokio::test]
1649 async fn test_invalidated_edges_excluded_from_bfs() {
1650 let gs = setup().await;
1651 let a = gs
1652 .upsert_entity("InvA", "InvA", EntityType::Concept, None)
1653 .await
1654 .unwrap();
1655 let b = gs
1656 .upsert_entity("InvB", "InvB", EntityType::Concept, None)
1657 .await
1658 .unwrap();
1659 let c = gs
1660 .upsert_entity("InvC", "InvC", EntityType::Concept, None)
1661 .await
1662 .unwrap();
1663 let ab = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1664 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1665 gs.invalidate_edge(ab).await.unwrap();
1667
1668 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1669 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1670 assert_eq!(ids, vec![a], "only start entity should be reachable");
1671 assert!(edges.is_empty(), "no active edges should be returned");
1672 }
1673
1674 #[tokio::test]
1675 async fn test_bfs_empty_graph() {
1676 let gs = setup().await;
1677 let a = gs
1678 .upsert_entity("IsoA", "IsoA", EntityType::Concept, None)
1679 .await
1680 .unwrap();
1681
1682 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1683 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1684 assert_eq!(ids, vec![a], "isolated node: only start entity returned");
1685 assert!(edges.is_empty(), "no edges for isolated node");
1686 }
1687
1688 #[tokio::test]
1689 async fn test_bfs_diamond() {
1690 let gs = setup().await;
1691 let a = gs
1692 .upsert_entity("DiamA", "DiamA", EntityType::Concept, None)
1693 .await
1694 .unwrap();
1695 let b = gs
1696 .upsert_entity("DiamB", "DiamB", EntityType::Concept, None)
1697 .await
1698 .unwrap();
1699 let c = gs
1700 .upsert_entity("DiamC", "DiamC", EntityType::Concept, None)
1701 .await
1702 .unwrap();
1703 let d = gs
1704 .upsert_entity("DiamD", "DiamD", EntityType::Concept, None)
1705 .await
1706 .unwrap();
1707 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1708 gs.insert_edge(a, c, "r", "f2", 1.0, None).await.unwrap();
1709 gs.insert_edge(b, d, "r", "f3", 1.0, None).await.unwrap();
1710 gs.insert_edge(c, d, "r", "f4", 1.0, None).await.unwrap();
1711
1712 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1713 let mut ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1714 ids.sort_unstable();
1715 let mut expected = vec![a, b, c, d];
1716 expected.sort_unstable();
1717 assert_eq!(ids, expected, "all 4 nodes reachable, no duplicates");
1718 assert_eq!(edges.len(), 4, "all 4 edges returned");
1719 }
1720
1721 #[tokio::test]
1722 async fn extraction_count_default_zero() {
1723 let gs = setup().await;
1724 assert_eq!(gs.extraction_count().await.unwrap(), 0);
1725 }
1726
1727 #[tokio::test]
1728 async fn extraction_count_after_set() {
1729 let gs = setup().await;
1730 gs.set_metadata("extraction_count", "7").await.unwrap();
1731 assert_eq!(gs.extraction_count().await.unwrap(), 7);
1732 }
1733
1734 #[tokio::test]
1735 async fn all_active_edges_stream_excludes_invalidated() {
1736 use futures::TryStreamExt as _;
1737 let gs = setup().await;
1738 let a = gs
1739 .upsert_entity("SA", "SA", EntityType::Concept, None)
1740 .await
1741 .unwrap();
1742 let b = gs
1743 .upsert_entity("SB", "SB", EntityType::Concept, None)
1744 .await
1745 .unwrap();
1746 let c = gs
1747 .upsert_entity("SC", "SC", EntityType::Concept, None)
1748 .await
1749 .unwrap();
1750 let e1 = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1751 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1752 gs.invalidate_edge(e1).await.unwrap();
1753
1754 let edges: Vec<_> = gs.all_active_edges_stream().try_collect().await.unwrap();
1755 assert_eq!(edges.len(), 1, "only the active edge should be returned");
1756 assert_eq!(edges[0].source_entity_id, b);
1757 assert_eq!(edges[0].target_entity_id, c);
1758 }
1759
1760 #[tokio::test]
1761 async fn find_community_by_id_found_and_not_found() {
1762 let gs = setup().await;
1763 let cid = gs
1764 .upsert_community("grp", "summary", &[1, 2], None)
1765 .await
1766 .unwrap();
1767 let found = gs.find_community_by_id(cid).await.unwrap();
1768 assert!(found.is_some());
1769 assert_eq!(found.unwrap().name, "grp");
1770
1771 let missing = gs.find_community_by_id(9999).await.unwrap();
1772 assert!(missing.is_none());
1773 }
1774
1775 #[tokio::test]
1776 async fn delete_all_communities_clears_table() {
1777 let gs = setup().await;
1778 gs.upsert_community("c1", "s1", &[1], None).await.unwrap();
1779 gs.upsert_community("c2", "s2", &[2], None).await.unwrap();
1780 assert_eq!(gs.community_count().await.unwrap(), 2);
1781 gs.delete_all_communities().await.unwrap();
1782 assert_eq!(gs.community_count().await.unwrap(), 0);
1783 }
1784
1785 #[tokio::test]
1786 async fn test_find_entities_fuzzy_no_results() {
1787 let gs = setup().await;
1788 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
1789 .await
1790 .unwrap();
1791 let results = gs.find_entities_fuzzy("zzzznonexistent", 10).await.unwrap();
1792 assert!(
1793 results.is_empty(),
1794 "no entities should match an unknown term"
1795 );
1796 }
1797
1798 #[tokio::test]
1801 async fn upsert_entity_stores_canonical_name() {
1802 let gs = setup().await;
1803 gs.upsert_entity("rust", "rust", EntityType::Language, None)
1804 .await
1805 .unwrap();
1806 let entity = gs
1807 .find_entity("rust", EntityType::Language)
1808 .await
1809 .unwrap()
1810 .unwrap();
1811 assert_eq!(entity.canonical_name, "rust");
1812 assert_eq!(entity.name, "rust");
1813 }
1814
1815 #[tokio::test]
1816 async fn add_alias_idempotent() {
1817 let gs = setup().await;
1818 let id = gs
1819 .upsert_entity("rust", "rust", EntityType::Language, None)
1820 .await
1821 .unwrap();
1822 gs.add_alias(id, "rust-lang").await.unwrap();
1823 gs.add_alias(id, "rust-lang").await.unwrap();
1825 let aliases = gs.aliases_for_entity(id).await.unwrap();
1826 assert_eq!(
1827 aliases
1828 .iter()
1829 .filter(|a| a.alias_name == "rust-lang")
1830 .count(),
1831 1
1832 );
1833 }
1834
1835 #[tokio::test]
1838 async fn find_entity_by_id_found() {
1839 let gs = setup().await;
1840 let id = gs
1841 .upsert_entity("FindById", "finbyid", EntityType::Concept, Some("summary"))
1842 .await
1843 .unwrap();
1844 let entity = gs.find_entity_by_id(id).await.unwrap();
1845 assert!(entity.is_some());
1846 let entity = entity.unwrap();
1847 assert_eq!(entity.id, id);
1848 assert_eq!(entity.name, "FindById");
1849 }
1850
1851 #[tokio::test]
1852 async fn find_entity_by_id_not_found() {
1853 let gs = setup().await;
1854 let result = gs.find_entity_by_id(99999).await.unwrap();
1855 assert!(result.is_none());
1856 }
1857
1858 #[tokio::test]
1859 async fn set_entity_qdrant_point_id_updates() {
1860 let gs = setup().await;
1861 let id = gs
1862 .upsert_entity("QdrantPoint", "qdrantpoint", EntityType::Concept, None)
1863 .await
1864 .unwrap();
1865 let point_id = "550e8400-e29b-41d4-a716-446655440000";
1866 gs.set_entity_qdrant_point_id(id, point_id).await.unwrap();
1867
1868 let entity = gs.find_entity_by_id(id).await.unwrap().unwrap();
1869 assert_eq!(entity.qdrant_point_id.as_deref(), Some(point_id));
1870 }
1871
1872 #[tokio::test]
1873 async fn find_entities_fuzzy_matches_summary() {
1874 let gs = setup().await;
1875 gs.upsert_entity(
1876 "Rust",
1877 "Rust",
1878 EntityType::Language,
1879 Some("a systems programming language"),
1880 )
1881 .await
1882 .unwrap();
1883 gs.upsert_entity(
1884 "Go",
1885 "Go",
1886 EntityType::Language,
1887 Some("a compiled language by Google"),
1888 )
1889 .await
1890 .unwrap();
1891 let results = gs.find_entities_fuzzy("systems", 10).await.unwrap();
1893 assert_eq!(results.len(), 1);
1894 assert_eq!(results[0].name, "Rust");
1895 }
1896
1897 #[tokio::test]
1898 async fn find_entities_fuzzy_empty_query() {
1899 let gs = setup().await;
1900 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
1901 .await
1902 .unwrap();
1903 let results = gs.find_entities_fuzzy("", 10).await.unwrap();
1905 assert!(results.is_empty(), "empty query should return no results");
1906 let results = gs.find_entities_fuzzy(" ", 10).await.unwrap();
1908 assert!(
1909 results.is_empty(),
1910 "whitespace query should return no results"
1911 );
1912 }
1913
1914 #[tokio::test]
1915 async fn find_entity_by_alias_case_insensitive() {
1916 let gs = setup().await;
1917 let id = gs
1918 .upsert_entity("rust", "rust", EntityType::Language, None)
1919 .await
1920 .unwrap();
1921 gs.add_alias(id, "rust").await.unwrap();
1922 gs.add_alias(id, "rust-lang").await.unwrap();
1923
1924 let found = gs
1925 .find_entity_by_alias("RUST-LANG", EntityType::Language)
1926 .await
1927 .unwrap();
1928 assert!(found.is_some());
1929 assert_eq!(found.unwrap().id, id);
1930 }
1931
1932 #[tokio::test]
1933 async fn find_entity_by_alias_returns_none_for_unknown() {
1934 let gs = setup().await;
1935 let id = gs
1936 .upsert_entity("rust", "rust", EntityType::Language, None)
1937 .await
1938 .unwrap();
1939 gs.add_alias(id, "rust").await.unwrap();
1940
1941 let found = gs
1942 .find_entity_by_alias("python", EntityType::Language)
1943 .await
1944 .unwrap();
1945 assert!(found.is_none());
1946 }
1947
1948 #[tokio::test]
1949 async fn find_entity_by_alias_filters_by_entity_type() {
1950 let gs = setup().await;
1952 let lang_id = gs
1953 .upsert_entity("python", "python", EntityType::Language, None)
1954 .await
1955 .unwrap();
1956 gs.add_alias(lang_id, "python").await.unwrap();
1957
1958 let found_tool = gs
1959 .find_entity_by_alias("python", EntityType::Tool)
1960 .await
1961 .unwrap();
1962 assert!(
1963 found_tool.is_none(),
1964 "cross-type alias collision must not occur"
1965 );
1966
1967 let found_lang = gs
1968 .find_entity_by_alias("python", EntityType::Language)
1969 .await
1970 .unwrap();
1971 assert!(found_lang.is_some());
1972 assert_eq!(found_lang.unwrap().id, lang_id);
1973 }
1974
1975 #[tokio::test]
1976 async fn aliases_for_entity_returns_all() {
1977 let gs = setup().await;
1978 let id = gs
1979 .upsert_entity("rust", "rust", EntityType::Language, None)
1980 .await
1981 .unwrap();
1982 gs.add_alias(id, "rust").await.unwrap();
1983 gs.add_alias(id, "rust-lang").await.unwrap();
1984 gs.add_alias(id, "rustlang").await.unwrap();
1985
1986 let aliases = gs.aliases_for_entity(id).await.unwrap();
1987 assert_eq!(aliases.len(), 3);
1988 let names: Vec<&str> = aliases.iter().map(|a| a.alias_name.as_str()).collect();
1989 assert!(names.contains(&"rust"));
1990 assert!(names.contains(&"rust-lang"));
1991 assert!(names.contains(&"rustlang"));
1992 }
1993
1994 #[tokio::test]
1995 async fn find_entities_fuzzy_includes_aliases() {
1996 let gs = setup().await;
1997 let id = gs
1998 .upsert_entity("rust", "rust", EntityType::Language, None)
1999 .await
2000 .unwrap();
2001 gs.add_alias(id, "rust-lang").await.unwrap();
2002 gs.upsert_entity("python", "python", EntityType::Language, None)
2003 .await
2004 .unwrap();
2005
2006 let results = gs.find_entities_fuzzy("rust-lang", 10).await.unwrap();
2008 assert!(!results.is_empty());
2009 assert!(results.iter().any(|e| e.id == id));
2010 }
2011
2012 #[tokio::test]
2013 async fn orphan_alias_cleanup_on_entity_delete() {
2014 let gs = setup().await;
2015 let id = gs
2016 .upsert_entity("rust", "rust", EntityType::Language, None)
2017 .await
2018 .unwrap();
2019 gs.add_alias(id, "rust").await.unwrap();
2020 gs.add_alias(id, "rust-lang").await.unwrap();
2021
2022 sqlx::query("DELETE FROM graph_entities WHERE id = ?1")
2024 .bind(id)
2025 .execute(&gs.pool)
2026 .await
2027 .unwrap();
2028
2029 let aliases = gs.aliases_for_entity(id).await.unwrap();
2031 assert!(
2032 aliases.is_empty(),
2033 "aliases should cascade-delete with entity"
2034 );
2035 }
2036
    // Hand-replays migration 024 (add canonical_name, rebuild graph_entities,
    // seed the alias table) against a freshly built PRE-migration schema, then
    // asserts that entities, aliases, and edges all survive the rebuild.
    // Statement order mirrors the real migration and is load-bearing.
    #[tokio::test]
    #[allow(clippy::too_many_lines)]
    async fn migration_024_backfill_preserves_entities_and_edges() {
        use sqlx::Acquire as _;
        use sqlx::ConnectOptions as _;
        use sqlx::sqlite::SqliteConnectOptions;

        let opts = SqliteConnectOptions::from_url(&"sqlite::memory:".parse().unwrap())
            .unwrap()
            .foreign_keys(true);
        // max_connections(1): each sqlite::memory: connection is its own
        // database, so all statements must share one connection.
        let pool = sqlx::pool::PoolOptions::<sqlx::Sqlite>::new()
            .max_connections(1)
            .connect_with(opts)
            .await
            .unwrap();

        // --- Pre-migration schema: no canonical_name column; uniqueness is
        // on (name, entity_type). ---
        sqlx::query(
            "CREATE TABLE graph_entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                summary TEXT,
                first_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                last_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                qdrant_point_id TEXT,
                UNIQUE(name, entity_type)
            )",
        )
        .execute(&pool)
        .await
        .unwrap();

        sqlx::query(
            "CREATE TABLE graph_edges (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                target_entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                relation TEXT NOT NULL,
                fact TEXT NOT NULL,
                confidence REAL NOT NULL DEFAULT 1.0,
                valid_from TEXT NOT NULL DEFAULT (datetime('now')),
                valid_to TEXT,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                expired_at TEXT,
                episode_id INTEGER,
                qdrant_point_id TEXT
            )",
        )
        .execute(&pool)
        .await
        .unwrap();

        // External-content FTS5 index over entity name/summary, kept in sync
        // by insert/delete/update triggers (matching the production schema).
        sqlx::query(
            "CREATE VIRTUAL TABLE IF NOT EXISTS graph_entities_fts USING fts5(
                name, summary, content='graph_entities', content_rowid='id',
                tokenize='unicode61 remove_diacritics 2'
            )",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_insert AFTER INSERT ON graph_entities
            BEGIN INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, '')); END",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_delete AFTER DELETE ON graph_entities
            BEGIN INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, '')); END",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_update AFTER UPDATE ON graph_entities
            BEGIN
                INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, ''));
                INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, ''));
            END",
        )
        .execute(&pool)
        .await
        .unwrap();

        // --- Seed pre-migration data: two entities and one edge. ---
        let alice_id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, entity_type) VALUES ('Alice', 'person') RETURNING id",
        )
        .fetch_one(&pool)
        .await
        .unwrap();

        let rust_id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, entity_type) VALUES ('Rust', 'language') RETURNING id",
        )
        .fetch_one(&pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact)
            VALUES (?1, ?2, 'uses', 'Alice uses Rust')",
        )
        .bind(alice_id)
        .bind(rust_id)
        .execute(&pool)
        .await
        .unwrap();

        // Pin a single connection for the migration replay: PRAGMAs are
        // per-connection, so they must apply to the same connection that runs
        // the DDL below.
        let mut conn = pool.acquire().await.unwrap();
        let conn = conn.acquire().await.unwrap();

        // --- Replay migration 024. FKs are disabled so the DROP/RENAME of a
        // table referenced by graph_edges does not fail or cascade. ---
        sqlx::query("PRAGMA foreign_keys = OFF")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query("ALTER TABLE graph_entities ADD COLUMN canonical_name TEXT")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Backfill: canonical_name starts as a copy of the surface name.
        sqlx::query("UPDATE graph_entities SET canonical_name = name WHERE canonical_name IS NULL")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Rebuild the table so the UNIQUE constraint moves to
        // (canonical_name, entity_type) and canonical_name becomes NOT NULL.
        sqlx::query(
            "CREATE TABLE graph_entities_new (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                canonical_name TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                summary TEXT,
                first_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                last_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                qdrant_point_id TEXT,
                UNIQUE(canonical_name, entity_type)
            )",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "INSERT INTO graph_entities_new
                (id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id)
            SELECT id, name, COALESCE(canonical_name, name), entity_type, summary,
                first_seen_at, last_seen_at, qdrant_point_id
            FROM graph_entities",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // Dropping the old table also drops its triggers; they are recreated
        // against the renamed table below.
        sqlx::query("DROP TABLE graph_entities")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query("ALTER TABLE graph_entities_new RENAME TO graph_entities")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_insert AFTER INSERT ON graph_entities
            BEGIN INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, '')); END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_delete AFTER DELETE ON graph_entities
            BEGIN INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, '')); END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_update AFTER UPDATE ON graph_entities
            BEGIN
                INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, ''));
                INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, ''));
            END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // Repopulate the external-content FTS index from the rebuilt table.
        sqlx::query("INSERT INTO graph_entities_fts(graph_entities_fts) VALUES('rebuild')")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query(
            "CREATE TABLE graph_entity_aliases (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                alias_name TEXT NOT NULL,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                UNIQUE(alias_name, entity_id)
            )",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // Seed each entity's surface name as its first alias.
        sqlx::query(
            "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name)
            SELECT id, name FROM graph_entities",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query("PRAGMA foreign_keys = ON")
            .execute(&mut *conn)
            .await
            .unwrap();

        // --- Post-migration assertions. ---
        let alice_canon: String =
            sqlx::query_scalar("SELECT canonical_name FROM graph_entities WHERE id = ?1")
                .bind(alice_id)
                .fetch_one(&mut *conn)
                .await
                .unwrap();
        assert_eq!(
            alice_canon, "Alice",
            "canonical_name should equal pre-migration name"
        );

        let rust_canon: String =
            sqlx::query_scalar("SELECT canonical_name FROM graph_entities WHERE id = ?1")
                .bind(rust_id)
                .fetch_one(&mut *conn)
                .await
                .unwrap();
        assert_eq!(
            rust_canon, "Rust",
            "canonical_name should equal pre-migration name"
        );

        let alice_aliases: Vec<String> =
            sqlx::query_scalar("SELECT alias_name FROM graph_entity_aliases WHERE entity_id = ?1")
                .bind(alice_id)
                .fetch_all(&mut *conn)
                .await
                .unwrap();
        assert!(
            alice_aliases.contains(&"Alice".to_owned()),
            "initial alias should be seeded from entity name"
        );

        let edge_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges")
            .fetch_one(&mut *conn)
            .await
            .unwrap();
        assert_eq!(
            edge_count, 1,
            "graph_edges must survive migration 024 table recreation"
        );
    }
2311
2312 #[tokio::test]
2313 async fn find_entity_by_alias_same_alias_two_entities_deterministic() {
2314 let gs = setup().await;
2316 let id1 = gs
2317 .upsert_entity("python-v2", "python-v2", EntityType::Language, None)
2318 .await
2319 .unwrap();
2320 let id2 = gs
2321 .upsert_entity("python-v3", "python-v3", EntityType::Language, None)
2322 .await
2323 .unwrap();
2324 gs.add_alias(id1, "python").await.unwrap();
2325 gs.add_alias(id2, "python").await.unwrap();
2326
2327 let found = gs
2329 .find_entity_by_alias("python", EntityType::Language)
2330 .await
2331 .unwrap();
2332 assert!(found.is_some(), "should find an entity by shared alias");
2333 assert_eq!(
2335 found.unwrap().id,
2336 id1,
2337 "first-registered entity should win on shared alias"
2338 );
2339 }
2340
2341 #[tokio::test]
2344 async fn find_entities_fuzzy_special_chars() {
2345 let gs = setup().await;
2346 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2347 .await
2348 .unwrap();
2349 let results = gs.find_entities_fuzzy("graph\"()*:^", 10).await.unwrap();
2351 assert!(results.iter().any(|e| e.name == "Graph"));
2353 }
2354
2355 #[tokio::test]
2356 async fn find_entities_fuzzy_prefix_match() {
2357 let gs = setup().await;
2358 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2359 .await
2360 .unwrap();
2361 gs.upsert_entity("GraphQL", "GraphQL", EntityType::Concept, None)
2362 .await
2363 .unwrap();
2364 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
2365 .await
2366 .unwrap();
2367 let results = gs.find_entities_fuzzy("Gra", 10).await.unwrap();
2369 assert_eq!(results.len(), 2);
2370 assert!(results.iter().any(|e| e.name == "Graph"));
2371 assert!(results.iter().any(|e| e.name == "GraphQL"));
2372 }
2373
2374 #[tokio::test]
2375 async fn find_entities_fuzzy_fts5_operator_injection() {
2376 let gs = setup().await;
2377 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2378 .await
2379 .unwrap();
2380 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
2381 .await
2382 .unwrap();
2383 let results = gs
2388 .find_entities_fuzzy("graph OR unrelated", 10)
2389 .await
2390 .unwrap();
2391 assert!(
2392 results.is_empty(),
2393 "implicit AND of 'graph*' and 'unrelated*' should match no entity"
2394 );
2395 }
2396
2397 #[tokio::test]
2398 async fn find_entities_fuzzy_after_entity_update() {
2399 let gs = setup().await;
2400 gs.upsert_entity(
2402 "Foo",
2403 "Foo",
2404 EntityType::Concept,
2405 Some("initial summary bar"),
2406 )
2407 .await
2408 .unwrap();
2409 gs.upsert_entity(
2411 "Foo",
2412 "Foo",
2413 EntityType::Concept,
2414 Some("updated summary baz"),
2415 )
2416 .await
2417 .unwrap();
2418 let old_results = gs.find_entities_fuzzy("bar", 10).await.unwrap();
2420 assert!(
2421 old_results.is_empty(),
2422 "old summary content should not match after update"
2423 );
2424 let new_results = gs.find_entities_fuzzy("baz", 10).await.unwrap();
2426 assert_eq!(new_results.len(), 1);
2427 assert_eq!(new_results[0].name, "Foo");
2428 }
2429
2430 #[tokio::test]
2431 async fn find_entities_fuzzy_only_special_chars() {
2432 let gs = setup().await;
2433 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
2434 .await
2435 .unwrap();
2436 let results = gs.find_entities_fuzzy("***", 10).await.unwrap();
2440 assert!(
2441 results.is_empty(),
2442 "only special chars should return no results"
2443 );
2444 let results = gs.find_entities_fuzzy("(((", 10).await.unwrap();
2445 assert!(results.is_empty(), "only parens should return no results");
2446 let results = gs.find_entities_fuzzy("\"\"\"", 10).await.unwrap();
2447 assert!(results.is_empty(), "only quotes should return no results");
2448 }
2449
2450 #[tokio::test]
2453 async fn find_entity_by_name_exact_wins_over_summary_mention() {
2454 let gs = setup().await;
2457 gs.upsert_entity(
2458 "Alice",
2459 "Alice",
2460 EntityType::Person,
2461 Some("A person named Alice"),
2462 )
2463 .await
2464 .unwrap();
2465 gs.upsert_entity(
2467 "Google",
2468 "Google",
2469 EntityType::Organization,
2470 Some("Company where Charlie, Alice, and Bob have worked"),
2471 )
2472 .await
2473 .unwrap();
2474
2475 let results = gs.find_entity_by_name("Alice").await.unwrap();
2476 assert!(!results.is_empty(), "must find at least one entity");
2477 assert_eq!(
2478 results[0].name, "Alice",
2479 "exact name match must come first, not entity with 'Alice' in summary"
2480 );
2481 }
2482
2483 #[tokio::test]
2484 async fn find_entity_by_name_case_insensitive_exact() {
2485 let gs = setup().await;
2486 gs.upsert_entity("Bob", "Bob", EntityType::Person, None)
2487 .await
2488 .unwrap();
2489
2490 let results = gs.find_entity_by_name("bob").await.unwrap();
2491 assert!(!results.is_empty());
2492 assert_eq!(results[0].name, "Bob");
2493 }
2494
2495 #[tokio::test]
2496 async fn find_entity_by_name_falls_back_to_fuzzy_when_no_exact_match() {
2497 let gs = setup().await;
2498 gs.upsert_entity("Charlie", "Charlie", EntityType::Person, None)
2499 .await
2500 .unwrap();
2501
2502 let results = gs.find_entity_by_name("Char").await.unwrap();
2504 assert!(!results.is_empty(), "prefix search must find Charlie");
2505 }
2506
2507 #[tokio::test]
2508 async fn find_entity_by_name_returns_empty_for_unknown() {
2509 let gs = setup().await;
2510 let results = gs.find_entity_by_name("NonExistent").await.unwrap();
2511 assert!(results.is_empty());
2512 }
2513
2514 #[tokio::test]
2515 async fn find_entity_by_name_matches_canonical_name() {
2516 let gs = setup().await;
2518 gs.upsert_entity("Dave (Engineer)", "Dave", EntityType::Person, None)
2520 .await
2521 .unwrap();
2522
2523 let results = gs.find_entity_by_name("Dave").await.unwrap();
2526 assert!(
2527 !results.is_empty(),
2528 "canonical_name match must return entity"
2529 );
2530 assert_eq!(results[0].canonical_name, "Dave");
2531 }
2532
2533 async fn insert_test_message(gs: &GraphStore, content: &str) -> crate::types::MessageId {
2534 let conv_id: i64 =
2536 sqlx::query_scalar("INSERT INTO conversations DEFAULT VALUES RETURNING id")
2537 .fetch_one(&gs.pool)
2538 .await
2539 .unwrap();
2540 let id: i64 = sqlx::query_scalar(
2541 "INSERT INTO messages (conversation_id, role, content) VALUES (?1, 'user', ?2) RETURNING id",
2542 )
2543 .bind(conv_id)
2544 .bind(content)
2545 .fetch_one(&gs.pool)
2546 .await
2547 .unwrap();
2548 crate::types::MessageId(id)
2549 }
2550
2551 #[tokio::test]
2552 async fn unprocessed_messages_for_backfill_returns_unprocessed() {
2553 let gs = setup().await;
2554 let id1 = insert_test_message(&gs, "hello world").await;
2555 let id2 = insert_test_message(&gs, "second message").await;
2556
2557 let rows = gs.unprocessed_messages_for_backfill(10).await.unwrap();
2558 assert_eq!(rows.len(), 2);
2559 assert!(rows.iter().any(|(id, _)| *id == id1));
2560 assert!(rows.iter().any(|(id, _)| *id == id2));
2561 }
2562
2563 #[tokio::test]
2564 async fn unprocessed_messages_for_backfill_respects_limit() {
2565 let gs = setup().await;
2566 insert_test_message(&gs, "msg1").await;
2567 insert_test_message(&gs, "msg2").await;
2568 insert_test_message(&gs, "msg3").await;
2569
2570 let rows = gs.unprocessed_messages_for_backfill(2).await.unwrap();
2571 assert_eq!(rows.len(), 2);
2572 }
2573
2574 #[tokio::test]
2575 async fn mark_messages_graph_processed_updates_flag() {
2576 let gs = setup().await;
2577 let id1 = insert_test_message(&gs, "to process").await;
2578 let _id2 = insert_test_message(&gs, "also to process").await;
2579
2580 let count_before = gs.unprocessed_message_count().await.unwrap();
2582 assert_eq!(count_before, 2);
2583
2584 gs.mark_messages_graph_processed(&[id1]).await.unwrap();
2585
2586 let count_after = gs.unprocessed_message_count().await.unwrap();
2587 assert_eq!(count_after, 1);
2588
2589 let rows = gs.unprocessed_messages_for_backfill(10).await.unwrap();
2591 assert!(!rows.iter().any(|(id, _)| *id == id1));
2592 }
2593
2594 #[tokio::test]
2595 async fn mark_messages_graph_processed_empty_ids_is_noop() {
2596 let gs = setup().await;
2597 insert_test_message(&gs, "message").await;
2598
2599 gs.mark_messages_graph_processed(&[]).await.unwrap();
2600
2601 let count = gs.unprocessed_message_count().await.unwrap();
2602 assert_eq!(count, 1);
2603 }
2604
2605 #[tokio::test]
2606 async fn edges_after_id_first_page_returns_all_within_limit() {
2607 let gs = setup().await;
2608 let a = gs
2609 .upsert_entity("PA", "PA", EntityType::Concept, None)
2610 .await
2611 .unwrap();
2612 let b = gs
2613 .upsert_entity("PB", "PB", EntityType::Concept, None)
2614 .await
2615 .unwrap();
2616 let c = gs
2617 .upsert_entity("PC", "PC", EntityType::Concept, None)
2618 .await
2619 .unwrap();
2620 let e1 = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
2621 let e2 = gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
2622 let e3 = gs.insert_edge(a, c, "r", "f3", 1.0, None).await.unwrap();
2623
2624 let page1 = gs.edges_after_id(0, 2).await.unwrap();
2626 assert_eq!(page1.len(), 2);
2627 assert_eq!(page1[0].id, e1);
2628 assert_eq!(page1[1].id, e2);
2629
2630 let page2 = gs
2632 .edges_after_id(page1.last().unwrap().id, 2)
2633 .await
2634 .unwrap();
2635 assert_eq!(page2.len(), 1);
2636 assert_eq!(page2[0].id, e3);
2637
2638 let page3 = gs
2640 .edges_after_id(page2.last().unwrap().id, 2)
2641 .await
2642 .unwrap();
2643 assert!(page3.is_empty(), "no more edges after last id");
2644 }
2645
2646 #[tokio::test]
2647 async fn edges_after_id_skips_invalidated_edges() {
2648 let gs = setup().await;
2649 let a = gs
2650 .upsert_entity("IA", "IA", EntityType::Concept, None)
2651 .await
2652 .unwrap();
2653 let b = gs
2654 .upsert_entity("IB", "IB", EntityType::Concept, None)
2655 .await
2656 .unwrap();
2657 let c = gs
2658 .upsert_entity("IC", "IC", EntityType::Concept, None)
2659 .await
2660 .unwrap();
2661 let e1 = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
2662 let e2 = gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
2663
2664 gs.invalidate_edge(e1).await.unwrap();
2666
2667 let page = gs.edges_after_id(0, 10).await.unwrap();
2668 assert_eq!(page.len(), 1, "invalidated edge must be excluded");
2669 assert_eq!(page[0].id, e2);
2670 }
2671}