1use std::sync::Arc;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use std::sync::atomic::Ordering;
9use zeph_db::DbPool;
10
11use zeph_llm::any::AnyProvider;
12use zeph_llm::provider::LlmProvider as _;
13
14use crate::embedding_store::EmbeddingStore;
15use crate::error::MemoryError;
16use crate::graph::extractor::ExtractionResult as ExtractorResult;
17use crate::vector_store::VectorFilter;
18
19use super::SemanticMemory;
20
/// Optional validation hook run after LLM extraction but before any DB
/// upserts; returning `Err(reason)` makes `extract_and_store` skip the
/// upsert entirely (the reason is logged at warn level).
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
26
/// Tuning knobs for graph extraction and its follow-up maintenance passes
/// (note linking, community detection, eviction, link-weight decay).
#[derive(Debug, Clone)]
pub struct GraphExtractionConfig {
    /// Max entities the LLM extractor may return per call.
    pub max_entities: usize,
    /// Max edges the LLM extractor may return per call.
    pub max_edges: usize,
    /// Hard timeout for a single `extract_and_store` run, in seconds.
    pub extraction_timeout_secs: u64,
    /// Trigger community detection every N extractions; 0 disables the cycle.
    pub community_refresh_interval: usize,
    /// Retention window passed to graph eviction for expired edges, in days.
    pub expired_edge_retention_days: u32,
    /// Entity-count cap enforced by graph eviction.
    pub max_entities_cap: usize,
    /// Byte budget for a single community-summary prompt.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency passed to community-summary generation.
    pub community_summary_concurrency: usize,
    /// Edge chunk size passed to community detection (LPA batching).
    pub lpa_edge_chunk_size: usize,
    /// Settings for the post-extraction semantic note-linking pass.
    pub note_linking: NoteLinkingConfig,
    /// Decay factor for edge retrieval counts; values <= 0.0 disable decay.
    pub link_weight_decay_lambda: f64,
    /// Minimum seconds between decay passes; 0 disables decay.
    pub link_weight_decay_interval_secs: u64,
}
49
50impl Default for GraphExtractionConfig {
51 fn default() -> Self {
52 Self {
53 max_entities: 0,
54 max_edges: 0,
55 extraction_timeout_secs: 0,
56 community_refresh_interval: 0,
57 expired_edge_retention_days: 0,
58 max_entities_cap: 0,
59 community_summary_max_prompt_bytes: 0,
60 community_summary_concurrency: 0,
61 lpa_edge_chunk_size: 0,
62 note_linking: NoteLinkingConfig::default(),
63 link_weight_decay_lambda: 0.95,
64 link_weight_decay_interval_secs: 86400,
65 }
66 }
67}
68
/// Configuration for the semantic note-linking pass that connects
/// newly extracted entities to similar existing ones.
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    /// Master switch; when false the linking pass is never spawned.
    pub enabled: bool,
    /// Minimum vector-similarity score for a candidate link.
    pub similarity_threshold: f32,
    /// Maximum links created per source entity.
    pub top_k: usize,
    /// Hard timeout for the whole linking pass, in seconds.
    pub timeout_secs: u64,
}
77
78impl Default for NoteLinkingConfig {
79 fn default() -> Self {
80 Self {
81 enabled: false,
82 similarity_threshold: 0.85,
83 top_k: 10,
84 timeout_secs: 5,
85 }
86 }
87}
88
/// Counters produced by one `extract_and_store` run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved/upserted via the `EntityResolver`.
    pub entities_upserted: usize,
    /// Edges actually inserted (self-loops and unresolved endpoints excluded).
    pub edges_inserted: usize,
}
95
/// Outcome of `extract_and_store`: counters plus the ids of every entity
/// touched this run (used to seed the note-linking pass).
#[derive(Debug, Default)]
pub struct ExtractionResult {
    pub stats: ExtractionStats,
    /// Ids of all entities resolved this run — includes matches against
    /// pre-existing entities, not only newly created rows.
    pub entity_ids: Vec<i64>,
}
103
/// Counters produced by one `link_memory_notes` run.
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities for which a vector search completed successfully.
    pub entities_processed: usize,
    /// `similar_to` edges inserted.
    pub edges_created: usize,
}
110
/// Vector-store collection holding graph-entity embeddings.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";
113
/// Per-entity data gathered up front so embedding and search can run as
/// batched `join_all` calls instead of per-entity awaits.
struct EntityWorkItem {
    entity_id: i64,
    // Canonical display name, used only in debug logs.
    canonical_name: String,
    // Text sent to the embedding provider: "name: summary", or just the
    // name when no non-empty summary exists.
    embed_text: String,
    // The entity's own vector point id (if indexed), used to filter
    // self-matches out of search results.
    self_point_id: Option<String>,
}
121
/// Best-effort pass that links semantically similar entities with
/// `similar_to` edges.
///
/// For each id: load the entity, embed "name: summary", search the entity
/// collection, and insert one canonicalized (smaller id first) edge per
/// candidate scoring at or above `cfg.similarity_threshold`, capped at
/// `cfg.top_k` per source entity. Every failure (load, embed, search,
/// insert) is logged at debug level and skipped — the pass never errors.
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: DbPool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: load entities sequentially and collect the data needed for
    // the batched embed/search phases below. Missing or unloadable entities
    // are skipped.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Embed "name: summary" when a non-empty summary exists; otherwise
        // fall back to the bare canonical name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all work items concurrently. Result order matches
    // `work_items` order (join_all preserves input order).
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // Fetch one extra result so the entity's own vector (if indexed) does
    // not consume a top_k slot after self-filtering.
    let search_limit = cfg.top_k + 1;
    // Keep only successful embeddings, remembering each one's index into
    // `work_items` so phase 3 can recover the owning entity.
    let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    // Phase 3: search the entity collection concurrently; `search_results`
    // is index-aligned with `valid`.
    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Dedupe by canonicalized (src, tgt) pair so A->B and B->A produce a
    // single edge per run.
    let mut seen_pairs = std::collections::HashSet::new();

    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        // Counted only when the search itself succeeded.
        stats.entities_processed += 1;

        // Drop the entity's own point and anything under the similarity
        // threshold, then cap at top_k candidates.
        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            // Candidate points must carry an integer `entity_id` payload
            // field; anything else is skipped.
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            // Secondary self-match guard by entity id (the point-id filter
            // above only works when self_point_id is known).
            if target_id == w.entity_id {
                continue;
            }

            // Canonicalize so the pair key is order-independent.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            // The similarity score doubles as the edge weight.
            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
