1use std::sync::Arc;
5use std::sync::atomic::Ordering;
6
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::LlmProvider as _;
9
10use crate::embedding_store::EmbeddingStore;
11use crate::error::MemoryError;
12use crate::graph::extractor::ExtractionResult as ExtractorResult;
13use crate::vector_store::VectorFilter;
14
15use super::SemanticMemory;
16
/// Optional hook run on the raw LLM extraction result before any DB upserts;
/// returning `Err(reason)` logs a warning and aborts the whole upsert
/// (see `extract_and_store`). `Send` so it can cross into a spawned task.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
22
/// Tuning knobs for the background graph-extraction pipeline.
///
/// Consumed by `extract_and_store` and `SemanticMemory::spawn_graph_extraction`.
#[derive(Debug, Clone)]
pub struct GraphExtractionConfig {
    /// Max entities requested from the LLM extractor (`GraphExtractor::new`).
    pub max_entities: usize,
    /// Max edges requested from the LLM extractor.
    pub max_edges: usize,
    /// Timeout applied around a single `extract_and_store` run.
    pub extraction_timeout_secs: u64,
    /// Run community detection every N successful extractions; 0 disables it.
    pub community_refresh_interval: usize,
    /// Retention window for expired edges, forwarded to graph eviction.
    pub expired_edge_retention_days: u32,
    /// Hard cap on total entities, forwarded to graph eviction.
    pub max_entities_cap: usize,
    /// Prompt-size budget for community summaries (`detect_communities`).
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency for community summarisation (`detect_communities`).
    pub community_summary_concurrency: usize,
    /// Edge chunk size for LPA community detection (`detect_communities`).
    pub lpa_edge_chunk_size: usize,
    /// Settings for the post-extraction note-linking phase.
    pub note_linking: NoteLinkingConfig,
    /// Multiplicative decay factor for edge retrieval counts; the decay pass
    /// only runs when this is > 0.
    pub link_weight_decay_lambda: f64,
    /// Minimum seconds between decay passes (tracked in graph metadata under
    /// `last_link_weight_decay_at`); the decay pass only runs when this is > 0.
    pub link_weight_decay_interval_secs: u64,
}
45
46impl Default for GraphExtractionConfig {
47 fn default() -> Self {
48 Self {
49 max_entities: 0,
50 max_edges: 0,
51 extraction_timeout_secs: 0,
52 community_refresh_interval: 0,
53 expired_edge_retention_days: 0,
54 max_entities_cap: 0,
55 community_summary_max_prompt_bytes: 0,
56 community_summary_concurrency: 0,
57 lpa_edge_chunk_size: 0,
58 note_linking: NoteLinkingConfig::default(),
59 link_weight_decay_lambda: 0.95,
60 link_weight_decay_interval_secs: 86400,
61 }
62 }
63}
64
/// Settings for the note-linking phase that connects semantically similar
/// entities after extraction (see `link_memory_notes`).
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    /// Gate for the whole note-linking phase.
    pub enabled: bool,
    /// Minimum vector-search score for a candidate to become a link.
    pub similarity_threshold: f32,
    /// Maximum number of candidate links considered per entity.
    pub top_k: usize,
    /// Timeout applied around one `link_memory_notes` pass.
    pub timeout_secs: u64,
}
73
74impl Default for NoteLinkingConfig {
75 fn default() -> Self {
76 Self {
77 enabled: false,
78 similarity_threshold: 0.85,
79 top_k: 10,
80 timeout_secs: 5,
81 }
82 }
83}
84
/// Aggregate counters produced by one `extract_and_store` run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    // Entities successfully resolved/upserted (resolution failures are skipped).
    pub entities_upserted: usize,
    // Edges actually inserted (self-loops and unresolved endpoints are skipped).
    pub edges_inserted: usize,
}
91
/// Outcome of `extract_and_store`: counters plus the resolved entity ids,
/// which feed the follow-up note-linking phase.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    pub stats: ExtractionStats,
    // DB ids of every entity resolved this run (insert order not guaranteed:
    // collected from a HashMap).
    pub entity_ids: Vec<i64>,
}
99
/// Aggregate counters produced by one `link_memory_notes` pass.
#[derive(Debug, Default)]
pub struct LinkingStats {
    // Entities whose vector search completed (embed/search failures excluded).
    pub entities_processed: usize,
    // `similar_to` edges successfully inserted.
    pub edges_created: usize,
}
106
// Vector collection holding one point per graph entity.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Per-entity unit of work for note linking: what to embed and how to
/// recognise the entity's own point in search results.
struct EntityWorkItem {
    entity_id: i64,
    canonical_name: String,
    // "name: summary" when a non-empty summary exists, otherwise just the name.
    embed_text: String,
    // The entity's own vector point id, excluded from its search results.
    self_point_id: Option<String>,
}
117
/// Creates `similar_to` edges between each entity in `entity_ids` and its
/// nearest neighbours in the `zeph_graph_entities` vector collection.
///
/// Best-effort by design: every per-entity failure (entity load, embedding,
/// vector search, edge insert) is logged at debug level and skipped, so the
/// function is infallible and only returns aggregate [`LinkingStats`].
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: sqlx::SqlitePool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1 (sequential): load each entity and build its embedding text.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Embed "name: summary" when a non-empty summary exists, else the name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2 (concurrent): embed all work items; join_all preserves input order.
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // +1 so the entity's own point (filtered out below) doesn't consume a
    // top_k slot. `valid` keeps (work_items index, vector) pairs for the
    // embeddings that succeeded.
    let search_limit = cfg.top_k + 1; let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    // Phase 3 (concurrent): vector search for each successfully embedded entity.
    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Dedupes undirected (src, tgt) pairs across the whole batch so A~B and
    // B~A produce at most one edge.
    let mut seen_pairs = std::collections::HashSet::new();

    // Phase 4 (sequential): turn search hits into `similar_to` edges.
    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        stats.entities_processed += 1;

        // Drop the entity's own point and anything below the score threshold,
        // then cap at top_k candidates.
        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            // Second self-check by DB id, in case self_point_id was None.
            if target_id == w.entity_id {
                continue; }

            // Canonical (smaller, larger) ordering so the pair-dedupe set is
            // direction-independent.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            // Edge weight is the raw search score; failures are logged, not counted.
            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
