1use std::collections::HashMap;
5
6use futures::Stream;
7use sqlx::SqlitePool;
8
9use crate::error::MemoryError;
10use crate::sqlite::messages::sanitize_fts5_query;
11use crate::types::MessageId;
12
13use super::types::{Community, Edge, Entity, EntityAlias, EntityType};
14
/// SQLite-backed store for the knowledge-graph tables: entities, aliases,
/// temporal edges, communities, and graph metadata.
pub struct GraphStore {
    // Shared connection pool; all queries go through this handle.
    pool: SqlitePool,
}
18
19impl GraphStore {
    /// Creates a `GraphStore` over an already-initialised SQLite pool.
    /// Assumes the graph schema (migrations) has been applied elsewhere.
    #[must_use]
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }
24
    /// Returns a reference to the underlying connection pool.
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
29
30 pub async fn upsert_entity(
43 &self,
44 surface_name: &str,
45 canonical_name: &str,
46 entity_type: EntityType,
47 summary: Option<&str>,
48 ) -> Result<i64, MemoryError> {
49 let type_str = entity_type.as_str();
50 let id: i64 = sqlx::query_scalar(
51 "INSERT INTO graph_entities (name, canonical_name, entity_type, summary)
52 VALUES (?1, ?2, ?3, ?4)
53 ON CONFLICT(canonical_name, entity_type) DO UPDATE SET
54 name = excluded.name,
55 summary = COALESCE(excluded.summary, summary),
56 last_seen_at = datetime('now')
57 RETURNING id",
58 )
59 .bind(surface_name)
60 .bind(canonical_name)
61 .bind(type_str)
62 .bind(summary)
63 .fetch_one(&self.pool)
64 .await?;
65 Ok(id)
66 }
67
68 pub async fn find_entity(
74 &self,
75 canonical_name: &str,
76 entity_type: EntityType,
77 ) -> Result<Option<Entity>, MemoryError> {
78 let type_str = entity_type.as_str();
79 let row: Option<EntityRow> = sqlx::query_as(
80 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
81 FROM graph_entities
82 WHERE canonical_name = ?1 AND entity_type = ?2",
83 )
84 .bind(canonical_name)
85 .bind(type_str)
86 .fetch_optional(&self.pool)
87 .await?;
88 row.map(entity_from_row).transpose()
89 }
90
91 pub async fn find_entity_by_id(&self, entity_id: i64) -> Result<Option<Entity>, MemoryError> {
97 let row: Option<EntityRow> = sqlx::query_as(
98 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
99 FROM graph_entities
100 WHERE id = ?1",
101 )
102 .bind(entity_id)
103 .fetch_optional(&self.pool)
104 .await?;
105 row.map(entity_from_row).transpose()
106 }
107
108 pub async fn set_entity_qdrant_point_id(
114 &self,
115 entity_id: i64,
116 point_id: &str,
117 ) -> Result<(), MemoryError> {
118 sqlx::query("UPDATE graph_entities SET qdrant_point_id = ?1 WHERE id = ?2")
119 .bind(point_id)
120 .bind(entity_id)
121 .execute(&self.pool)
122 .await?;
123 Ok(())
124 }
125
    /// Fuzzy-searches entities: prefix-matches the sanitized query against the
    /// FTS5 name index, unioned with a case-insensitive substring scan over
    /// aliases. Returns at most `limit` entities.
    ///
    /// # Errors
    /// Returns an error if `limit` overflows `i64`, the query fails, or a row
    /// carries an invalid entity type.
    pub async fn find_entities_fuzzy(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Entity>, MemoryError> {
        // Bare FTS5 keywords would be parsed as operators rather than search
        // terms, so they are filtered out of the prefix query below.
        const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT", "NEAR"];
        // Cap the query at 512 bytes without splitting a UTF-8 code point.
        // NOTE(review): `str::floor_char_boundary` is an unstable API —
        // presumably this crate builds on nightly; confirm toolchain.
        let query = &query[..query.floor_char_boundary(512)];
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(vec![]);
        }
        // Turn every surviving token into a prefix term: "foo" -> "foo*".
        let fts_query: String = sanitized
            .split_whitespace()
            .filter(|t| !FTS5_OPERATORS.contains(t))
            .map(|t| format!("{t}*"))
            .collect::<Vec<_>>()
            .join(" ");
        if fts_query.is_empty() {
            return Ok(vec![]);
        }

        let limit = i64::try_from(limit)?;
        // The LIMIT applies to the whole compound SELECT (FTS hits UNION
        // alias hits); UNION also deduplicates across the two branches.
        let rows: Vec<EntityRow> = sqlx::query_as(
            "SELECT DISTINCT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entities_fts fts
             JOIN graph_entities e ON e.id = fts.rowid
             WHERE graph_entities_fts MATCH ?1
             UNION
             SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name LIKE ?2 ESCAPE '\\' COLLATE NOCASE
             LIMIT ?3",
        )
        .bind(&fts_query)
        // Escape LIKE wildcards so user input matches literally.
        .bind(format!(
            "%{}%",
            query
                .trim()
                .replace('\\', "\\\\")
                .replace('%', "\\%")
                .replace('_', "\\_")
        ))
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        rows.into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()
    }
