1use std::collections::HashMap;
5
6use futures::Stream;
7use sqlx::SqlitePool;
8
9use crate::error::MemoryError;
10use crate::sqlite::messages::sanitize_fts5_query;
11use crate::types::MessageId;
12
13use super::types::{Community, Edge, Entity, EntityAlias, EntityType};
14
/// SQLite-backed persistence layer for the knowledge graph:
/// entities, aliases, edges, communities, and key/value metadata.
pub struct GraphStore {
    pool: SqlitePool,
}
18
19impl GraphStore {
    /// Creates a store backed by the given SQLite connection pool.
    #[must_use]
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }

    /// Returns a reference to the underlying connection pool.
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
29
    /// Inserts an entity or, when `(canonical_name, entity_type)` already
    /// exists, refreshes its surface `name`, keeps or replaces the summary
    /// (`COALESCE` keeps the old one when `summary` is `None`), and bumps
    /// `last_seen_at`.
    ///
    /// Returns the row id of the inserted or updated entity.
    ///
    /// # Errors
    /// Returns `MemoryError` if the underlying query fails.
    pub async fn upsert_entity(
        &self,
        surface_name: &str,
        canonical_name: &str,
        entity_type: EntityType,
        summary: Option<&str>,
    ) -> Result<i64, MemoryError> {
        let type_str = entity_type.as_str();
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, canonical_name, entity_type, summary)
             VALUES (?1, ?2, ?3, ?4)
             ON CONFLICT(canonical_name, entity_type) DO UPDATE SET
                 name = excluded.name,
                 summary = COALESCE(excluded.summary, summary),
                 last_seen_at = datetime('now')
             RETURNING id",
        )
        .bind(surface_name)
        .bind(canonical_name)
        .bind(type_str)
        .bind(summary)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
67
68 pub async fn find_entity(
74 &self,
75 canonical_name: &str,
76 entity_type: EntityType,
77 ) -> Result<Option<Entity>, MemoryError> {
78 let type_str = entity_type.as_str();
79 let row: Option<EntityRow> = sqlx::query_as(
80 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
81 FROM graph_entities
82 WHERE canonical_name = ?1 AND entity_type = ?2",
83 )
84 .bind(canonical_name)
85 .bind(type_str)
86 .fetch_optional(&self.pool)
87 .await?;
88 row.map(entity_from_row).transpose()
89 }
90
91 pub async fn find_entity_by_id(&self, entity_id: i64) -> Result<Option<Entity>, MemoryError> {
97 let row: Option<EntityRow> = sqlx::query_as(
98 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
99 FROM graph_entities
100 WHERE id = ?1",
101 )
102 .bind(entity_id)
103 .fetch_optional(&self.pool)
104 .await?;
105 row.map(entity_from_row).transpose()
106 }
107
108 pub async fn set_entity_qdrant_point_id(
114 &self,
115 entity_id: i64,
116 point_id: &str,
117 ) -> Result<(), MemoryError> {
118 sqlx::query("UPDATE graph_entities SET qdrant_point_id = ?1 WHERE id = ?2")
119 .bind(point_id)
120 .bind(entity_id)
121 .execute(&self.pool)
122 .await?;
123 Ok(())
124 }
125
    /// Fuzzy entity search: unions an FTS5 prefix match over entity names
    /// with a case-insensitive LIKE match over the alias table.
    ///
    /// The query is truncated to at most 512 bytes on a char boundary,
    /// sanitized for FTS5, and bare FTS5 keyword operators are stripped so
    /// user input cannot change query semantics. Returns at most `limit`
    /// entities.
    ///
    /// # Errors
    /// Returns `MemoryError` on query failure or limit conversion.
    pub async fn find_entities_fuzzy(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Entity>, MemoryError> {
        // These would be parsed as FTS5 syntax rather than search terms.
        const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT", "NEAR"];
        // Cap query length without splitting a multi-byte character.
        let query = &query[..query.floor_char_boundary(512)];
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(vec![]);
        }
        // Turn each surviving token into a prefix match ("tok*").
        let fts_query: String = sanitized
            .split_whitespace()
            .filter(|t| !FTS5_OPERATORS.contains(t))
            .map(|t| format!("{t}*"))
            .collect::<Vec<_>>()
            .join(" ");
        if fts_query.is_empty() {
            return Ok(vec![]);
        }

        let limit = i64::try_from(limit)?;
        let rows: Vec<EntityRow> = sqlx::query_as(
            "SELECT DISTINCT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entities_fts fts
             JOIN graph_entities e ON e.id = fts.rowid
             WHERE graph_entities_fts MATCH ?1
             UNION
             SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name LIKE ?2 ESCAPE '\\' COLLATE NOCASE
             LIMIT ?3",
        )
        .bind(&fts_query)
        // Escape LIKE wildcards in the raw query before wrapping in %...%
        // (backslash first, so the escapes themselves are not re-escaped).
        .bind(format!(
            "%{}%",
            query
                .trim()
                .replace('\\', "\\\\")
                .replace('%', "\\%")
                .replace('_', "\\_")
        ))
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        rows.into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()
    }
205
    /// Streams all entities ordered by id, converting rows lazily so the
    /// full table is never materialised in memory.
    pub fn all_entities_stream(&self) -> impl Stream<Item = Result<Entity, MemoryError>> + '_ {
        use futures::StreamExt as _;
        sqlx::query_as::<_, EntityRow>(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities ORDER BY id ASC",
        )
        .fetch(&self.pool)
        .map(|r: Result<EntityRow, sqlx::Error>| {
            r.map_err(MemoryError::from).and_then(entity_from_row)
        })
    }
218
219 pub async fn add_alias(&self, entity_id: i64, alias_name: &str) -> Result<(), MemoryError> {
227 sqlx::query(
228 "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name) VALUES (?1, ?2)",
229 )
230 .bind(entity_id)
231 .bind(alias_name)
232 .execute(&self.pool)
233 .await?;
234 Ok(())
235 }
236
    /// Resolves an entity through the alias table (case-insensitive
    /// match), restricted to the given entity type. When several entities
    /// share the alias, the one with the lowest id wins.
    ///
    /// # Errors
    /// Returns `MemoryError` on query failure or if the stored entity
    /// type string cannot be parsed.
    pub async fn find_entity_by_alias(
        &self,
        alias_name: &str,
        entity_type: EntityType,
    ) -> Result<Option<Entity>, MemoryError> {
        let type_str = entity_type.as_str();
        let row: Option<EntityRow> = sqlx::query_as(
            "SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name = ?1 COLLATE NOCASE
               AND e.entity_type = ?2
             ORDER BY e.id ASC
             LIMIT 1",
        )
        .bind(alias_name)
        .bind(type_str)
        .fetch_optional(&self.pool)
        .await?;
        row.map(entity_from_row).transpose()
    }
266
267 pub async fn aliases_for_entity(
273 &self,
274 entity_id: i64,
275 ) -> Result<Vec<EntityAlias>, MemoryError> {
276 let rows: Vec<AliasRow> = sqlx::query_as(
277 "SELECT id, entity_id, alias_name, created_at
278 FROM graph_entity_aliases
279 WHERE entity_id = ?1
280 ORDER BY id ASC",
281 )
282 .bind(entity_id)
283 .fetch_all(&self.pool)
284 .await?;
285 Ok(rows.into_iter().map(alias_from_row).collect())
286 }
287
288 pub async fn all_entities(&self) -> Result<Vec<Entity>, MemoryError> {
294 use futures::TryStreamExt as _;
295 self.all_entities_stream().try_collect().await
296 }
297
298 pub async fn entity_count(&self) -> Result<i64, MemoryError> {
304 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_entities")
305 .fetch_one(&self.pool)
306 .await?;
307 Ok(count)
308 }
309
    /// Inserts a new edge (a relation "fact") between two entities.
    ///
    /// `confidence` is clamped to [0.0, 1.0] before storage; the optional
    /// `episode_id` records the message the fact came from.
    ///
    /// Returns the new edge's row id.
    ///
    /// # Errors
    /// Returns `MemoryError` if the insert fails.
    pub async fn insert_edge(
        &self,
        source_entity_id: i64,
        target_entity_id: i64,
        relation: &str,
        fact: &str,
        confidence: f32,
        episode_id: Option<MessageId>,
    ) -> Result<i64, MemoryError> {
        let confidence = confidence.clamp(0.0, 1.0);
        let episode_raw: Option<i64> = episode_id.map(|m| m.0);
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, episode_id)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             RETURNING id",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .bind(fact)
        // SQLite REAL is f64; widen explicitly, no precision loss.
        .bind(f64::from(confidence))
        .bind(episode_raw)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
