1use std::collections::HashMap;
5
6use futures::Stream;
7use sqlx::SqlitePool;
8
9use crate::error::MemoryError;
10use crate::sqlite::messages::sanitize_fts5_query;
11use crate::types::MessageId;
12
13use super::types::{Community, Edge, EdgeType, Entity, EntityAlias, EntityType};
14
/// SQLite-backed store for the knowledge graph: entities, entity aliases,
/// temporal edges, communities, and a small key/value metadata table.
pub struct GraphStore {
    // Shared connection pool; all queries go through this handle.
    pool: SqlitePool,
}
18
19impl GraphStore {
    /// Creates a graph store backed by the given SQLite connection pool.
    #[must_use]
    pub fn new(pool: SqlitePool) -> Self {
        Self { pool }
    }
24
    /// Returns a reference to the underlying SQLite connection pool.
    #[must_use]
    pub fn pool(&self) -> &SqlitePool {
        &self.pool
    }
29
    /// Inserts an entity, or on a `(canonical_name, entity_type)` conflict
    /// refreshes its surface `name`, keeps the old summary unless a new one is
    /// provided, and bumps `last_seen_at`. Returns the row id in either case.
    ///
    /// # Errors
    /// Returns an error if the database query fails.
    pub async fn upsert_entity(
        &self,
        surface_name: &str,
        canonical_name: &str,
        entity_type: EntityType,
        summary: Option<&str>,
    ) -> Result<i64, MemoryError> {
        let type_str = entity_type.as_str();
        // RETURNING id makes insert and update paths both yield the row id.
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_entities (name, canonical_name, entity_type, summary)
             VALUES (?1, ?2, ?3, ?4)
             ON CONFLICT(canonical_name, entity_type) DO UPDATE SET
                 name = excluded.name,
                 summary = COALESCE(excluded.summary, summary),
                 last_seen_at = datetime('now')
             RETURNING id",
        )
        .bind(surface_name)
        .bind(canonical_name)
        .bind(type_str)
        .bind(summary)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
67
    /// Looks up an entity by exact `(canonical_name, entity_type)` key.
    /// Returns `None` when no matching row exists.
    ///
    /// # Errors
    /// Returns an error if the query or row conversion fails.
    pub async fn find_entity(
        &self,
        canonical_name: &str,
        entity_type: EntityType,
    ) -> Result<Option<Entity>, MemoryError> {
        let type_str = entity_type.as_str();
        let row: Option<EntityRow> = sqlx::query_as(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities
             WHERE canonical_name = ?1 AND entity_type = ?2",
        )
        .bind(canonical_name)
        .bind(type_str)
        .fetch_optional(&self.pool)
        .await?;
        // entity_from_row is fallible, so transpose Option<Result> -> Result<Option>.
        row.map(entity_from_row).transpose()
    }
90
    /// Fetches an entity by primary key; `None` if the id does not exist.
    ///
    /// # Errors
    /// Returns an error if the query or row conversion fails.
    pub async fn find_entity_by_id(&self, entity_id: i64) -> Result<Option<Entity>, MemoryError> {
        let row: Option<EntityRow> = sqlx::query_as(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities
             WHERE id = ?1",
        )
        .bind(entity_id)
        .fetch_optional(&self.pool)
        .await?;
        row.map(entity_from_row).transpose()
    }
107
    /// Records the Qdrant vector point id associated with an entity row.
    /// A no-op (zero rows affected) when `entity_id` does not exist.
    ///
    /// # Errors
    /// Returns an error if the update fails.
    pub async fn set_entity_qdrant_point_id(
        &self,
        entity_id: i64,
        point_id: &str,
    ) -> Result<(), MemoryError> {
        sqlx::query("UPDATE graph_entities SET qdrant_point_id = ?1 WHERE id = ?2")
            .bind(point_id)
            .bind(entity_id)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
125
    /// Fuzzy entity search: prefix-matches each query token against the FTS5
    /// index over entities, unioned with a case-insensitive substring match
    /// against alias names. Returns up to `limit` distinct entities.
    ///
    /// # Errors
    /// Returns an error if `limit` overflows `i64`, or if the query or row
    /// conversion fails.
    pub async fn find_entities_fuzzy(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<Entity>, MemoryError> {
        // Bare FTS5 keywords would be parsed as operators, so drop them.
        const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT", "NEAR"];
        // Cap the query at 512 bytes, rounding down to a char boundary.
        let query = &query[..query.floor_char_boundary(512)];
        let sanitized = sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(vec![]);
        }
        // Turn each remaining token into a prefix match: "foo" -> "foo*".
        let fts_query: String = sanitized
            .split_whitespace()
            .filter(|t| !FTS5_OPERATORS.contains(t))
            .map(|t| format!("{t}*"))
            .collect::<Vec<_>>()
            .join(" ");
        if fts_query.is_empty() {
            return Ok(vec![]);
        }

        let limit = i64::try_from(limit)?;
        // UNION (not UNION ALL) dedupes rows matched by both branches;
        // the LIMIT applies to the whole compound query.
        let rows: Vec<EntityRow> = sqlx::query_as(
            "SELECT DISTINCT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entities_fts fts
             JOIN graph_entities e ON e.id = fts.rowid
             WHERE graph_entities_fts MATCH ?1
             UNION
             SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name LIKE ?2 ESCAPE '\\' COLLATE NOCASE
             LIMIT ?3",
        )
        .bind(&fts_query)
        // Escape LIKE metacharacters in the raw query before wrapping in %...%.
        .bind(format!(
            "%{}%",
            query
                .trim()
                .replace('\\', "\\\\")
                .replace('%', "\\%")
                .replace('_', "\\_")
        ))
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        rows.into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()
    }
205
    /// Runs a PASSIVE WAL checkpoint (checkpoints what it can without
    /// blocking readers or writers).
    ///
    /// # Errors
    /// Returns an error if the PRAGMA statement fails.
    pub async fn checkpoint_wal(&self) -> Result<(), MemoryError> {
        sqlx::query("PRAGMA wal_checkpoint(PASSIVE)")
            .execute(&self.pool)
            .await?;
        Ok(())
    }
221
    /// Streams every entity ordered by id ascending, converting each row
    /// lazily so the full table is never materialized in memory.
    pub fn all_entities_stream(&self) -> impl Stream<Item = Result<Entity, MemoryError>> + '_ {
        use futures::StreamExt as _;
        sqlx::query_as::<_, EntityRow>(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
             FROM graph_entities ORDER BY id ASC",
        )
        .fetch(&self.pool)
        // and_then: entity_from_row is itself fallible.
        .map(|r: Result<EntityRow, sqlx::Error>| {
            r.map_err(MemoryError::from).and_then(entity_from_row)
        })
    }
234
    /// Registers an alias for an entity; duplicates are silently ignored
    /// via `INSERT OR IGNORE`.
    ///
    /// # Errors
    /// Returns an error if the insert fails.
    pub async fn add_alias(&self, entity_id: i64, alias_name: &str) -> Result<(), MemoryError> {
        sqlx::query(
            "INSERT OR IGNORE INTO graph_entity_aliases (entity_id, alias_name) VALUES (?1, ?2)",
        )
        .bind(entity_id)
        .bind(alias_name)
        .execute(&self.pool)
        .await?;
        Ok(())
    }
252
    /// Resolves an alias (case-insensitive exact match) to an entity of the
    /// given type. Ties are broken deterministically by lowest entity id.
    ///
    /// # Errors
    /// Returns an error if the query or row conversion fails.
    pub async fn find_entity_by_alias(
        &self,
        alias_name: &str,
        entity_type: EntityType,
    ) -> Result<Option<Entity>, MemoryError> {
        let type_str = entity_type.as_str();
        let row: Option<EntityRow> = sqlx::query_as(
            "SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                    e.first_seen_at, e.last_seen_at, e.qdrant_point_id
             FROM graph_entity_aliases a
             JOIN graph_entities e ON e.id = a.entity_id
             WHERE a.alias_name = ?1 COLLATE NOCASE
               AND e.entity_type = ?2
             ORDER BY e.id ASC
             LIMIT 1",
        )
        .bind(alias_name)
        .bind(type_str)
        .fetch_optional(&self.pool)
        .await?;
        row.map(entity_from_row).transpose()
    }
282
    /// Lists all aliases recorded for an entity, oldest (lowest id) first.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn aliases_for_entity(
        &self,
        entity_id: i64,
    ) -> Result<Vec<EntityAlias>, MemoryError> {
        let rows: Vec<AliasRow> = sqlx::query_as(
            "SELECT id, entity_id, alias_name, created_at
             FROM graph_entity_aliases
             WHERE entity_id = ?1
             ORDER BY id ASC",
        )
        .bind(entity_id)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(alias_from_row).collect())
    }