205
206 pub fn all_entities_stream(&self) -> impl Stream<Item = Result<Entity, MemoryError>> + '_ {
208 use futures::StreamExt as _;
209 sqlx::query_as::<_, EntityRow>(
210 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
211 FROM graph_entities ORDER BY id ASC",
212 )
213 .fetch(&self.pool)
214 .map(|r: Result<EntityRow, sqlx::Error>| {
215 r.map_err(MemoryError::from).and_then(entity_from_row)
216 })
217 }
218
219 pub async fn add_alias(&self, entity_id: i64, alias_name: &str) -> Result<(), MemoryError> {
227 sqlx::query(
228 "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name) VALUES (?1, ?2)",
229 )
230 .bind(entity_id)
231 .bind(alias_name)
232 .execute(&self.pool)
233 .await?;
234 Ok(())
235 }
236
237 pub async fn find_entity_by_alias(
245 &self,
246 alias_name: &str,
247 entity_type: EntityType,
248 ) -> Result<Option<Entity>, MemoryError> {
249 let type_str = entity_type.as_str();
250 let row: Option<EntityRow> = sqlx::query_as(
251 "SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
252 e.first_seen_at, e.last_seen_at, e.qdrant_point_id
253 FROM graph_entity_aliases a
254 JOIN graph_entities e ON e.id = a.entity_id
255 WHERE a.alias_name = ?1 COLLATE NOCASE
256 AND e.entity_type = ?2
257 ORDER BY e.id ASC
258 LIMIT 1",
259 )
260 .bind(alias_name)
261 .bind(type_str)
262 .fetch_optional(&self.pool)
263 .await?;
264 row.map(entity_from_row).transpose()
265 }
266
267 pub async fn aliases_for_entity(
273 &self,
274 entity_id: i64,
275 ) -> Result<Vec<EntityAlias>, MemoryError> {
276 let rows: Vec<AliasRow> = sqlx::query_as(
277 "SELECT id, entity_id, alias_name, created_at
278 FROM graph_entity_aliases
279 WHERE entity_id = ?1
280 ORDER BY id ASC",
281 )
282 .bind(entity_id)
283 .fetch_all(&self.pool)
284 .await?;
285 Ok(rows.into_iter().map(alias_from_row).collect())
286 }
287
    /// Collects every entity into a `Vec` by draining [`Self::all_entities_stream`].
    ///
    /// # Errors
    /// Returns the first row or conversion error encountered.
    pub async fn all_entities(&self) -> Result<Vec<Entity>, MemoryError> {
        use futures::TryStreamExt as _;
        self.all_entities_stream().try_collect().await
    }
