1use std::sync::Arc;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use std::sync::atomic::Ordering;
9use zeph_db::DbPool;
10
11pub use zeph_common::config::memory::NoteLinkingConfig;
12use zeph_llm::any::AnyProvider;
13use zeph_llm::provider::LlmProvider as _;
14
15use crate::embedding_store::EmbeddingStore;
16use crate::error::MemoryError;
17use crate::graph::extractor::ExtractionResult as ExtractorResult;
18use crate::vector_store::VectorFilter;
19
20use super::SemanticMemory;
21
/// Optional hook run on the raw LLM extraction before anything is persisted;
/// returning `Err(reason)` makes `extract_and_store` skip the upsert entirely.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
27
/// Tunables for LLM-driven graph extraction and the follow-up maintenance
/// passes (note linking, community detection, eviction, link-weight decay).
///
/// NOTE(review): `Default` zeroes all the limits — callers are expected to
/// populate real values from application config.
#[derive(Debug, Clone)]
pub struct GraphExtractionConfig {
    /// Max entities the extractor may return per call.
    pub max_entities: usize,
    /// Max edges the extractor may return per call.
    pub max_edges: usize,
    /// Hard timeout for one background extraction run (seconds).
    pub extraction_timeout_secs: u64,
    /// Run community detection every N persisted extractions; 0 disables it.
    pub community_refresh_interval: usize,
    /// Retention for expired edges before graph eviction deletes them (days).
    pub expired_edge_retention_days: u32,
    /// Entity cap enforced by graph eviction.
    pub max_entities_cap: usize,
    /// Byte budget for the community-summary LLM prompt.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency for community summarization.
    pub community_summary_concurrency: usize,
    /// Edge chunk size for label-propagation community detection.
    pub lpa_edge_chunk_size: usize,
    /// Settings for post-extraction semantic note linking.
    pub note_linking: NoteLinkingConfig,
    /// Decay factor applied to edge retrieval counts; <= 0.0 disables decay.
    pub link_weight_decay_lambda: f64,
    /// Minimum interval between decay passes (seconds); 0 disables decay.
    pub link_weight_decay_interval_secs: u64,
    /// Enables belief revision when resolving typed edges.
    pub belief_revision_enabled: bool,
    /// Similarity threshold handed to belief revision when enabled.
    pub belief_revision_similarity_threshold: f32,
    /// Conversation whose episode extracted entities get linked to, if any.
    pub conversation_id: Option<i64>,
}
57
58impl Default for GraphExtractionConfig {
59 fn default() -> Self {
60 Self {
61 max_entities: 0,
62 max_edges: 0,
63 extraction_timeout_secs: 0,
64 community_refresh_interval: 0,
65 expired_edge_retention_days: 0,
66 max_entities_cap: 0,
67 community_summary_max_prompt_bytes: 0,
68 community_summary_concurrency: 0,
69 lpa_edge_chunk_size: 0,
70 note_linking: NoteLinkingConfig::default(),
71 link_weight_decay_lambda: 0.95,
72 link_weight_decay_interval_secs: 86400,
73 belief_revision_enabled: false,
74 belief_revision_similarity_threshold: 0.85,
75 conversation_id: None,
76 }
77 }
78}
79
/// Counters produced by a single `extract_and_store` run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved and upserted.
    pub entities_upserted: usize,
    /// Edges successfully inserted (self-loops and unresolved endpoints are skipped).
    pub edges_inserted: usize,
}
86
/// Outcome of `extract_and_store`: counters plus the entity IDs touched,
/// which downstream note linking operates on.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    pub stats: ExtractionStats,
    // IDs come from the name→id map's values, so two extracted names that
    // resolved to the same entity can produce duplicate IDs here.
    pub entity_ids: Vec<i64>,
}
94
/// Counters produced by `link_memory_notes`.
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities for which both embed and vector search succeeded.
    pub entities_processed: usize,
    /// `similar_to` edges inserted between semantically close entities.
    pub edges_created: usize,
}
101
/// Vector-store collection holding one point per graph entity.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Per-entity input gathered up front so embedding and search can run as
/// concurrent batches in `link_memory_notes`.
struct EntityWorkItem {
    entity_id: i64,
    canonical_name: String,
    /// Text to embed: "name: summary" when a non-empty summary exists, else the name.
    embed_text: String,
    /// The entity's own point ID, used to exclude self-matches from search results.
    self_point_id: Option<String>,
}
112
/// Links semantically similar entities with `similar_to` edges.
///
/// For each entity in `entity_ids`: load it from the graph store, embed its
/// name (+summary), search the entity collection, and insert an edge to each
/// sufficiently similar neighbour. The pass is best-effort: every failure
/// (missing entity, embed error, search error, insert error) is logged at
/// debug level and skipped, and the function never returns an error.
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: DbPool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: load entities and build the texts to embed.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // "name: summary" when a non-empty summary exists, otherwise just the name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all texts concurrently; join_all preserves input order,
    // so index i below corresponds to work_items[i].
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // +1 so the entity's own point can be filtered out without losing a slot.
    let search_limit = cfg.top_k + 1;
    // Keep only successful embeddings, remembering their work-item index.
    let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    // Phase 3: run all vector searches concurrently (same order as `valid`).
    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Normalized (min, max) id pairs so A→B and B→A produce a single edge.
    let mut seen_pairs = std::collections::HashSet::new();

    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        stats.entities_processed += 1;

        // Drop the entity's own point and anything below the similarity bar,
        // then keep at most top_k candidates.
        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            // Extra self-loop guard by entity id, in case the point-id filter
            // above did not catch a self match.
            if target_id == w.entity_id {
                continue;
            }

            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            // Skip pairs already linked earlier in this run.
            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