303
    /// Collects every entity into a `Vec`; convenience wrapper over
    /// [`Self::all_entities_stream`].
    ///
    /// # Errors
    /// Returns the first error produced by the underlying stream.
    pub async fn all_entities(&self) -> Result<Vec<Entity>, MemoryError> {
        use futures::TryStreamExt as _;
        self.all_entities_stream().try_collect().await
    }
313
314 pub async fn entity_count(&self) -> Result<i64, MemoryError> {
320 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_entities")
321 .fetch_one(&self.pool)
322 .await?;
323 Ok(count)
324 }
325
    /// Inserts (or reinforces) a semantic edge between two entities;
    /// shorthand for [`Self::insert_edge_typed`] with [`EdgeType::Semantic`].
    ///
    /// # Errors
    /// See [`Self::insert_edge_typed`].
    pub async fn insert_edge(
        &self,
        source_entity_id: i64,
        target_entity_id: i64,
        relation: &str,
        fact: &str,
        confidence: f32,
        episode_id: Option<MessageId>,
    ) -> Result<i64, MemoryError> {
        self.insert_edge_typed(
            source_entity_id,
            target_entity_id,
            relation,
            fact,
            confidence,
            episode_id,
            EdgeType::Semantic,
        )
        .await
    }
363
    /// Inserts an edge, or if an active (`valid_to IS NULL`) edge with the
    /// same `(source, target, relation, edge_type)` already exists, raises its
    /// confidence to the max of old and new and returns the existing id.
    /// Self-loops are rejected; `confidence` is clamped to `[0.0, 1.0]`.
    ///
    /// NOTE(review): the existence check and the insert are two separate
    /// statements, so concurrent writers could race and insert duplicate
    /// active edges — confirm callers serialize graph writes.
    ///
    /// # Errors
    /// Returns [`MemoryError::InvalidInput`] for self-loops, or a database
    /// error if any of the queries fail.
    #[allow(clippy::too_many_arguments)]
    pub async fn insert_edge_typed(
        &self,
        source_entity_id: i64,
        target_entity_id: i64,
        relation: &str,
        fact: &str,
        confidence: f32,
        episode_id: Option<MessageId>,
        edge_type: EdgeType,
    ) -> Result<i64, MemoryError> {
        if source_entity_id == target_entity_id {
            return Err(MemoryError::InvalidInput(format!(
                "self-loop edge rejected: source and target are the same entity (id={source_entity_id})"
            )));
        }
        let confidence = confidence.clamp(0.0, 1.0);
        let edge_type_str = edge_type.as_str();

        // Look for an active duplicate to reinforce instead of inserting.
        let existing: Option<(i64, f64)> = sqlx::query_as(
            "SELECT id, confidence FROM graph_edges
             WHERE source_entity_id = ?1
               AND target_entity_id = ?2
               AND relation = ?3
               AND edge_type = ?4
               AND valid_to IS NULL
             LIMIT 1",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .bind(edge_type_str)
        .fetch_optional(&self.pool)
        .await?;

        if let Some((existing_id, stored_conf)) = existing {
            // Confidence only ever ratchets upward on reinforcement.
            let updated_conf = f64::from(confidence).max(stored_conf);
            sqlx::query("UPDATE graph_edges SET confidence = ?1 WHERE id = ?2")
                .bind(updated_conf)
                .bind(existing_id)
                .execute(&self.pool)
                .await?;
            return Ok(existing_id);
        }

        let episode_raw: Option<i64> = episode_id.map(|m| m.0);
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_edges
                 (source_entity_id, target_entity_id, relation, fact, confidence, episode_id, edge_type)
             VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
             RETURNING id",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .bind(relation)
        .bind(fact)
        .bind(f64::from(confidence))
        .bind(episode_raw)
        .bind(edge_type_str)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
435
    /// Marks an edge as no longer valid by stamping both `valid_to` and
    /// `expired_at` with the current time. Idempotent in effect but
    /// re-stamps the timestamps if called again.
    ///
    /// # Errors
    /// Returns an error if the update fails.
    pub async fn invalidate_edge(&self, edge_id: i64) -> Result<(), MemoryError> {
        sqlx::query(
            "UPDATE graph_edges SET valid_to = datetime('now'), expired_at = datetime('now')
             WHERE id = ?1",
        )
        .bind(edge_id)
        .execute(&self.pool)
        .await?;
        Ok(())
    }
451
    /// Fetches all active edges touching any of `entity_ids` (as source or
    /// target), optionally filtered to `edge_types` (empty slice = all types).
    /// Entity ids are chunked to stay below SQLite's bound-parameter limit.
    ///
    /// # Errors
    /// Returns an error if any batch query fails.
    pub async fn edges_for_entities(
        &self,
        entity_ids: &[i64],
        edge_types: &[super::types::EdgeType],
    ) -> Result<Vec<Edge>, MemoryError> {
        // Leaves headroom under SQLite's default 999-parameter limit
        // (each id binds once; edge types add a few more).
        const MAX_BATCH_ENTITIES: usize = 490;

        let mut all_edges: Vec<Edge> = Vec::new();

        for chunk in entity_ids.chunks(MAX_BATCH_ENTITIES) {
            let edges = self.query_batch_edges(chunk, edge_types).await?;
            all_edges.extend(edges);
        }

        Ok(all_edges)
    }
487
    /// Runs one batched edge query for a single chunk of entity ids.
    /// Builds numbered placeholders `?1..?N` for the ids (reused by both IN
    /// lists, so each id binds once) and `?N+1..` for the edge types.
    async fn query_batch_edges(
        &self,
        entity_ids: &[i64],
        edge_types: &[super::types::EdgeType],
    ) -> Result<Vec<Edge>, MemoryError> {
        if entity_ids.is_empty() {
            return Ok(Vec::new());
        }

        // "?1, ?2, ..." — numbered so the same params serve both IN clauses.
        let placeholders: String = (1..=entity_ids.len())
            .map(|i| format!("?{i}"))
            .collect::<Vec<_>>()
            .join(", ");

        let sql = if edge_types.is_empty() {
            format!(
                "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                        valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                        edge_type, retrieval_count, last_retrieved_at
                 FROM graph_edges
                 WHERE valid_to IS NULL
                   AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))"
            )
        } else {
            // Edge-type placeholders continue numbering after the id params.
            let type_placeholders: String = (entity_ids.len() + 1
                ..=entity_ids.len() + edge_types.len())
                .map(|i| format!("?{i}"))
                .collect::<Vec<_>>()
                .join(", ");
            format!(
                "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                        valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                        edge_type, retrieval_count, last_retrieved_at
                 FROM graph_edges
                 WHERE valid_to IS NULL
                   AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))
                   AND edge_type IN ({type_placeholders})"
            )
        };

        // Bind order must match placeholder numbering: ids first, then types.
        let mut query = sqlx::query_as::<_, EdgeRow>(&sql);
        for id in entity_ids {
            query = query.bind(*id);
        }
        for et in edge_types {
            query = query.bind(et.as_str());
        }

        let rows: Vec<EdgeRow> = query.fetch_all(&self.pool).await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
545
    /// Fetches all active edges where the entity is either source or target.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn edges_for_entity(&self, entity_id: i64) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NULL
               AND (source_entity_id = ?1 OR target_entity_id = ?1)",
        )
        .bind(entity_id)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
565
    /// Fetches up to `limit` edges touching an entity, including invalidated
    /// ones, newest `valid_from` first — a full temporal history.
    ///
    /// # Errors
    /// Returns an error if `limit` overflows `i64` or the query fails.
    pub async fn edge_history_for_entity(
        &self,
        entity_id: i64,
        limit: usize,
    ) -> Result<Vec<Edge>, MemoryError> {
        let limit = i64::try_from(limit)?;
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE source_entity_id = ?1 OR target_entity_id = ?1
             ORDER BY valid_from DESC
             LIMIT ?2",
        )
        .bind(entity_id)
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
593
    /// Fetches active edges between two entities in either direction.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn edges_between(
        &self,
        entity_a: i64,
        entity_b: i64,
    ) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NULL
               AND ((source_entity_id = ?1 AND target_entity_id = ?2)
                 OR (source_entity_id = ?2 AND target_entity_id = ?1))",
        )
        .bind(entity_a)
        .bind(entity_b)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