297
298 pub async fn entity_count(&self) -> Result<i64, MemoryError> {
304 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_entities")
305 .fetch_one(&self.pool)
306 .await?;
307 Ok(count)
308 }
309
    /// Inserts a new active edge, or — when an active edge with the same
    /// `(source, target, relation)` already exists — keeps it and raises its
    /// confidence to the maximum of the stored and incoming values.
    /// Returns the id of the edge that now represents the fact.
    ///
    /// `confidence` is clamped to `[0.0, 1.0]` before use.
    ///
    /// NOTE(review): the SELECT-then-INSERT pair is not wrapped in a
    /// transaction, so two concurrent writers could both miss the existing
    /// edge and insert duplicates — confirm a single-writer assumption or
    /// wrap in a transaction.
    ///
    /// # Errors
    /// Returns an error if any of the queries fail.
    pub async fn insert_edge(
        &self,
        source_entity_id: i64,
        target_entity_id: i64,
        relation: &str,
        fact: &str,
        confidence: f32,
        episode_id: Option<MessageId>,
    ) -> Result<i64, MemoryError> {
        let confidence = confidence.clamp(0.0, 1.0);

        // Look for an already-active edge carrying the same relation.
        let existing: Option<(i64, f64)> = sqlx::query_as(
            "SELECT id, confidence FROM graph_edges
             WHERE source_entity_id = ?1
               AND target_entity_id = ?2
               AND relation = ?3
               AND valid_to IS NULL
             LIMIT 1",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .fetch_optional(&self.pool)
        .await?;

        if let Some((existing_id, stored_conf)) = existing {
            // Reinforce rather than duplicate: keep the higher confidence.
            let updated_conf = f64::from(confidence).max(stored_conf);
            sqlx::query("UPDATE graph_edges SET confidence = ?1 WHERE id = ?2")
                .bind(updated_conf)
                .bind(existing_id)
                .execute(&self.pool)
                .await?;
            return Ok(existing_id);
        }

        let episode_raw: Option<i64> = episode_id.map(|m| m.0);
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_edges (source_entity_id, target_entity_id, relation, fact, confidence, episode_id)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6)
             RETURNING id",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .bind(fact)
        .bind(f64::from(confidence))
        .bind(episode_raw)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
373
374 pub async fn invalidate_edge(&self, edge_id: i64) -> Result<(), MemoryError> {
380 sqlx::query(
381 "UPDATE graph_edges SET valid_to = datetime('now'), expired_at = datetime('now')
382 WHERE id = ?1",
383 )
384 .bind(edge_id)
385 .execute(&self.pool)
386 .await?;
387 Ok(())
388 }
389
390 pub async fn edges_for_entity(&self, entity_id: i64) -> Result<Vec<Edge>, MemoryError> {
396 let rows: Vec<EdgeRow> = sqlx::query_as(
397 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
398 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
399 FROM graph_edges
400 WHERE valid_to IS NULL
401 AND (source_entity_id = ?1 OR target_entity_id = ?1)",
402 )
403 .bind(entity_id)
404 .fetch_all(&self.pool)
405 .await?;
406 Ok(rows.into_iter().map(edge_from_row).collect())
407 }
408
409 pub async fn edge_history_for_entity(
416 &self,
417 entity_id: i64,
418 limit: usize,
419 ) -> Result<Vec<Edge>, MemoryError> {
420 let limit = i64::try_from(limit)?;
421 let rows: Vec<EdgeRow> = sqlx::query_as(
422 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
423 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
424 FROM graph_edges
425 WHERE source_entity_id = ?1 OR target_entity_id = ?1
426 ORDER BY valid_from DESC
427 LIMIT ?2",
428 )
429 .bind(entity_id)
430 .bind(limit)
431 .fetch_all(&self.pool)
432 .await?;
433 Ok(rows.into_iter().map(edge_from_row).collect())
434 }
435
436 pub async fn edges_between(
442 &self,
443 entity_a: i64,
444 entity_b: i64,
445 ) -> Result<Vec<Edge>, MemoryError> {
446 let rows: Vec<EdgeRow> = sqlx::query_as(
447 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
448 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
449 FROM graph_edges
450 WHERE valid_to IS NULL
451 AND ((source_entity_id = ?1 AND target_entity_id = ?2)
452 OR (source_entity_id = ?2 AND target_entity_id = ?1))",
453 )
454 .bind(entity_a)
455 .bind(entity_b)
456 .fetch_all(&self.pool)
457 .await?;
458 Ok(rows.into_iter().map(edge_from_row).collect())
459 }
460
461 pub async fn edges_exact(
467 &self,
468 source_entity_id: i64,
469 target_entity_id: i64,
470 ) -> Result<Vec<Edge>, MemoryError> {
471 let rows: Vec<EdgeRow> = sqlx::query_as(
472 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
473 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
474 FROM graph_edges
475 WHERE valid_to IS NULL
476 AND source_entity_id = ?1
477 AND target_entity_id = ?2",
478 )
479 .bind(source_entity_id)
480 .bind(target_entity_id)
481 .fetch_all(&self.pool)
482 .await?;
483 Ok(rows.into_iter().map(edge_from_row).collect())
484 }
485
486 pub async fn active_edge_count(&self) -> Result<i64, MemoryError> {
492 let count: i64 =
493 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
494 .fetch_one(&self.pool)
495 .await?;
496 Ok(count)
497 }
498
499 pub async fn upsert_community(
511 &self,
512 name: &str,
513 summary: &str,
514 entity_ids: &[i64],
515 fingerprint: Option<&str>,
516 ) -> Result<i64, MemoryError> {
517 let entity_ids_json = serde_json::to_string(entity_ids)?;
518 let id: i64 = sqlx::query_scalar(
519 "INSERT INTO graph_communities (name, summary, entity_ids, fingerprint)
520 VALUES (?1, ?2, ?3, ?4)
521 ON CONFLICT(name) DO UPDATE SET
522 summary = excluded.summary,
523 entity_ids = excluded.entity_ids,
524 fingerprint = COALESCE(excluded.fingerprint, fingerprint),
525 updated_at = datetime('now')
526 RETURNING id",
527 )
528 .bind(name)
529 .bind(summary)
530 .bind(entity_ids_json)
531 .bind(fingerprint)
532 .fetch_one(&self.pool)
533 .await?;
534 Ok(id)
535 }
536
537 pub async fn community_fingerprints(&self) -> Result<HashMap<String, i64>, MemoryError> {
544 let rows: Vec<(String, i64)> = sqlx::query_as(
545 "SELECT fingerprint, id FROM graph_communities WHERE fingerprint IS NOT NULL",
546 )
547 .fetch_all(&self.pool)
548 .await?;
549 Ok(rows.into_iter().collect())
550 }
551
552 pub async fn delete_community_by_id(&self, id: i64) -> Result<(), MemoryError> {
558 sqlx::query("DELETE FROM graph_communities WHERE id = ?1")
559 .bind(id)
560 .execute(&self.pool)
561 .await?;
562 Ok(())
563 }
564
565 pub async fn clear_community_fingerprint(&self, id: i64) -> Result<(), MemoryError> {
574 sqlx::query("UPDATE graph_communities SET fingerprint = NULL WHERE id = ?1")
575 .bind(id)
576 .execute(&self.pool)
577 .await?;
578 Ok(())
579 }
580
581 pub async fn community_for_entity(
590 &self,
591 entity_id: i64,
592 ) -> Result<Option<Community>, MemoryError> {
593 let row: Option<CommunityRow> = sqlx::query_as(
594 "SELECT c.id, c.name, c.summary, c.entity_ids, c.fingerprint, c.created_at, c.updated_at
595 FROM graph_communities c, json_each(c.entity_ids) j
596 WHERE CAST(j.value AS INTEGER) = ?1
597 LIMIT 1",
598 )
599 .bind(entity_id)
600 .fetch_optional(&self.pool)
601 .await?;
602 match row {
603 Some(row) => {
604 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
605 Ok(Some(Community {
606 id: row.id,
607 name: row.name,
608 summary: row.summary,
609 entity_ids,
610 fingerprint: row.fingerprint,
611 created_at: row.created_at,
612 updated_at: row.updated_at,
613 }))
614 }
615 None => Ok(None),
616 }
617 }
618
619 pub async fn all_communities(&self) -> Result<Vec<Community>, MemoryError> {
625 let rows: Vec<CommunityRow> = sqlx::query_as(
626 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
627 FROM graph_communities
628 ORDER BY id ASC",
629 )
630 .fetch_all(&self.pool)
631 .await?;
632
633 rows.into_iter()
634 .map(|row| {
635 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
636 Ok(Community {
637 id: row.id,
638 name: row.name,
639 summary: row.summary,
640 entity_ids,
641 fingerprint: row.fingerprint,
642 created_at: row.created_at,
643 updated_at: row.updated_at,
644 })
645 })
646 .collect()
647 }
648
649 pub async fn community_count(&self) -> Result<i64, MemoryError> {
655 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_communities")
656 .fetch_one(&self.pool)
657 .await?;
658 Ok(count)
659 }
660
661 pub async fn get_metadata(&self, key: &str) -> Result<Option<String>, MemoryError> {
669 let val: Option<String> =
670 sqlx::query_scalar("SELECT value FROM graph_metadata WHERE key = ?1")
671 .bind(key)
672 .fetch_optional(&self.pool)
673 .await?;
674 Ok(val)
675 }
676
677 pub async fn set_metadata(&self, key: &str, value: &str) -> Result<(), MemoryError> {
683 sqlx::query(
684 "INSERT INTO graph_metadata (key, value) VALUES (?1, ?2)
685 ON CONFLICT(key) DO UPDATE SET value = excluded.value",
686 )
687 .bind(key)
688 .bind(value)
689 .execute(&self.pool)
690 .await?;
691 Ok(())
692 }
693
694 pub async fn extraction_count(&self) -> Result<i64, MemoryError> {
702 let val = self.get_metadata("extraction_count").await?;
703 Ok(val.and_then(|v| v.parse::<i64>().ok()).unwrap_or(0))
704 }
705
706 pub fn all_active_edges_stream(&self) -> impl Stream<Item = Result<Edge, MemoryError>> + '_ {
708 use futures::StreamExt as _;
709 sqlx::query_as::<_, EdgeRow>(
710 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
711 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
712 FROM graph_edges
713 WHERE valid_to IS NULL
714 ORDER BY id ASC",
715 )
716 .fetch(&self.pool)
717 .map(|r| r.map_err(MemoryError::from).map(edge_from_row))
718 }
719
720 pub async fn edges_after_id(
737 &self,
738 after_id: i64,
739 limit: i64,
740 ) -> Result<Vec<Edge>, MemoryError> {
741 let rows: Vec<EdgeRow> = sqlx::query_as(
742 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
743 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
744 FROM graph_edges
745 WHERE valid_to IS NULL AND id > ?1
746 ORDER BY id ASC
747 LIMIT ?2",
748 )
749 .bind(after_id)
750 .bind(limit)
751 .fetch_all(&self.pool)
752 .await?;
753 Ok(rows.into_iter().map(edge_from_row).collect())
754 }
755
756 pub async fn find_community_by_id(&self, id: i64) -> Result<Option<Community>, MemoryError> {
762 let row: Option<CommunityRow> = sqlx::query_as(
763 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
764 FROM graph_communities
765 WHERE id = ?1",
766 )
767 .bind(id)
768 .fetch_optional(&self.pool)
769 .await?;
770 match row {
771 Some(row) => {
772 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
773 Ok(Some(Community {
774 id: row.id,
775 name: row.name,
776 summary: row.summary,
777 entity_ids,
778 fingerprint: row.fingerprint,
779 created_at: row.created_at,
780 updated_at: row.updated_at,
781 }))
782 }
783 None => Ok(None),
784 }
785 }
786
787 pub async fn delete_all_communities(&self) -> Result<(), MemoryError> {
793 sqlx::query("DELETE FROM graph_communities")
794 .execute(&self.pool)
795 .await?;
796 Ok(())
797 }
798
799 pub async fn delete_expired_edges(&self, retention_days: u32) -> Result<usize, MemoryError> {
805 let days = i64::from(retention_days);
806 let result = sqlx::query(
807 "DELETE FROM graph_edges
808 WHERE expired_at IS NOT NULL
809 AND expired_at < datetime('now', '-' || ?1 || ' days')",
810 )
811 .bind(days)
812 .execute(&self.pool)
813 .await?;
814 Ok(usize::try_from(result.rows_affected())?)
815 }
816
817 pub async fn delete_orphan_entities(&self, retention_days: u32) -> Result<usize, MemoryError> {
823 let days = i64::from(retention_days);
824 let result = sqlx::query(
825 "DELETE FROM graph_entities
826 WHERE id NOT IN (
827 SELECT DISTINCT source_entity_id FROM graph_edges WHERE valid_to IS NULL
828 UNION
829 SELECT DISTINCT target_entity_id FROM graph_edges WHERE valid_to IS NULL
830 )
831 AND last_seen_at < datetime('now', '-' || ?1 || ' days')",
832 )
833 .bind(days)
834 .execute(&self.pool)
835 .await?;
836 Ok(usize::try_from(result.rows_affected())?)
837 }
838
    /// Enforces a hard cap on the entity table: when the count exceeds
    /// `max_entities`, deletes the excess, preferring entities with the
    /// fewest active edges and, among ties, the oldest `last_seen_at`.
    /// Returns the number of rows actually deleted (0 when under the cap).
    ///
    /// NOTE(review): the count-then-delete pair is not transactional, so a
    /// concurrent insert between the two statements could leave the table
    /// slightly over or under the cap — confirm this tolerance is acceptable.
    ///
    /// # Errors
    /// Returns an error if `max_entities` overflows `i64`, a query fails, or
    /// the affected-row count overflows `usize`.
    pub async fn cap_entities(&self, max_entities: usize) -> Result<usize, MemoryError> {
        let current = self.entity_count().await?;
        let max = i64::try_from(max_entities)?;
        if current <= max {
            return Ok(0);
        }
        let excess = current - max;
        // Rank entities by active-edge degree (edges counted once per side via
        // UNION ALL), then by staleness, and delete the `excess` lowest-ranked.
        let result = sqlx::query(
            "DELETE FROM graph_entities
             WHERE id IN (
                 SELECT e.id
                 FROM graph_entities e
                 LEFT JOIN (
                     SELECT source_entity_id AS eid, COUNT(*) AS cnt
                     FROM graph_edges WHERE valid_to IS NULL GROUP BY source_entity_id
                     UNION ALL
                     SELECT target_entity_id AS eid, COUNT(*) AS cnt
                     FROM graph_edges WHERE valid_to IS NULL GROUP BY target_entity_id
                 ) edge_counts ON e.id = edge_counts.eid
                 ORDER BY COALESCE(edge_counts.cnt, 0) ASC, e.last_seen_at ASC
                 LIMIT ?1
             )",
        )
        .bind(excess)
        .execute(&self.pool)
        .await?;
        Ok(usize::try_from(result.rows_affected())?)
    }