343
344 pub async fn invalidate_edge(&self, edge_id: i64) -> Result<(), MemoryError> {
350 sqlx::query(
351 "UPDATE graph_edges SET valid_to = datetime('now'), expired_at = datetime('now')
352 WHERE id = ?1",
353 )
354 .bind(edge_id)
355 .execute(&self.pool)
356 .await?;
357 Ok(())
358 }
359
    /// Returns every active (non-invalidated) edge incident to the given
    /// entity, whether it is the source or the target.
    ///
    /// # Errors
    /// Returns `MemoryError` if the query fails.
    pub async fn edges_for_entity(&self, entity_id: i64) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
             FROM graph_edges
             WHERE valid_to IS NULL
               AND (source_entity_id = ?1 OR target_entity_id = ?1)",
        )
        .bind(entity_id)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
378
    /// Returns the active edges connecting two entities, matching in
    /// either direction.
    ///
    /// # Errors
    /// Returns `MemoryError` if the query fails.
    pub async fn edges_between(
        &self,
        entity_a: i64,
        entity_b: i64,
    ) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
             FROM graph_edges
             WHERE valid_to IS NULL
               AND ((source_entity_id = ?1 AND target_entity_id = ?2)
                 OR (source_entity_id = ?2 AND target_entity_id = ?1))",
        )
        .bind(entity_a)
        .bind(entity_b)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
403
404 pub async fn edges_exact(
410 &self,
411 source_entity_id: i64,
412 target_entity_id: i64,
413 ) -> Result<Vec<Edge>, MemoryError> {
414 let rows: Vec<EdgeRow> = sqlx::query_as(
415 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
416 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
417 FROM graph_edges
418 WHERE valid_to IS NULL
419 AND source_entity_id = ?1
420 AND target_entity_id = ?2",
421 )
422 .bind(source_entity_id)
423 .bind(target_entity_id)
424 .fetch_all(&self.pool)
425 .await?;
426 Ok(rows.into_iter().map(edge_from_row).collect())
427 }
428
429 pub async fn active_edge_count(&self) -> Result<i64, MemoryError> {
435 let count: i64 =
436 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
437 .fetch_one(&self.pool)
438 .await?;
439 Ok(count)
440 }
441
    /// Inserts a community or updates the existing one with the same
    /// name. `entity_ids` is serialized as a JSON array; an existing
    /// fingerprint is kept when `fingerprint` is `None` (COALESCE), and
    /// `updated_at` is bumped on update.
    ///
    /// Returns the community's row id.
    ///
    /// # Errors
    /// Returns `MemoryError` on JSON serialization or query failure.
    pub async fn upsert_community(
        &self,
        name: &str,
        summary: &str,
        entity_ids: &[i64],
        fingerprint: Option<&str>,
    ) -> Result<i64, MemoryError> {
        let entity_ids_json = serde_json::to_string(entity_ids)?;
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_communities (name, summary, entity_ids, fingerprint)
             VALUES (?1, ?2, ?3, ?4)
             ON CONFLICT(name) DO UPDATE SET
                 summary = excluded.summary,
                 entity_ids = excluded.entity_ids,
                 fingerprint = COALESCE(excluded.fingerprint, fingerprint),
                 updated_at = datetime('now')
             RETURNING id",
        )
        .bind(name)
        .bind(summary)
        .bind(entity_ids_json)
        .bind(fingerprint)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
479
480 pub async fn community_fingerprints(&self) -> Result<HashMap<String, i64>, MemoryError> {
487 let rows: Vec<(String, i64)> = sqlx::query_as(
488 "SELECT fingerprint, id FROM graph_communities WHERE fingerprint IS NOT NULL",
489 )
490 .fetch_all(&self.pool)
491 .await?;
492 Ok(rows.into_iter().collect())
493 }
494
495 pub async fn delete_community_by_id(&self, id: i64) -> Result<(), MemoryError> {
501 sqlx::query("DELETE FROM graph_communities WHERE id = ?1")
502 .bind(id)
503 .execute(&self.pool)
504 .await?;
505 Ok(())
506 }
507
508 pub async fn clear_community_fingerprint(&self, id: i64) -> Result<(), MemoryError> {
517 sqlx::query("UPDATE graph_communities SET fingerprint = NULL WHERE id = ?1")
518 .bind(id)
519 .execute(&self.pool)
520 .await?;
521 Ok(())
522 }
523
524 pub async fn community_for_entity(
533 &self,
534 entity_id: i64,
535 ) -> Result<Option<Community>, MemoryError> {
536 let row: Option<CommunityRow> = sqlx::query_as(
537 "SELECT c.id, c.name, c.summary, c.entity_ids, c.fingerprint, c.created_at, c.updated_at
538 FROM graph_communities c, json_each(c.entity_ids) j
539 WHERE CAST(j.value AS INTEGER) = ?1
540 LIMIT 1",
541 )
542 .bind(entity_id)
543 .fetch_optional(&self.pool)
544 .await?;
545 match row {
546 Some(row) => {
547 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
548 Ok(Some(Community {
549 id: row.id,
550 name: row.name,
551 summary: row.summary,
552 entity_ids,
553 fingerprint: row.fingerprint,
554 created_at: row.created_at,
555 updated_at: row.updated_at,
556 }))
557 }
558 None => Ok(None),
559 }
560 }
561
562 pub async fn all_communities(&self) -> Result<Vec<Community>, MemoryError> {
568 let rows: Vec<CommunityRow> = sqlx::query_as(
569 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
570 FROM graph_communities
571 ORDER BY id ASC",
572 )
573 .fetch_all(&self.pool)
574 .await?;
575
576 rows.into_iter()
577 .map(|row| {
578 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
579 Ok(Community {
580 id: row.id,
581 name: row.name,
582 summary: row.summary,
583 entity_ids,
584 fingerprint: row.fingerprint,
585 created_at: row.created_at,
586 updated_at: row.updated_at,
587 })
588 })
589 .collect()
590 }
591
592 pub async fn community_count(&self) -> Result<i64, MemoryError> {
598 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_communities")
599 .fetch_one(&self.pool)
600 .await?;
601 Ok(count)
602 }
603
604 pub async fn get_metadata(&self, key: &str) -> Result<Option<String>, MemoryError> {
612 let val: Option<String> =
613 sqlx::query_scalar("SELECT value FROM graph_metadata WHERE key = ?1")
614 .bind(key)
615 .fetch_optional(&self.pool)
616 .await?;
617 Ok(val)
618 }
619
620 pub async fn set_metadata(&self, key: &str, value: &str) -> Result<(), MemoryError> {
626 sqlx::query(
627 "INSERT INTO graph_metadata (key, value) VALUES (?1, ?2)
628 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
629 )
630 .bind(key)
631 .bind(value)
632 .execute(&self.pool)
633 .await?;
634 Ok(())
635 }
636
637 pub async fn extraction_count(&self) -> Result<i64, MemoryError> {
645 let val = self.get_metadata("extraction_count").await?;
646 Ok(val.and_then(|v| v.parse::<i64>().ok()).unwrap_or(0))
647 }
648
    /// Streams every active (non-invalidated) edge ordered by id without
    /// materialising the whole table in memory.
    pub fn all_active_edges_stream(&self) -> impl Stream<Item = Result<Edge, MemoryError>> + '_ {
        use futures::StreamExt as _;
        sqlx::query_as::<_, EdgeRow>(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
             FROM graph_edges
             WHERE valid_to IS NULL
             ORDER BY id ASC",
        )
        .fetch(&self.pool)
        .map(|r| r.map_err(MemoryError::from).map(edge_from_row))
    }
662
663 pub async fn find_community_by_id(&self, id: i64) -> Result<Option<Community>, MemoryError> {
669 let row: Option<CommunityRow> = sqlx::query_as(
670 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
671 FROM graph_communities
672 WHERE id = ?1",
673 )
674 .bind(id)
675 .fetch_optional(&self.pool)
676 .await?;
677 match row {
678 Some(row) => {
679 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
680 Ok(Some(Community {
681 id: row.id,
682 name: row.name,
683 summary: row.summary,
684 entity_ids,
685 fingerprint: row.fingerprint,
686 created_at: row.created_at,
687 updated_at: row.updated_at,
688 }))
689 }
690 None => Ok(None),
691 }
692 }
693
694 pub async fn delete_all_communities(&self) -> Result<(), MemoryError> {
700 sqlx::query("DELETE FROM graph_communities")
701 .execute(&self.pool)
702 .await?;
703 Ok(())
704 }
705
    /// Hard-deletes edges that expired more than `retention_days` ago.
    ///
    /// The day count is a bound parameter concatenated into the SQLite
    /// datetime modifier, so no SQL injection is possible.
    ///
    /// Returns the number of rows removed.
    ///
    /// # Errors
    /// Returns `MemoryError` on query failure or count conversion.
    pub async fn delete_expired_edges(&self, retention_days: u32) -> Result<usize, MemoryError> {
        let days = i64::from(retention_days);
        let result = sqlx::query(
            "DELETE FROM graph_edges
             WHERE expired_at IS NOT NULL
               AND expired_at < datetime('now', '-' || ?1 || ' days')",
        )
        .bind(days)
        .execute(&self.pool)
        .await?;
        Ok(usize::try_from(result.rows_affected())?)
    }