619
    /// Fetches active edges with exactly this source → target direction
    /// (unlike [`Self::edges_between`], which is direction-agnostic).
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn edges_exact(
        &self,
        source_entity_id: i64,
        target_entity_id: i64,
    ) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NULL
               AND source_entity_id = ?1
               AND target_entity_id = ?2",
        )
        .bind(source_entity_id)
        .bind(target_entity_id)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
645
646 pub async fn active_edge_count(&self) -> Result<i64, MemoryError> {
652 let count: i64 =
653 sqlx::query_scalar("SELECT COUNT(*) FROM graph_edges WHERE valid_to IS NULL")
654 .fetch_one(&self.pool)
655 .await?;
656 Ok(count)
657 }
658
    /// Counts active edges grouped by `edge_type`, ordered by type name.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn edge_type_distribution(&self) -> Result<Vec<(String, i64)>, MemoryError> {
        let rows: Vec<(String, i64)> = sqlx::query_as(
            "SELECT edge_type, COUNT(*) FROM graph_edges WHERE valid_to IS NULL GROUP BY edge_type ORDER BY edge_type",
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(rows)
    }
672
    /// Inserts a community, or on a `name` conflict replaces its summary and
    /// member list, keeps the old fingerprint unless a new one is given, and
    /// bumps `updated_at`. The member ids are stored as a JSON array.
    ///
    /// # Errors
    /// Returns an error if JSON serialization or the query fails.
    pub async fn upsert_community(
        &self,
        name: &str,
        summary: &str,
        entity_ids: &[i64],
        fingerprint: Option<&str>,
    ) -> Result<i64, MemoryError> {
        let entity_ids_json = serde_json::to_string(entity_ids)?;
        let id: i64 = sqlx::query_scalar(
            "INSERT INTO graph_communities (name, summary, entity_ids, fingerprint)
             VALUES (?1, ?2, ?3, ?4)
             ON CONFLICT(name) DO UPDATE SET
                 summary = excluded.summary,
                 entity_ids = excluded.entity_ids,
                 fingerprint = COALESCE(excluded.fingerprint, fingerprint),
                 updated_at = datetime('now')
             RETURNING id",
        )
        .bind(name)
        .bind(summary)
        .bind(entity_ids_json)
        .bind(fingerprint)
        .fetch_one(&self.pool)
        .await?;
        Ok(id)
    }
710
    /// Maps each non-NULL community fingerprint to its community id.
    /// If two communities shared a fingerprint, the later row would win the
    /// map slot.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn community_fingerprints(&self) -> Result<HashMap<String, i64>, MemoryError> {
        let rows: Vec<(String, i64)> = sqlx::query_as(
            "SELECT fingerprint, id FROM graph_communities WHERE fingerprint IS NOT NULL",
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().collect())
    }
725
726 pub async fn delete_community_by_id(&self, id: i64) -> Result<(), MemoryError> {
732 sqlx::query("DELETE FROM graph_communities WHERE id = ?1")
733 .bind(id)
734 .execute(&self.pool)
735 .await?;
736 Ok(())
737 }
738
    /// Clears a community's fingerprint (sets it to NULL), e.g. to force it
    /// to be treated as changed on the next fingerprint comparison.
    ///
    /// # Errors
    /// Returns an error if the update fails.
    pub async fn clear_community_fingerprint(&self, id: i64) -> Result<(), MemoryError> {
        sqlx::query("UPDATE graph_communities SET fingerprint = NULL WHERE id = ?1")
            .bind(id)
            .execute(&self.pool)
            .await?;
        Ok(())
    }
754
755 pub async fn community_for_entity(
764 &self,
765 entity_id: i64,
766 ) -> Result<Option<Community>, MemoryError> {
767 let row: Option<CommunityRow> = sqlx::query_as(
768 "SELECT c.id, c.name, c.summary, c.entity_ids, c.fingerprint, c.created_at, c.updated_at
769 FROM graph_communities c, json_each(c.entity_ids) j
770 WHERE CAST(j.value AS INTEGER) = ?1
771 LIMIT 1",
772 )
773 .bind(entity_id)
774 .fetch_optional(&self.pool)
775 .await?;
776 match row {
777 Some(row) => {
778 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
779 Ok(Some(Community {
780 id: row.id,
781 name: row.name,
782 summary: row.summary,
783 entity_ids,
784 fingerprint: row.fingerprint,
785 created_at: row.created_at,
786 updated_at: row.updated_at,
787 }))
788 }
789 None => Ok(None),
790 }
791 }
792
    /// Lists every community ordered by id, decoding each stored JSON member
    /// list into `entity_ids`.
    ///
    /// # Errors
    /// Returns an error if the query fails or any stored member list is not
    /// valid JSON.
    pub async fn all_communities(&self) -> Result<Vec<Community>, MemoryError> {
        let rows: Vec<CommunityRow> = sqlx::query_as(
            "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
             FROM graph_communities
             ORDER BY id ASC",
        )
        .fetch_all(&self.pool)
        .await?;

        // collect() into Result short-circuits on the first bad JSON payload.
        rows.into_iter()
            .map(|row| {
                let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
                Ok(Community {
                    id: row.id,
                    name: row.name,
                    summary: row.summary,
                    entity_ids,
                    fingerprint: row.fingerprint,
                    created_at: row.created_at,
                    updated_at: row.updated_at,
                })
            })
            .collect()
    }
822
823 pub async fn community_count(&self) -> Result<i64, MemoryError> {
829 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM graph_communities")
830 .fetch_one(&self.pool)
831 .await?;
832 Ok(count)
833 }
834
    /// Reads a value from the `graph_metadata` key/value table; `None` when
    /// the key is absent.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn get_metadata(&self, key: &str) -> Result<Option<String>, MemoryError> {
        let val: Option<String> =
            sqlx::query_scalar("SELECT value FROM graph_metadata WHERE key = ?1")
                .bind(key)
                .fetch_optional(&self.pool)
                .await?;
        Ok(val)
    }
850
    /// Writes a key/value pair into `graph_metadata`, overwriting any
    /// existing value for the key.
    ///
    /// # Errors
    /// Returns an error if the upsert fails.
    pub async fn set_metadata(&self, key: &str, value: &str) -> Result<(), MemoryError> {
        sqlx::query(
            "INSERT INTO graph_metadata (key, value) VALUES (?1, ?2)
             ON CONFLICT(key) DO UPDATE SET value = excluded.value",
        )
        .bind(key)
        .bind(value)
        .execute(&self.pool)
        .await?;
        Ok(())
    }
867
868 pub async fn extraction_count(&self) -> Result<i64, MemoryError> {
876 let val = self.get_metadata("extraction_count").await?;
877 Ok(val.and_then(|v| v.parse::<i64>().ok()).unwrap_or(0))
878 }
879
    /// Streams every active edge ordered by id ascending, converting rows
    /// lazily so the whole table is never held in memory.
    pub fn all_active_edges_stream(&self) -> impl Stream<Item = Result<Edge, MemoryError>> + '_ {
        use futures::StreamExt as _;
        sqlx::query_as::<_, EdgeRow>(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NULL
             ORDER BY id ASC",
        )
        .fetch(&self.pool)
        // edge_from_row is infallible, so plain map (cf. all_entities_stream).
        .map(|r| r.map_err(MemoryError::from).map(edge_from_row))
    }
894
    /// Pages through active edges by id: returns up to `limit` edges with
    /// `id > after_id`, ascending — pass the last id seen to continue.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn edges_after_id(
        &self,
        after_id: i64,
        limit: i64,
    ) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NULL AND id > ?1
             ORDER BY id ASC
             LIMIT ?2",
        )
        .bind(after_id)
        .bind(limit)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
