1use std::collections::HashMap;
5
6use futures::Stream;
7use sqlx::SqlitePool;
8
9use crate::error::MemoryError;
10use crate::sqlite::messages::sanitize_fts5_query;
11use crate::types::MessageId;
12
13use super::types::{Community, Edge, Entity, EntityAlias, EntityType};
14
/// SQLite-backed store for the knowledge graph: entities, aliases, edges,
/// communities, and a small key/value metadata table.
pub struct GraphStore {
    // Connection pool shared with the rest of the SQLite layer.
    pool: SqlitePool,
}
18
19impl GraphStore {
    /// Creates a `GraphStore` that executes all of its queries on `pool`.
    #[must_use]
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }
24
    /// Returns a reference to the underlying connection pool.
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
29
30 pub async fn upsert_entity(
43 &self,
44 surface_name: &str,
45 canonical_name: &str,
46 entity_type: EntityType,
47 summary: Option<&str>,
48 ) -> Result<i64, MemoryError> {
49 let type_str = entity_type.as_str();
50 let id: i64 = sqlx::query_scalar(
51 "INSERT INTO graph_entities (name, canonical_name, entity_type, summary)
52 VALUES (?1, ?2, ?3, ?4)
53 ON CONFLICT(canonical_name, entity_type) DO UPDATE SET
54 name = excluded.name,
55 summary = COALESCE(excluded.summary, summary),
56 last_seen_at = datetime('now')
57 RETURNING id",
58 )
59 .bind(surface_name)
60 .bind(canonical_name)
61 .bind(type_str)
62 .bind(summary)
63 .fetch_one(&self.pool)
64 .await?;
65 Ok(id)
66 }
67
68 pub async fn find_entity(
74 &self,
75 canonical_name: &str,
76 entity_type: EntityType,
77 ) -> Result<Option<Entity>, MemoryError> {
78 let type_str = entity_type.as_str();
79 let row: Option<EntityRow> = sqlx::query_as(
80 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
81 FROM graph_entities
82 WHERE canonical_name = ?1 AND entity_type = ?2",
83 )
84 .bind(canonical_name)
85 .bind(type_str)
86 .fetch_optional(&self.pool)
87 .await?;
88 row.map(entity_from_row).transpose()
89 }
90
91 pub async fn find_entity_by_id(&self, entity_id: i64) -> Result<Option<Entity>, MemoryError> {
97 let row: Option<EntityRow> = sqlx::query_as(
98 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
99 FROM graph_entities
100 WHERE id = ?1",
101 )
102 .bind(entity_id)
103 .fetch_optional(&self.pool)
104 .await?;
105 row.map(entity_from_row).transpose()
106 }
107
108 pub async fn set_entity_qdrant_point_id(
114 &self,
115 entity_id: i64,
116 point_id: &str,
117 ) -> Result<(), MemoryError> {
118 sqlx::query("UPDATE graph_entities SET qdrant_point_id = ?1 WHERE id = ?2")
119 .bind(point_id)
120 .bind(entity_id)
121 .execute(&self.pool)
122 .await?;
123 Ok(())
124 }
125
    /// Fuzzy entity search: an FTS5 prefix query over entity names UNIONed
    /// with a case-insensitive `LIKE` substring scan over aliases.
    ///
    /// The input is capped at 512 bytes (cut on a char boundary), sanitized,
    /// and turned into a per-token prefix query (`tok*`); bare FTS5 keywords
    /// are dropped so they cannot act as operators. Returns an empty vec when
    /// nothing queryable remains.
    ///
    /// NOTE(review): `str::floor_char_boundary` is a nightly-only API at the
    /// time of writing — confirm the crate's toolchain provides it.
    ///
    /// # Errors
    /// Returns an error when `limit` does not fit in `i64`, the query fails,
    /// or a row cannot be decoded.
    pub async fn find_entities_fuzzy(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Entity>, MemoryError> {
        // Tokens FTS5 would treat as query operators when they appear uppercase.
        const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT", "NEAR"];
        // Bound the scan; slicing on a char boundary avoids a mid-UTF-8 panic.
        let query = &query[..query.floor_char_boundary(512)];
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(vec![]);
        }
        // Each surviving token becomes a prefix match: `token*`.
        let fts_query: String = sanitized
            .split_whitespace()
            .filter(|t| !FTS5_OPERATORS.contains(t))
            .map(|t| format!("{t}*"))
            .collect::<Vec<_>>()
            .join(" ");
        if fts_query.is_empty() {
            return Ok(vec![]);
        }

        let limit = i64::try_from(limit)?;
        // UNION de-duplicates entities that match both branches; the LIMIT
        // applies to the combined result.
        let rows: Vec<EntityRow> = sqlx::query_as(
            "SELECT DISTINCT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entities_fts fts
             JOIN graph_entities e ON e.id = fts.rowid
             WHERE graph_entities_fts MATCH ?1
             UNION
             SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name LIKE ?2 ESCAPE '\\' COLLATE NOCASE
             LIMIT ?3",
        )
        .bind(&fts_query)
        // Escape LIKE metacharacters so user input is matched literally.
        .bind(format!(
            "%{}%",
            query
                .trim()
                .replace('\\', "\\\\")
                .replace('%', "\\%")
                .replace('_', "\\_")
        ))
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        rows.into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()
    }