875
876 pub async fn edges_at_timestamp(
890 &self,
891 entity_id: i64,
892 timestamp: &str,
893 ) -> Result<Vec<Edge>, MemoryError> {
894 let rows: Vec<EdgeRow> = sqlx::query_as(
898 "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
899 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
900 FROM graph_edges
901 WHERE valid_to IS NULL
902 AND valid_from <= ?2
903 AND (source_entity_id = ?1 OR target_entity_id = ?1)
904 UNION ALL
905 SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
906 valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
907 FROM graph_edges
908 WHERE valid_to IS NOT NULL
909 AND valid_from <= ?2
910 AND valid_to > ?2
911 AND (source_entity_id = ?1 OR target_entity_id = ?1)",
912 )
913 .bind(entity_id)
914 .bind(timestamp)
915 .fetch_all(&self.pool)
916 .await?;
917 Ok(rows.into_iter().map(edge_from_row).collect())
918 }
919
    /// Returns up to `limit` edges (active and expired) originating at
    /// `source_entity_id` whose `fact` text contains `predicate` as a
    /// substring, optionally also filtered to an exact `relation`, newest
    /// `valid_from` first.
    ///
    /// # Errors
    /// Returns an error if `limit` overflows `i64` or the query fails.
    pub async fn edge_history(
        &self,
        source_entity_id: i64,
        predicate: &str,
        relation: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Edge>, MemoryError> {
        // Escape LIKE wildcards so the predicate matches literally.
        let escaped = predicate
            .replace('\\', "\\\\")
            .replace('%', "\\%")
            .replace('_', "\\_");
        let like_pattern = format!("%{escaped}%");
        let limit = i64::try_from(limit)?;
        // Two near-identical statements: adding the relation filter shifts
        // the LIMIT placeholder from ?3 to ?4, so the branches stay separate.
        let rows: Vec<EdgeRow> = if let Some(rel) = relation {
            sqlx::query_as(
                "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                        valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
                 FROM graph_edges
                 WHERE source_entity_id = ?1
                   AND fact LIKE ?2 ESCAPE '\\'
                   AND relation = ?3
                 ORDER BY valid_from DESC
                 LIMIT ?4",
            )
            .bind(source_entity_id)
            .bind(&like_pattern)
            .bind(rel)
            .bind(limit)
            .fetch_all(&self.pool)
            .await?
        } else {
            sqlx::query_as(
                "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                        valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
                 FROM graph_edges
                 WHERE source_entity_id = ?1
                   AND fact LIKE ?2 ESCAPE '\\'
                 ORDER BY valid_from DESC
                 LIMIT ?3",
            )
            .bind(source_entity_id)
            .bind(&like_pattern)
            .bind(limit)
            .fetch_all(&self.pool)
            .await?
        };
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
977
978 pub async fn bfs(
995 &self,
996 start_entity_id: i64,
997 max_hops: u32,
998 ) -> Result<(Vec<Entity>, Vec<Edge>), MemoryError> {
999 self.bfs_with_depth(start_entity_id, max_hops)
1000 .await
1001 .map(|(e, ed, _)| (e, ed))
1002 }
1003
    /// Breadth-first traversal over currently-active edges, additionally
    /// returning each visited entity's hop distance from the start.
    ///
    /// # Errors
    /// Returns an error if the underlying traversal queries fail.
    pub async fn bfs_with_depth(
        &self,
        start_entity_id: i64,
        max_hops: u32,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        self.bfs_core(start_entity_id, max_hops, None).await
    }
1021
    /// Breadth-first traversal as of a point in time: only edges whose
    /// validity interval covered `timestamp` are followed.
    ///
    /// # Errors
    /// Returns an error if the underlying traversal queries fail.
    pub async fn bfs_at_timestamp(
        &self,
        start_entity_id: i64,
        max_hops: u32,
        timestamp: &str,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        self.bfs_core(start_entity_id, max_hops, Some(timestamp))
            .await
    }
1041
    /// Shared BFS driver: expands one hop per loop iteration, querying all
    /// neighbours of the current frontier in a single dynamically-built
    /// statement, until `max_hops` or an empty frontier.
    ///
    /// When `at_timestamp` is `Some`, only edges valid at that instant are
    /// traversed; otherwise only currently-active edges (`valid_to IS NULL`).
    ///
    /// # Errors
    /// Returns an error if a neighbour query or the final result fetch fails.
    async fn bfs_core(
        &self,
        start_entity_id: i64,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        use std::collections::HashMap;

        // Frontier cap keeps the dynamic bind count bounded: the frontier is
        // bound three times (once per IN list), so 300 ids -> at most 901
        // parameters including the optional timestamp — under SQLite's
        // default 999-variable limit.
        const MAX_FRONTIER: usize = 300;

        // depth_map doubles as the visited set; value = hop distance.
        let mut depth_map: HashMap<i64, u32> = HashMap::new();
        let mut frontier: Vec<i64> = vec![start_entity_id];
        depth_map.insert(start_entity_id, 0);

        for hop in 0..max_hops {
            if frontier.is_empty() {
                break;
            }
            frontier.truncate(MAX_FRONTIER);
            // ?1..?n placeholders for the frontier ids.
            let placeholders = frontier
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");
            let edge_filter = if at_timestamp.is_some() {
                // The timestamp parameter sits after all three copies of the
                // frontier binds, hence position frontier.len() * 3 + 1.
                let ts_pos = frontier.len() * 3 + 1;
                format!("valid_from <= ?{ts_pos} AND (valid_to IS NULL OR valid_to > ?{ts_pos})")
            } else {
                "valid_to IS NULL".to_owned()
            };
            // For each matching edge, pick whichever endpoint is NOT in the
            // frontier (for frontier-to-frontier edges this still yields a
            // frontier id, which the visited check below discards).
            let neighbour_sql = format!(
                "SELECT DISTINCT CASE
                     WHEN source_entity_id IN ({placeholders}) THEN target_entity_id
                     ELSE source_entity_id
                 END as neighbour_id
                 FROM graph_edges
                 WHERE {edge_filter}
                   AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))"
            );
            // Bind order must mirror placeholder order: the frontier ids three
            // times (CASE list, source IN list, target IN list), then the
            // optional timestamp.
            let mut q = sqlx::query_scalar::<_, i64>(&neighbour_sql);
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            if let Some(ts) = at_timestamp {
                q = q.bind(ts);
            }
            let neighbours: Vec<i64> = q.fetch_all(&self.pool).await?;
            // Only unseen nodes enter the next frontier and the depth map.
            let mut next_frontier: Vec<i64> = Vec::new();
            for nbr in neighbours {
                if let std::collections::hash_map::Entry::Vacant(e) = depth_map.entry(nbr) {
                    e.insert(hop + 1);
                    next_frontier.push(nbr);
                }
            }
            frontier = next_frontier;
        }

        self.bfs_fetch_results(depth_map, at_timestamp).await
    }