931
932 pub async fn find_community_by_id(&self, id: i64) -> Result<Option<Community>, MemoryError> {
938 let row: Option<CommunityRow> = sqlx::query_as(
939 "SELECT id, name, summary, entity_ids, fingerprint, created_at, updated_at
940 FROM graph_communities
941 WHERE id = ?1",
942 )
943 .bind(id)
944 .fetch_optional(&self.pool)
945 .await?;
946 match row {
947 Some(row) => {
948 let entity_ids: Vec<i64> = serde_json::from_str(&row.entity_ids)?;
949 Ok(Some(Community {
950 id: row.id,
951 name: row.name,
952 summary: row.summary,
953 entity_ids,
954 fingerprint: row.fingerprint,
955 created_at: row.created_at,
956 updated_at: row.updated_at,
957 }))
958 }
959 None => Ok(None),
960 }
961 }
962
963 pub async fn delete_all_communities(&self) -> Result<(), MemoryError> {
969 sqlx::query("DELETE FROM graph_communities")
970 .execute(&self.pool)
971 .await?;
972 Ok(())
973 }
974
    /// Like [`Self::find_entities_fuzzy`], but also returns a relevance score
    /// in `[0.0, 1.0]` per entity. FTS matches are scored by negated BM25
    /// (name column weighted 10x over the other indexed column); alias
    /// substring matches get a fixed 0.5. Scores are normalized by the batch
    /// maximum; duplicates keep their highest-ranked occurrence.
    ///
    /// # Errors
    /// Returns an error if `limit` overflows `i64` or the query fails.
    pub async fn find_entities_ranked(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<(Entity, f32)>, MemoryError> {
        // Raw row tuple: entity columns plus the computed fts_rank (index 8).
        type EntityFtsRow = (
            i64,
            String,
            String,
            String,
            Option<String>,
            String,
            String,
            Option<String>,
            f64,
        );

        // Bare FTS5 keywords would be parsed as operators, so drop them.
        const FTS5_OPERATORS: &[&str] = &["AND", "OR", "NOT", "NEAR"];
        // Cap the query at 512 bytes, rounding down to a char boundary.
        let query = &query[..query.floor_char_boundary(512)];
        let sanitized = crate::sqlite::messages::sanitize_fts5_query(query);
        if sanitized.is_empty() {
            return Ok(vec![]);
        }
        // Turn each remaining token into a prefix match: "foo" -> "foo*".
        let fts_query: String = sanitized
            .split_whitespace()
            .filter(|t| !FTS5_OPERATORS.contains(t))
            .map(|t| format!("{t}*"))
            .collect::<Vec<_>>()
            .join(" ");
        if fts_query.is_empty() {
            return Ok(vec![]);
        }

        let limit_i64 = i64::try_from(limit)?;

        // UNION ALL keeps both branches' rows (deduped below in Rust, after
        // ranking); bm25 is negated because smaller bm25 = better match.
        let rows: Vec<EntityFtsRow> = sqlx::query_as(
            "SELECT * FROM (
                 SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                        e.first_seen_at, e.last_seen_at, e.qdrant_point_id,
                        -bm25(graph_entities_fts, 10.0, 1.0) AS fts_rank
                 FROM graph_entities_fts fts
                 JOIN graph_entities e ON e.id = fts.rowid
                 WHERE graph_entities_fts MATCH ?1
                 UNION ALL
                 SELECT e.id, e.name, e.canonical_name, e.entity_type, e.summary,
                        e.first_seen_at, e.last_seen_at, e.qdrant_point_id,
                        0.5 AS fts_rank
                 FROM graph_entity_aliases a
                 JOIN graph_entities e ON e.id = a.entity_id
                 WHERE a.alias_name LIKE ?2 ESCAPE '\\' COLLATE NOCASE
             )
             ORDER BY fts_rank DESC
             LIMIT ?3",
        )
        .bind(&fts_query)
        // Escape LIKE metacharacters in the raw query before wrapping in %...%.
        .bind(format!(
            "%{}%",
            query
                .trim()
                .replace('\\', "\\\\")
                .replace('%', "\\%")
                .replace('_', "\\_")
        ))
        .bind(limit_i64)
        .fetch_all(&self.pool)
        .await?;

        if rows.is_empty() {
            return Ok(vec![]);
        }

        // Normalize by the best score; guard against non-positive maxima so
        // division below stays well-defined.
        let max_score: f64 = rows.iter().map(|r| r.8).fold(0.0_f64, f64::max);
        let max_score = if max_score <= 0.0 { 1.0 } else { max_score };

        // Rows are already sorted best-first, so the first occurrence of an
        // id wins and later duplicates are skipped.
        let mut seen_ids: std::collections::HashSet<i64> = std::collections::HashSet::new();
        let mut result: Vec<(Entity, f32)> = Vec::with_capacity(rows.len());
        for (
            id,
            name,
            canonical_name,
            entity_type_str,
            summary,
            first_seen_at,
            last_seen_at,
            qdrant_point_id,
            raw_score,
        ) in rows
        {
            if !seen_ids.insert(id) {
                continue;
            }
            // Unknown type strings fall back to Concept rather than failing.
            let entity_type = entity_type_str
                .parse()
                .unwrap_or(super::types::EntityType::Concept);
            let entity = Entity {
                id,
                name,
                canonical_name,
                entity_type,
                summary,
                first_seen_at,
                last_seen_at,
                qdrant_point_id,
            };
            #[allow(clippy::cast_possible_truncation)]
            let normalized = (raw_score / max_score).clamp(0.0, 1.0) as f32;
            result.push((entity, normalized));
        }

        Ok(result)
    }
1105
    /// Computes a structural importance score in `[0.0, 1.0]` per entity:
    /// `0.6 * degree / max_degree + 0.4 * min(type_diversity / 4, 1)`, where
    /// degree counts active edges (either direction) and type_diversity is
    /// the number of distinct edge types. Entities with no edges score 0.
    ///
    /// # Errors
    /// Returns an error if any batch query fails.
    pub async fn entity_structural_scores(
        &self,
        entity_ids: &[i64],
    ) -> Result<HashMap<i64, f32>, MemoryError> {
        // Each id binds three times (two inner IN lists + outer filter), so
        // the chunk size keeps 3 * MAX_BATCH under SQLite's parameter limit.
        const MAX_BATCH: usize = 163;

        if entity_ids.is_empty() {
            return Ok(HashMap::new());
        }

        let mut all_rows: Vec<(i64, i64, i64)> = Vec::new();
        for chunk in entity_ids.chunks(MAX_BATCH) {
            let placeholders = chunk
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");

            // UNION ALL gathers one row per edge endpoint; grouping by
            // entity_id then yields degree and distinct-type counts.
            let sql = format!(
                "SELECT entity_id,
                        COUNT(*) AS degree,
                        COUNT(DISTINCT edge_type) AS type_diversity
                 FROM (
                     SELECT source_entity_id AS entity_id, edge_type
                     FROM graph_edges
                     WHERE valid_to IS NULL AND source_entity_id IN ({placeholders})
                     UNION ALL
                     SELECT target_entity_id AS entity_id, edge_type
                     FROM graph_edges
                     WHERE valid_to IS NULL AND target_entity_id IN ({placeholders})
                 )
                 WHERE entity_id IN ({placeholders})
                 GROUP BY entity_id"
            );

            // Same placeholder numbers appear in all three IN lists, but sqlx
            // binds positionally, so the chunk is bound once per list.
            let mut query = sqlx::query_as::<_, (i64, i64, i64)>(&sql);
            for id in chunk {
                query = query.bind(*id);
            }
            for id in chunk {
                query = query.bind(*id);
            }
            for id in chunk {
                query = query.bind(*id);
            }

            let chunk_rows: Vec<(i64, i64, i64)> = query.fetch_all(&self.pool).await?;
            all_rows.extend(chunk_rows);
        }

        if all_rows.is_empty() {
            return Ok(entity_ids.iter().map(|&id| (id, 0.0_f32)).collect());
        }

        // Normalize degrees against the batch maximum (at least 1 to avoid
        // dividing by zero).
        let max_degree = all_rows
            .iter()
            .map(|(_, d, _)| *d)
            .max()
            .unwrap_or(1)
            .max(1);

        // Start everyone at 0.0 so entities with no edges still appear.
        let mut scores: HashMap<i64, f32> = entity_ids.iter().map(|&id| (id, 0.0_f32)).collect();
        for (entity_id, degree, type_diversity) in all_rows {
            #[allow(clippy::cast_precision_loss)]
            let norm_degree = degree as f32 / max_degree as f32;
            #[allow(clippy::cast_precision_loss)]
            let norm_diversity = (type_diversity as f32 / 4.0).min(1.0);
            let score = 0.6 * norm_degree + 0.4 * norm_diversity;
            scores.insert(entity_id, score);
        }

        Ok(scores)
    }