205
206 pub fn all_entities_stream(&self) -> impl Stream<Item = Result<Entity, MemoryError>> + '_ {
208 use futures::StreamExt as _;
209 sqlx::query_as::<_, EntityRow>(
210 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
211 FROM graph_entities ORDER BY id ASC",
212 )
213 .fetch(&self.pool)
214 .map(|r: Result<EntityRow, sqlx::Error>| {
215 r.map_err(MemoryError::from).and_then(entity_from_row)
216 })
217 }
218
219 pub async fn add_alias(&self, entity_id: i64, alias_name: &str) -> Result<(), MemoryError> {
227 sqlx::query(
228 "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name) VALUES (?1, ?2)",
229 )
230 .bind(entity_id)
231 .bind(alias_name)
232 .execute(&self.pool)
233 .await?;
234 Ok(())
235 }
236
237 pub async fn find_entity_by_alias(
245 &self,
246 alias_name: &str,
247 entity_type: EntityType,
248 ) -> Result<Option<Entity>, MemoryError> {
249 let type_str = entity_type.as_str();
250 let row: Option<EntityRow> = sqlx::query_as(
251 "SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
252 e.first_seen_at, e.last_seen_at, e.qdrant_point_id
253 FROM graph_entity_aliases a
254 JOIN graph_entities e ON e.id = a.entity_id
255 WHERE a.alias_name = ?1 COLLATE NOCASE
256 AND e.entity_type = ?2
257 ORDER BY e.id ASC
258 LIMIT 1",
259 )
260 .bind(alias_name)
261 .bind(type_str)
262 .fetch_optional(&self.pool)
263 .await?;
264 row.map(entity_from_row).transpose()
265 }
266
267 pub async fn aliases_for_entity(
273 &self,
274 entity_id: i64,
275 ) -> Result<Vec<EntityAlias>, MemoryError> {
276 let rows: Vec<AliasRow> = sqlx::query_as(
277 "SELECT id, entity_id, alias_name, created_at
278 FROM graph_entity_aliases
279 WHERE entity_id = ?1
280 ORDER BY id ASC",
281 )
282 .bind(entity_id)
283 .fetch_all(&self.pool)
284 .await?;
285 Ok(rows.into_iter().map(alias_from_row).collect())
286 }
287
288 pub async fn all_entities(&self) -> Result<Vec<Entity>, MemoryError> {
294 use futures::TryStreamExt as _;
295 self.all_entities_stream().try_collect().await
296 }
297
298 pub async fn entity_count(&self) -> Result<i64, MemoryError> {
304 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_entities")
305 .fetch_one(&self.pool)
306 .await?;
307 Ok(count)
308 }
309
    /// Inserts a new active edge, or — when an active edge with the same
    /// (source, target, relation) already exists — raises that edge's
    /// confidence to the max of stored and incoming values and returns the
    /// existing id. `confidence` is clamped to `[0.0, 1.0]` first.
    ///
    /// NOTE(review): the existence check and the insert are separate
    /// statements with no enclosing transaction, so concurrent callers could
    /// race and create duplicate active edges — confirm writes are serialized
    /// elsewhere before relying on the dedup guarantee.
    ///
    /// # Errors
    /// Returns an error when any of the queries fail.
    pub async fn insert_edge(
        &self,
        source_entity_id: i64,
        target_entity_id: i64,
        relation: &str,
        fact: &str,
        confidence: f32,
        episode_id: Option<MessageId>,
    ) -> Result<i64, MemoryError> {
        let confidence = confidence.clamp(0.0, 1.0);

        // De-dup check: is there already an active edge for this triple?
        let existing: Option<(i64, f64)> = sqlx::query_as(
            "SELECT id, confidence FROM graph_edges
             WHERE source_entity_id = ?1
               AND target_entity_id = ?2
               AND relation = ?3
               AND valid_to IS NULL
             LIMIT 1",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .fetch_optional(&self.pool)
        .await?;

        if let Some((existing_id, stored_conf)) = existing {
            // Keep the highest confidence observed for the duplicate edge.
            let updated_conf = f64::from(confidence).max(stored_conf);
            sqlx::query("UPDATE graph_edges SET confidence = ?1 WHERE id = ?2")
                .bind(updated_conf)
                .bind(existing_id)
                .execute(&self.pool)
                .await?;
            return Ok(existing_id);
        }

        let episode_raw: Option<i64> = episode_id.map(|m| m.0);
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, episode_id)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             RETURNING id",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .bind(fact)
        .bind(f64::from(confidence))
        .bind(episode_raw)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
373
374 pub async fn invalidate_edge(&self, edge_id: i64) -> Result<(), MemoryError> {
380 sqlx::query(
381 "UPDATE graph_edges SET valid_to = datetime('now'), expired_at = datetime('now')
382 WHERE id = ?1",
383 )
384 .bind(edge_id)
385 .execute(&self.pool)
386 .await?;
387 Ok(())
388 }
389
390 pub async fn edges_for_entity(&self, entity_id: i64) -> Result<Vec<Edge>, MemoryError> {
396 let rows: Vec<EdgeRow> = sqlx::query_as(
397 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
398 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
399 FROM graph_edges
400 WHERE valid_to IS NULL
401 AND (source_entity_id = ?1 OR target_entity_id = ?1)",
402 )
403 .bind(entity_id)
404 .fetch_all(&self.pool)
405 .await?;
406 Ok(rows.into_iter().map(edge_from_row).collect())
407 }
408
409 pub async fn edges_between(
415 &self,
416 entity_a: i64,
417 entity_b: i64,
418 ) -> Result<Vec<Edge>, MemoryError> {
419 let rows: Vec<EdgeRow> = sqlx::query_as(
420 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
421 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
422 FROM graph_edges
423 WHERE valid_to IS NULL
424 AND ((source_entity_id = ?1 AND target_entity_id = ?2)
425 OR (source_entity_id = ?2 AND target_entity_id = ?1))",
426 )
427 .bind(entity_a)
428 .bind(entity_b)
429 .fetch_all(&self.pool)
430 .await?;
431 Ok(rows.into_iter().map(edge_from_row).collect())
432 }
433
434 pub async fn edges_exact(
440 &self,
441 source_entity_id: i64,
442 target_entity_id: i64,
443 ) -> Result<Vec<Edge>, MemoryError> {
444 let rows: Vec<EdgeRow> = sqlx::query_as(
445 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
446 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
447 FROM graph_edges
448 WHERE valid_to IS NULL
449 AND source_entity_id = ?1
450 AND target_entity_id = ?2",
451 )
452 .bind(source_entity_id)
453 .bind(target_entity_id)
454 .fetch_all(&self.pool)
455 .await?;
456 Ok(rows.into_iter().map(edge_from_row).collect())
457 }
458
459 pub async fn active_edge_count(&self) -> Result<i64, MemoryError> {
465 let count: i64 =
466 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
467 .fetch_one(&self.pool)
468 .await?;
469 Ok(count)
470 }
471
472 pub async fn upsert_community(
484 &self,
485 name: &str,
486 summary: &str,
487 entity_ids: &[i64],
488 fingerprint: Option<&str>,
489 ) -> Result<i64, MemoryError> {
490 let entity_ids_json = serde_json::to_string(entity_ids)?;
491 let id: i64 = sqlx::query_scalar(
492 "INSERT INTO graph_communities (name, summary, entity_ids, fingerprint)
493 VALUES (?1, ?2, ?3, ?4)
494 ON CONFLICT(name) DO UPDATE SET
495 summary = excluded.summary,
496 entity_ids = excluded.entity_ids,
497 fingerprint = COALESCE(excluded.fingerprint, fingerprint),
498 updated_at = datetime('now')
499 RETURNING id",
500 )
501 .bind(name)
502 .bind(summary)
503 .bind(entity_ids_json)
504 .bind(fingerprint)
505 .fetch_one(&self.pool)
506 .await?;
507 Ok(id)
508 }
509
510 pub async fn community_fingerprints(&self) -> Result<HashMap<String, i64>, MemoryError> {
517 let rows: Vec<(String, i64)> = sqlx::query_as(
518 "SELECT fingerprint, id FROM graph_communities WHERE fingerprint IS NOT NULL",
519 )
520 .fetch_all(&self.pool)
521 .await?;
522 Ok(rows.into_iter().collect())
523 }
524
525 pub async fn delete_community_by_id(&self, id: i64) -> Result<(), MemoryError> {
531 sqlx::query("DELETE FROM graph_communities WHERE id = ?1")
532 .bind(id)
533 .execute(&self.pool)
534 .await?;
535 Ok(())
536 }
537
538 pub async fn clear_community_fingerprint(&self, id: i64) -> Result<(), MemoryError> {
547 sqlx::query("UPDATE graph_communities SET fingerprint = NULL WHERE id = ?1")
548 .bind(id)
549 .execute(&self.pool)
550 .await?;
551 Ok(())
552 }
553
554 pub async fn community_for_entity(
563 &self,
564 entity_id: i64,
565 ) -> Result<Option<Community>, MemoryError> {
566 let row: Option<CommunityRow> = sqlx::query_as(
567 "SELECT c.id, c.name, c.summary, c.entity_ids, c.fingerprint, c.created_at, c.updated_at
568 FROM graph_communities c, json_each(c.entity_ids) j
569 WHERE CAST(j.value AS INTEGER) = ?1
570 LIMIT 1",
571 )
572 .bind(entity_id)
573 .fetch_optional(&self.pool)
574 .await?;
575 match row {
576 Some(row) => {
577 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
578 Ok(Some(Community {
579 id: row.id,
580 name: row.name,
581 summary: row.summary,
582 entity_ids,
583 fingerprint: row.fingerprint,
584 created_at: row.created_at,
585 updated_at: row.updated_at,
586 }))
587 }
588 None => Ok(None),
589 }
590 }
591
592 pub async fn all_communities(&self) -> Result<Vec<Community>, MemoryError> {
598 let rows: Vec<CommunityRow> = sqlx::query_as(
599 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
600 FROM graph_communities
601 ORDER BY id ASC",
602 )
603 .fetch_all(&self.pool)
604 .await?;
605
606 rows.into_iter()
607 .map(|row| {
608 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
609 Ok(Community {
610 id: row.id,
611 name: row.name,
612 summary: row.summary,
613 entity_ids,
614 fingerprint: row.fingerprint,
615 created_at: row.created_at,
616 updated_at: row.updated_at,
617 })
618 })
619 .collect()
620 }
621
622 pub async fn community_count(&self) -> Result<i64, MemoryError> {
628 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_communities")
629 .fetch_one(&self.pool)
630 .await?;
631 Ok(count)
632 }
633
634 pub async fn get_metadata(&self, key: &str) -> Result<Option<String>, MemoryError> {
642 let val: Option<String> =
643 sqlx::query_scalar("SELECT value FROM graph_metadata WHERE key = ?1")
644 .bind(key)
645 .fetch_optional(&self.pool)
646 .await?;
647 Ok(val)
648 }
649
650 pub async fn set_metadata(&self, key: &str, value: &str) -> Result<(), MemoryError> {
656 sqlx::query(
657 "INSERT INTO graph_metadata (key, value) VALUES (?1, ?2)
658 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
659 )
660 .bind(key)
661 .bind(value)
662 .execute(&self.pool)
663 .await?;
664 Ok(())
665 }
666
667 pub async fn extraction_count(&self) -> Result<i64, MemoryError> {
675 let val = self.get_metadata("extraction_count").await?;
676 Ok(val.and_then(|v| v.parse::<i64>().ok()).unwrap_or(0))
677 }
678
679 pub fn all_active_edges_stream(&self) -> impl Stream<Item = Result<Edge, MemoryError>> + '_ {
681 use futures::StreamExt as _;
682 sqlx::query_as::<_, EdgeRow>(
683 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
684 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
685 FROM graph_edges
686 WHERE valid_to IS NULL
687 ORDER BY id ASC",
688 )
689 .fetch(&self.pool)
690 .map(|r| r.map_err(MemoryError::from).map(edge_from_row))
691 }
692
693 pub async fn find_community_by_id(&self, id: i64) -> Result<Option<Community>, MemoryError> {
699 let row: Option<CommunityRow> = sqlx::query_as(
700 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
701 FROM graph_communities
702 WHERE id = ?1",
703 )
704 .bind(id)
705 .fetch_optional(&self.pool)
706 .await?;
707 match row {
708 Some(row) => {
709 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
710 Ok(Some(Community {
711 id: row.id,
712 name: row.name,
713 summary: row.summary,
714 entity_ids,
715 fingerprint: row.fingerprint,
716 created_at: row.created_at,
717 updated_at: row.updated_at,
718 }))
719 }
720 None => Ok(None),
721 }
722 }
723
724 pub async fn delete_all_communities(&self) -> Result<(), MemoryError> {
730 sqlx::query("DELETE FROM graph_communities")
731 .execute(&self.pool)
732 .await?;
733 Ok(())
734 }
735
736 pub async fn delete_expired_edges(&self, retention_days: u32) -> Result<usize, MemoryError> {
742 let days = i64::from(retention_days);
743 let result = sqlx::query(
744 "DELETE FROM graph_edges
745 WHERE expired_at IS NOT NULL
746 AND expired_at < datetime('now', '-' || ?1 || ' days')",
747 )
748 .bind(days)
749 .execute(&self.pool)
750 .await?;
751 Ok(usize::try_from(result.rows_affected())?)
752 }
753
754 pub async fn delete_orphan_entities(&self, retention_days: u32) -> Result<usize, MemoryError> {
760 let days = i64::from(retention_days);
761 let result = sqlx::query(
762 "DELETE FROM graph_entities
763 WHERE id NOT IN (
764 SELECT DISTINCT source_entity_id FROM graph_edges WHERE valid_to IS NULL
765 UNION
766 SELECT DISTINCT target_entity_id FROM graph_edges WHERE valid_to IS NULL
767 )
768 AND last_seen_at < datetime('now', '-' || ?1 || ' days')",
769 )
770 .bind(days)
771 .execute(&self.pool)
772 .await?;
773 Ok(usize::try_from(result.rows_affected())?)
774 }
775
776 pub async fn cap_entities(&self, max_entities: usize) -> Result<usize, MemoryError> {
785 let current = self.entity_count().await?;
786 let max = i64::try_from(max_entities)?;
787 if current <= max {
788 return Ok(0);
789 }
790 let excess = current - max;
791 let result = sqlx::query(
792 "DELETE FROM graph_entities
793 WHERE id IN (
794 SELECT e.id
795 FROM graph_entities e
796 LEFT JOIN (
797 SELECT source_entity_id AS eid, COUNT(*) AS cnt
798 FROM graph_edges WHERE valid_to IS NULL GROUP BY source_entity_id
799 UNION ALL
800 SELECT target_entity_id AS eid, COUNT(*) AS cnt
801 FROM graph_edges WHERE valid_to IS NULL GROUP BY target_entity_id
802 ) edge_counts ON e.id = edge_counts.eid
803 ORDER BY COALESCE(edge_counts.cnt, 0) ASC, e.last_seen_at ASC
804 LIMIT ?1
805 )",
806 )
807 .bind(excess)
808 .execute(&self.pool)
809 .await?;
810 Ok(usize::try_from(result.rows_affected())?)
811 }
812
813 pub async fn bfs(
830 &self,
831 start_entity_id: i64,
832 max_hops: u32,
833 ) -> Result<(Vec<Entity>, Vec<Edge>), MemoryError> {
834 self.bfs_with_depth(start_entity_id, max_hops)
835 .await
836 .map(|(e, ed, _)| (e, ed))
837 }
838
839 pub async fn bfs_with_depth(
850 &self,
851 start_entity_id: i64,
852 max_hops: u32,
853 ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
854 use std::collections::HashMap;
855
856 const MAX_FRONTIER: usize = 300;
859
860 let mut depth_map: HashMap<i64, u32> = HashMap::new();
861 let mut frontier: Vec<i64> = vec![start_entity_id];
862 depth_map.insert(start_entity_id, 0);
863
864 for hop in 0..max_hops {
865 if frontier.is_empty() {
866 break;
867 }
868 frontier.truncate(MAX_FRONTIER);
869 let placeholders = frontier
871 .iter()
872 .enumerate()
873 .map(|(i, _)| format!("?{}", i + 1))
874 .collect::<Vec<_>>()
875 .join(", ");
876 let neighbour_sql = format!(
877 "SELECT DISTINCT CASE
878 WHEN source_entity_id IN ({placeholders}) THEN target_entity_id
879 ELSE source_entity_id
880 END as neighbour_id
881 FROM graph_edges
882 WHERE valid_to IS NULL
883 AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))"
884 );
885 let mut q = sqlx::query_scalar::<_, i64>(&neighbour_sql);
886 for id in &frontier {
887 q = q.bind(*id);
888 }
889 for id in &frontier {
890 q = q.bind(*id);
891 }
892 for id in &frontier {
893 q = q.bind(*id);
894 }
895 let neighbours: Vec<i64> = q.fetch_all(&self.pool).await?;
896
897 let mut next_frontier: Vec<i64> = Vec::new();
898 for nbr in neighbours {
899 if let std::collections::hash_map::Entry::Vacant(e) = depth_map.entry(nbr) {
900 e.insert(hop + 1);
901 next_frontier.push(nbr);
902 }
903 }
904 frontier = next_frontier;
905 }
906
907 let mut visited_ids: Vec<i64> = depth_map.keys().copied().collect();
908 if visited_ids.is_empty() {
909 return Ok((Vec::new(), Vec::new(), depth_map));
910 }
911 visited_ids.truncate(499);
913
914 let placeholders = visited_ids
916 .iter()
917 .enumerate()
918 .map(|(i, _)| format!("?{}", i + 1))
919 .collect::<Vec<_>>()
920 .join(", ");
921
922 let edge_sql = format!(
923 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
924 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
925 FROM graph_edges
926 WHERE valid_to IS NULL
927 AND source_entity_id IN ({placeholders})
928 AND target_entity_id IN ({placeholders})"
929 );
930 let mut edge_query = sqlx::query_as::<_, EdgeRow>(&edge_sql);
931 for id in &visited_ids {
932 edge_query = edge_query.bind(*id);
933 }
934 for id in &visited_ids {
935 edge_query = edge_query.bind(*id);
936 }
937 let edge_rows: Vec<EdgeRow> = edge_query.fetch_all(&self.pool).await?;
938
939 let entity_sql = format!(
940 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
941 FROM graph_entities WHERE id IN ({placeholders})"
942 );
943 let mut entity_query = sqlx::query_as::<_, EntityRow>(&entity_sql);
944 for id in &visited_ids {
945 entity_query = entity_query.bind(*id);
946 }
947 let entity_rows: Vec<EntityRow> = entity_query.fetch_all(&self.pool).await?;
948
949 let entities: Vec<Entity> = entity_rows
950 .into_iter()
951 .map(entity_from_row)
952 .collect::<Result<Vec<_>, _>>()?;
953 let edges: Vec<Edge> = edge_rows.into_iter().map(edge_from_row).collect();
954
955 Ok((entities, edges, depth_map))
956 }
957
958 pub async fn find_entity_by_name(&self, name: &str) -> Result<Vec<Entity>, MemoryError> {
974 let rows: Vec<EntityRow> = sqlx::query_as(
975 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
976 FROM graph_entities
977 WHERE name = ?1 COLLATE NOCASE OR canonical_name = ?1 COLLATE NOCASE
978 LIMIT 5",
979 )
980 .bind(name)
981 .fetch_all(&self.pool)
982 .await?;
983
984 if !rows.is_empty() {
985 return rows.into_iter().map(entity_from_row).collect();
986 }
987
988 self.find_entities_fuzzy(name, 5).await
989 }
990
991 pub async fn unprocessed_messages_for_backfill(
999 &self,
1000 limit: usize,
1001 ) -> Result<Vec<(crate::types::MessageId, String)>, MemoryError> {
1002 let limit = i64::try_from(limit)?;
1003 let rows: Vec<(i64, String)> = sqlx::query_as(
1004 "SELECT id, content FROM messages
1005 WHERE graph_processed = 0
1006 ORDER BY id ASC
1007 LIMIT ?1",
1008 )
1009 .bind(limit)
1010 .fetch_all(&self.pool)
1011 .await?;
1012 Ok(rows
1013 .into_iter()
1014 .map(|(id, content)| (crate::types::MessageId(id), content))
1015 .collect())
1016 }
1017
1018 pub async fn unprocessed_message_count(&self) -> Result<i64, MemoryError> {
1024 let count: i64 =
1025 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE graph_processed = 0")
1026 .fetch_one(&self.pool)
1027 .await?;
1028 Ok(count)
1029 }
1030
1031 pub async fn mark_messages_graph_processed(
1037 &self,
1038 ids: &[crate::types::MessageId],
1039 ) -> Result<(), MemoryError> {
1040 if ids.is_empty() {
1041 return Ok(());
1042 }
1043 let placeholders = ids
1044 .iter()
1045 .enumerate()
1046 .map(|(i, _)| format!("?{}", i + 1))
1047 .collect::<Vec<_>>()
1048 .join(", ");
1049 let sql = format!("UPDATE messages SET graph_processed = 1 WHERE id IN ({placeholders})");
1050 let mut query = sqlx::query(&sql);
1051 for id in ids {
1052 query = query.bind(id.0);
1053 }
1054 query.execute(&self.pool).await?;
1055 Ok(())
1056 }
1057}
1058
/// Raw row shape for `graph_entities` queries, decoded by sqlx.
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: i64,
    name: String,
    canonical_name: String,
    // Stored as text; parsed into `EntityType` by `entity_from_row`.
    entity_type: String,
    summary: Option<String>,
    first_seen_at: String,
    last_seen_at: String,
    qdrant_point_id: Option<String>,
}
1072
1073fn entity_from_row(row: EntityRow) -> Result<Entity, MemoryError> {
1074 let entity_type = row
1075 .entity_type
1076 .parse::<EntityType>()
1077 .map_err(MemoryError::GraphStore)?;
1078 Ok(Entity {
1079 id: row.id,
1080 name: row.name,
1081 canonical_name: row.canonical_name,
1082 entity_type,
1083 summary: row.summary,
1084 first_seen_at: row.first_seen_at,
1085 last_seen_at: row.last_seen_at,
1086 qdrant_point_id: row.qdrant_point_id,
1087 })
1088}
1089
/// Raw row shape for `graph_entity_aliases` queries, decoded by sqlx.
#[derive(sqlx::FromRow)]
struct AliasRow {
    id: i64,
    entity_id: i64,
    alias_name: String,
    created_at: String,
}
1097
1098fn alias_from_row(row: AliasRow) -> EntityAlias {
1099 EntityAlias {
1100 id: row.id,
1101 entity_id: row.entity_id,
1102 alias_name: row.alias_name,
1103 created_at: row.created_at,
1104 }
1105}
1106
/// Raw row shape for `graph_edges` queries, decoded by sqlx.
#[derive(sqlx::FromRow)]
struct EdgeRow {
    id: i64,
    source_entity_id: i64,
    target_entity_id: i64,
    relation: String,
    fact: String,
    // REAL column; narrowed to f32 in `edge_from_row`.
    confidence: f64,
    valid_from: String,
    // NULL while the edge is still active.
    valid_to: Option<String>,
    created_at: String,
    expired_at: Option<String>,
    episode_id: Option<i64>,
    qdrant_point_id: Option<String>,
}
1122
1123fn edge_from_row(row: EdgeRow) -> Edge {
1124 Edge {
1125 id: row.id,
1126 source_entity_id: row.source_entity_id,
1127 target_entity_id: row.target_entity_id,
1128 relation: row.relation,
1129 fact: row.fact,
1130 #[allow(clippy::cast_possible_truncation)]
1131 confidence: row.confidence as f32,
1132 valid_from: row.valid_from,
1133 valid_to: row.valid_to,
1134 created_at: row.created_at,
1135 expired_at: row.expired_at,
1136 episode_id: row.episode_id.map(MessageId),
1137 qdrant_point_id: row.qdrant_point_id,
1138 }
1139}
1140
/// Raw row shape for `graph_communities` queries, decoded by sqlx.
#[derive(sqlx::FromRow)]
struct CommunityRow {
    id: i64,
    name: String,
    summary: String,
    // JSON-encoded array of member entity ids.
    entity_ids: String,
    fingerprint: Option<String>,
    created_at: String,
    updated_at: String,
}
1151
1152#[cfg(test)]
1155mod tests {
1156 use super::*;
1157 use crate::sqlite::SqliteStore;
1158
1159 async fn setup() -> GraphStore {
1160 let store = SqliteStore::new(":memory:").await.unwrap();
1161 GraphStore::new(store.pool().clone())
1162 }
1163
1164 #[tokio::test]
1165 async fn upsert_entity_insert_new() {
1166 let gs = setup().await;
1167 let id = gs
1168 .upsert_entity("Alice", "Alice", EntityType::Person, Some("a person"))
1169 .await
1170 .unwrap();
1171 assert!(id > 0);
1172 }
1173
1174 #[tokio::test]
1175 async fn upsert_entity_update_existing() {
1176 let gs = setup().await;
1177 let id1 = gs
1178 .upsert_entity("Alice", "Alice", EntityType::Person, None)
1179 .await
1180 .unwrap();
1181 let id2 = gs
1184 .upsert_entity("Alice", "Alice", EntityType::Person, Some("updated"))
1185 .await
1186 .unwrap();
1187 assert_eq!(id1, id2);
1188 let entity = gs
1189 .find_entity("Alice", EntityType::Person)
1190 .await
1191 .unwrap()
1192 .unwrap();
1193 assert_eq!(entity.summary.as_deref(), Some("updated"));
1194 }
1195
1196 #[tokio::test]
1197 async fn find_entity_found() {
1198 let gs = setup().await;
1199 gs.upsert_entity("Bob", "Bob", EntityType::Tool, Some("a tool"))
1200 .await
1201 .unwrap();
1202 let entity = gs
1203 .find_entity("Bob", EntityType::Tool)
1204 .await
1205 .unwrap()
1206 .unwrap();
1207 assert_eq!(entity.name, "Bob");
1208 assert_eq!(entity.entity_type, EntityType::Tool);
1209 }
1210
1211 #[tokio::test]
1212 async fn find_entity_not_found() {
1213 let gs = setup().await;
1214 let result = gs.find_entity("Nobody", EntityType::Person).await.unwrap();
1215 assert!(result.is_none());
1216 }
1217
1218 #[tokio::test]
1219 async fn find_entities_fuzzy_partial_match() {
1220 let gs = setup().await;
1221 gs.upsert_entity("GraphQL", "GraphQL", EntityType::Concept, None)
1222 .await
1223 .unwrap();
1224 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
1225 .await
1226 .unwrap();
1227 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
1228 .await
1229 .unwrap();
1230
1231 let results = gs.find_entities_fuzzy("graph", 10).await.unwrap();
1232 assert_eq!(results.len(), 2);
1233 assert!(results.iter().any(|e| e.name == "GraphQL"));
1234 assert!(results.iter().any(|e| e.name == "Graph"));
1235 }
1236
1237 #[tokio::test]
1238 async fn entity_count_empty() {
1239 let gs = setup().await;
1240 assert_eq!(gs.entity_count().await.unwrap(), 0);
1241 }
1242
1243 #[tokio::test]
1244 async fn entity_count_non_empty() {
1245 let gs = setup().await;
1246 gs.upsert_entity("A", "A", EntityType::Concept, None)
1247 .await
1248 .unwrap();
1249 gs.upsert_entity("B", "B", EntityType::Concept, None)
1250 .await
1251 .unwrap();
1252 assert_eq!(gs.entity_count().await.unwrap(), 2);
1253 }
1254
1255 #[tokio::test]
1256 async fn all_entities_and_stream() {
1257 let gs = setup().await;
1258 gs.upsert_entity("X", "X", EntityType::Project, None)
1259 .await
1260 .unwrap();
1261 gs.upsert_entity("Y", "Y", EntityType::Language, None)
1262 .await
1263 .unwrap();
1264
1265 let all = gs.all_entities().await.unwrap();
1266 assert_eq!(all.len(), 2);
1267
1268 use futures::StreamExt as _;
1269 let streamed: Vec<Result<Entity, _>> = gs.all_entities_stream().collect().await;
1270 assert_eq!(streamed.len(), 2);
1271 assert!(streamed.iter().all(|r| r.is_ok()));
1272 }
1273
1274 #[tokio::test]
1275 async fn insert_edge_without_episode() {
1276 let gs = setup().await;
1277 let src = gs
1278 .upsert_entity("Src", "Src", EntityType::Concept, None)
1279 .await
1280 .unwrap();
1281 let tgt = gs
1282 .upsert_entity("Tgt", "Tgt", EntityType::Concept, None)
1283 .await
1284 .unwrap();
1285 let eid = gs
1286 .insert_edge(src, tgt, "relates_to", "Src relates to Tgt", 0.9, None)
1287 .await
1288 .unwrap();
1289 assert!(eid > 0);
1290 }
1291
1292 #[tokio::test]
1293 async fn insert_edge_deduplicates_active_edge() {
1294 let gs = setup().await;
1295 let src = gs
1296 .upsert_entity("Alice", "Alice", EntityType::Person, None)
1297 .await
1298 .unwrap();
1299 let tgt = gs
1300 .upsert_entity("Google", "Google", EntityType::Organization, None)
1301 .await
1302 .unwrap();
1303
1304 let id1 = gs
1305 .insert_edge(src, tgt, "works_at", "Alice works at Google", 0.7, None)
1306 .await
1307 .unwrap();
1308
1309 let id2 = gs
1311 .insert_edge(src, tgt, "works_at", "Alice works at Google", 0.9, None)
1312 .await
1313 .unwrap();
1314 assert_eq!(id1, id2, "duplicate active edge must not be created");
1315
1316 let count: i64 =
1318 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
1319 .fetch_one(&gs.pool)
1320 .await
1321 .unwrap();
1322 assert_eq!(count, 1, "only one active edge must exist");
1323
1324 let conf: f64 = sqlx::query_scalar("SELECT confidence FROM graph_edges WHERE id = ?1")
1325 .bind(id1)
1326 .fetch_one(&gs.pool)
1327 .await
1328 .unwrap();
1329 assert!(
1331 (conf - f64::from(0.9_f32)).abs() < 1e-6,
1332 "confidence must be updated to max, got {conf}"
1333 );
1334 }
1335
1336 #[tokio::test]
1337 async fn insert_edge_different_relations_are_distinct() {
1338 let gs = setup().await;
1339 let src = gs
1340 .upsert_entity("Bob", "Bob", EntityType::Person, None)
1341 .await
1342 .unwrap();
1343 let tgt = gs
1344 .upsert_entity("Acme", "Acme", EntityType::Organization, None)
1345 .await
1346 .unwrap();
1347
1348 let id1 = gs
1349 .insert_edge(src, tgt, "founded", "Bob founded Acme", 0.8, None)
1350 .await
1351 .unwrap();
1352 let id2 = gs
1353 .insert_edge(src, tgt, "chairs", "Bob chairs Acme", 0.8, None)
1354 .await
1355 .unwrap();
1356 assert_ne!(id1, id2, "different relations must produce distinct edges");
1357
1358 let count: i64 =
1359 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
1360 .fetch_one(&gs.pool)
1361 .await
1362 .unwrap();
1363 assert_eq!(count, 2);
1364 }
1365
1366 #[tokio::test]
1367 async fn insert_edge_with_episode() {
1368 let gs = setup().await;
1369 let src = gs
1370 .upsert_entity("Src2", "Src2", EntityType::Concept, None)
1371 .await
1372 .unwrap();
1373 let tgt = gs
1374 .upsert_entity("Tgt2", "Tgt2", EntityType::Concept, None)
1375 .await
1376 .unwrap();
1377 let episode = MessageId(999);
1383 let result = gs
1384 .insert_edge(src, tgt, "uses", "Src2 uses Tgt2", 1.0, Some(episode))
1385 .await;
1386 match &result {
1387 Ok(eid) => assert!(*eid > 0, "inserted edge should have positive id"),
1388 Err(MemoryError::Sqlite(_)) => {} Err(e) => panic!("unexpected error: {e}"),
1390 }
1391 }
1392
1393 #[tokio::test]
1394 async fn invalidate_edge_sets_timestamps() {
1395 let gs = setup().await;
1396 let src = gs
1397 .upsert_entity("E1", "E1", EntityType::Concept, None)
1398 .await
1399 .unwrap();
1400 let tgt = gs
1401 .upsert_entity("E2", "E2", EntityType::Concept, None)
1402 .await
1403 .unwrap();
1404 let eid = gs
1405 .insert_edge(src, tgt, "r", "fact", 1.0, None)
1406 .await
1407 .unwrap();
1408 gs.invalidate_edge(eid).await.unwrap();
1409
1410 let row: (Option<String>, Option<String>) =
1411 sqlx::query_as("SELECT valid_to, expired_at FROM graph_edges WHERE id = ?1")
1412 .bind(eid)
1413 .fetch_one(&gs.pool)
1414 .await
1415 .unwrap();
1416 assert!(row.0.is_some(), "valid_to should be set");
1417 assert!(row.1.is_some(), "expired_at should be set");
1418 }
1419
1420 #[tokio::test]
1421 async fn edges_for_entity_both_directions() {
1422 let gs = setup().await;
1423 let a = gs
1424 .upsert_entity("A", "A", EntityType::Concept, None)
1425 .await
1426 .unwrap();
1427 let b = gs
1428 .upsert_entity("B", "B", EntityType::Concept, None)
1429 .await
1430 .unwrap();
1431 let c = gs
1432 .upsert_entity("C", "C", EntityType::Concept, None)
1433 .await
1434 .unwrap();
1435 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1436 gs.insert_edge(c, a, "r", "f2", 1.0, None).await.unwrap();
1437
1438 let edges = gs.edges_for_entity(a).await.unwrap();
1439 assert_eq!(edges.len(), 2);
1440 }
1441
1442 #[tokio::test]
1443 async fn edges_between_both_directions() {
1444 let gs = setup().await;
1445 let a = gs
1446 .upsert_entity("PA", "PA", EntityType::Person, None)
1447 .await
1448 .unwrap();
1449 let b = gs
1450 .upsert_entity("PB", "PB", EntityType::Person, None)
1451 .await
1452 .unwrap();
1453 gs.insert_edge(a, b, "knows", "PA knows PB", 1.0, None)
1454 .await
1455 .unwrap();
1456
1457 let fwd = gs.edges_between(a, b).await.unwrap();
1458 assert_eq!(fwd.len(), 1);
1459 let rev = gs.edges_between(b, a).await.unwrap();
1460 assert_eq!(rev.len(), 1);
1461 }
1462
1463 #[tokio::test]
1464 async fn active_edge_count_excludes_invalidated() {
1465 let gs = setup().await;
1466 let a = gs
1467 .upsert_entity("N1", "N1", EntityType::Concept, None)
1468 .await
1469 .unwrap();
1470 let b = gs
1471 .upsert_entity("N2", "N2", EntityType::Concept, None)
1472 .await
1473 .unwrap();
1474 let e1 = gs.insert_edge(a, b, "r1", "f1", 1.0, None).await.unwrap();
1475 gs.insert_edge(a, b, "r2", "f2", 1.0, None).await.unwrap();
1476 gs.invalidate_edge(e1).await.unwrap();
1477
1478 assert_eq!(gs.active_edge_count().await.unwrap(), 1);
1479 }
1480
1481 #[tokio::test]
1482 async fn upsert_community_insert_and_update() {
1483 let gs = setup().await;
1484 let id1 = gs
1485 .upsert_community("clusterA", "summary A", &[1, 2, 3], None)
1486 .await
1487 .unwrap();
1488 assert!(id1 > 0);
1489 let id2 = gs
1490 .upsert_community("clusterA", "summary A updated", &[1, 2, 3, 4], None)
1491 .await
1492 .unwrap();
1493 assert_eq!(id1, id2);
1494 let communities = gs.all_communities().await.unwrap();
1495 assert_eq!(communities.len(), 1);
1496 assert_eq!(communities[0].summary, "summary A updated");
1497 assert_eq!(communities[0].entity_ids, vec![1, 2, 3, 4]);
1498 }
1499
1500 #[tokio::test]
1501 async fn community_for_entity_found() {
1502 let gs = setup().await;
1503 let a = gs
1504 .upsert_entity("CA", "CA", EntityType::Concept, None)
1505 .await
1506 .unwrap();
1507 let b = gs
1508 .upsert_entity("CB", "CB", EntityType::Concept, None)
1509 .await
1510 .unwrap();
1511 gs.upsert_community("cA", "summary", &[a, b], None)
1512 .await
1513 .unwrap();
1514 let result = gs.community_for_entity(a).await.unwrap();
1515 assert!(result.is_some());
1516 assert_eq!(result.unwrap().name, "cA");
1517 }
1518
1519 #[tokio::test]
1520 async fn community_for_entity_not_found() {
1521 let gs = setup().await;
1522 let result = gs.community_for_entity(999).await.unwrap();
1523 assert!(result.is_none());
1524 }
1525
1526 #[tokio::test]
1527 async fn community_count() {
1528 let gs = setup().await;
1529 assert_eq!(gs.community_count().await.unwrap(), 0);
1530 gs.upsert_community("c1", "s1", &[], None).await.unwrap();
1531 gs.upsert_community("c2", "s2", &[], None).await.unwrap();
1532 assert_eq!(gs.community_count().await.unwrap(), 2);
1533 }
1534
1535 #[tokio::test]
1536 async fn metadata_get_set_round_trip() {
1537 let gs = setup().await;
1538 assert_eq!(gs.get_metadata("counter").await.unwrap(), None);
1539 gs.set_metadata("counter", "42").await.unwrap();
1540 assert_eq!(gs.get_metadata("counter").await.unwrap(), Some("42".into()));
1541 gs.set_metadata("counter", "43").await.unwrap();
1542 assert_eq!(gs.get_metadata("counter").await.unwrap(), Some("43".into()));
1543 }
1544
1545 #[tokio::test]
1546 async fn bfs_max_hops_0_returns_only_start() {
1547 let gs = setup().await;
1548 let a = gs
1549 .upsert_entity("BfsA", "BfsA", EntityType::Concept, None)
1550 .await
1551 .unwrap();
1552 let b = gs
1553 .upsert_entity("BfsB", "BfsB", EntityType::Concept, None)
1554 .await
1555 .unwrap();
1556 gs.insert_edge(a, b, "r", "f", 1.0, None).await.unwrap();
1557
1558 let (entities, edges) = gs.bfs(a, 0).await.unwrap();
1559 assert_eq!(entities.len(), 1);
1560 assert_eq!(entities[0].id, a);
1561 assert!(edges.is_empty());
1562 }
1563
1564 #[tokio::test]
1565 async fn bfs_max_hops_2_chain() {
1566 let gs = setup().await;
1567 let a = gs
1568 .upsert_entity("ChainA", "ChainA", EntityType::Concept, None)
1569 .await
1570 .unwrap();
1571 let b = gs
1572 .upsert_entity("ChainB", "ChainB", EntityType::Concept, None)
1573 .await
1574 .unwrap();
1575 let c = gs
1576 .upsert_entity("ChainC", "ChainC", EntityType::Concept, None)
1577 .await
1578 .unwrap();
1579 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1580 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1581
1582 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1583 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1584 assert!(ids.contains(&a));
1585 assert!(ids.contains(&b));
1586 assert!(ids.contains(&c));
1587 assert_eq!(edges.len(), 2);
1588 }
1589
1590 #[tokio::test]
1591 async fn bfs_cycle_no_infinite_loop() {
1592 let gs = setup().await;
1593 let a = gs
1594 .upsert_entity("CycA", "CycA", EntityType::Concept, None)
1595 .await
1596 .unwrap();
1597 let b = gs
1598 .upsert_entity("CycB", "CycB", EntityType::Concept, None)
1599 .await
1600 .unwrap();
1601 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1602 gs.insert_edge(b, a, "r", "f2", 1.0, None).await.unwrap();
1603
1604 let (entities, _edges) = gs.bfs(a, 3).await.unwrap();
1605 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1606 assert!(ids.contains(&a));
1608 assert!(ids.contains(&b));
1609 assert_eq!(ids.len(), 2);
1610 }
1611
1612 #[tokio::test]
1613 async fn test_invalidated_edges_excluded_from_bfs() {
1614 let gs = setup().await;
1615 let a = gs
1616 .upsert_entity("InvA", "InvA", EntityType::Concept, None)
1617 .await
1618 .unwrap();
1619 let b = gs
1620 .upsert_entity("InvB", "InvB", EntityType::Concept, None)
1621 .await
1622 .unwrap();
1623 let c = gs
1624 .upsert_entity("InvC", "InvC", EntityType::Concept, None)
1625 .await
1626 .unwrap();
1627 let ab = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1628 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1629 gs.invalidate_edge(ab).await.unwrap();
1631
1632 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1633 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1634 assert_eq!(ids, vec![a], "only start entity should be reachable");
1635 assert!(edges.is_empty(), "no active edges should be returned");
1636 }
1637
1638 #[tokio::test]
1639 async fn test_bfs_empty_graph() {
1640 let gs = setup().await;
1641 let a = gs
1642 .upsert_entity("IsoA", "IsoA", EntityType::Concept, None)
1643 .await
1644 .unwrap();
1645
1646 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1647 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1648 assert_eq!(ids, vec![a], "isolated node: only start entity returned");
1649 assert!(edges.is_empty(), "no edges for isolated node");
1650 }
1651
1652 #[tokio::test]
1653 async fn test_bfs_diamond() {
1654 let gs = setup().await;
1655 let a = gs
1656 .upsert_entity("DiamA", "DiamA", EntityType::Concept, None)
1657 .await
1658 .unwrap();
1659 let b = gs
1660 .upsert_entity("DiamB", "DiamB", EntityType::Concept, None)
1661 .await
1662 .unwrap();
1663 let c = gs
1664 .upsert_entity("DiamC", "DiamC", EntityType::Concept, None)
1665 .await
1666 .unwrap();
1667 let d = gs
1668 .upsert_entity("DiamD", "DiamD", EntityType::Concept, None)
1669 .await
1670 .unwrap();
1671 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1672 gs.insert_edge(a, c, "r", "f2", 1.0, None).await.unwrap();
1673 gs.insert_edge(b, d, "r", "f3", 1.0, None).await.unwrap();
1674 gs.insert_edge(c, d, "r", "f4", 1.0, None).await.unwrap();
1675
1676 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1677 let mut ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1678 ids.sort_unstable();
1679 let mut expected = vec![a, b, c, d];
1680 expected.sort_unstable();
1681 assert_eq!(ids, expected, "all 4 nodes reachable, no duplicates");
1682 assert_eq!(edges.len(), 4, "all 4 edges returned");
1683 }
1684
1685 #[tokio::test]
1686 async fn extraction_count_default_zero() {
1687 let gs = setup().await;
1688 assert_eq!(gs.extraction_count().await.unwrap(), 0);
1689 }
1690
1691 #[tokio::test]
1692 async fn extraction_count_after_set() {
1693 let gs = setup().await;
1694 gs.set_metadata("extraction_count", "7").await.unwrap();
1695 assert_eq!(gs.extraction_count().await.unwrap(), 7);
1696 }
1697
1698 #[tokio::test]
1699 async fn all_active_edges_stream_excludes_invalidated() {
1700 use futures::TryStreamExt as _;
1701 let gs = setup().await;
1702 let a = gs
1703 .upsert_entity("SA", "SA", EntityType::Concept, None)
1704 .await
1705 .unwrap();
1706 let b = gs
1707 .upsert_entity("SB", "SB", EntityType::Concept, None)
1708 .await
1709 .unwrap();
1710 let c = gs
1711 .upsert_entity("SC", "SC", EntityType::Concept, None)
1712 .await
1713 .unwrap();
1714 let e1 = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1715 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1716 gs.invalidate_edge(e1).await.unwrap();
1717
1718 let edges: Vec<_> = gs.all_active_edges_stream().try_collect().await.unwrap();
1719 assert_eq!(edges.len(), 1, "only the active edge should be returned");
1720 assert_eq!(edges[0].source_entity_id, b);
1721 assert_eq!(edges[0].target_entity_id, c);
1722 }
1723
1724 #[tokio::test]
1725 async fn find_community_by_id_found_and_not_found() {
1726 let gs = setup().await;
1727 let cid = gs
1728 .upsert_community("grp", "summary", &[1, 2], None)
1729 .await
1730 .unwrap();
1731 let found = gs.find_community_by_id(cid).await.unwrap();
1732 assert!(found.is_some());
1733 assert_eq!(found.unwrap().name, "grp");
1734
1735 let missing = gs.find_community_by_id(9999).await.unwrap();
1736 assert!(missing.is_none());
1737 }
1738
1739 #[tokio::test]
1740 async fn delete_all_communities_clears_table() {
1741 let gs = setup().await;
1742 gs.upsert_community("c1", "s1", &[1], None).await.unwrap();
1743 gs.upsert_community("c2", "s2", &[2], None).await.unwrap();
1744 assert_eq!(gs.community_count().await.unwrap(), 2);
1745 gs.delete_all_communities().await.unwrap();
1746 assert_eq!(gs.community_count().await.unwrap(), 0);
1747 }
1748
1749 #[tokio::test]
1750 async fn test_find_entities_fuzzy_no_results() {
1751 let gs = setup().await;
1752 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
1753 .await
1754 .unwrap();
1755 let results = gs.find_entities_fuzzy("zzzznonexistent", 10).await.unwrap();
1756 assert!(
1757 results.is_empty(),
1758 "no entities should match an unknown term"
1759 );
1760 }
1761
1762 #[tokio::test]
1765 async fn upsert_entity_stores_canonical_name() {
1766 let gs = setup().await;
1767 gs.upsert_entity("rust", "rust", EntityType::Language, None)
1768 .await
1769 .unwrap();
1770 let entity = gs
1771 .find_entity("rust", EntityType::Language)
1772 .await
1773 .unwrap()
1774 .unwrap();
1775 assert_eq!(entity.canonical_name, "rust");
1776 assert_eq!(entity.name, "rust");
1777 }
1778
1779 #[tokio::test]
1780 async fn add_alias_idempotent() {
1781 let gs = setup().await;
1782 let id = gs
1783 .upsert_entity("rust", "rust", EntityType::Language, None)
1784 .await
1785 .unwrap();
1786 gs.add_alias(id, "rust-lang").await.unwrap();
1787 gs.add_alias(id, "rust-lang").await.unwrap();
1789 let aliases = gs.aliases_for_entity(id).await.unwrap();
1790 assert_eq!(
1791 aliases
1792 .iter()
1793 .filter(|a| a.alias_name == "rust-lang")
1794 .count(),
1795 1
1796 );
1797 }
1798
1799 #[tokio::test]
1802 async fn find_entity_by_id_found() {
1803 let gs = setup().await;
1804 let id = gs
1805 .upsert_entity("FindById", "finbyid", EntityType::Concept, Some("summary"))
1806 .await
1807 .unwrap();
1808 let entity = gs.find_entity_by_id(id).await.unwrap();
1809 assert!(entity.is_some());
1810 let entity = entity.unwrap();
1811 assert_eq!(entity.id, id);
1812 assert_eq!(entity.name, "FindById");
1813 }
1814
1815 #[tokio::test]
1816 async fn find_entity_by_id_not_found() {
1817 let gs = setup().await;
1818 let result = gs.find_entity_by_id(99999).await.unwrap();
1819 assert!(result.is_none());
1820 }
1821
1822 #[tokio::test]
1823 async fn set_entity_qdrant_point_id_updates() {
1824 let gs = setup().await;
1825 let id = gs
1826 .upsert_entity("QdrantPoint", "qdrantpoint", EntityType::Concept, None)
1827 .await
1828 .unwrap();
1829 let point_id = "550e8400-e29b-41d4-a716-446655440000";
1830 gs.set_entity_qdrant_point_id(id, point_id).await.unwrap();
1831
1832 let entity = gs.find_entity_by_id(id).await.unwrap().unwrap();
1833 assert_eq!(entity.qdrant_point_id.as_deref(), Some(point_id));
1834 }
1835
1836 #[tokio::test]
1837 async fn find_entities_fuzzy_matches_summary() {
1838 let gs = setup().await;
1839 gs.upsert_entity(
1840 "Rust",
1841 "Rust",
1842 EntityType::Language,
1843 Some("a systems programming language"),
1844 )
1845 .await
1846 .unwrap();
1847 gs.upsert_entity(
1848 "Go",
1849 "Go",
1850 EntityType::Language,
1851 Some("a compiled language by Google"),
1852 )
1853 .await
1854 .unwrap();
1855 let results = gs.find_entities_fuzzy("systems", 10).await.unwrap();
1857 assert_eq!(results.len(), 1);
1858 assert_eq!(results[0].name, "Rust");
1859 }
1860
1861 #[tokio::test]
1862 async fn find_entities_fuzzy_empty_query() {
1863 let gs = setup().await;
1864 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
1865 .await
1866 .unwrap();
1867 let results = gs.find_entities_fuzzy("", 10).await.unwrap();
1869 assert!(results.is_empty(), "empty query should return no results");
1870 let results = gs.find_entities_fuzzy(" ", 10).await.unwrap();
1872 assert!(
1873 results.is_empty(),
1874 "whitespace query should return no results"
1875 );
1876 }
1877
1878 #[tokio::test]
1879 async fn find_entity_by_alias_case_insensitive() {
1880 let gs = setup().await;
1881 let id = gs
1882 .upsert_entity("rust", "rust", EntityType::Language, None)
1883 .await
1884 .unwrap();
1885 gs.add_alias(id, "rust").await.unwrap();
1886 gs.add_alias(id, "rust-lang").await.unwrap();
1887
1888 let found = gs
1889 .find_entity_by_alias("RUST-LANG", EntityType::Language)
1890 .await
1891 .unwrap();
1892 assert!(found.is_some());
1893 assert_eq!(found.unwrap().id, id);
1894 }
1895
1896 #[tokio::test]
1897 async fn find_entity_by_alias_returns_none_for_unknown() {
1898 let gs = setup().await;
1899 let id = gs
1900 .upsert_entity("rust", "rust", EntityType::Language, None)
1901 .await
1902 .unwrap();
1903 gs.add_alias(id, "rust").await.unwrap();
1904
1905 let found = gs
1906 .find_entity_by_alias("python", EntityType::Language)
1907 .await
1908 .unwrap();
1909 assert!(found.is_none());
1910 }
1911
1912 #[tokio::test]
1913 async fn find_entity_by_alias_filters_by_entity_type() {
1914 let gs = setup().await;
1916 let lang_id = gs
1917 .upsert_entity("python", "python", EntityType::Language, None)
1918 .await
1919 .unwrap();
1920 gs.add_alias(lang_id, "python").await.unwrap();
1921
1922 let found_tool = gs
1923 .find_entity_by_alias("python", EntityType::Tool)
1924 .await
1925 .unwrap();
1926 assert!(
1927 found_tool.is_none(),
1928 "cross-type alias collision must not occur"
1929 );
1930
1931 let found_lang = gs
1932 .find_entity_by_alias("python", EntityType::Language)
1933 .await
1934 .unwrap();
1935 assert!(found_lang.is_some());
1936 assert_eq!(found_lang.unwrap().id, lang_id);
1937 }
1938
1939 #[tokio::test]
1940 async fn aliases_for_entity_returns_all() {
1941 let gs = setup().await;
1942 let id = gs
1943 .upsert_entity("rust", "rust", EntityType::Language, None)
1944 .await
1945 .unwrap();
1946 gs.add_alias(id, "rust").await.unwrap();
1947 gs.add_alias(id, "rust-lang").await.unwrap();
1948 gs.add_alias(id, "rustlang").await.unwrap();
1949
1950 let aliases = gs.aliases_for_entity(id).await.unwrap();
1951 assert_eq!(aliases.len(), 3);
1952 let names: Vec<&str> = aliases.iter().map(|a| a.alias_name.as_str()).collect();
1953 assert!(names.contains(&"rust"));
1954 assert!(names.contains(&"rust-lang"));
1955 assert!(names.contains(&"rustlang"));
1956 }
1957
1958 #[tokio::test]
1959 async fn find_entities_fuzzy_includes_aliases() {
1960 let gs = setup().await;
1961 let id = gs
1962 .upsert_entity("rust", "rust", EntityType::Language, None)
1963 .await
1964 .unwrap();
1965 gs.add_alias(id, "rust-lang").await.unwrap();
1966 gs.upsert_entity("python", "python", EntityType::Language, None)
1967 .await
1968 .unwrap();
1969
1970 let results = gs.find_entities_fuzzy("rust-lang", 10).await.unwrap();
1972 assert!(!results.is_empty());
1973 assert!(results.iter().any(|e| e.id == id));
1974 }
1975
1976 #[tokio::test]
1977 async fn orphan_alias_cleanup_on_entity_delete() {
1978 let gs = setup().await;
1979 let id = gs
1980 .upsert_entity("rust", "rust", EntityType::Language, None)
1981 .await
1982 .unwrap();
1983 gs.add_alias(id, "rust").await.unwrap();
1984 gs.add_alias(id, "rust-lang").await.unwrap();
1985
1986 sqlx::query("DELETE FROM graph_entities WHERE id = ?1")
1988 .bind(id)
1989 .execute(&gs.pool)
1990 .await
1991 .unwrap();
1992
1993 let aliases = gs.aliases_for_entity(id).await.unwrap();
1995 assert!(
1996 aliases.is_empty(),
1997 "aliases should cascade-delete with entity"
1998 );
1999 }
2000
    #[tokio::test]
    async fn migration_024_backfill_preserves_entities_and_edges() {
        use sqlx::Acquire as _;
        use sqlx::ConnectOptions as _;
        use sqlx::sqlite::SqliteConnectOptions;

        // This test does NOT use `setup()`: it hand-builds the PRE-migration
        // schema, seeds data, then replays migration 024 statement by
        // statement to prove entities, edges, and FTS survive the table swap.
        // A single-connection pool is required so every statement (including
        // the PRAGMAs, which are per-connection) runs on the same connection.
        let opts = SqliteConnectOptions::from_url(&"sqlite::memory:".parse().unwrap())
            .unwrap()
            .foreign_keys(true);
        let pool = sqlx::pool::PoolOptions::<sqlx::Sqlite>::new()
            .max_connections(1)
            .connect_with(opts)
            .await
            .unwrap();

        // Pre-migration graph_entities: no canonical_name column, uniqueness
        // on (name, entity_type).
        sqlx::query(
            "CREATE TABLE graph_entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                summary TEXT,
                first_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                last_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                qdrant_point_id TEXT,
                UNIQUE(name, entity_type)
            )",
        )
        .execute(&pool)
        .await
        .unwrap();

        // graph_edges references graph_entities with ON DELETE CASCADE —
        // the part that makes the later DROP/RENAME dance risky.
        sqlx::query(
            "CREATE TABLE graph_edges (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                target_entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                relation TEXT NOT NULL,
                fact TEXT NOT NULL,
                confidence REAL NOT NULL DEFAULT 1.0,
                valid_from TEXT NOT NULL DEFAULT (datetime('now')),
                valid_to TEXT,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                expired_at TEXT,
                episode_id INTEGER,
                qdrant_point_id TEXT
            )",
        )
        .execute(&pool)
        .await
        .unwrap();

        // External-content FTS5 index over graph_entities plus the three
        // sync triggers (insert/delete/update) that keep it consistent.
        sqlx::query(
            "CREATE VIRTUAL TABLE IF NOT EXISTS graph_entities_fts USING fts5(
                name, summary, content='graph_entities', content_rowid='id',
                tokenize='unicode61 remove_diacritics 2'
            )",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_insert AFTER INSERT ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, '')); END",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_delete AFTER DELETE ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, '')); END",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_update AFTER UPDATE ON graph_entities
             BEGIN
               INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, ''));
               INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, ''));
             END",
        )
        .execute(&pool)
        .await
        .unwrap();

        // Seed two pre-migration entities and one edge between them.
        let alice_id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, entity_type) VALUES ('Alice', 'person') RETURNING id",
        )
        .fetch_one(&pool)
        .await
        .unwrap();

        let rust_id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, entity_type) VALUES ('Rust', 'language') RETURNING id",
        )
        .fetch_one(&pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact)
             VALUES (?1, ?2, 'uses', 'Alice uses Rust')",
        )
        .bind(alice_id)
        .bind(rust_id)
        .execute(&pool)
        .await
        .unwrap();

        // Pin one pooled connection for the whole migration replay: the
        // PRAGMA foreign_keys toggle below only affects this connection.
        let mut conn = pool.acquire().await.unwrap();
        let conn = conn.acquire().await.unwrap();

        // --- Replay of migration 024 ---
        // FKs off so dropping/renaming graph_entities does not cascade into
        // graph_edges.
        sqlx::query("PRAGMA foreign_keys = OFF")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Step 1: add canonical_name and backfill it from name.
        sqlx::query("ALTER TABLE graph_entities ADD COLUMN canonical_name TEXT")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query("UPDATE graph_entities SET canonical_name = name WHERE canonical_name IS NULL")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Step 2: rebuild the table so canonical_name is NOT NULL and the
        // unique key becomes (canonical_name, entity_type).
        sqlx::query(
            "CREATE TABLE graph_entities_new (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                canonical_name TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                summary TEXT,
                first_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                last_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                qdrant_point_id TEXT,
                UNIQUE(canonical_name, entity_type)
            )",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "INSERT INTO graph_entities_new
             (id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id)
             SELECT id, name, COALESCE(canonical_name, name), entity_type, summary,
                    first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // Dropping the old table also drops its triggers; they are recreated
        // below against the renamed table.
        sqlx::query("DROP TABLE graph_entities")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query("ALTER TABLE graph_entities_new RENAME TO graph_entities")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Step 3: recreate the FTS sync triggers on the new table.
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_insert AFTER INSERT ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, '')); END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_delete AFTER DELETE ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, '')); END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_update AFTER UPDATE ON graph_entities
             BEGIN
               INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, ''));
               INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, ''));
             END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // Step 4: rebuild the FTS index from the recreated content table.
        sqlx::query("INSERT INTO graph_entities_fts(graph_entities_fts) VALUES('rebuild')")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Step 5: introduce the alias table and seed one alias per entity
        // from its pre-migration name.
        sqlx::query(
            "CREATE TABLE graph_entity_aliases (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                alias_name TEXT NOT NULL,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                UNIQUE(alias_name, entity_id)
            )",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name)
             SELECT id, name FROM graph_entities",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query("PRAGMA foreign_keys = ON")
            .execute(&mut *conn)
            .await
            .unwrap();

        // --- Post-migration assertions ---
        // canonical_name was backfilled from the original names.
        let alice_canon: String =
            sqlx::query_scalar("SELECT canonical_name FROM graph_entities WHERE id = ?1")
                .bind(alice_id)
                .fetch_one(&mut *conn)
                .await
                .unwrap();
        assert_eq!(
            alice_canon, "Alice",
            "canonical_name should equal pre-migration name"
        );

        let rust_canon: String =
            sqlx::query_scalar("SELECT canonical_name FROM graph_entities WHERE id = ?1")
                .bind(rust_id)
                .fetch_one(&mut *conn)
                .await
                .unwrap();
        assert_eq!(
            rust_canon, "Rust",
            "canonical_name should equal pre-migration name"
        );

        // Each entity got its name seeded as an initial alias.
        let alice_aliases: Vec<String> =
            sqlx::query_scalar("SELECT alias_name FROM graph_entity_aliases WHERE entity_id = ?1")
                .bind(alice_id)
                .fetch_all(&mut *conn)
                .await
                .unwrap();
        assert!(
            alice_aliases.contains(&"Alice".to_owned()),
            "initial alias should be seeded from entity name"
        );

        // The FK-off table swap must not have cascaded away the edge.
        let edge_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges")
            .fetch_one(&mut *conn)
            .await
            .unwrap();
        assert_eq!(
            edge_count, 1,
            "graph_edges must survive migration 024 table recreation"
        );
    }