1118
    /// Materialises BFS results: loads the visited entities and the edges
    /// whose BOTH endpoints were visited, honouring the same temporal filter
    /// used during traversal.
    ///
    /// # Errors
    /// Returns an error if either fetch fails or an entity row carries an
    /// invalid type string.
    async fn bfs_fetch_results(
        &self,
        depth_map: std::collections::HashMap<i64, u32>,
        at_timestamp: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        let mut visited_ids: Vec<i64> = depth_map.keys().copied().collect();
        if visited_ids.is_empty() {
            return Ok((Vec::new(), Vec::new(), depth_map));
        }
        // The edge query binds the id list twice (+ optional timestamp), so
        // cap at 499 ids to stay under SQLite's default 999-variable limit.
        // Dropped entities keep their depth_map entries but produce no rows.
        if visited_ids.len() > 499 {
            tracing::warn!(
                total = visited_ids.len(),
                retained = 499,
                "bfs_fetch_results: visited entity set truncated to 499 to stay within SQLite bind limit; \
                 some reachable entities will be dropped from results"
            );
            visited_ids.truncate(499);
        }

        let placeholders = visited_ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("?{}", i + 1))
            .collect::<Vec<_>>()
            .join(", ");
        let edge_filter = if at_timestamp.is_some() {
            // Timestamp parameter comes after both copies of the id list.
            let ts_pos = visited_ids.len() * 2 + 1;
            format!("valid_from <= ?{ts_pos} AND (valid_to IS NULL OR valid_to > ?{ts_pos})")
        } else {
            "valid_to IS NULL".to_owned()
        };
        let edge_sql = format!(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id
             FROM graph_edges
             WHERE {edge_filter}
               AND source_entity_id IN ({placeholders})
               AND target_entity_id IN ({placeholders})"
        );
        // Bind order mirrors placeholder order: id list for the source IN,
        // id list again for the target IN, then the optional timestamp.
        let mut edge_query = sqlx::query_as::<_, EdgeRow>(&edge_sql);
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        if let Some(ts) = at_timestamp {
            edge_query = edge_query.bind(ts);
        }
        let edge_rows: Vec<EdgeRow> = edge_query.fetch_all(&self.pool).await?;

        let entity_sql = format!(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities WHERE id IN ({placeholders})"
        );
        let mut entity_query = sqlx::query_as::<_, EntityRow>(&entity_sql);
        for id in &visited_ids {
            entity_query = entity_query.bind(*id);
        }
        let entity_rows: Vec<EntityRow> = entity_query.fetch_all(&self.pool).await?;

        let entities: Vec<Entity> = entity_rows
            .into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()?;
        let edges: Vec<Edge> = edge_rows.into_iter().map(edge_from_row).collect();

        Ok((entities, edges, depth_map))
    }