282
/// Runs LLM graph extraction over `content` and persists the resulting
/// entities and edges.
///
/// Pipeline: bump the `extraction_count` metadata counter, run the extractor,
/// optionally validate the raw result via `post_extract_validator`, resolve
/// each entity (embedding-assisted when `embedding_store` is provided),
/// insert edges between resolved endpoints (skipping self-loops), then
/// checkpoint the WAL.
///
/// # Errors
/// Returns [`MemoryError`] on metadata SQL failures, extractor errors, or a
/// failed WAL checkpoint; per-entity and per-edge resolution failures are
/// logged at debug level and skipped instead.
#[allow(clippy::too_many_lines)]
pub async fn extract_and_store(
    content: String,
    context_messages: Vec<String>,
    provider: AnyProvider,
    pool: sqlx::SqlitePool,
    config: GraphExtractionConfig,
    post_extract_validator: PostExtractValidator,
    embedding_store: Option<Arc<EmbeddingStore>>,
) -> Result<ExtractionResult, MemoryError> {
    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};

    let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();

    let store = GraphStore::new(pool);

    // Seed the counter row once, then increment it. NOTE(review): the counter
    // is bumped before extraction runs, so it counts attempts, not successes —
    // confirm that is intended for the community-refresh interval.
    let pool = store.pool();
    sqlx::query(
        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
         ON CONFLICT(key) DO NOTHING",
    )
    .execute(pool)
    .await?;
    sqlx::query(
        "UPDATE graph_metadata
         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
         WHERE key = 'extraction_count'",
    )
    .execute(pool)
    .await?;

    // `None` from the extractor means nothing to store — not an error.
    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
        return Ok(ExtractionResult::default());
    };

    // Validator veto: skip the entire upsert, report empty result as success.
    if let Some(ref validator) = post_extract_validator
        && let Err(reason) = validator(&result)
    {
        tracing::warn!(
            reason,
            "graph extraction validation failed, skipping upsert"
        );
        return Ok(ExtractionResult::default());
    }

    // With an embedding store, the resolver also gets the provider so it can
    // embed entities during resolution.
    let resolver = if let Some(ref emb) = embedding_store {
        EntityResolver::new(&store)
            .with_embedding_store(emb)
            .with_provider(&provider)
    } else {
        EntityResolver::new(&store)
    };

    let mut entities_upserted = 0usize;
    // Maps extracted entity name -> resolved DB id, for edge endpoint lookup.
    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
        std::collections::HashMap::new();

    for entity in &result.entities {
        match resolver
            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
            .await
        {
            Ok((id, _outcome)) => {
                entity_name_to_id.insert(entity.name.clone(), id);
                entities_upserted += 1;
            }
            Err(e) => {
                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
            }
        }
    }

    let mut edges_inserted = 0usize;
    for edge in &result.edges {
        // Both endpoints must have resolved above.
        let (Some(&src_id), Some(&tgt_id)) = (
            entity_name_to_id.get(&edge.source),
            entity_name_to_id.get(&edge.target),
        ) else {
            tracing::debug!(
                "graph: skipping edge {:?}->{:?}: entity not resolved",
                edge.source,
                edge.target
            );
            continue;
        };
        // Distinct names can resolve to one entity — reject the self-loop.
        if src_id == tgt_id {
            tracing::debug!(
                "graph: skipping self-loop edge {:?}->{:?} (entity_id={src_id})",
                edge.source,
                edge.target
            );
            continue;
        }
        // Unknown edge types from the LLM degrade to Semantic with a warning.
        let edge_type = edge
            .edge_type
            .parse::<crate::graph::EdgeType>()
            .unwrap_or_else(|_| {
                tracing::warn!(
                    raw_type = %edge.edge_type,
                    "graph: unknown edge_type from LLM, defaulting to semantic"
                );
                crate::graph::EdgeType::Semantic
            });
        // 0.8 is a fixed confidence for LLM-extracted edges; `Ok(None)` means
        // the resolver chose not to insert (counted as neither success nor error).
        match resolver
            .resolve_edge_typed(
                src_id,
                tgt_id,
                &edge.relation,
                &edge.fact,
                0.8,
                None,
                edge_type,
            )
            .await
        {
            Ok(Some(_)) => edges_inserted += 1,
            Ok(None) => {} Err(e) => {
                tracing::debug!("graph: skipping edge: {e:#}");
            }
        }
    }

    // Flush the WAL so other connections (and sessions) see the new rows.
    store.checkpoint_wal().await?;

    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();

    Ok(ExtractionResult {
        stats: ExtractionStats {
            entities_upserted,
            edges_inserted,
        },
        entity_ids: new_entity_ids,
    })
}
433
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    use super::GraphExtractionConfig;

    // Shared fixture: in-memory SQLite pool + in-memory vector store, so no
    // external services are needed.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    // Regression test for #1829: when both an embedding store and an
    // embedding-capable provider are supplied, the resolver must record the
    // entity's vector point id in SQLite.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // Lookup uses the normalised ("rust") name, not the raw LLM casing.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    // Without an embedding store the SQLite upsert must still happen, but no
    // vector point id should be recorded.
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }

    // Regression test for #2166: writes in one connection (session A) must be
    // visible to a fresh connection (session B) — exercises the WAL checkpoint
    // at the end of extract_and_store via an on-disk database file.
    #[tokio::test]
    async fn extract_and_store_fts5_cross_session_visibility() {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        let path = file.path().to_str().expect("valid path").to_string();

        // Session A: extract and store, then drop the connection.
        {
            let sqlite = crate::sqlite::SqliteStore::new(&path).await.unwrap();
            let extraction_json = r#"{"entities":[{"name":"Ferris","type":"concept","summary":"Rust mascot"}],"edges":[]}"#;
            let mock =
                zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
            let provider = AnyProvider::Mock(mock);
            let config = GraphExtractionConfig {
                max_entities: 10,
                max_edges: 10,
                extraction_timeout_secs: 10,
                ..Default::default()
            };
            extract_and_store(
                "Ferris is the Rust mascot.".to_owned(),
                vec![],
                provider,
                sqlite.pool().clone(),
                config,
                None,
                None,
            )
            .await
            .unwrap();
        }

        // Session B: a brand-new connection must see the extracted entity.
        let sqlite_b = crate::sqlite::SqliteStore::new(&path).await.unwrap();
        let gs_b = crate::graph::GraphStore::new(sqlite_b.pool().clone());
        let results = gs_b.find_entities_fuzzy("Ferris", 10).await.unwrap();
        assert!(
            !results.is_empty(),
            "FTS5 cross-session (#2166): entity extracted in session A must be visible in session B"
        );
    }

    // Regression test for #2215: an edge whose source and target resolve to
    // the same entity must be dropped, not inserted as a self-loop.
    #[tokio::test]
    async fn extract_and_store_skips_self_loop_edges() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{
            "entities":[{"name":"Rust","type":"language","summary":"systems language"}],
            "edges":[{"source":"Rust","target":"Rust","relation":"is","fact":"Rust is Rust","edge_type":"semantic"}]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);
        assert_eq!(
            result.stats.edges_inserted, 0,
            "self-loop edge must be rejected (#2215)"
        );
    }
}
637
impl SemanticMemory {
    /// Spawns a detached background task running the full graph pipeline:
    /// timeout-bounded extraction, an optional note-linking phase, and — every
    /// `community_refresh_interval` extractions — a community-detection /
    /// eviction / link-weight-decay pass.
    ///
    /// Returns the `JoinHandle` of the outer task only; the community-refresh
    /// work runs in a second, nested `tokio::spawn` that the handle does not
    /// cover. Success/failure counters are tracked via the atomics captured
    /// from `self`.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the 'static task needs before moving into spawn.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            // Phase 1: extraction, bounded by the configured timeout.
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Timeouts and errors both count as failures; only success yields
            // entity ids for note linking.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // Phase 2: note linking — only after a successful extraction, when
            // enabled, with new entities, and with an embedding store available.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        // A timeout here may leave some edges already written.
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Phase 3: periodic maintenance, keyed off the persisted
            // extraction_count metadata (not the in-process atomic).
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    let decay_lambda = config.link_weight_decay_lambda;
                    let decay_interval_secs = config.link_weight_decay_interval_secs;
                    // Nested task: detection -> eviction -> decay, in order.
                    // Detection failure does not prevent eviction/decay.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }

                        // Link-weight decay: rate-limited via the
                        // last_link_weight_decay_at metadata timestamp; both
                        // knobs must be positive for it to run at all.
                        if decay_lambda > 0.0 && decay_interval_secs > 0 {
                            let now_secs = std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .map(|d| d.as_secs())
                                .unwrap_or(0);
                            let last_decay = store2
                                .get_metadata("last_link_weight_decay_at")
                                .await
                                .ok()
                                .flatten()
                                .and_then(|s| s.parse::<u64>().ok())
                                .unwrap_or(0);
                            if now_secs.saturating_sub(last_decay) >= decay_interval_secs {
                                match store2
                                    .decay_edge_retrieval_counts(decay_lambda, decay_interval_secs)
                                    .await
                                {
                                    Ok(affected) => {
                                        tracing::info!(affected, "link weight decay applied");
                                        // Best-effort timestamp write; a failure
                                        // just means decay may run again sooner.
                                        let _ = store2
                                            .set_metadata(
                                                "last_link_weight_decay_at",
                                                &now_secs.to_string(),
                                            )
                                            .await;
                                    }
                                    Err(e) => {
                                        tracing::warn!("link weight decay failed: {e:#}");
                                    }
                                }
                            }
                        }
                    });
                }
            }
        })
    }
}