1use std::sync::Arc;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use std::sync::atomic::Ordering;
9use zeph_db::DbPool;
10
11use zeph_llm::any::AnyProvider;
12use zeph_llm::provider::LlmProvider as _;
13
14use crate::embedding_store::EmbeddingStore;
15use crate::error::MemoryError;
16use crate::graph::extractor::ExtractionResult as ExtractorResult;
17use crate::vector_store::VectorFilter;
18
19use super::SemanticMemory;
20
/// Optional post-extraction hook: inspects the raw LLM extraction result and
/// may veto persistence by returning `Err(reason)` (the run is skipped, not
/// retried). NOTE(review): the `Option` is baked into the alias, so callers
/// pass `None` explicitly when no validator is wanted.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
26
/// Tuning knobs for background knowledge-graph extraction and maintenance.
#[derive(Debug, Clone)]
pub struct GraphExtractionConfig {
    /// Max entities the LLM extractor may emit per run.
    pub max_entities: usize,
    /// Max edges the LLM extractor may emit per run.
    pub max_edges: usize,
    /// Hard timeout for one `extract_and_store` invocation.
    pub extraction_timeout_secs: u64,
    /// Run community detection every N extractions (0 disables the refresh).
    pub community_refresh_interval: usize,
    /// Retention passed to graph eviction for expired edges.
    pub expired_edge_retention_days: u32,
    /// Entity-count cap enforced during graph eviction.
    pub max_entities_cap: usize,
    /// Byte budget for a community-summary prompt.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency for community-summary generation.
    pub community_summary_concurrency: usize,
    /// Edge chunk size used by label-propagation community detection.
    pub lpa_edge_chunk_size: usize,
    /// Settings for the post-extraction note-linking pass.
    pub note_linking: NoteLinkingConfig,
    /// Decay factor for edge retrieval counts (<= 0.0 disables decay).
    pub link_weight_decay_lambda: f64,
    /// Minimum seconds between decay passes (0 disables decay).
    pub link_weight_decay_interval_secs: u64,
    /// Enables belief revision when inserting typed edges.
    pub belief_revision_enabled: bool,
    /// Similarity threshold handed to belief revision when enabled.
    pub belief_revision_similarity_threshold: f32,
}
53
54impl Default for GraphExtractionConfig {
55 fn default() -> Self {
56 Self {
57 max_entities: 0,
58 max_edges: 0,
59 extraction_timeout_secs: 0,
60 community_refresh_interval: 0,
61 expired_edge_retention_days: 0,
62 max_entities_cap: 0,
63 community_summary_max_prompt_bytes: 0,
64 community_summary_concurrency: 0,
65 lpa_edge_chunk_size: 0,
66 note_linking: NoteLinkingConfig::default(),
67 link_weight_decay_lambda: 0.95,
68 link_weight_decay_interval_secs: 86400,
69 belief_revision_enabled: false,
70 belief_revision_similarity_threshold: 0.85,
71 }
72 }
73}
74
/// Settings for the note-linking pass that runs after extraction.
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    /// Master switch; linking is skipped entirely when false.
    pub enabled: bool,
    /// Minimum vector-search score for a candidate to become an edge.
    pub similarity_threshold: f32,
    /// Max `similar_to` edges created per source entity.
    pub top_k: usize,
    /// Hard timeout for the whole linking pass.
    pub timeout_secs: u64,
}
83
84impl Default for NoteLinkingConfig {
85 fn default() -> Self {
86 Self {
87 enabled: false,
88 similarity_threshold: 0.85,
89 top_k: 10,
90 timeout_secs: 5,
91 }
92 }
93}
94
/// Entity/edge counters for one extraction run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved/upserted (one per extracted name).
    pub entities_upserted: usize,
    /// Edges accepted by the resolver (self-loops and failures excluded).
    pub edges_inserted: usize,
}
101
/// Outcome of one [`extract_and_store`] run. Not to be confused with the
/// extractor's raw output type, aliased in this module as `ExtractorResult`.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    /// Counts of persisted entities and edges.
    pub stats: ExtractionStats,
    /// IDs of entities touched by this run (input to note linking).
    pub entity_ids: Vec<i64>,
}
109
/// Counters produced by [`link_memory_notes`].
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities whose embed + vector search both succeeded.
    pub entities_processed: usize,
    /// `similar_to` edges successfully inserted.
    pub edges_created: usize,
}
116
/// Vector collection holding one point per graph entity.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Per-entity unit of work for [`link_memory_notes`].
struct EntityWorkItem {
    // SQLite row id of the entity.
    entity_id: i64,
    // Display name; used only for debug logging here.
    canonical_name: String,
    // Text sent to the embedding provider ("name" or "name: summary").
    embed_text: String,
    // The entity's own vector point id, if indexed; used to drop self-matches.
    self_point_id: Option<String>,
}
127
#[allow(clippy::too_many_lines)]
/// Best-effort pass that links newly extracted entities to semantically
/// similar existing entities via `similar_to` edges.
///
/// Pipeline: load each entity from SQLite → embed "canonical_name: summary"
/// → search the entity vector collection → insert one edge per hit scoring
/// at or above `cfg.similarity_threshold` (at most `cfg.top_k` per entity).
/// Every failure (DB load, embed, search, insert) is logged at debug level
/// and skipped; the function itself never fails.
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: DbPool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: resolve ids to embeddable text, skipping missing/erroring rows.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Embed "name: summary" when a non-empty summary exists, else just the name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all work items concurrently; `embed_results` stays
    // index-aligned with `work_items`.
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // `top_k + 1` so the entity's own point (returned by the search) can be
    // filtered out while still leaving up to `top_k` real candidates.
    // `valid` keeps each surviving embedding paired with its work-item index.
    let search_limit = cfg.top_k + 1; let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    // Phase 3: concurrent vector searches; `search_results[j]` corresponds
    // to `valid[j]` (and via its stored index, to a work item).
    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Dedupe undirected pairs within this call so (a,b) and (b,a) yield one edge.
    let mut seen_pairs = std::collections::HashSet::new();

    // Phase 4: convert qualifying hits into `similar_to` edges.
    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        // Counted only once embed + search have both succeeded.
        stats.entities_processed += 1;

        // Drop the entity's own point and sub-threshold hits; keep top_k.
        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            // Each point's payload must carry the target's SQLite entity id.
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            // Secondary self-match guard: payload ids can match even when
            // point ids differ.
            if target_id == w.entity_id {
                continue; }

            // Normalize to an ordered pair so (a,b) == (b,a) for dedup.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            // Edge weight is the similarity score; failures are logged and skipped.
            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