723
    /// Hard-deletes entities that participate in no active edge (either
    /// direction) and were last seen more than `retention_days` ago.
    ///
    /// Returns the number of rows removed.
    ///
    /// # Errors
    /// Returns `MemoryError` on query failure or count conversion.
    pub async fn delete_orphan_entities(&self, retention_days: u32) -> Result<usize, MemoryError> {
        let days = i64::from(retention_days);
        let result = sqlx::query(
            "DELETE FROM graph_entities
             WHERE id NOT IN (
                 SELECT DISTINCT source_entity_id FROM graph_edges WHERE valid_to IS NULL
                 UNION
                 SELECT DISTINCT target_entity_id FROM graph_edges WHERE valid_to IS NULL
             )
             AND last_seen_at < datetime('now', '-' || ?1 || ' days')",
        )
        .bind(days)
        .execute(&self.pool)
        .await?;
        Ok(usize::try_from(result.rows_affected())?)
    }
745
746 pub async fn cap_entities(&self, max_entities: usize) -> Result<usize, MemoryError> {
755 let current = self.entity_count().await?;
756 let max = i64::try_from(max_entities)?;
757 if current <= max {
758 return Ok(0);
759 }
760 let excess = current - max;
761 let result = sqlx::query(
762 "DELETE FROM graph_entities
763 WHERE id IN (
764 SELECT e.id
765 FROM graph_entities e
766 LEFT JOIN (
767 SELECT source_entity_id AS eid, COUNT(*) AS cnt
768 FROM graph_edges WHERE valid_to IS NULL GROUP BY source_entity_id
769 UNION ALL
770 SELECT target_entity_id AS eid, COUNT(*) AS cnt
771 FROM graph_edges WHERE valid_to IS NULL GROUP BY target_entity_id
772 ) edge_counts ON e.id = edge_counts.eid
773 ORDER BY COALESCE(edge_counts.cnt, 0) ASC, e.last_seen_at ASC
774 LIMIT ?1
775 )",
776 )
777 .bind(excess)
778 .execute(&self.pool)
779 .await?;
780 Ok(usize::try_from(result.rows_affected())?)
781 }
782
783 pub async fn bfs(
800 &self,
801 start_entity_id: i64,
802 max_hops: u32,
803 ) -> Result<(Vec<Entity>, Vec<Edge>), MemoryError> {
804 self.bfs_with_depth(start_entity_id, max_hops)
805 .await
806 .map(|(e, ed, _)| (e, ed))
807 }
808
    /// Breadth-first traversal over active edges from `start_entity_id`,
    /// up to `max_hops` hops.
    ///
    /// Returns the visited entities, the active edges whose endpoints are
    /// both in the visited set, and a map from entity id to the hop depth
    /// at which it was first reached (the start node has depth 0).
    ///
    /// NOTE(review): each frontier is truncated to 300 ids and the final
    /// visited set to 499 — presumably to stay below SQLite's bound
    /// parameter limit; confirm. Ids dropped by the 499 cap can still
    /// appear in the returned depth map even though their entities and
    /// edges are not returned.
    ///
    /// # Errors
    /// Returns `MemoryError` if any underlying query fails.
    pub async fn bfs_with_depth(
        &self,
        start_entity_id: i64,
        max_hops: u32,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        use std::collections::HashMap;

        // Upper bound on nodes expanded per hop; keeps the IN-lists small.
        const MAX_FRONTIER: usize = 300;

        let mut depth_map: HashMap<i64, u32> = HashMap::new();
        let mut frontier: Vec<i64> = vec![start_entity_id];
        depth_map.insert(start_entity_id, 0);

        for hop in 0..max_hops {
            if frontier.is_empty() {
                break;
            }
            frontier.truncate(MAX_FRONTIER);
            // Positional placeholders ?1..?k, one per frontier id.
            let placeholders = frontier
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");
            // For every active edge touching the frontier, select the
            // endpoint on the far side.
            let neighbour_sql = format!(
                "SELECT DISTINCT CASE
                    WHEN source_entity_id IN ({placeholders}) THEN target_entity_id
                    ELSE source_entity_id
                END as neighbour_id
                FROM graph_edges
                WHERE valid_to IS NULL
                AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))"
            );
            let mut q = sqlx::query_scalar::<_, i64>(&neighbour_sql);
            // The placeholder list appears three times in the SQL, so the
            // frontier is bound once per occurrence.
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            let neighbours: Vec<i64> = q.fetch_all(&self.pool).await?;

            // Only ids not already seen at a shallower depth enter the
            // next hop.
            let mut next_frontier: Vec<i64> = Vec::new();
            for nbr in neighbours {
                if let std::collections::hash_map::Entry::Vacant(e) = depth_map.entry(nbr) {
                    e.insert(hop + 1);
                    next_frontier.push(nbr);
                }
            }
            frontier = next_frontier;
        }

        let mut visited_ids: Vec<i64> = depth_map.keys().copied().collect();
        if visited_ids.is_empty() {
            return Ok((Vec::new(), Vec::new(), depth_map));
        }
        // Cap total ids used in the final IN-lists (see note above).
        visited_ids.truncate(499);

        let placeholders = visited_ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("?{}", i + 1))
            .collect::<Vec<_>>()
            .join(", ");

        // Keep only edges whose BOTH endpoints are in the visited set.
        let edge_sql = format!(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
             FROM graph_edges
             WHERE valid_to IS NULL
               AND source_entity_id IN ({placeholders})
               AND target_entity_id IN ({placeholders})"
        );
        let mut edge_query = sqlx::query_as::<_, EdgeRow>(&edge_sql);
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        let edge_rows: Vec<EdgeRow> = edge_query.fetch_all(&self.pool).await?;

        let entity_sql = format!(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities WHERE id IN ({placeholders})"
        );
        let mut entity_query = sqlx::query_as::<_, EntityRow>(&entity_sql);
        for id in &visited_ids {
            entity_query = entity_query.bind(*id);
        }
        let entity_rows: Vec<EntityRow> = entity_query.fetch_all(&self.pool).await?;

        let entities: Vec<Entity> = entity_rows
            .into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()?;
        let edges: Vec<Edge> = edge_rows.into_iter().map(edge_from_row).collect();

        Ok((entities, edges, depth_map))
    }
927
    /// Convenience wrapper around [`Self::find_entities_fuzzy`] with a
    /// fixed result limit of 5.
    ///
    /// # Errors
    /// Returns `MemoryError` if the underlying search fails.
    pub async fn find_entity_by_name(&self, name: &str) -> Result<Vec<Entity>, MemoryError> {
        self.find_entities_fuzzy(name, 5).await
    }