286
/// Runs LLM graph extraction on `content` and persists entities and edges.
///
/// Pipeline: bump the persisted `extraction_count` metadata row, run the
/// LLM extractor, optionally validate via `post_extract_validator`, resolve
/// entities (with embedding-assisted resolution when `embedding_store` is
/// provided), insert edges between resolved endpoints (self-loops and
/// unknown endpoints skipped), then checkpoint the WAL.
///
/// NOTE(review): the metadata counter is incremented before extraction, so
/// it also counts runs that return no extraction or fail validation —
/// presumably intentional, as it drives the community-refresh cadence.
///
/// # Errors
/// Returns `MemoryError` on metadata updates, extraction, or the WAL
/// checkpoint; per-entity/per-edge failures are logged and skipped instead.
#[allow(clippy::too_many_lines)]
pub async fn extract_and_store(
    content: String,
    context_messages: Vec<String>,
    provider: AnyProvider,
    pool: DbPool,
    config: GraphExtractionConfig,
    post_extract_validator: PostExtractValidator,
    embedding_store: Option<Arc<EmbeddingStore>>,
) -> Result<ExtractionResult, MemoryError> {
    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};

    let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();

    let store = GraphStore::new(pool);

    // Seed the counter row on first run, then increment it. The value is
    // stored as TEXT, hence the CAST round-trip.
    let pool = store.pool();
    zeph_db::query(sql!(
        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
         ON CONFLICT(key) DO NOTHING"
    ))
    .execute(pool)
    .await?;
    zeph_db::query(sql!(
        "UPDATE graph_metadata
         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
         WHERE key = 'extraction_count'"
    ))
    .execute(pool)
    .await?;

    // `None` means the extractor found nothing worth storing.
    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
        return Ok(ExtractionResult::default());
    };

    // Give the caller a veto over the raw extraction before any upsert.
    if let Some(ref validator) = post_extract_validator
        && let Err(reason) = validator(&result)
    {
        tracing::warn!(
            reason,
            "graph extraction validation failed, skipping upsert"
        );
        return Ok(ExtractionResult::default());
    }

    // Embedding-assisted resolution only when both a store and (implicitly)
    // the provider are available.
    let resolver = if let Some(ref emb) = embedding_store {
        EntityResolver::new(&store)
            .with_embedding_store(emb)
            .with_provider(&provider)
    } else {
        EntityResolver::new(&store)
    };

    let mut entities_upserted = 0usize;
    // Maps extracted entity names to resolved DB ids; edges below can only
    // connect endpoints present in this map.
    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
        std::collections::HashMap::new();

    for entity in &result.entities {
        match resolver
            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
            .await
        {
            Ok((id, _outcome)) => {
                entity_name_to_id.insert(entity.name.clone(), id);
                entities_upserted += 1;
            }
            Err(e) => {
                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
            }
        }
    }

    let mut edges_inserted = 0usize;
    for edge in &result.edges {
        // Both endpoints must have resolved above.
        let (Some(&src_id), Some(&tgt_id)) = (
            entity_name_to_id.get(&edge.source),
            entity_name_to_id.get(&edge.target),
        ) else {
            tracing::debug!(
                "graph: skipping edge {:?}->{:?}: entity not resolved",
                edge.source,
                edge.target
            );
            continue;
        };
        // Reject self-loops — distinct names can resolve to the same id.
        if src_id == tgt_id {
            tracing::debug!(
                "graph: skipping self-loop edge {:?}->{:?} (entity_id={src_id})",
                edge.source,
                edge.target
            );
            continue;
        }
        // Unknown edge types from the LLM degrade to Semantic rather than
        // dropping the edge.
        let edge_type = edge
            .edge_type
            .parse::<crate::graph::EdgeType>()
            .unwrap_or_else(|_| {
                tracing::warn!(
                    raw_type = %edge.edge_type,
                    "graph: unknown edge_type from LLM, defaulting to semantic"
                );
                crate::graph::EdgeType::Semantic
            });
        // Weight is a fixed 0.8 for LLM-extracted edges. Ok(None) means the
        // resolver deliberately declined the edge (e.g. a duplicate).
        match resolver
            .resolve_edge_typed(
                src_id,
                tgt_id,
                &edge.relation,
                &edge.fact,
                0.8,
                None,
                edge_type,
            )
            .await
        {
            Ok(Some(_)) => edges_inserted += 1,
            Ok(None) => {}
            Err(e) => {
                tracing::debug!("graph: skipping edge: {e:#}");
            }
        }
    }

    // Flush the WAL so other connections (and processes) see the writes.
    store.checkpoint_wal().await?;

    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();

    Ok(ExtractionResult {
        stats: ExtractionStats {
            entities_upserted,
            edges_inserted,
        },
        entity_ids: new_entity_ids,
    })
}
437
impl SemanticMemory {
    /// Spawns a detached background task running graph extraction and, on
    /// success, the follow-up passes: optional note linking and — every
    /// `community_refresh_interval` extractions — a nested task doing
    /// community detection, graph eviction, and link-weight decay.
    ///
    /// Never blocks the caller; failures only bump the atomic failure
    /// counters and emit warnings.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the 'static task needs before moving in.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            // Phase 1: extraction, bounded by the configured timeout.
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Outer Err = timeout elapsed; inner Err = extraction failed.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // Phase 2: note linking — only on success, when enabled, with
            // entities to link, and with a vector store available.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        // link_memory_notes inserts edges as it goes, so a
                        // timeout can leave some edges committed.
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Phase 3: periodic maintenance cycle.
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                // NOTE: shadows the atomic counter above — this is the
                // persisted DB count from graph_metadata, not the in-process one.
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                // Fire only on exact multiples of the refresh interval.
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    let decay_lambda = config.link_weight_decay_lambda;
                    let decay_interval_secs = config.link_weight_decay_interval_secs;
                    // Nested detached task so the long-running maintenance
                    // work doesn't extend this extraction task's lifetime.
                    tokio::spawn(async move {
                        // 3a: community detection (failures counted).
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // 3b: eviction runs even if detection failed.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }

                        // 3c: link-weight decay, rate-limited via a persisted
                        // last-run timestamp in graph_metadata.
                        if decay_lambda > 0.0 && decay_interval_secs > 0 {
                            let now_secs = std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .map(|d| d.as_secs())
                                .unwrap_or(0);
                            // Missing/unparseable metadata counts as "never ran".
                            let last_decay = store2
                                .get_metadata("last_link_weight_decay_at")
                                .await
                                .ok()
                                .flatten()
                                .and_then(|s| s.parse::<u64>().ok())
                                .unwrap_or(0);
                            if now_secs.saturating_sub(last_decay) >= decay_interval_secs {
                                match store2
                                    .decay_edge_retrieval_counts(decay_lambda, decay_interval_secs)
                                    .await
                                {
                                    Ok(affected) => {
                                        tracing::info!(affected, "link weight decay applied");
                                        // Best-effort timestamp update; a failure
                                        // just means decay may re-run early.
                                        let _ = store2
                                            .set_metadata(
                                                "last_link_weight_decay_at",
                                                &now_secs.to_string(),
                                            )
                                            .await;
                                    }
                                    Err(e) => {
                                        tracing::warn!("link weight decay failed: {e:#}");
                                    }
                                }
                            }
                        }
                    });
                }
            }
        })
    }
}
633
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    use super::GraphExtractionConfig;

    /// Builds an in-memory SQLite graph store plus an in-memory vector
    /// store wrapped as an `EmbeddingStore`.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    // Regression for #1829: when both an embedding store and an
    // embedding-capable provider are supplied, resolved entities must be
    // indexed and get a qdrant_point_id.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // Entity names are stored lowercased ("rust"), per find_entity below.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    // Without an embedding store, upserts still happen but no vector point
    // is recorded.
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }

    // Regression for #2166: FTS5 index entries written in one connection
    // (session A) must be visible after reopening the same DB file
    // (session B) — exercises the checkpoint_wal call in extract_and_store.
    #[tokio::test]
    async fn extract_and_store_fts5_cross_session_visibility() {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        let path = file.path().to_str().expect("valid path").to_string();

        // Session A: extract and drop the store (scope ends).
        {
            let sqlite = crate::store::SqliteStore::new(&path).await.unwrap();
            let extraction_json = r#"{"entities":[{"name":"Ferris","type":"concept","summary":"Rust mascot"}],"edges":[]}"#;
            let mock =
                zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
            let provider = AnyProvider::Mock(mock);
            let config = GraphExtractionConfig {
                max_entities: 10,
                max_edges: 10,
                extraction_timeout_secs: 10,
                ..Default::default()
            };
            extract_and_store(
                "Ferris is the Rust mascot.".to_owned(),
                vec![],
                provider,
                sqlite.pool().clone(),
                config,
                None,
                None,
            )
            .await
            .unwrap();
        }

        // Session B: fresh connection against the same file.
        let sqlite_b = crate::store::SqliteStore::new(&path).await.unwrap();
        let gs_b = crate::graph::GraphStore::new(sqlite_b.pool().clone());
        let results = gs_b.find_entities_fuzzy("Ferris", 10).await.unwrap();
        assert!(
            !results.is_empty(),
            "FTS5 cross-session (#2166): entity extracted in session A must be visible in session B"
        );
    }

    // Regression for #2215: an edge whose source and target resolve to the
    // same entity id must be rejected.
    #[tokio::test]
    async fn extract_and_store_skips_self_loop_edges() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{
            "entities":[{"name":"Rust","type":"language","summary":"systems language"}],
            "edges":[{"source":"Rust","target":"Rust","relation":"is","fact":"Rust is Rust","edge_type":"semantic"}]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);
        assert_eq!(
            result.stats.edges_inserted, 0,
            "self-loop edge must be rejected (#2215)"
        );
    }
}