2274
2275 #[tokio::test]
2276 async fn find_entity_by_alias_same_alias_two_entities_deterministic() {
2277 let gs = setup().await;
2279 let id1 = gs
2280 .upsert_entity("python-v2", "python-v2", EntityType::Language, None)
2281 .await
2282 .unwrap();
2283 let id2 = gs
2284 .upsert_entity("python-v3", "python-v3", EntityType::Language, None)
2285 .await
2286 .unwrap();
2287 gs.add_alias(id1, "python").await.unwrap();
2288 gs.add_alias(id2, "python").await.unwrap();
2289
2290 let found = gs
2292 .find_entity_by_alias("python", EntityType::Language)
2293 .await
2294 .unwrap();
2295 assert!(found.is_some(), "should find an entity by shared alias");
2296 assert_eq!(
2298 found.unwrap().id,
2299 id1,
2300 "first-registered entity should win on shared alias"
2301 );
2302 }
2303
2304 #[tokio::test]
2307 async fn find_entities_fuzzy_special_chars() {
2308 let gs = setup().await;
2309 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2310 .await
2311 .unwrap();
2312 let results = gs.find_entities_fuzzy("graph\"()*:^", 10).await.unwrap();
2314 assert!(results.iter().any(|e| e.name == "Graph"));
2316 }
2317
2318 #[tokio::test]
2319 async fn find_entities_fuzzy_prefix_match() {
2320 let gs = setup().await;
2321 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2322 .await
2323 .unwrap();
2324 gs.upsert_entity("GraphQL", "GraphQL", EntityType::Concept, None)
2325 .await
2326 .unwrap();
2327 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
2328 .await
2329 .unwrap();
2330 let results = gs.find_entities_fuzzy("Gra", 10).await.unwrap();
2332 assert_eq!(results.len(), 2);
2333 assert!(results.iter().any(|e| e.name == "Graph"));
2334 assert!(results.iter().any(|e| e.name == "GraphQL"));
2335 }
2336
2337 #[tokio::test]
2338 async fn find_entities_fuzzy_fts5_operator_injection() {
2339 let gs = setup().await;
2340 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2341 .await
2342 .unwrap();
2343 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
2344 .await
2345 .unwrap();
2346 let results = gs
2351 .find_entities_fuzzy("graph OR unrelated", 10)
2352 .await
2353 .unwrap();
2354 assert!(
2355 results.is_empty(),
2356 "implicit AND of 'graph*' and 'unrelated*' should match no entity"
2357 );
2358 }
2359
2360 #[tokio::test]
2361 async fn find_entities_fuzzy_after_entity_update() {
2362 let gs = setup().await;
2363 gs.upsert_entity(
2365 "Foo",
2366 "Foo",
2367 EntityType::Concept,
2368 Some("initial summary bar"),
2369 )
2370 .await
2371 .unwrap();
2372 gs.upsert_entity(
2374 "Foo",
2375 "Foo",
2376 EntityType::Concept,
2377 Some("updated summary baz"),
2378 )
2379 .await
2380 .unwrap();
2381 let old_results = gs.find_entities_fuzzy("bar", 10).await.unwrap();
2383 assert!(
2384 old_results.is_empty(),
2385 "old summary content should not match after update"
2386 );
2387 let new_results = gs.find_entities_fuzzy("baz", 10).await.unwrap();
2389 assert_eq!(new_results.len(), 1);
2390 assert_eq!(new_results[0].name, "Foo");
2391 }
2392
2393 #[tokio::test]
2394 async fn find_entities_fuzzy_only_special_chars() {
2395 let gs = setup().await;
2396 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
2397 .await
2398 .unwrap();
2399 let results = gs.find_entities_fuzzy("***", 10).await.unwrap();
2403 assert!(
2404 results.is_empty(),
2405 "only special chars should return no results"
2406 );
2407 let results = gs.find_entities_fuzzy("(((", 10).await.unwrap();
2408 assert!(results.is_empty(), "only parens should return no results");
2409 let results = gs.find_entities_fuzzy("\"\"\"", 10).await.unwrap();
2410 assert!(results.is_empty(), "only quotes should return no results");
2411 }
2412
2413 #[tokio::test]
2416 async fn find_entity_by_name_exact_wins_over_summary_mention() {
2417 let gs = setup().await;
2420 gs.upsert_entity(
2421 "Alice",
2422 "Alice",
2423 EntityType::Person,
2424 Some("A person named Alice"),
2425 )
2426 .await
2427 .unwrap();
2428 gs.upsert_entity(
2430 "Google",
2431 "Google",
2432 EntityType::Organization,
2433 Some("Company where Charlie, Alice, and Bob have worked"),
2434 )
2435 .await
2436 .unwrap();
2437
2438 let results = gs.find_entity_by_name("Alice").await.unwrap();
2439 assert!(!results.is_empty(), "must find at least one entity");
2440 assert_eq!(
2441 results[0].name, "Alice",
2442 "exact name match must come first, not entity with 'Alice' in summary"
2443 );
2444 }
2445
2446 #[tokio::test]
2447 async fn find_entity_by_name_case_insensitive_exact() {
2448 let gs = setup().await;
2449 gs.upsert_entity("Bob", "Bob", EntityType::Person, None)
2450 .await
2451 .unwrap();
2452
2453 let results = gs.find_entity_by_name("bob").await.unwrap();
2454 assert!(!results.is_empty());
2455 assert_eq!(results[0].name, "Bob");
2456 }
2457
2458 #[tokio::test]
2459 async fn find_entity_by_name_falls_back_to_fuzzy_when_no_exact_match() {
2460 let gs = setup().await;
2461 gs.upsert_entity("Charlie", "Charlie", EntityType::Person, None)
2462 .await
2463 .unwrap();
2464
2465 let results = gs.find_entity_by_name("Char").await.unwrap();
2467 assert!(!results.is_empty(), "prefix search must find Charlie");
2468 }
2469
2470 #[tokio::test]
2471 async fn find_entity_by_name_returns_empty_for_unknown() {
2472 let gs = setup().await;
2473 let results = gs.find_entity_by_name("NonExistent").await.unwrap();
2474 assert!(results.is_empty());
2475 }
2476
2477 #[tokio::test]
2478 async fn find_entity_by_name_matches_canonical_name() {
2479 let gs = setup().await;
2481 gs.upsert_entity("Dave (Engineer)", "Dave", EntityType::Person, None)
2483 .await
2484 .unwrap();
2485
2486 let results = gs.find_entity_by_name("Dave").await.unwrap();
2489 assert!(
2490 !results.is_empty(),
2491 "canonical_name match must return entity"
2492 );
2493 assert_eq!(results[0].canonical_name, "Dave");
2494 }
2495
2496 async fn insert_test_message(gs: &GraphStore, content: &str) -> crate::types::MessageId {
2497 let conv_id: i64 =
2499 sqlx::query_scalar("INSERT INTO conversations DEFAULT VALUES RETURNING id")
2500 .fetch_one(&gs.pool)
2501 .await
2502 .unwrap();
2503 let id: i64 = sqlx::query_scalar(
2504 "INSERT INTO messages (conversation_id, role, content) VALUES (?1, 'user', ?2) RETURNING id",
2505 )
2506 .bind(conv_id)
2507 .bind(content)
2508 .fetch_one(&gs.pool)
2509 .await
2510 .unwrap();
2511 crate::types::MessageId(id)
2512 }
2513
2514 #[tokio::test]
2515 async fn unprocessed_messages_for_backfill_returns_unprocessed() {
2516 let gs = setup().await;
2517 let id1 = insert_test_message(&gs, "hello world").await;
2518 let id2 = insert_test_message(&gs, "second message").await;
2519
2520 let rows = gs.unprocessed_messages_for_backfill(10).await.unwrap();
2521 assert_eq!(rows.len(), 2);
2522 assert!(rows.iter().any(|(id, _)| *id == id1));
2523 assert!(rows.iter().any(|(id, _)| *id == id2));
2524 }
2525
2526 #[tokio::test]
2527 async fn unprocessed_messages_for_backfill_respects_limit() {
2528 let gs = setup().await;
2529 insert_test_message(&gs, "msg1").await;
2530 insert_test_message(&gs, "msg2").await;
2531 insert_test_message(&gs, "msg3").await;
2532
2533 let rows = gs.unprocessed_messages_for_backfill(2).await.unwrap();
2534 assert_eq!(rows.len(), 2);
2535 }
2536
2537 #[tokio::test]
2538 async fn mark_messages_graph_processed_updates_flag() {
2539 let gs = setup().await;
2540 let id1 = insert_test_message(&gs, "to process").await;
2541 let _id2 = insert_test_message(&gs, "also to process").await;
2542
2543 let count_before = gs.unprocessed_message_count().await.unwrap();
2545 assert_eq!(count_before, 2);
2546
2547 gs.mark_messages_graph_processed(&[id1]).await.unwrap();
2548
2549 let count_after = gs.unprocessed_message_count().await.unwrap();
2550 assert_eq!(count_after, 1);
2551
2552 let rows = gs.unprocessed_messages_for_backfill(10).await.unwrap();
2554 assert!(!rows.iter().any(|(id, _)| *id == id1));
2555 }
2556
2557 #[tokio::test]
2558 async fn mark_messages_graph_processed_empty_ids_is_noop() {
2559 let gs = setup().await;
2560 insert_test_message(&gs, "message").await;
2561
2562 gs.mark_messages_graph_processed(&[]).await.unwrap();
2563
2564 let count = gs.unprocessed_message_count().await.unwrap();
2565 assert_eq!(count, 1);
2566 }
2567}