940
941 pub async fn unprocessed_messages_for_backfill(
949 &self,
950 limit: usize,
951 ) -> Result<Vec<(crate::types::MessageId, String)>, MemoryError> {
952 let limit = i64::try_from(limit)?;
953 let rows: Vec<(i64, String)> = sqlx::query_as(
954 "SELECT id, content FROM messages
955 WHERE graph_processed = 0
956 ORDER BY id ASC
957 LIMIT ?1",
958 )
959 .bind(limit)
960 .fetch_all(&self.pool)
961 .await?;
962 Ok(rows
963 .into_iter()
964 .map(|(id, content)| (crate::types::MessageId(id), content))
965 .collect())
966 }
967
968 pub async fn unprocessed_message_count(&self) -> Result<i64, MemoryError> {
974 let count: i64 =
975 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE graph_processed = 0")
976 .fetch_one(&self.pool)
977 .await?;
978 Ok(count)
979 }
980
981 pub async fn mark_messages_graph_processed(
987 &self,
988 ids: &[crate::types::MessageId],
989 ) -> Result<(), MemoryError> {
990 if ids.is_empty() {
991 return Ok(());
992 }
993 let placeholders = ids
994 .iter()
995 .enumerate()
996 .map(|(i, _)| format!("?{}", i + 1))
997 .collect::<Vec<_>>()
998 .join(", ");
999 let sql = format!("UPDATE messages SET graph_processed = 1 WHERE id IN ({placeholders})");
1000 let mut query = sqlx::query(&sql);
1001 for id in ids {
1002 query = query.bind(id.0);
1003 }
1004 query.execute(&self.pool).await?;
1005 Ok(())
1006 }
1007}
1008
/// Raw row shape for `graph_entities`; converted via `entity_from_row`.
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: i64,
    name: String,
    canonical_name: String,
    // Stored as text; parsed into `EntityType` during conversion.
    entity_type: String,
    summary: Option<String>,
    first_seen_at: String,
    last_seen_at: String,
    qdrant_point_id: Option<String>,
}
1022
1023fn entity_from_row(row: EntityRow) -> Result<Entity, MemoryError> {
1024 let entity_type = row
1025 .entity_type
1026 .parse::<EntityType>()
1027 .map_err(MemoryError::GraphStore)?;
1028 Ok(Entity {
1029 id: row.id,
1030 name: row.name,
1031 canonical_name: row.canonical_name,
1032 entity_type,
1033 summary: row.summary,
1034 first_seen_at: row.first_seen_at,
1035 last_seen_at: row.last_seen_at,
1036 qdrant_point_id: row.qdrant_point_id,
1037 })
1038}
1039
/// Raw row shape for `graph_entity_aliases`; converted via `alias_from_row`.
#[derive(sqlx::FromRow)]
struct AliasRow {
    id: i64,
    entity_id: i64,
    alias_name: String,
    created_at: String,
}
1047
1048fn alias_from_row(row: AliasRow) -> EntityAlias {
1049 EntityAlias {
1050 id: row.id,
1051 entity_id: row.entity_id,
1052 alias_name: row.alias_name,
1053 created_at: row.created_at,
1054 }
1055}
1056
/// Raw row shape for `graph_edges`; converted via `edge_from_row`.
#[derive(sqlx::FromRow)]
struct EdgeRow {
    id: i64,
    source_entity_id: i64,
    target_entity_id: i64,
    relation: String,
    fact: String,
    // Stored as SQLite REAL (f64); narrowed to f32 in `edge_from_row`.
    confidence: f64,
    valid_from: String,
    // NULL while the edge is active; set when invalidated.
    valid_to: Option<String>,
    created_at: String,
    expired_at: Option<String>,
    episode_id: Option<i64>,
    qdrant_point_id: Option<String>,
}
1072
1073fn edge_from_row(row: EdgeRow) -> Edge {
1074 Edge {
1075 id: row.id,
1076 source_entity_id: row.source_entity_id,
1077 target_entity_id: row.target_entity_id,
1078 relation: row.relation,
1079 fact: row.fact,
1080 #[allow(clippy::cast_possible_truncation)]
1081 confidence: row.confidence as f32,
1082 valid_from: row.valid_from,
1083 valid_to: row.valid_to,
1084 created_at: row.created_at,
1085 expired_at: row.expired_at,
1086 episode_id: row.episode_id.map(MessageId),
1087 qdrant_point_id: row.qdrant_point_id,
1088 }
1089}
1090
/// Raw row shape for `graph_communities`; `entity_ids` holds a JSON
/// array of entity ids that callers decode with `serde_json`.
#[derive(sqlx::FromRow)]
struct CommunityRow {
    id: i64,
    name: String,
    summary: String,
    entity_ids: String,
    fingerprint: Option<String>,
    created_at: String,
    updated_at: String,
}
1101
1102#[cfg(test)]
1105mod tests {
1106 use super::*;
1107 use crate::sqlite::SqliteStore;
1108
1109 async fn setup() -> GraphStore {
1110 let store = SqliteStore::new(":memory:").await.unwrap();
1111 GraphStore::new(store.pool().clone())
1112 }
1113
1114 #[tokio::test]
1115 async fn upsert_entity_insert_new() {
1116 let gs = setup().await;
1117 let id = gs
1118 .upsert_entity("Alice", "Alice", EntityType::Person, Some("a person"))
1119 .await
1120 .unwrap();
1121 assert!(id > 0);
1122 }
1123
1124 #[tokio::test]
1125 async fn upsert_entity_update_existing() {
1126 let gs = setup().await;
1127 let id1 = gs
1128 .upsert_entity("Alice", "Alice", EntityType::Person, None)
1129 .await
1130 .unwrap();
1131 let id2 = gs
1134 .upsert_entity("Alice", "Alice", EntityType::Person, Some("updated"))
1135 .await
1136 .unwrap();
1137 assert_eq!(id1, id2);
1138 let entity = gs
1139 .find_entity("Alice", EntityType::Person)
1140 .await
1141 .unwrap()
1142 .unwrap();
1143 assert_eq!(entity.summary.as_deref(), Some("updated"));
1144 }
1145
1146 #[tokio::test]
1147 async fn find_entity_found() {
1148 let gs = setup().await;
1149 gs.upsert_entity("Bob", "Bob", EntityType::Tool, Some("a tool"))
1150 .await
1151 .unwrap();
1152 let entity = gs
1153 .find_entity("Bob", EntityType::Tool)
1154 .await
1155 .unwrap()
1156 .unwrap();
1157 assert_eq!(entity.name, "Bob");
1158 assert_eq!(entity.entity_type, EntityType::Tool);
1159 }
1160
1161 #[tokio::test]
1162 async fn find_entity_not_found() {
1163 let gs = setup().await;
1164 let result = gs.find_entity("Nobody", EntityType::Person).await.unwrap();
1165 assert!(result.is_none());
1166 }
1167
1168 #[tokio::test]
1169 async fn find_entities_fuzzy_partial_match() {
1170 let gs = setup().await;
1171 gs.upsert_entity("GraphQL", "GraphQL", EntityType::Concept, None)
1172 .await
1173 .unwrap();
1174 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
1175 .await
1176 .unwrap();
1177 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
1178 .await
1179 .unwrap();
1180
1181 let results = gs.find_entities_fuzzy("graph", 10).await.unwrap();
1182 assert_eq!(results.len(), 2);
1183 assert!(results.iter().any(|e| e.name == "GraphQL"));
1184 assert!(results.iter().any(|e| e.name == "Graph"));
1185 }
1186
1187 #[tokio::test]
1188 async fn entity_count_empty() {
1189 let gs = setup().await;
1190 assert_eq!(gs.entity_count().await.unwrap(), 0);
1191 }
1192
1193 #[tokio::test]
1194 async fn entity_count_non_empty() {
1195 let gs = setup().await;
1196 gs.upsert_entity("A", "A", EntityType::Concept, None)
1197 .await
1198 .unwrap();
1199 gs.upsert_entity("B", "B", EntityType::Concept, None)
1200 .await
1201 .unwrap();
1202 assert_eq!(gs.entity_count().await.unwrap(), 2);
1203 }
1204
1205 #[tokio::test]
1206 async fn all_entities_and_stream() {
1207 let gs = setup().await;
1208 gs.upsert_entity("X", "X", EntityType::Project, None)
1209 .await
1210 .unwrap();
1211 gs.upsert_entity("Y", "Y", EntityType::Language, None)
1212 .await
1213 .unwrap();
1214
1215 let all = gs.all_entities().await.unwrap();
1216 assert_eq!(all.len(), 2);
1217
1218 use futures::StreamExt as _;
1219 let streamed: Vec<Result<Entity, _>> = gs.all_entities_stream().collect().await;
1220 assert_eq!(streamed.len(), 2);
1221 assert!(streamed.iter().all(|r| r.is_ok()));
1222 }
1223
1224 #[tokio::test]
1225 async fn insert_edge_without_episode() {
1226 let gs = setup().await;
1227 let src = gs
1228 .upsert_entity("Src", "Src", EntityType::Concept, None)
1229 .await
1230 .unwrap();
1231 let tgt = gs
1232 .upsert_entity("Tgt", "Tgt", EntityType::Concept, None)
1233 .await
1234 .unwrap();
1235 let eid = gs
1236 .insert_edge(src, tgt, "relates_to", "Src relates to Tgt", 0.9, None)
1237 .await
1238 .unwrap();
1239 assert!(eid > 0);
1240 }
1241
1242 #[tokio::test]
1243 async fn insert_edge_with_episode() {
1244 let gs = setup().await;
1245 let src = gs
1246 .upsert_entity("Src2", "Src2", EntityType::Concept, None)
1247 .await
1248 .unwrap();
1249 let tgt = gs
1250 .upsert_entity("Tgt2", "Tgt2", EntityType::Concept, None)
1251 .await
1252 .unwrap();
1253 let episode = MessageId(999);
1259 let result = gs
1260 .insert_edge(src, tgt, "uses", "Src2 uses Tgt2", 1.0, Some(episode))
1261 .await;
1262 match &result {
1263 Ok(eid) => assert!(*eid > 0, "inserted edge should have positive id"),
1264 Err(MemoryError::Sqlite(_)) => {} Err(e) => panic!("unexpected error: {e}"),
1266 }
1267 }
1268
1269 #[tokio::test]
1270 async fn invalidate_edge_sets_timestamps() {
1271 let gs = setup().await;
1272 let src = gs
1273 .upsert_entity("E1", "E1", EntityType::Concept, None)
1274 .await
1275 .unwrap();
1276 let tgt = gs
1277 .upsert_entity("E2", "E2", EntityType::Concept, None)
1278 .await
1279 .unwrap();
1280 let eid = gs
1281 .insert_edge(src, tgt, "r", "fact", 1.0, None)
1282 .await
1283 .unwrap();
1284 gs.invalidate_edge(eid).await.unwrap();
1285
1286 let row: (Option<String>, Option<String>) =
1287 sqlx::query_as("SELECT valid_to, expired_at FROM graph_edges WHERE id = ?1")
1288 .bind(eid)
1289 .fetch_one(&gs.pool)
1290 .await
1291 .unwrap();
1292 assert!(row.0.is_some(), "valid_to should be set");
1293 assert!(row.1.is_some(), "expired_at should be set");
1294 }
1295
1296 #[tokio::test]
1297 async fn edges_for_entity_both_directions() {
1298 let gs = setup().await;
1299 let a = gs
1300 .upsert_entity("A", "A", EntityType::Concept, None)
1301 .await
1302 .unwrap();
1303 let b = gs
1304 .upsert_entity("B", "B", EntityType::Concept, None)
1305 .await
1306 .unwrap();
1307 let c = gs
1308 .upsert_entity("C", "C", EntityType::Concept, None)
1309 .await
1310 .unwrap();
1311 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1312 gs.insert_edge(c, a, "r", "f2", 1.0, None).await.unwrap();
1313
1314 let edges = gs.edges_for_entity(a).await.unwrap();
1315 assert_eq!(edges.len(), 2);
1316 }
1317
1318 #[tokio::test]
1319 async fn edges_between_both_directions() {
1320 let gs = setup().await;
1321 let a = gs
1322 .upsert_entity("PA", "PA", EntityType::Person, None)
1323 .await
1324 .unwrap();
1325 let b = gs
1326 .upsert_entity("PB", "PB", EntityType::Person, None)
1327 .await
1328 .unwrap();
1329 gs.insert_edge(a, b, "knows", "PA knows PB", 1.0, None)
1330 .await
1331 .unwrap();
1332
1333 let fwd = gs.edges_between(a, b).await.unwrap();
1334 assert_eq!(fwd.len(), 1);
1335 let rev = gs.edges_between(b, a).await.unwrap();
1336 assert_eq!(rev.len(), 1);
1337 }
1338
1339 #[tokio::test]
1340 async fn active_edge_count_excludes_invalidated() {
1341 let gs = setup().await;
1342 let a = gs
1343 .upsert_entity("N1", "N1", EntityType::Concept, None)
1344 .await
1345 .unwrap();
1346 let b = gs
1347 .upsert_entity("N2", "N2", EntityType::Concept, None)
1348 .await
1349 .unwrap();
1350 let e1 = gs.insert_edge(a, b, "r1", "f1", 1.0, None).await.unwrap();
1351 gs.insert_edge(a, b, "r2", "f2", 1.0, None).await.unwrap();
1352 gs.invalidate_edge(e1).await.unwrap();
1353
1354 assert_eq!(gs.active_edge_count().await.unwrap(), 1);
1355 }
1356
1357 #[tokio::test]
1358 async fn upsert_community_insert_and_update() {
1359 let gs = setup().await;
1360 let id1 = gs
1361 .upsert_community("clusterA", "summary A", &[1, 2, 3], None)
1362 .await
1363 .unwrap();
1364 assert!(id1 > 0);
1365 let id2 = gs
1366 .upsert_community("clusterA", "summary A updated", &[1, 2, 3, 4], None)
1367 .await
1368 .unwrap();
1369 assert_eq!(id1, id2);
1370 let communities = gs.all_communities().await.unwrap();
1371 assert_eq!(communities.len(), 1);
1372 assert_eq!(communities[0].summary, "summary A updated");
1373 assert_eq!(communities[0].entity_ids, vec![1, 2, 3, 4]);
1374 }
1375
1376 #[tokio::test]
1377 async fn community_for_entity_found() {
1378 let gs = setup().await;
1379 let a = gs
1380 .upsert_entity("CA", "CA", EntityType::Concept, None)
1381 .await
1382 .unwrap();
1383 let b = gs
1384 .upsert_entity("CB", "CB", EntityType::Concept, None)
1385 .await
1386 .unwrap();
1387 gs.upsert_community("cA", "summary", &[a, b], None)
1388 .await
1389 .unwrap();
1390 let result = gs.community_for_entity(a).await.unwrap();
1391 assert!(result.is_some());
1392 assert_eq!(result.unwrap().name, "cA");
1393 }
1394
1395 #[tokio::test]
1396 async fn community_for_entity_not_found() {
1397 let gs = setup().await;
1398 let result = gs.community_for_entity(999).await.unwrap();
1399 assert!(result.is_none());
1400 }
1401
1402 #[tokio::test]
1403 async fn community_count() {
1404 let gs = setup().await;
1405 assert_eq!(gs.community_count().await.unwrap(), 0);
1406 gs.upsert_community("c1", "s1", &[], None).await.unwrap();
1407 gs.upsert_community("c2", "s2", &[], None).await.unwrap();
1408 assert_eq!(gs.community_count().await.unwrap(), 2);
1409 }
1410
1411 #[tokio::test]
1412 async fn metadata_get_set_round_trip() {
1413 let gs = setup().await;
1414 assert_eq!(gs.get_metadata("counter").await.unwrap(), None);
1415 gs.set_metadata("counter", "42").await.unwrap();
1416 assert_eq!(gs.get_metadata("counter").await.unwrap(), Some("42".into()));
1417 gs.set_metadata("counter", "43").await.unwrap();
1418 assert_eq!(gs.get_metadata("counter").await.unwrap(), Some("43".into()));
1419 }
1420
1421 #[tokio::test]
1422 async fn bfs_max_hops_0_returns_only_start() {
1423 let gs = setup().await;
1424 let a = gs
1425 .upsert_entity("BfsA", "BfsA", EntityType::Concept, None)
1426 .await
1427 .unwrap();
1428 let b = gs
1429 .upsert_entity("BfsB", "BfsB", EntityType::Concept, None)
1430 .await
1431 .unwrap();
1432 gs.insert_edge(a, b, "r", "f", 1.0, None).await.unwrap();
1433
1434 let (entities, edges) = gs.bfs(a, 0).await.unwrap();
1435 assert_eq!(entities.len(), 1);
1436 assert_eq!(entities[0].id, a);
1437 assert!(edges.is_empty());
1438 }
1439
1440 #[tokio::test]
1441 async fn bfs_max_hops_2_chain() {
1442 let gs = setup().await;
1443 let a = gs
1444 .upsert_entity("ChainA", "ChainA", EntityType::Concept, None)
1445 .await
1446 .unwrap();
1447 let b = gs
1448 .upsert_entity("ChainB", "ChainB", EntityType::Concept, None)
1449 .await
1450 .unwrap();
1451 let c = gs
1452 .upsert_entity("ChainC", "ChainC", EntityType::Concept, None)
1453 .await
1454 .unwrap();
1455 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1456 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1457
1458 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1459 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1460 assert!(ids.contains(&a));
1461 assert!(ids.contains(&b));
1462 assert!(ids.contains(&c));
1463 assert_eq!(edges.len(), 2);
1464 }
1465
1466 #[tokio::test]
1467 async fn bfs_cycle_no_infinite_loop() {
1468 let gs = setup().await;
1469 let a = gs
1470 .upsert_entity("CycA", "CycA", EntityType::Concept, None)
1471 .await
1472 .unwrap();
1473 let b = gs
1474 .upsert_entity("CycB", "CycB", EntityType::Concept, None)
1475 .await
1476 .unwrap();
1477 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1478 gs.insert_edge(b, a, "r", "f2", 1.0, None).await.unwrap();
1479
1480 let (entities, _edges) = gs.bfs(a, 3).await.unwrap();
1481 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1482 assert!(ids.contains(&a));
1484 assert!(ids.contains(&b));
1485 assert_eq!(ids.len(), 2);
1486 }
1487
1488 #[tokio::test]
1489 async fn test_invalidated_edges_excluded_from_bfs() {
1490 let gs = setup().await;
1491 let a = gs
1492 .upsert_entity("InvA", "InvA", EntityType::Concept, None)
1493 .await
1494 .unwrap();
1495 let b = gs
1496 .upsert_entity("InvB", "InvB", EntityType::Concept, None)
1497 .await
1498 .unwrap();
1499 let c = gs
1500 .upsert_entity("InvC", "InvC", EntityType::Concept, None)
1501 .await
1502 .unwrap();
1503 let ab = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1504 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1505 gs.invalidate_edge(ab).await.unwrap();
1507
1508 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1509 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1510 assert_eq!(ids, vec![a], "only start entity should be reachable");
1511 assert!(edges.is_empty(), "no active edges should be returned");
1512 }
1513
1514 #[tokio::test]
1515 async fn test_bfs_empty_graph() {
1516 let gs = setup().await;
1517 let a = gs
1518 .upsert_entity("IsoA", "IsoA", EntityType::Concept, None)
1519 .await
1520 .unwrap();
1521
1522 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1523 let ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1524 assert_eq!(ids, vec![a], "isolated node: only start entity returned");
1525 assert!(edges.is_empty(), "no edges for isolated node");
1526 }
1527
1528 #[tokio::test]
1529 async fn test_bfs_diamond() {
1530 let gs = setup().await;
1531 let a = gs
1532 .upsert_entity("DiamA", "DiamA", EntityType::Concept, None)
1533 .await
1534 .unwrap();
1535 let b = gs
1536 .upsert_entity("DiamB", "DiamB", EntityType::Concept, None)
1537 .await
1538 .unwrap();
1539 let c = gs
1540 .upsert_entity("DiamC", "DiamC", EntityType::Concept, None)
1541 .await
1542 .unwrap();
1543 let d = gs
1544 .upsert_entity("DiamD", "DiamD", EntityType::Concept, None)
1545 .await
1546 .unwrap();
1547 gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1548 gs.insert_edge(a, c, "r", "f2", 1.0, None).await.unwrap();
1549 gs.insert_edge(b, d, "r", "f3", 1.0, None).await.unwrap();
1550 gs.insert_edge(c, d, "r", "f4", 1.0, None).await.unwrap();
1551
1552 let (entities, edges) = gs.bfs(a, 2).await.unwrap();
1553 let mut ids: Vec<_> = entities.iter().map(|e| e.id).collect();
1554 ids.sort_unstable();
1555 let mut expected = vec![a, b, c, d];
1556 expected.sort_unstable();
1557 assert_eq!(ids, expected, "all 4 nodes reachable, no duplicates");
1558 assert_eq!(edges.len(), 4, "all 4 edges returned");
1559 }
1560
1561 #[tokio::test]
1562 async fn extraction_count_default_zero() {
1563 let gs = setup().await;
1564 assert_eq!(gs.extraction_count().await.unwrap(), 0);
1565 }
1566
1567 #[tokio::test]
1568 async fn extraction_count_after_set() {
1569 let gs = setup().await;
1570 gs.set_metadata("extraction_count", "7").await.unwrap();
1571 assert_eq!(gs.extraction_count().await.unwrap(), 7);
1572 }
1573
1574 #[tokio::test]
1575 async fn all_active_edges_stream_excludes_invalidated() {
1576 use futures::TryStreamExt as _;
1577 let gs = setup().await;
1578 let a = gs
1579 .upsert_entity("SA", "SA", EntityType::Concept, None)
1580 .await
1581 .unwrap();
1582 let b = gs
1583 .upsert_entity("SB", "SB", EntityType::Concept, None)
1584 .await
1585 .unwrap();
1586 let c = gs
1587 .upsert_entity("SC", "SC", EntityType::Concept, None)
1588 .await
1589 .unwrap();
1590 let e1 = gs.insert_edge(a, b, "r", "f1", 1.0, None).await.unwrap();
1591 gs.insert_edge(b, c, "r", "f2", 1.0, None).await.unwrap();
1592 gs.invalidate_edge(e1).await.unwrap();
1593
1594 let edges: Vec<_> = gs.all_active_edges_stream().try_collect().await.unwrap();
1595 assert_eq!(edges.len(), 1, "only the active edge should be returned");
1596 assert_eq!(edges[0].source_entity_id, b);
1597 assert_eq!(edges[0].target_entity_id, c);
1598 }
1599
1600 #[tokio::test]
1601 async fn find_community_by_id_found_and_not_found() {
1602 let gs = setup().await;
1603 let cid = gs
1604 .upsert_community("grp", "summary", &[1, 2], None)
1605 .await
1606 .unwrap();
1607 let found = gs.find_community_by_id(cid).await.unwrap();
1608 assert!(found.is_some());
1609 assert_eq!(found.unwrap().name, "grp");
1610
1611 let missing = gs.find_community_by_id(9999).await.unwrap();
1612 assert!(missing.is_none());
1613 }
1614
1615 #[tokio::test]
1616 async fn delete_all_communities_clears_table() {
1617 let gs = setup().await;
1618 gs.upsert_community("c1", "s1", &[1], None).await.unwrap();
1619 gs.upsert_community("c2", "s2", &[2], None).await.unwrap();
1620 assert_eq!(gs.community_count().await.unwrap(), 2);
1621 gs.delete_all_communities().await.unwrap();
1622 assert_eq!(gs.community_count().await.unwrap(), 0);
1623 }
1624
1625 #[tokio::test]
1626 async fn test_find_entities_fuzzy_no_results() {
1627 let gs = setup().await;
1628 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
1629 .await
1630 .unwrap();
1631 let results = gs.find_entities_fuzzy("zzzznonexistent", 10).await.unwrap();
1632 assert!(
1633 results.is_empty(),
1634 "no entities should match an unknown term"
1635 );
1636 }
1637
1638 #[tokio::test]
1641 async fn upsert_entity_stores_canonical_name() {
1642 let gs = setup().await;
1643 gs.upsert_entity("rust", "rust", EntityType::Language, None)
1644 .await
1645 .unwrap();
1646 let entity = gs
1647 .find_entity("rust", EntityType::Language)
1648 .await
1649 .unwrap()
1650 .unwrap();
1651 assert_eq!(entity.canonical_name, "rust");
1652 assert_eq!(entity.name, "rust");
1653 }
1654
1655 #[tokio::test]
1656 async fn add_alias_idempotent() {
1657 let gs = setup().await;
1658 let id = gs
1659 .upsert_entity("rust", "rust", EntityType::Language, None)
1660 .await
1661 .unwrap();
1662 gs.add_alias(id, "rust-lang").await.unwrap();
1663 gs.add_alias(id, "rust-lang").await.unwrap();
1665 let aliases = gs.aliases_for_entity(id).await.unwrap();
1666 assert_eq!(
1667 aliases
1668 .iter()
1669 .filter(|a| a.alias_name == "rust-lang")
1670 .count(),
1671 1
1672 );
1673 }
1674
1675 #[tokio::test]
1678 async fn find_entity_by_id_found() {
1679 let gs = setup().await;
1680 let id = gs
1681 .upsert_entity("FindById", "finbyid", EntityType::Concept, Some("summary"))
1682 .await
1683 .unwrap();
1684 let entity = gs.find_entity_by_id(id).await.unwrap();
1685 assert!(entity.is_some());
1686 let entity = entity.unwrap();
1687 assert_eq!(entity.id, id);
1688 assert_eq!(entity.name, "FindById");
1689 }
1690
1691 #[tokio::test]
1692 async fn find_entity_by_id_not_found() {
1693 let gs = setup().await;
1694 let result = gs.find_entity_by_id(99999).await.unwrap();
1695 assert!(result.is_none());
1696 }
1697
1698 #[tokio::test]
1699 async fn set_entity_qdrant_point_id_updates() {
1700 let gs = setup().await;
1701 let id = gs
1702 .upsert_entity("QdrantPoint", "qdrantpoint", EntityType::Concept, None)
1703 .await
1704 .unwrap();
1705 let point_id = "550e8400-e29b-41d4-a716-446655440000";
1706 gs.set_entity_qdrant_point_id(id, point_id).await.unwrap();
1707
1708 let entity = gs.find_entity_by_id(id).await.unwrap().unwrap();
1709 assert_eq!(entity.qdrant_point_id.as_deref(), Some(point_id));
1710 }
1711
1712 #[tokio::test]
1713 async fn find_entities_fuzzy_matches_summary() {
1714 let gs = setup().await;
1715 gs.upsert_entity(
1716 "Rust",
1717 "Rust",
1718 EntityType::Language,
1719 Some("a systems programming language"),
1720 )
1721 .await
1722 .unwrap();
1723 gs.upsert_entity(
1724 "Go",
1725 "Go",
1726 EntityType::Language,
1727 Some("a compiled language by Google"),
1728 )
1729 .await
1730 .unwrap();
1731 let results = gs.find_entities_fuzzy("systems", 10).await.unwrap();
1733 assert_eq!(results.len(), 1);
1734 assert_eq!(results[0].name, "Rust");
1735 }
1736
1737 #[tokio::test]
1738 async fn find_entities_fuzzy_empty_query() {
1739 let gs = setup().await;
1740 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
1741 .await
1742 .unwrap();
1743 let results = gs.find_entities_fuzzy("", 10).await.unwrap();
1745 assert!(results.is_empty(), "empty query should return no results");
1746 let results = gs.find_entities_fuzzy(" ", 10).await.unwrap();
1748 assert!(
1749 results.is_empty(),
1750 "whitespace query should return no results"
1751 );
1752 }
1753
1754 #[tokio::test]
1755 async fn find_entity_by_alias_case_insensitive() {
1756 let gs = setup().await;
1757 let id = gs
1758 .upsert_entity("rust", "rust", EntityType::Language, None)
1759 .await
1760 .unwrap();
1761 gs.add_alias(id, "rust").await.unwrap();
1762 gs.add_alias(id, "rust-lang").await.unwrap();
1763
1764 let found = gs
1765 .find_entity_by_alias("RUST-LANG", EntityType::Language)
1766 .await
1767 .unwrap();
1768 assert!(found.is_some());
1769 assert_eq!(found.unwrap().id, id);
1770 }
1771
1772 #[tokio::test]
1773 async fn find_entity_by_alias_returns_none_for_unknown() {
1774 let gs = setup().await;
1775 let id = gs
1776 .upsert_entity("rust", "rust", EntityType::Language, None)
1777 .await
1778 .unwrap();
1779 gs.add_alias(id, "rust").await.unwrap();
1780
1781 let found = gs
1782 .find_entity_by_alias("python", EntityType::Language)
1783 .await
1784 .unwrap();
1785 assert!(found.is_none());
1786 }
1787
1788 #[tokio::test]
1789 async fn find_entity_by_alias_filters_by_entity_type() {
1790 let gs = setup().await;
1792 let lang_id = gs
1793 .upsert_entity("python", "python", EntityType::Language, None)
1794 .await
1795 .unwrap();
1796 gs.add_alias(lang_id, "python").await.unwrap();
1797
1798 let found_tool = gs
1799 .find_entity_by_alias("python", EntityType::Tool)
1800 .await
1801 .unwrap();
1802 assert!(
1803 found_tool.is_none(),
1804 "cross-type alias collision must not occur"
1805 );
1806
1807 let found_lang = gs
1808 .find_entity_by_alias("python", EntityType::Language)
1809 .await
1810 .unwrap();
1811 assert!(found_lang.is_some());
1812 assert_eq!(found_lang.unwrap().id, lang_id);
1813 }
1814
1815 #[tokio::test]
1816 async fn aliases_for_entity_returns_all() {
1817 let gs = setup().await;
1818 let id = gs
1819 .upsert_entity("rust", "rust", EntityType::Language, None)
1820 .await
1821 .unwrap();
1822 gs.add_alias(id, "rust").await.unwrap();
1823 gs.add_alias(id, "rust-lang").await.unwrap();
1824 gs.add_alias(id, "rustlang").await.unwrap();
1825
1826 let aliases = gs.aliases_for_entity(id).await.unwrap();
1827 assert_eq!(aliases.len(), 3);
1828 let names: Vec<&str> = aliases.iter().map(|a| a.alias_name.as_str()).collect();
1829 assert!(names.contains(&"rust"));
1830 assert!(names.contains(&"rust-lang"));
1831 assert!(names.contains(&"rustlang"));
1832 }
1833
1834 #[tokio::test]
1835 async fn find_entities_fuzzy_includes_aliases() {
1836 let gs = setup().await;
1837 let id = gs
1838 .upsert_entity("rust", "rust", EntityType::Language, None)
1839 .await
1840 .unwrap();
1841 gs.add_alias(id, "rust-lang").await.unwrap();
1842 gs.upsert_entity("python", "python", EntityType::Language, None)
1843 .await
1844 .unwrap();
1845
1846 let results = gs.find_entities_fuzzy("rust-lang", 10).await.unwrap();
1848 assert!(!results.is_empty());
1849 assert!(results.iter().any(|e| e.id == id));
1850 }
1851
1852 #[tokio::test]
1853 async fn orphan_alias_cleanup_on_entity_delete() {
1854 let gs = setup().await;
1855 let id = gs
1856 .upsert_entity("rust", "rust", EntityType::Language, None)
1857 .await
1858 .unwrap();
1859 gs.add_alias(id, "rust").await.unwrap();
1860 gs.add_alias(id, "rust-lang").await.unwrap();
1861
1862 sqlx::query("DELETE FROM graph_entities WHERE id = ?1")
1864 .bind(id)
1865 .execute(&gs.pool)
1866 .await
1867 .unwrap();
1868
1869 let aliases = gs.aliases_for_entity(id).await.unwrap();
1871 assert!(
1872 aliases.is_empty(),
1873 "aliases should cascade-delete with entity"
1874 );
1875 }
1876
    /// Replays the migration-024 backfill (adding `canonical_name` and the
    /// alias table) against a hand-built pre-migration schema, then verifies
    /// that entities, edges, and seeded aliases survive the table recreation.
    /// The statement order mirrors the real migration and must not change.
    #[tokio::test]
    async fn migration_024_backfill_preserves_entities_and_edges() {
        use sqlx::Acquire as _;
        use sqlx::ConnectOptions as _;
        use sqlx::sqlite::SqliteConnectOptions;

        // Single-connection in-memory pool: every statement below must hit
        // the same SQLite database, and FKs start enabled as in production.
        let opts = SqliteConnectOptions::from_url(&"sqlite::memory:".parse().unwrap())
            .unwrap()
            .foreign_keys(true);
        let pool = sqlx::pool::PoolOptions::<sqlx::Sqlite>::new()
            .max_connections(1)
            .connect_with(opts)
            .await
            .unwrap();

        // --- Pre-migration schema: graph_entities WITHOUT canonical_name,
        // unique on (name, entity_type) instead. ---
        sqlx::query(
            "CREATE TABLE graph_entities (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                summary TEXT,
                first_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                last_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                qdrant_point_id TEXT,
                UNIQUE(name, entity_type)
            )",
        )
        .execute(&pool)
        .await
        .unwrap();

        // graph_edges references graph_entities with ON DELETE CASCADE —
        // the part migration 024 must not break when it drops/renames tables.
        sqlx::query(
            "CREATE TABLE graph_edges (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                source_entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                target_entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                relation TEXT NOT NULL,
                fact TEXT NOT NULL,
                confidence REAL NOT NULL DEFAULT 1.0,
                valid_from TEXT NOT NULL DEFAULT (datetime('now')),
                valid_to TEXT,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                expired_at TEXT,
                episode_id INTEGER,
                qdrant_point_id TEXT
            )",
        )
        .execute(&pool)
        .await
        .unwrap();

        // External-content FTS5 index plus the three sync triggers,
        // exactly as the pre-migration schema defines them.
        sqlx::query(
            "CREATE VIRTUAL TABLE IF NOT EXISTS graph_entities_fts USING fts5(
                name, summary, content='graph_entities', content_rowid='id',
                tokenize='unicode61 remove_diacritics 2'
            )",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_insert AFTER INSERT ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, '')); END",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_delete AFTER DELETE ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, '')); END",
        )
        .execute(&pool)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_update AFTER UPDATE ON graph_entities
             BEGIN
               INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, ''));
               INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, ''));
             END",
        )
        .execute(&pool)
        .await
        .unwrap();

        // --- Seed pre-migration data: two entities and one edge. ---
        let alice_id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, entity_type) VALUES ('Alice', 'person') RETURNING id",
        )
        .fetch_one(&pool)
        .await
        .unwrap();

        let rust_id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, entity_type) VALUES ('Rust', 'language') RETURNING id",
        )
        .fetch_one(&pool)
        .await
        .unwrap();

        sqlx::query(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact)
             VALUES (?1, ?2, 'uses', 'Alice uses Rust')",
        )
        .bind(alice_id)
        .bind(rust_id)
        .execute(&pool)
        .await
        .unwrap();

        // Pin one connection: the PRAGMA toggles below are per-connection,
        // so everything must run on this exact connection.
        let mut conn = pool.acquire().await.unwrap();
        let conn = conn.acquire().await.unwrap();

        // --- Migration 024 body, statement for statement. ---
        // FKs off so dropping/renaming graph_entities doesn't trip the
        // graph_edges references mid-migration.
        sqlx::query("PRAGMA foreign_keys = OFF")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query("ALTER TABLE graph_entities ADD COLUMN canonical_name TEXT")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Backfill: canonical_name defaults to the existing display name.
        sqlx::query("UPDATE graph_entities SET canonical_name = name WHERE canonical_name IS NULL")
            .execute(&mut *conn)
            .await
            .unwrap();
        // Rebuild the table so canonical_name is NOT NULL and uniqueness
        // moves to (canonical_name, entity_type).
        sqlx::query(
            "CREATE TABLE graph_entities_new (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                name TEXT NOT NULL,
                canonical_name TEXT NOT NULL,
                entity_type TEXT NOT NULL,
                summary TEXT,
                first_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                last_seen_at TEXT NOT NULL DEFAULT (datetime('now')),
                qdrant_point_id TEXT,
                UNIQUE(canonical_name, entity_type)
            )",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "INSERT INTO graph_entities_new
                (id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id)
             SELECT id, name, COALESCE(canonical_name, name), entity_type, summary,
                    first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // DROP removes the old table's triggers too; recreate them after the
        // rename so FTS stays in sync going forward.
        sqlx::query("DROP TABLE graph_entities")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query("ALTER TABLE graph_entities_new RENAME TO graph_entities")
            .execute(&mut *conn)
            .await
            .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_insert AFTER INSERT ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, '')); END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_delete AFTER DELETE ON graph_entities
             BEGIN INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, '')); END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "CREATE TRIGGER IF NOT EXISTS graph_entities_fts_update AFTER UPDATE ON graph_entities
             BEGIN
               INSERT INTO graph_entities_fts(graph_entities_fts, rowid, name, summary) VALUES ('delete', old.id, old.name, COALESCE(old.summary, ''));
               INSERT INTO graph_entities_fts(rowid, name, summary) VALUES (new.id, new.name, COALESCE(new.summary, ''));
             END",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        // Full FTS rebuild: the content table was recreated underneath it.
        sqlx::query("INSERT INTO graph_entities_fts(graph_entities_fts) VALUES('rebuild')")
            .execute(&mut *conn)
            .await
            .unwrap();
        // New alias table, seeded with each entity's display name.
        sqlx::query(
            "CREATE TABLE graph_entity_aliases (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                entity_id INTEGER NOT NULL REFERENCES graph_entities(id) ON DELETE CASCADE,
                alias_name TEXT NOT NULL,
                created_at TEXT NOT NULL DEFAULT (datetime('now')),
                UNIQUE(alias_name, entity_id)
            )",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query(
            "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name)
             SELECT id, name FROM graph_entities",
        )
        .execute(&mut *conn)
        .await
        .unwrap();
        sqlx::query("PRAGMA foreign_keys = ON")
            .execute(&mut *conn)
            .await
            .unwrap();

        // --- Post-migration assertions. ---
        let alice_canon: String =
            sqlx::query_scalar("SELECT canonical_name FROM graph_entities WHERE id = ?1")
                .bind(alice_id)
                .fetch_one(&mut *conn)
                .await
                .unwrap();
        assert_eq!(
            alice_canon, "Alice",
            "canonical_name should equal pre-migration name"
        );

        let rust_canon: String =
            sqlx::query_scalar("SELECT canonical_name FROM graph_entities WHERE id = ?1")
                .bind(rust_id)
                .fetch_one(&mut *conn)
                .await
                .unwrap();
        assert_eq!(
            rust_canon, "Rust",
            "canonical_name should equal pre-migration name"
        );

        let alice_aliases: Vec<String> =
            sqlx::query_scalar("SELECT alias_name FROM graph_entity_aliases WHERE entity_id = ?1")
                .bind(alice_id)
                .fetch_all(&mut *conn)
                .await
                .unwrap();
        assert!(
            alice_aliases.contains(&"Alice".to_owned()),
            "initial alias should be seeded from entity name"
        );

        let edge_count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges")
            .fetch_one(&mut *conn)
            .await
            .unwrap();
        assert_eq!(
            edge_count, 1,
            "graph_edges must survive migration 024 table recreation"
        );
    }