1193
    /// Maps each given entity id to the id of a community whose JSON member
    /// list contains it. Entities in no community are absent from the map;
    /// if an entity appeared in several communities, the last row returned
    /// would win its map slot.
    ///
    /// # Errors
    /// Returns an error if any batch query fails.
    pub async fn entity_community_ids(
        &self,
        entity_ids: &[i64],
    ) -> Result<HashMap<i64, i64>, MemoryError> {
        // Stays under SQLite's default bound-parameter limit.
        const MAX_BATCH: usize = 490;

        if entity_ids.is_empty() {
            return Ok(HashMap::new());
        }

        let mut result: HashMap<i64, i64> = HashMap::new();
        for chunk in entity_ids.chunks(MAX_BATCH) {
            let placeholders = chunk
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");

            // json_each expands each community's member array to rows.
            let sql = format!(
                "SELECT CAST(j.value AS INTEGER) AS entity_id, c.id AS community_id
                 FROM graph_communities c, json_each(c.entity_ids) j
                 WHERE CAST(j.value AS INTEGER) IN ({placeholders})"
            );

            let mut query = sqlx::query_as::<_, (i64, i64)>(&sql);
            for id in chunk {
                query = query.bind(*id);
            }

            let rows: Vec<(i64, i64)> = query.fetch_all(&self.pool).await?;
            result.extend(rows);
        }

        Ok(result)
    }
1238
    /// Increments `retrieval_count` and stamps `last_retrieved_at` (unix
    /// seconds) for each given edge id, in batches. A no-op on an empty
    /// slice.
    ///
    /// # Errors
    /// Returns an error if any batch update fails.
    pub async fn record_edge_retrieval(&self, edge_ids: &[i64]) -> Result<(), MemoryError> {
        // Stays under SQLite's default bound-parameter limit.
        const MAX_BATCH: usize = 490;
        for chunk in edge_ids.chunks(MAX_BATCH) {
            let placeholders = chunk
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");
            let sql = format!(
                "UPDATE graph_edges
                 SET retrieval_count = retrieval_count + 1,
                     last_retrieved_at = unixepoch('now')
                 WHERE id IN ({placeholders})"
            );
            let mut q = sqlx::query(&sql);
            for id in chunk {
                q = q.bind(*id);
            }
            q.execute(&self.pool).await?;
        }
        Ok(())
    }
1270
    /// Multiplies `retrieval_count` by `decay_lambda` (truncated toward zero,
    /// floored at 0) for every active edge not retrieved within the last
    /// `interval_secs`. Returns the number of edges decayed.
    ///
    /// # Errors
    /// Returns an error if the update fails or the affected-row count
    /// overflows `usize`.
    pub async fn decay_edge_retrieval_counts(
        &self,
        decay_lambda: f64,
        interval_secs: u64,
    ) -> Result<usize, MemoryError> {
        let result = sqlx::query(
            "UPDATE graph_edges
             SET retrieval_count = MAX(CAST(retrieval_count * ?1 AS INTEGER), 0)
             WHERE valid_to IS NULL
               AND retrieval_count > 0
               AND (last_retrieved_at IS NULL OR last_retrieved_at < unixepoch('now') - ?2)",
        )
        .bind(decay_lambda)
        // An interval too large for i64 saturates to i64::MAX (decays nothing recent).
        .bind(i64::try_from(interval_secs).unwrap_or(i64::MAX))
        .execute(&self.pool)
        .await?;
        Ok(usize::try_from(result.rows_affected())?)
    }
1297
    /// Permanently deletes edges whose `expired_at` is older than
    /// `retention_days` days. Returns the number of deleted rows.
    ///
    /// # Errors
    /// Returns an error if the delete fails or the affected-row count
    /// overflows `usize`.
    pub async fn delete_expired_edges(&self, retention_days: u32) -> Result<usize, MemoryError> {
        let days = i64::from(retention_days);
        let result = sqlx::query(
            "DELETE FROM graph_edges
             WHERE expired_at IS NOT NULL
               AND expired_at < datetime('now', '-' || ?1 || ' days')",
        )
        .bind(days)
        .execute(&self.pool)
        .await?;
        Ok(usize::try_from(result.rows_affected())?)
    }
1315
    /// Deletes entities that have no active edges on either side AND were
    /// last seen more than `retention_days` days ago. Returns the number of
    /// deleted rows.
    ///
    /// # Errors
    /// Returns an error if the delete fails or the affected-row count
    /// overflows `usize`.
    pub async fn delete_orphan_entities(&self, retention_days: u32) -> Result<usize, MemoryError> {
        let days = i64::from(retention_days);
        let result = sqlx::query(
            "DELETE FROM graph_entities
             WHERE id NOT IN (
                 SELECT DISTINCT source_entity_id FROM graph_edges WHERE valid_to IS NULL
                 UNION
                 SELECT DISTINCT target_entity_id FROM graph_edges WHERE valid_to IS NULL
             )
             AND last_seen_at < datetime('now', '-' || ?1 || ' days')",
        )
        .bind(days)
        .execute(&self.pool)
        .await?;
        Ok(usize::try_from(result.rows_affected())?)
    }