1190
1191 pub async fn find_entity_by_name(&self, name: &str) -> Result<Vec<Entity>, MemoryError> {
1207 let rows: Vec<EntityRow> = sqlx::query_as(
1208 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
1209 FROM graph_entities
1210 WHERE name = ?1 COLLATE NOCASE OR canonical_name = ?1 COLLATE NOCASE
1211 LIMIT 5",
1212 )
1213 .bind(name)
1214 .fetch_all(&self.pool)
1215 .await?;
1216
1217 if !rows.is_empty() {
1218 return rows.into_iter().map(entity_from_row).collect();
1219 }
1220
1221 self.find_entities_fuzzy(name, 5).await
1222 }
1223
1224 pub async fn unprocessed_messages_for_backfill(
1232 &self,
1233 limit: usize,
1234 ) -> Result<Vec<(crate::types::MessageId, String)>, MemoryError> {
1235 let limit = i64::try_from(limit)?;
1236 let rows: Vec<(i64, String)> = sqlx::query_as(
1237 "SELECT id, content FROM messages
1238 WHERE graph_processed = 0
1239 ORDER BY id ASC
1240 LIMIT ?1",
1241 )
1242 .bind(limit)
1243 .fetch_all(&self.pool)
1244 .await?;
1245 Ok(rows
1246 .into_iter()
1247 .map(|(id, content)| (crate::types::MessageId(id), content))
1248 .collect())
1249 }
1250
1251 pub async fn unprocessed_message_count(&self) -> Result<i64, MemoryError> {
1257 let count: i64 =
1258 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE graph_processed = 0")
1259 .fetch_one(&self.pool)
1260 .await?;
1261 Ok(count)
1262 }
1263
1264 pub async fn mark_messages_graph_processed(
1270 &self,
1271 ids: &[crate::types::MessageId],
1272 ) -> Result<(), MemoryError> {
1273 if ids.is_empty() {
1274 return Ok(());
1275 }
1276 let placeholders = ids
1277 .iter()
1278 .enumerate()
1279 .map(|(i, _)| format!("?{}", i + 1))
1280 .collect::<Vec<_>>()
1281 .join(", ");
1282 let sql = format!("UPDATE messages SET graph_processed = 1 WHERE id IN ({placeholders})");
1283 let mut query = sqlx::query(&sql);
1284 for id in ids {
1285 query = query.bind(id.0);
1286 }
1287 query.execute(&self.pool).await?;
1288 Ok(())
1289 }
1290}
1291
/// Raw `graph_entities` row as stored in SQLite; converted to the public
/// [`Entity`] type by [`entity_from_row`].
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: i64,
    name: String,
    canonical_name: String,
    // Stored as text; parsed into `EntityType` during conversion.
    entity_type: String,
    summary: Option<String>,
    first_seen_at: String,
    last_seen_at: String,
    qdrant_point_id: Option<String>,
}
1305
1306fn entity_from_row(row: EntityRow) -> Result<Entity, MemoryError> {
1307 let entity_type = row
1308 .entity_type
1309 .parse::<EntityType>()
1310 .map_err(MemoryError::GraphStore)?;
1311 Ok(Entity {
1312 id: row.id,
1313 name: row.name,
1314 canonical_name: row.canonical_name,
1315 entity_type,
1316 summary: row.summary,
1317 first_seen_at: row.first_seen_at,
1318 last_seen_at: row.last_seen_at,
1319 qdrant_point_id: row.qdrant_point_id,
1320 })
1321}
1322
/// Raw `graph_entity_aliases` row; converted to [`EntityAlias`] by
/// [`alias_from_row`].
#[derive(sqlx::FromRow)]
struct AliasRow {
    id: i64,
    entity_id: i64,
    alias_name: String,
    created_at: String,
}
1330
1331fn alias_from_row(row: AliasRow) -> EntityAlias {
1332 EntityAlias {
1333 id: row.id,
1334 entity_id: row.entity_id,
1335 alias_name: row.alias_name,
1336 created_at: row.created_at,
1337 }
1338}
1339
/// Raw `graph_edges` row; converted to the public [`Edge`] by
/// [`edge_from_row`].
#[derive(sqlx::FromRow)]
struct EdgeRow {
    id: i64,
    source_entity_id: i64,
    target_entity_id: i64,
    relation: String,
    fact: String,
    // SQLite REAL is f64; narrowed to f32 during conversion.
    confidence: f64,
    valid_from: String,
    // NULL while the edge is still active.
    valid_to: Option<String>,
    created_at: String,
    expired_at: Option<String>,
    // Raw message id; wrapped in `MessageId` during conversion.
    episode_id: Option<i64>,
    qdrant_point_id: Option<String>,
}
1355
1356fn edge_from_row(row: EdgeRow) -> Edge {
1357 Edge {
1358 id: row.id,
1359 source_entity_id: row.source_entity_id,
1360 target_entity_id: row.target_entity_id,
1361 relation: row.relation,
1362 fact: row.fact,
1363 #[allow(clippy::cast_possible_truncation)]
1364 confidence: row.confidence as f32,
1365 valid_from: row.valid_from,
1366 valid_to: row.valid_to,
1367 created_at: row.created_at,
1368 expired_at: row.expired_at,
1369 episode_id: row.episode_id.map(MessageId),
1370 qdrant_point_id: row.qdrant_point_id,
1371 }
1372}
1373
/// Raw `graph_communities` row; converted to [`Community`] at the call
/// sites, which decode `entity_ids` from JSON.
#[derive(sqlx::FromRow)]
struct CommunityRow {
    id: i64,
    name: String,
    summary: String,
    // JSON-encoded array of member entity ids (see `upsert_community`).
    entity_ids: String,
    fingerprint: Option<String>,
    created_at: String,
    updated_at: String,
}
1384
1385#[cfg(test)]
1388mod tests;