2150
2151 #[tokio::test]
2152 async fn find_entity_by_alias_same_alias_two_entities_deterministic() {
2153 let gs = setup().await;
2155 let id1 = gs
2156 .upsert_entity("python-v2", "python-v2", EntityType::Language, None)
2157 .await
2158 .unwrap();
2159 let id2 = gs
2160 .upsert_entity("python-v3", "python-v3", EntityType::Language, None)
2161 .await
2162 .unwrap();
2163 gs.add_alias(id1, "python").await.unwrap();
2164 gs.add_alias(id2, "python").await.unwrap();
2165
2166 let found = gs
2168 .find_entity_by_alias("python", EntityType::Language)
2169 .await
2170 .unwrap();
2171 assert!(found.is_some(), "should find an entity by shared alias");
2172 assert_eq!(
2174 found.unwrap().id,
2175 id1,
2176 "first-registered entity should win on shared alias"
2177 );
2178 }
2179
2180 #[tokio::test]
2183 async fn find_entities_fuzzy_special_chars() {
2184 let gs = setup().await;
2185 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2186 .await
2187 .unwrap();
2188 let results = gs.find_entities_fuzzy("graph\"()*:^", 10).await.unwrap();
2190 assert!(results.iter().any(|e| e.name == "Graph"));
2192 }
2193
2194 #[tokio::test]
2195 async fn find_entities_fuzzy_prefix_match() {
2196 let gs = setup().await;
2197 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2198 .await
2199 .unwrap();
2200 gs.upsert_entity("GraphQL", "GraphQL", EntityType::Concept, None)
2201 .await
2202 .unwrap();
2203 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
2204 .await
2205 .unwrap();
2206 let results = gs.find_entities_fuzzy("Gra", 10).await.unwrap();
2208 assert_eq!(results.len(), 2);
2209 assert!(results.iter().any(|e| e.name == "Graph"));
2210 assert!(results.iter().any(|e| e.name == "GraphQL"));
2211 }
2212
2213 #[tokio::test]
2214 async fn find_entities_fuzzy_fts5_operator_injection() {
2215 let gs = setup().await;
2216 gs.upsert_entity("Graph", "Graph", EntityType::Concept, None)
2217 .await
2218 .unwrap();
2219 gs.upsert_entity("Unrelated", "Unrelated", EntityType::Concept, None)
2220 .await
2221 .unwrap();
2222 let results = gs
2227 .find_entities_fuzzy("graph OR unrelated", 10)
2228 .await
2229 .unwrap();
2230 assert!(
2231 results.is_empty(),
2232 "implicit AND of 'graph*' and 'unrelated*' should match no entity"
2233 );
2234 }
2235
2236 #[tokio::test]
2237 async fn find_entities_fuzzy_after_entity_update() {
2238 let gs = setup().await;
2239 gs.upsert_entity(
2241 "Foo",
2242 "Foo",
2243 EntityType::Concept,
2244 Some("initial summary bar"),
2245 )
2246 .await
2247 .unwrap();
2248 gs.upsert_entity(
2250 "Foo",
2251 "Foo",
2252 EntityType::Concept,
2253 Some("updated summary baz"),
2254 )
2255 .await
2256 .unwrap();
2257 let old_results = gs.find_entities_fuzzy("bar", 10).await.unwrap();
2259 assert!(
2260 old_results.is_empty(),
2261 "old summary content should not match after update"
2262 );
2263 let new_results = gs.find_entities_fuzzy("baz", 10).await.unwrap();
2265 assert_eq!(new_results.len(), 1);
2266 assert_eq!(new_results[0].name, "Foo");
2267 }
2268
2269 #[tokio::test]
2270 async fn find_entities_fuzzy_only_special_chars() {
2271 let gs = setup().await;
2272 gs.upsert_entity("Alpha", "Alpha", EntityType::Concept, None)
2273 .await
2274 .unwrap();
2275 let results = gs.find_entities_fuzzy("***", 10).await.unwrap();
2279 assert!(
2280 results.is_empty(),
2281 "only special chars should return no results"
2282 );
2283 let results = gs.find_entities_fuzzy("(((", 10).await.unwrap();
2284 assert!(results.is_empty(), "only parens should return no results");
2285 let results = gs.find_entities_fuzzy("\"\"\"", 10).await.unwrap();
2286 assert!(results.is_empty(), "only quotes should return no results");
2287 }
2288
2289 async fn insert_test_message(gs: &GraphStore, content: &str) -> crate::types::MessageId {
2290 let conv_id: i64 =
2292 sqlx::query_scalar("INSERT INTO conversations DEFAULT VALUES RETURNING id")
2293 .fetch_one(&gs.pool)
2294 .await
2295 .unwrap();
2296 let id: i64 = sqlx::query_scalar(
2297 "INSERT INTO messages (conversation_id, role, content) VALUES (?1, 'user', ?2) RETURNING id",
2298 )
2299 .bind(conv_id)
2300 .bind(content)
2301 .fetch_one(&gs.pool)
2302 .await
2303 .unwrap();
2304 crate::types::MessageId(id)
2305 }
2306
2307 #[tokio::test]
2308 async fn unprocessed_messages_for_backfill_returns_unprocessed() {
2309 let gs = setup().await;
2310 let id1 = insert_test_message(&gs, "hello world").await;
2311 let id2 = insert_test_message(&gs, "second message").await;
2312
2313 let rows = gs.unprocessed_messages_for_backfill(10).await.unwrap();
2314 assert_eq!(rows.len(), 2);
2315 assert!(rows.iter().any(|(id, _)| *id == id1));
2316 assert!(rows.iter().any(|(id, _)| *id == id2));
2317 }
2318
2319 #[tokio::test]
2320 async fn unprocessed_messages_for_backfill_respects_limit() {
2321 let gs = setup().await;
2322 insert_test_message(&gs, "msg1").await;
2323 insert_test_message(&gs, "msg2").await;
2324 insert_test_message(&gs, "msg3").await;
2325
2326 let rows = gs.unprocessed_messages_for_backfill(2).await.unwrap();
2327 assert_eq!(rows.len(), 2);
2328 }
2329
2330 #[tokio::test]
2331 async fn mark_messages_graph_processed_updates_flag() {
2332 let gs = setup().await;
2333 let id1 = insert_test_message(&gs, "to process").await;
2334 let _id2 = insert_test_message(&gs, "also to process").await;
2335
2336 let count_before = gs.unprocessed_message_count().await.unwrap();
2338 assert_eq!(count_before, 2);
2339
2340 gs.mark_messages_graph_processed(&[id1]).await.unwrap();
2341
2342 let count_after = gs.unprocessed_message_count().await.unwrap();
2343 assert_eq!(count_after, 1);
2344
2345 let rows = gs.unprocessed_messages_for_backfill(10).await.unwrap();
2347 assert!(!rows.iter().any(|(id, _)| *id == id1));
2348 }
2349
2350 #[tokio::test]
2351 async fn mark_messages_graph_processed_empty_ids_is_noop() {
2352 let gs = setup().await;
2353 insert_test_message(&gs, "message").await;
2354
2355 gs.mark_messages_graph_processed(&[]).await.unwrap();
2356
2357 let count = gs.unprocessed_message_count().await.unwrap();
2358 assert_eq!(count, 1);
2359 }
2360}