1337
1338 pub async fn cap_entities(&self, max_entities: usize) -> Result<usize, MemoryError> {
1347 let current = self.entity_count().await?;
1348 let max = i64::try_from(max_entities)?;
1349 if current <= max {
1350 return Ok(0);
1351 }
1352 let excess = current - max;
1353 let result = sqlx::query(
1354 "DELETE FROM graph_entities
1355 WHERE id IN (
1356 SELECT e.id
1357 FROM graph_entities e
1358 LEFT JOIN (
1359 SELECT source_entity_id AS eid, COUNT(*) AS cnt
1360 FROM graph_edges WHERE valid_to IS NULL GROUP BY source_entity_id
1361 UNION ALL
1362 SELECT target_entity_id AS eid, COUNT(*) AS cnt
1363 FROM graph_edges WHERE valid_to IS NULL GROUP BY target_entity_id
1364 ) edge_counts ON e.id = edge_counts.eid
1365 ORDER BY COALESCE(edge_counts.cnt, 0) ASC, e.last_seen_at ASC
1366 LIMIT ?1
1367 )",
1368 )
1369 .bind(excess)
1370 .execute(&self.pool)
1371 .await?;
1372 Ok(usize::try_from(result.rows_affected())?)
1373 }
1374
    /// Point-in-time view: returns the edges touching an entity that were
    /// valid at `timestamp` — still-active edges with `valid_from <= ts`,
    /// plus invalidated edges whose `[valid_from, valid_to)` interval
    /// contains `ts`.
    ///
    /// NOTE(review): `timestamp` is compared lexically against the stored
    /// datetime strings — presumably the `datetime('now')` format
    /// (`YYYY-MM-DD HH:MM:SS`); confirm callers pass that format.
    ///
    /// # Errors
    /// Returns an error if the query fails.
    pub async fn edges_at_timestamp(
        &self,
        entity_id: i64,
        timestamp: &str,
    ) -> Result<Vec<Edge>, MemoryError> {
        let rows: Vec<EdgeRow> = sqlx::query_as(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NULL
               AND valid_from <= ?2
               AND (source_entity_id = ?1 OR target_entity_id = ?1)
             UNION ALL
             SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                    valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                    edge_type, retrieval_count, last_retrieved_at
             FROM graph_edges
             WHERE valid_to IS NOT NULL
               AND valid_from <= ?2
               AND valid_to > ?2
               AND (source_entity_id = ?1 OR target_entity_id = ?1)",
        )
        .bind(entity_id)
        .bind(timestamp)
        .fetch_all(&self.pool)
        .await?;
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
1420
    /// Returns up to `limit` edges (active and invalidated) originating from
    /// `source_entity_id` whose `fact` text contains `predicate` as a
    /// substring, optionally restricted to an exact `relation`, newest
    /// `valid_from` first.
    ///
    /// # Errors
    /// Returns an error if `limit` overflows `i64` or the query fails.
    pub async fn edge_history(
        &self,
        source_entity_id: i64,
        predicate: &str,
        relation: Option<&str>,
        limit: usize,
    ) -> Result<Vec<Edge>, MemoryError> {
        // Escape LIKE metacharacters so the predicate matches literally.
        let escaped = predicate
            .replace('\\', "\\\\")
            .replace('%', "\\%")
            .replace('_', "\\_");
        let like_pattern = format!("%{escaped}%");
        let limit = i64::try_from(limit)?;
        let rows: Vec<EdgeRow> = if let Some(rel) = relation {
            sqlx::query_as(
                "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                        valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                        edge_type, retrieval_count, last_retrieved_at
                 FROM graph_edges
                 WHERE source_entity_id = ?1
                   AND fact LIKE ?2 ESCAPE '\\'
                   AND relation = ?3
                 ORDER BY valid_from DESC
                 LIMIT ?4",
            )
            .bind(source_entity_id)
            .bind(&like_pattern)
            .bind(rel)
            .bind(limit)
            .fetch_all(&self.pool)
            .await?
        } else {
            sqlx::query_as(
                "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                        valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                        edge_type, retrieval_count, last_retrieved_at
                 FROM graph_edges
                 WHERE source_entity_id = ?1
                   AND fact LIKE ?2 ESCAPE '\\'
                 ORDER BY valid_from DESC
                 LIMIT ?3",
            )
            .bind(source_entity_id)
            .bind(&like_pattern)
            .bind(limit)
            .fetch_all(&self.pool)
            .await?
        };
        Ok(rows.into_iter().map(edge_from_row).collect())
    }
1480
1481 pub async fn bfs(
1498 &self,
1499 start_entity_id: i64,
1500 max_hops: u32,
1501 ) -> Result<(Vec<Entity>, Vec<Edge>), MemoryError> {
1502 self.bfs_with_depth(start_entity_id, max_hops)
1503 .await
1504 .map(|(e, ed, _)| (e, ed))
1505 }
1506
    /// Breadth-first traversal from `start_entity_id` over currently-valid
    /// edges (those with `valid_to IS NULL`), up to `max_hops` hops.
    ///
    /// Returns the reached entities, the edges among them, and a map from
    /// entity id to the hop at which it was first reached (start id maps
    /// to 0).
    ///
    /// # Errors
    /// Propagates any storage error from the underlying traversal.
    pub async fn bfs_with_depth(
        &self,
        start_entity_id: i64,
        max_hops: u32,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        self.bfs_core(start_entity_id, max_hops, None).await
    }
1524
    /// Breadth-first traversal from `start_entity_id`, considering only edges
    /// that were valid at `timestamp` (`valid_from <= timestamp` and either
    /// open-ended or `valid_to > timestamp`), up to `max_hops` hops.
    ///
    /// Returns the reached entities, the edges among them, and a map from
    /// entity id to hop distance.
    ///
    /// # Errors
    /// Propagates any storage error from the underlying traversal.
    pub async fn bfs_at_timestamp(
        &self,
        start_entity_id: i64,
        max_hops: u32,
        timestamp: &str,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        self.bfs_core(start_entity_id, max_hops, Some(timestamp))
            .await
    }
1544
1545 pub async fn bfs_typed(
1561 &self,
1562 start_entity_id: i64,
1563 max_hops: u32,
1564 edge_types: &[EdgeType],
1565 ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
1566 if edge_types.is_empty() {
1567 return self.bfs_with_depth(start_entity_id, max_hops).await;
1568 }
1569 self.bfs_core_typed(start_entity_id, max_hops, None, edge_types)
1570 .await
1571 }
1572
    /// Core BFS over `graph_edges`, shared by [`Self::bfs_with_depth`] and
    /// [`Self::bfs_at_timestamp`].
    ///
    /// Expands the frontier hop by hop, recording the first hop at which
    /// each entity id was reached, then loads the visited entities and the
    /// edges among them. With `at_timestamp` set, only edges valid at that
    /// instant are traversed; otherwise only currently-valid
    /// (`valid_to IS NULL`) edges.
    async fn bfs_core(
        &self,
        start_entity_id: i64,
        max_hops: u32,
        at_timestamp: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        use std::collections::HashMap;

        // Caps each hop's frontier so the generated IN-lists keep the bind
        // count bounded (3 * 300 + 1 binds worst case).
        const MAX_FRONTIER: usize = 300;

        let mut depth_map: HashMap<i64, u32> = HashMap::new();
        let mut frontier: Vec<i64> = vec![start_entity_id];
        depth_map.insert(start_entity_id, 0);

        for hop in 0..max_hops {
            if frontier.is_empty() {
                break;
            }
            frontier.truncate(MAX_FRONTIER);
            // ?1..?n carry the frontier ids; the same numbered placeholders
            // are reused in all three IN-lists of the query below.
            let placeholders = frontier
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");
            let edge_filter = if at_timestamp.is_some() {
                // Timestamp position assumes three frontier-sized bind runs
                // precede it (see the triple bind loop below).
                let ts_pos = frontier.len() * 3 + 1;
                format!("valid_from <= ?{ts_pos} AND (valid_to IS NULL OR valid_to > ?{ts_pos})")
            } else {
                "valid_to IS NULL".to_owned()
            };
            // Follows edges in both directions: whichever endpoint is not in
            // the frontier is the neighbour.
            let neighbour_sql = format!(
                "SELECT DISTINCT CASE
                    WHEN source_entity_id IN ({placeholders}) THEN target_entity_id
                    ELSE source_entity_id
                END as neighbour_id
                FROM graph_edges
                WHERE {edge_filter}
                AND (source_entity_id IN ({placeholders}) OR target_entity_id IN ({placeholders}))"
            );
            let mut q = sqlx::query_scalar::<_, i64>(&neighbour_sql);
            // The frontier is bound three times so the positional argument
            // list lines up with `ts_pos` above; the 2nd and 3rd runs fill
            // positions with no distinct placeholder of their own.
            // NOTE(review): relies on sqlx/SQLite accepting arguments for
            // those filler positions — confirm before restructuring.
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            if let Some(ts) = at_timestamp {
                q = q.bind(ts);
            }
            let neighbours: Vec<i64> = q.fetch_all(&self.pool).await?;
            let mut next_frontier: Vec<i64> = Vec::new();
            for nbr in neighbours {
                // First visit wins: record the hop distance and queue the id
                // for the next expansion; already-seen ids are skipped.
                if let std::collections::hash_map::Entry::Vacant(e) = depth_map.entry(nbr) {
                    e.insert(hop + 1);
                    next_frontier.push(nbr);
                }
            }
            frontier = next_frontier;
        }

        self.bfs_fetch_results(depth_map, at_timestamp).await
    }
1649
    /// Core BFS restricted to edges whose `edge_type` is in `edge_types`.
    ///
    /// Same hop-by-hop expansion as the untyped core: records the first hop
    /// at which each entity id is reached, then loads the visited entities
    /// and qualifying edges. With `at_timestamp` set, only edges valid at
    /// that instant are traversed; otherwise only currently-valid
    /// (`valid_to IS NULL`) edges.
    async fn bfs_core_typed(
        &self,
        start_entity_id: i64,
        max_hops: u32,
        at_timestamp: Option<&str>,
        edge_types: &[EdgeType],
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        use std::collections::HashMap;

        // Caps each hop's frontier to bound the number of bind parameters.
        const MAX_FRONTIER: usize = 300;

        let type_strs: Vec<&str> = edge_types.iter().map(|t| t.as_str()).collect();

        let mut depth_map: HashMap<i64, u32> = HashMap::new();
        let mut frontier: Vec<i64> = vec![start_entity_id];
        depth_map.insert(start_entity_id, 0);

        // Placeholder layout: ?1..?n_types carry the edge-type filter;
        // frontier ids start at ?{n_types + 1}.
        let n_types = type_strs.len();
        let type_in = (1..=n_types)
            .map(|i| format!("?{i}"))
            .collect::<Vec<_>>()
            .join(", ");
        let id_start = n_types + 1;

        for hop in 0..max_hops {
            if frontier.is_empty() {
                break;
            }
            frontier.truncate(MAX_FRONTIER);

            let n_frontier = frontier.len();
            // Frontier placeholders are reused in all three IN-lists below.
            let frontier_placeholders = frontier
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", id_start + i))
                .collect::<Vec<_>>()
                .join(", ");

            let edge_filter = if at_timestamp.is_some() {
                // Timestamp position assumes three frontier-sized bind runs
                // after the edge types (see the triple bind loop below).
                let ts_pos = id_start + n_frontier * 3;
                format!(
                    "edge_type IN ({type_in}) AND valid_from <= ?{ts_pos} AND (valid_to IS NULL OR valid_to > ?{ts_pos})"
                )
            } else {
                format!("edge_type IN ({type_in}) AND valid_to IS NULL")
            };

            // Follows edges in both directions: whichever endpoint is not in
            // the frontier is the neighbour.
            let neighbour_sql = format!(
                "SELECT DISTINCT CASE
                    WHEN source_entity_id IN ({frontier_placeholders}) THEN target_entity_id
                    ELSE source_entity_id
                END as neighbour_id
                FROM graph_edges
                WHERE {edge_filter}
                AND (source_entity_id IN ({frontier_placeholders}) OR target_entity_id IN ({frontier_placeholders}))"
            );

            let mut q = sqlx::query_scalar::<_, i64>(&neighbour_sql);
            for t in &type_strs {
                q = q.bind(*t);
            }
            // The frontier is bound three times so the positional argument
            // list lines up with `ts_pos` above; the 2nd and 3rd runs fill
            // positions with no distinct placeholder of their own.
            // NOTE(review): relies on sqlx/SQLite accepting arguments for
            // those filler positions — confirm before restructuring.
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            for id in &frontier {
                q = q.bind(*id);
            }
            if let Some(ts) = at_timestamp {
                q = q.bind(ts);
            }

            let neighbours: Vec<i64> = q.fetch_all(&self.pool).await?;
            let mut next_frontier: Vec<i64> = Vec::new();
            for nbr in neighbours {
                // First visit wins: record the hop distance and queue the id
                // for the next expansion; already-seen ids are skipped.
                if let std::collections::hash_map::Entry::Vacant(e) = depth_map.entry(nbr) {
                    e.insert(hop + 1);
                    next_frontier.push(nbr);
                }
            }
            frontier = next_frontier;
        }

        self.bfs_fetch_results_typed(depth_map, at_timestamp, &type_strs)
            .await
    }
1751
    /// Loads the entities and edges for a finished typed BFS.
    ///
    /// `depth_map` is the visited-id → hop-distance map built by the typed
    /// BFS core. Only edges with BOTH endpoints in the visited set and an
    /// `edge_type` in `type_strs` are returned; with `at_timestamp` set,
    /// edges must have been valid at that instant, otherwise only
    /// currently-valid (`valid_to IS NULL`) edges qualify. The full
    /// `depth_map` is returned unchanged.
    async fn bfs_fetch_results_typed(
        &self,
        depth_map: std::collections::HashMap<i64, u32>,
        at_timestamp: Option<&str>,
        type_strs: &[&str],
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        let mut visited_ids: Vec<i64> = depth_map.keys().copied().collect();
        if visited_ids.is_empty() {
            return Ok((Vec::new(), Vec::new(), depth_map));
        }
        // Cap the IN-lists to bound the bind count; truncated ids keep their
        // depth_map entries but contribute no entity/edge rows.
        if visited_ids.len() > 499 {
            tracing::warn!(
                total = visited_ids.len(),
                retained = 499,
                "bfs_fetch_results_typed: visited entity set truncated to 499"
            );
            visited_ids.truncate(499);
        }

        let n_types = type_strs.len();
        let n_visited = visited_ids.len();

        // Placeholder layout: ?1..?n_types = edge types, then the visited
        // ids from ?{id_start}, reused in both IN-lists; the optional
        // timestamp sits after two visited-sized bind runs.
        let type_in = (1..=n_types)
            .map(|i| format!("?{i}"))
            .collect::<Vec<_>>()
            .join(", ");
        let id_start = n_types + 1;
        let placeholders = visited_ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("?{}", id_start + i))
            .collect::<Vec<_>>()
            .join(", ");

        let edge_filter = if at_timestamp.is_some() {
            let ts_pos = id_start + n_visited * 2;
            format!(
                "edge_type IN ({type_in}) AND valid_from <= ?{ts_pos} AND (valid_to IS NULL OR valid_to > ?{ts_pos})"
            )
        } else {
            format!("edge_type IN ({type_in}) AND valid_to IS NULL")
        };

        let edge_sql = format!(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                edge_type, retrieval_count, last_retrieved_at
                FROM graph_edges
                WHERE {edge_filter}
                AND source_entity_id IN ({placeholders})
                AND target_entity_id IN ({placeholders})"
        );
        let mut edge_query = sqlx::query_as::<_, EdgeRow>(&edge_sql);
        for t in type_strs {
            edge_query = edge_query.bind(*t);
        }
        // Ids are bound twice so the positional argument list lines up with
        // `ts_pos` above; the second run fills positions with no distinct
        // placeholder of their own. NOTE(review): relies on sqlx/SQLite
        // accepting arguments for those filler positions — confirm.
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        if let Some(ts) = at_timestamp {
            edge_query = edge_query.bind(ts);
        }
        let edge_rows: Vec<EdgeRow> = edge_query.fetch_all(&self.pool).await?;

        // Separate statement, so the entity IN-list restarts at ?1.
        let entity_sql2 = {
            let ph = visited_ids
                .iter()
                .enumerate()
                .map(|(i, _)| format!("?{}", i + 1))
                .collect::<Vec<_>>()
                .join(", ");
            format!(
                "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
                FROM graph_entities WHERE id IN ({ph})"
            )
        };
        let mut entity_query = sqlx::query_as::<_, EntityRow>(&entity_sql2);
        for id in &visited_ids {
            entity_query = entity_query.bind(*id);
        }
        let entity_rows: Vec<EntityRow> = entity_query.fetch_all(&self.pool).await?;

        // Entity conversion can fail (unknown entity_type); edge conversion
        // is infallible.
        let entities: Vec<Entity> = entity_rows
            .into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()?;
        let edges: Vec<Edge> = edge_rows.into_iter().map(edge_from_row).collect();

        Ok((entities, edges, depth_map))
    }