277
/// Runs LLM graph extraction over `content` and persists the result.
///
/// Pipeline: bump the persisted `extraction_count` metadata counter, extract
/// entities/edges via the LLM, optionally veto via `post_extract_validator`,
/// resolve entities (embedding-backed when `embedding_store` is provided),
/// insert edges (skipping self-loops and unresolved endpoints), checkpoint
/// the WAL, and best-effort link entities to the conversation's episode.
///
/// # Errors
/// Returns `MemoryError` on DB failures or when extraction itself errors;
/// individual entity/edge failures are logged at debug level and skipped.
#[cfg_attr(
    feature = "profiling",
    tracing::instrument(name = "memory.graph_extract", skip_all, fields(entities = tracing::field::Empty, edges = tracing::field::Empty))
)]
#[allow(clippy::too_many_lines)]
pub async fn extract_and_store(
    content: String,
    context_messages: Vec<String>,
    provider: AnyProvider,
    pool: DbPool,
    config: GraphExtractionConfig,
    post_extract_validator: PostExtractValidator,
    embedding_store: Option<Arc<EmbeddingStore>>,
) -> Result<ExtractionResult, MemoryError> {
    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};

    let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();

    let store = GraphStore::new(pool);

    // Seed-then-increment the global extraction counter. Note this counts
    // attempts: it runs before extraction, so empty/vetoed runs count too.
    let pool = store.pool();
    zeph_db::query(sql!(
        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
         ON CONFLICT(key) DO NOTHING"
    ))
    .execute(pool)
    .await?;
    zeph_db::query(sql!(
        "UPDATE graph_metadata
         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
         WHERE key = 'extraction_count'"
    ))
    .execute(pool)
    .await?;

    // Nothing extracted — succeed with an empty result.
    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
        return Ok(ExtractionResult::default());
    };

    // Caller-supplied veto: abandon the whole upsert if validation fails.
    if let Some(ref validator) = post_extract_validator
        && let Err(reason) = validator(&result)
    {
        tracing::warn!(
            reason,
            "graph extraction validation failed, skipping upsert"
        );
        return Ok(ExtractionResult::default());
    }

    // Attach embedding store + provider to the resolver only when available.
    let resolver = if let Some(ref emb) = embedding_store {
        EntityResolver::new(&store)
            .with_embedding_store(emb)
            .with_provider(&provider)
    } else {
        EntityResolver::new(&store)
    };

    let mut entities_upserted = 0usize;
    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
        std::collections::HashMap::new();

    // Resolve each extracted entity to a stored entity id; failures skip.
    for entity in &result.entities {
        match resolver
            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
            .await
        {
            Ok((id, _outcome)) => {
                entity_name_to_id.insert(entity.name.clone(), id);
                entities_upserted += 1;
            }
            Err(e) => {
                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
            }
        }
    }

    let mut edges_inserted = 0usize;
    for edge in &result.edges {
        // Both endpoints must have resolved above, otherwise drop the edge.
        let (Some(&src_id), Some(&tgt_id)) = (
            entity_name_to_id.get(&edge.source),
            entity_name_to_id.get(&edge.target),
        ) else {
            tracing::debug!(
                "graph: skipping edge {:?}->{:?}: entity not resolved",
                edge.source,
                edge.target
            );
            continue;
        };
        // Reject self-loops (distinct names can resolve to the same entity).
        if src_id == tgt_id {
            tracing::debug!(
                "graph: skipping self-loop edge {:?}->{:?} (entity_id={src_id})",
                edge.source,
                edge.target
            );
            continue;
        }
        // Unknown edge types degrade to Semantic instead of failing the run.
        let edge_type = edge
            .edge_type
            .parse::<crate::graph::EdgeType>()
            .unwrap_or_else(|_| {
                tracing::warn!(
                    raw_type = %edge.edge_type,
                    "graph: unknown edge_type from LLM, defaulting to semantic"
                );
                crate::graph::EdgeType::Semantic
            });
        let belief_cfg =
            config
                .belief_revision_enabled
                .then_some(crate::graph::BeliefRevisionConfig {
                    similarity_threshold: config.belief_revision_similarity_threshold,
                });
        // 0.8 is the fixed confidence assigned to LLM-extracted edges here.
        match resolver
            .resolve_edge_typed(
                src_id,
                tgt_id,
                &edge.relation,
                &edge.fact,
                0.8,
                None,
                edge_type,
                belief_cfg.as_ref(),
            )
            .await
        {
            Ok(Some(_)) => edges_inserted += 1,
            // Ok(None): resolver chose not to insert (reason internal to it).
            Ok(None) => {}
            Err(e) => {
                tracing::debug!("graph: skipping edge: {e:#}");
            }
        }
    }

    store.checkpoint_wal().await?;

    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();

    // Episode linking is best-effort: failures are logged, never propagated.
    if let Some(conv_id) = config.conversation_id {
        match store.ensure_episode(conv_id).await {
            Ok(episode_id) => {
                for &entity_id in &new_entity_ids {
                    if let Err(e) = store.link_entity_to_episode(episode_id, entity_id).await {
                        tracing::debug!("episode linking skipped for entity {entity_id}: {e:#}");
                    }
                }
            }
            Err(e) => {
                tracing::warn!("failed to ensure episode for conversation {conv_id}: {e:#}");
            }
        }
    }

    // Fill the span fields declared in the instrument attribute above.
    #[cfg(feature = "profiling")]
    {
        let span = tracing::Span::current();
        span.record("entities", entities_upserted);
        span.record("edges", edges_inserted);
    }

    Ok(ExtractionResult {
        stats: ExtractionStats {
            entities_upserted,
            edges_inserted,
        },
        entity_ids: new_entity_ids,
    })
}
462
impl SemanticMemory {
    /// Fire-and-forget background graph extraction for `content`.
    ///
    /// Spawns a task that (1) runs `extract_and_store` under
    /// `extraction_timeout_secs`, (2) on success optionally runs note linking
    /// over the new entities under its own timeout, and (3) every
    /// `community_refresh_interval` persisted extractions spawns a nested
    /// task doing community detection, graph eviction, and link-weight decay.
    /// Failures bump the metrics counters and are logged; the returned handle
    /// is only useful for awaiting completion.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the detached task needs before moving into it.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Outer Err = timeout elapsed; inner Err = extraction failed.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // Note linking: best-effort, requires a vector store, and runs
            // only when extraction produced entities.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        // Timed-out linking is not rolled back: some edges may
                        // already have been inserted.
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Periodic maintenance, gated on the DB-persisted extraction
            // count (not the in-process atomic) modulo the refresh interval.
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    let decay_lambda = config.link_weight_decay_lambda;
                    let decay_interval_secs = config.link_weight_decay_interval_secs;
                    tokio::spawn(async move {
                        // Step 1: community detection (failure counted).
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Step 2: evict expired edges, orphans, and excess
                        // entities; runs even if detection failed above.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }

                        // Step 3: link-weight decay, rate-limited via a
                        // last-run timestamp stored in graph metadata.
                        if decay_lambda > 0.0 && decay_interval_secs > 0 {
                            let now_secs = std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .map(|d| d.as_secs())
                                .unwrap_or(0);
                            let last_decay = store2
                                .get_metadata("last_link_weight_decay_at")
                                .await
                                .ok()
                                .flatten()
                                .and_then(|s| s.parse::<u64>().ok())
                                .unwrap_or(0);
                            if now_secs.saturating_sub(last_decay) >= decay_interval_secs {
                                match store2
                                    .decay_edge_retrieval_counts(decay_lambda, decay_interval_secs)
                                    .await
                                {
                                    Ok(affected) => {
                                        tracing::info!(affected, "link weight decay applied");
                                        // Best-effort timestamp update; a
                                        // failure just means an earlier rerun.
                                        let _ = store2
                                            .set_metadata(
                                                "last_link_weight_decay_at",
                                                &now_secs.to_string(),
                                            )
                                            .await;
                                    }
                                    Err(e) => {
                                        tracing::warn!("link weight decay failed: {e:#}");
                                    }
                                }
                            }
                        }
                    });
                }
            }
        })
    }
}
658
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    use super::GraphExtractionConfig;

    /// In-memory SQLite plus an in-memory vector store sharing the same pool.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    // Regression test for #1829: with both an embedding store and an
    // embedding-capable provider, the stored entity must carry a point id.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        // Mock provider returns the canned extraction and a fixed embedding.
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // Entities are stored under their lowercased canonical name.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    // Without an embedding store the upsert still succeeds, but no vector
    // point is recorded on the entity.
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }

    // Regression test for #2166: FTS5 rows written in one connection/session
    // must be visible to a fresh store opened on the same file.
    #[tokio::test]
    async fn extract_and_store_fts5_cross_session_visibility() {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        let path = file.path().to_str().expect("valid path").to_string();

        // Session A: extract into a file-backed store, then drop it.
        {
            let sqlite = crate::store::SqliteStore::new(&path).await.unwrap();
            let extraction_json = r#"{"entities":[{"name":"Ferris","type":"concept","summary":"Rust mascot"}],"edges":[]}"#;
            let mock =
                zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
            let provider = AnyProvider::Mock(mock);
            let config = GraphExtractionConfig {
                max_entities: 10,
                max_edges: 10,
                extraction_timeout_secs: 10,
                ..Default::default()
            };
            extract_and_store(
                "Ferris is the Rust mascot.".to_owned(),
                vec![],
                provider,
                sqlite.pool().clone(),
                config,
                None,
                None,
            )
            .await
            .unwrap();
        }

        // Session B: reopen the same file and query via FTS.
        let sqlite_b = crate::store::SqliteStore::new(&path).await.unwrap();
        let gs_b = crate::graph::GraphStore::new(sqlite_b.pool().clone());
        let results = gs_b.find_entities_fuzzy("Ferris", 10).await.unwrap();
        assert!(
            !results.is_empty(),
            "FTS5 cross-session (#2166): entity extracted in session A must be visible in session B"
        );
    }

    // Regression test for #2215: a Rust→Rust edge must be dropped as a
    // self-loop even though the entity itself is upserted.
    #[tokio::test]
    async fn extract_and_store_skips_self_loop_edges() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{
            "entities":[{"name":"Rust","type":"language","summary":"systems language"}],
            "edges":[{"source":"Rust","target":"Rust","relation":"is","fact":"Rust is Rust","edge_type":"semantic"}]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);
        assert_eq!(
            result.stats.edges_inserted, 0,
            "self-loop edge must be rejected (#2215)"
        );
    }
}