292
293#[allow(clippy::too_many_lines)]
304pub async fn extract_and_store(
305 content: String,
306 context_messages: Vec<String>,
307 provider: AnyProvider,
308 pool: DbPool,
309 config: GraphExtractionConfig,
310 post_extract_validator: PostExtractValidator,
311 embedding_store: Option<Arc<EmbeddingStore>>,
312) -> Result<ExtractionResult, MemoryError> {
313 use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
314
315 let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
316 let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
317
318 let store = GraphStore::new(pool);
319
320 let pool = store.pool();
321 zeph_db::query(sql!(
322 "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
323 ON CONFLICT(key) DO NOTHING"
324 ))
325 .execute(pool)
326 .await?;
327 zeph_db::query(sql!(
328 "UPDATE graph_metadata
329 SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
330 WHERE key = 'extraction_count'"
331 ))
332 .execute(pool)
333 .await?;
334
335 let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
336 return Ok(ExtractionResult::default());
337 };
338
339 if let Some(ref validator) = post_extract_validator
342 && let Err(reason) = validator(&result)
343 {
344 tracing::warn!(
345 reason,
346 "graph extraction validation failed, skipping upsert"
347 );
348 return Ok(ExtractionResult::default());
349 }
350
351 let resolver = if let Some(ref emb) = embedding_store {
352 EntityResolver::new(&store)
353 .with_embedding_store(emb)
354 .with_provider(&provider)
355 } else {
356 EntityResolver::new(&store)
357 };
358
359 let mut entities_upserted = 0usize;
360 let mut entity_name_to_id: std::collections::HashMap<String, i64> =
361 std::collections::HashMap::new();
362
363 for entity in &result.entities {
364 match resolver
365 .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
366 .await
367 {
368 Ok((id, _outcome)) => {
369 entity_name_to_id.insert(entity.name.clone(), id);
370 entities_upserted += 1;
371 }
372 Err(e) => {
373 tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
374 }
375 }
376 }
377
378 let mut edges_inserted = 0usize;
379 for edge in &result.edges {
380 let (Some(&src_id), Some(&tgt_id)) = (
381 entity_name_to_id.get(&edge.source),
382 entity_name_to_id.get(&edge.target),
383 ) else {
384 tracing::debug!(
385 "graph: skipping edge {:?}->{:?}: entity not resolved",
386 edge.source,
387 edge.target
388 );
389 continue;
390 };
391 if src_id == tgt_id {
392 tracing::debug!(
393 "graph: skipping self-loop edge {:?}->{:?} (entity_id={src_id})",
394 edge.source,
395 edge.target
396 );
397 continue;
398 }
399 let edge_type = edge
402 .edge_type
403 .parse::<crate::graph::EdgeType>()
404 .unwrap_or_else(|_| {
405 tracing::warn!(
406 raw_type = %edge.edge_type,
407 "graph: unknown edge_type from LLM, defaulting to semantic"
408 );
409 crate::graph::EdgeType::Semantic
410 });
411 let belief_cfg =
412 config
413 .belief_revision_enabled
414 .then_some(crate::graph::BeliefRevisionConfig {
415 similarity_threshold: config.belief_revision_similarity_threshold,
416 });
417 match resolver
418 .resolve_edge_typed(
419 src_id,
420 tgt_id,
421 &edge.relation,
422 &edge.fact,
423 0.8,
424 None,
425 edge_type,
426 belief_cfg.as_ref(),
427 )
428 .await
429 {
430 Ok(Some(_)) => edges_inserted += 1,
431 Ok(None) => {} Err(e) => {
433 tracing::debug!("graph: skipping edge: {e:#}");
434 }
435 }
436 }
437
438 store.checkpoint_wal().await?;
439
440 let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();
441
442 Ok(ExtractionResult {
443 stats: ExtractionStats {
444 entities_upserted,
445 edges_inserted,
446 },
447 entity_ids: new_entity_ids,
448 })
449}
450
impl SemanticMemory {
    #[allow(clippy::too_many_lines)]
    /// Spawns a detached background task that runs graph extraction over
    /// `content` and, on success, optional note linking plus periodic
    /// community detection and graph maintenance.
    ///
    /// The returned `JoinHandle` may be awaited or dropped; the task runs
    /// either way. Failures and timeouts only bump the atomic counters —
    /// nothing is retried and no error reaches the caller.
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the 'static task needs before moving it in.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            // Bound the extraction with the configured timeout.
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Success yields the new entity ids for the linking pass;
            // failure/timeout yields an empty list and bumps a counter.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // Note linking: only after a successful extraction that produced
            // entities, and only when a vector store is configured.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Community refresh: every `community_refresh_interval`-th
            // extraction, judged by the persistent DB counter (not the
            // in-process atomic above).
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    let decay_lambda = config.link_weight_decay_lambda;
                    let decay_interval_secs = config.link_weight_decay_interval_secs;
                    // Maintenance runs in its own task so a slow refresh does
                    // not extend the lifetime of the extraction task.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs even if community detection failed.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }

                        // Link-weight decay: at most once per interval, gated
                        // by the 'last_link_weight_decay_at' metadata value.
                        if decay_lambda > 0.0 && decay_interval_secs > 0 {
                            let now_secs = std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .map(|d| d.as_secs())
                                .unwrap_or(0);
                            let last_decay = store2
                                .get_metadata("last_link_weight_decay_at")
                                .await
                                .ok()
                                .flatten()
                                .and_then(|s| s.parse::<u64>().ok())
                                .unwrap_or(0);
                            if now_secs.saturating_sub(last_decay) >= decay_interval_secs {
                                match store2
                                    .decay_edge_retrieval_counts(decay_lambda, decay_interval_secs)
                                    .await
                                {
                                    Ok(affected) => {
                                        tracing::info!(affected, "link weight decay applied");
                                        // Best-effort timestamp write; a miss
                                        // just means an earlier re-run later.
                                        let _ = store2
                                            .set_metadata(
                                                "last_link_weight_decay_at",
                                                &now_secs.to_string(),
                                            )
                                            .await;
                                    }
                                    Err(e) => {
                                        tracing::warn!("link weight decay failed: {e:#}");
                                    }
                                }
                            }
                        }
                    });
                }
            }
        })
    }
}
646
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    use super::GraphExtractionConfig;

    /// Fresh in-memory SQLite pool plus an in-memory vector store.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    // Regression for #1829: when both an embedding store and an
    // embedding-capable provider are supplied, the resolver must index the
    // entity and record its vector point id.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // Entity names are stored normalized (lowercase "rust").
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    // Without an embedding store the entity is still upserted to SQLite,
    // but no vector point id is recorded.
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }

    // Regression for #2166: FTS5 rows written by one connection/session must
    // be visible to a later session opening the same database file.
    #[tokio::test]
    async fn extract_and_store_fts5_cross_session_visibility() {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        let path = file.path().to_str().expect("valid path").to_string();

        // Session A: extract into the on-disk DB, then drop the connection.
        {
            let sqlite = crate::store::SqliteStore::new(&path).await.unwrap();
            let extraction_json = r#"{"entities":[{"name":"Ferris","type":"concept","summary":"Rust mascot"}],"edges":[]}"#;
            let mock =
                zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
            let provider = AnyProvider::Mock(mock);
            let config = GraphExtractionConfig {
                max_entities: 10,
                max_edges: 10,
                extraction_timeout_secs: 10,
                ..Default::default()
            };
            extract_and_store(
                "Ferris is the Rust mascot.".to_owned(),
                vec![],
                provider,
                sqlite.pool().clone(),
                config,
                None,
                None,
            )
            .await
            .unwrap();
        }

        // Session B: reopen the same file and search via FTS5.
        let sqlite_b = crate::store::SqliteStore::new(&path).await.unwrap();
        let gs_b = crate::graph::GraphStore::new(sqlite_b.pool().clone());
        let results = gs_b.find_entities_fuzzy("Ferris", 10).await.unwrap();
        assert!(
            !results.is_empty(),
            "FTS5 cross-session (#2166): entity extracted in session A must be visible in session B"
        );
    }

    // Regression for #2215: an edge whose source and target resolve to the
    // same entity must be dropped.
    #[tokio::test]
    async fn extract_and_store_skips_self_loop_edges() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{
        "entities":[{"name":"Rust","type":"language","summary":"systems language"}],
        "edges":[{"source":"Rust","target":"Rust","relation":"is","fact":"Rust is Rust","edge_type":"semantic"}]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);
        assert_eq!(
            result.stats.edges_inserted, 0,
            "self-loop edge must be rejected (#2215)"
        );
    }
}