1854
    /// Loads the entities and edges for a finished (untyped) BFS.
    ///
    /// `depth_map` is the visited-id → hop-distance map built by the BFS
    /// core. Only edges with BOTH endpoints in the visited set are returned;
    /// with `at_timestamp` set, edges must have been valid at that instant,
    /// otherwise only currently-valid (`valid_to IS NULL`) edges qualify.
    /// The full `depth_map` is returned unchanged.
    async fn bfs_fetch_results(
        &self,
        depth_map: std::collections::HashMap<i64, u32>,
        at_timestamp: Option<&str>,
    ) -> Result<(Vec<Entity>, Vec<Edge>, std::collections::HashMap<i64, u32>), MemoryError> {
        let mut visited_ids: Vec<i64> = depth_map.keys().copied().collect();
        if visited_ids.is_empty() {
            return Ok((Vec::new(), Vec::new(), depth_map));
        }
        // Cap the IN-lists to stay within the SQLite bind limit; truncated
        // ids keep their depth_map entries but contribute no rows.
        if visited_ids.len() > 499 {
            tracing::warn!(
                total = visited_ids.len(),
                retained = 499,
                "bfs_fetch_results: visited entity set truncated to 499 to stay within SQLite bind limit; \
                some reachable entities will be dropped from results"
            );
            visited_ids.truncate(499);
        }

        // ?1..?n carry the visited ids, reused in both IN-lists; the
        // optional timestamp sits after two visited-sized bind runs.
        let placeholders = visited_ids
            .iter()
            .enumerate()
            .map(|(i, _)| format!("?{}", i + 1))
            .collect::<Vec<_>>()
            .join(", ");
        let edge_filter = if at_timestamp.is_some() {
            let ts_pos = visited_ids.len() * 2 + 1;
            format!("valid_from <= ?{ts_pos} AND (valid_to IS NULL OR valid_to > ?{ts_pos})")
        } else {
            "valid_to IS NULL".to_owned()
        };
        let edge_sql = format!(
            "SELECT id, source_entity_id, target_entity_id, relation, fact, confidence,
                valid_from, valid_to, created_at, expired_at, episode_id, qdrant_point_id,
                edge_type, retrieval_count, last_retrieved_at
                FROM graph_edges
                WHERE {edge_filter}
                AND source_entity_id IN ({placeholders})
                AND target_entity_id IN ({placeholders})"
        );
        let mut edge_query = sqlx::query_as::<_, EdgeRow>(&edge_sql);
        // Ids are bound twice so the positional argument list lines up with
        // `ts_pos` above; the second run fills positions with no distinct
        // placeholder of their own. NOTE(review): relies on sqlx/SQLite
        // accepting arguments for those filler positions — confirm.
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        for id in &visited_ids {
            edge_query = edge_query.bind(*id);
        }
        if let Some(ts) = at_timestamp {
            edge_query = edge_query.bind(ts);
        }
        let edge_rows: Vec<EdgeRow> = edge_query.fetch_all(&self.pool).await?;

        let entity_sql = format!(
            "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
                FROM graph_entities WHERE id IN ({placeholders})"
        );
        let mut entity_query = sqlx::query_as::<_, EntityRow>(&entity_sql);
        for id in &visited_ids {
            entity_query = entity_query.bind(*id);
        }
        let entity_rows: Vec<EntityRow> = entity_query.fetch_all(&self.pool).await?;

        // Entity conversion can fail (unknown entity_type); edge conversion
        // is infallible.
        let entities: Vec<Entity> = entity_rows
            .into_iter()
            .map(entity_from_row)
            .collect::<Result<Vec<_>, _>>()?;
        let edges: Vec<Edge> = edge_rows.into_iter().map(edge_from_row).collect();

        Ok((entities, edges, depth_map))
    }
1927
1928 pub async fn find_entity_by_name(&self, name: &str) -> Result<Vec<Entity>, MemoryError> {
1944 let rows: Vec<EntityRow> = sqlx::query_as(
1945 "SELECT id, name, canonical_name, entity_type, summary, first_seen_at, last_seen_at, qdrant_point_id
1946 FROM graph_entities
1947 WHERE name = ?1 COLLATE NOCASE OR canonical_name = ?1 COLLATE NOCASE
1948 LIMIT 5",
1949 )
1950 .bind(name)
1951 .fetch_all(&self.pool)
1952 .await?;
1953
1954 if !rows.is_empty() {
1955 return rows.into_iter().map(entity_from_row).collect();
1956 }
1957
1958 self.find_entities_fuzzy(name, 5).await
1959 }
1960
1961 pub async fn unprocessed_messages_for_backfill(
1969 &self,
1970 limit: usize,
1971 ) -> Result<Vec<(crate::types::MessageId, String)>, MemoryError> {
1972 let limit = i64::try_from(limit)?;
1973 let rows: Vec<(i64, String)> = sqlx::query_as(
1974 "SELECT id, content FROM messages
1975 WHERE graph_processed = 0
1976 ORDER BY id ASC
1977 LIMIT ?1",
1978 )
1979 .bind(limit)
1980 .fetch_all(&self.pool)
1981 .await?;
1982 Ok(rows
1983 .into_iter()
1984 .map(|(id, content)| (crate::types::MessageId(id), content))
1985 .collect())
1986 }
1987
1988 pub async fn unprocessed_message_count(&self) -> Result<i64, MemoryError> {
1994 let count: i64 =
1995 sqlx::query_scalar("SELECT COUNT(*) FROM messages WHERE graph_processed = 0")
1996 .fetch_one(&self.pool)
1997 .await?;
1998 Ok(count)
1999 }
2000
2001 pub async fn mark_messages_graph_processed(
2007 &self,
2008 ids: &[crate::types::MessageId],
2009 ) -> Result<(), MemoryError> {
2010 if ids.is_empty() {
2011 return Ok(());
2012 }
2013 let placeholders = ids
2014 .iter()
2015 .enumerate()
2016 .map(|(i, _)| format!("?{}", i + 1))
2017 .collect::<Vec<_>>()
2018 .join(", ");
2019 let sql = format!("UPDATE messages SET graph_processed = 1 WHERE id IN ({placeholders})");
2020 let mut query = sqlx::query(&sql);
2021 for id in ids {
2022 query = query.bind(id.0);
2023 }
2024 query.execute(&self.pool).await?;
2025 Ok(())
2026 }
2027}
2028
/// Raw row shape for `graph_entities` SELECTs; converted to the public
/// [`Entity`] by `entity_from_row`.
#[derive(sqlx::FromRow)]
struct EntityRow {
    id: i64,
    name: String,
    canonical_name: String,
    // Stored as text; parsed into `EntityType` during conversion.
    entity_type: String,
    summary: Option<String>,
    first_seen_at: String,
    last_seen_at: String,
    qdrant_point_id: Option<String>,
}
2042
2043fn entity_from_row(row: EntityRow) -> Result<Entity, MemoryError> {
2044 let entity_type = row
2045 .entity_type
2046 .parse::<EntityType>()
2047 .map_err(MemoryError::GraphStore)?;
2048 Ok(Entity {
2049 id: row.id,
2050 name: row.name,
2051 canonical_name: row.canonical_name,
2052 entity_type,
2053 summary: row.summary,
2054 first_seen_at: row.first_seen_at,
2055 last_seen_at: row.last_seen_at,
2056 qdrant_point_id: row.qdrant_point_id,
2057 })
2058}
2059
/// Raw row shape for entity-alias SELECTs; converted to the public
/// [`EntityAlias`] by `alias_from_row`.
#[derive(sqlx::FromRow)]
struct AliasRow {
    id: i64,
    entity_id: i64,
    alias_name: String,
    created_at: String,
}
2067
2068fn alias_from_row(row: AliasRow) -> EntityAlias {
2069 EntityAlias {
2070 id: row.id,
2071 entity_id: row.entity_id,
2072 alias_name: row.alias_name,
2073 created_at: row.created_at,
2074 }
2075}
2076
/// Raw row shape for `graph_edges` SELECTs; converted to the public
/// [`Edge`] by `edge_from_row`.
#[derive(sqlx::FromRow)]
struct EdgeRow {
    id: i64,
    source_entity_id: i64,
    target_entity_id: i64,
    relation: String,
    fact: String,
    // Stored as SQLite REAL (f64); narrowed to f32 during conversion.
    confidence: f64,
    valid_from: String,
    // NULL means the edge is currently valid.
    valid_to: Option<String>,
    created_at: String,
    expired_at: Option<String>,
    // Wrapped into `MessageId` during conversion.
    episode_id: Option<i64>,
    qdrant_point_id: Option<String>,
    // Stored as text; parsed into `EdgeType` during conversion.
    edge_type: String,
    retrieval_count: i32,
    last_retrieved_at: Option<i64>,
}
2095
2096fn edge_from_row(row: EdgeRow) -> Edge {
2097 let edge_type = row
2098 .edge_type
2099 .parse::<EdgeType>()
2100 .unwrap_or(EdgeType::Semantic);
2101 Edge {
2102 id: row.id,
2103 source_entity_id: row.source_entity_id,
2104 target_entity_id: row.target_entity_id,
2105 relation: row.relation,
2106 fact: row.fact,
2107 #[allow(clippy::cast_possible_truncation)]
2108 confidence: row.confidence as f32,
2109 valid_from: row.valid_from,
2110 valid_to: row.valid_to,
2111 created_at: row.created_at,
2112 expired_at: row.expired_at,
2113 episode_id: row.episode_id.map(MessageId),
2114 qdrant_point_id: row.qdrant_point_id,
2115 edge_type,
2116 retrieval_count: row.retrieval_count,
2117 last_retrieved_at: row.last_retrieved_at,
2118 }
2119}
2120
/// Raw row shape for community SELECTs (maps to the public [`Community`]).
#[derive(sqlx::FromRow)]
struct CommunityRow {
    id: i64,
    name: String,
    summary: String,
    // Stored as a single string; presumably a serialized list of member
    // entity ids — TODO confirm against the community query/write paths.
    entity_ids: String,
    fingerprint: Option<String>,
    created_at: String,
    updated_at: String,
}
2131
// Compiled only for test builds; test bodies live in the `tests` submodule.
#[cfg(test)]
mod tests;