1use std::sync::Arc;
5use std::sync::atomic::Ordering;
6
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::LlmProvider as _;
9
10use crate::embedding_store::EmbeddingStore;
11use crate::error::MemoryError;
12use crate::graph::extractor::ExtractionResult as ExtractorResult;
13use crate::vector_store::VectorFilter;
14
15use super::SemanticMemory;
16
/// Optional hook run on the raw extractor output before anything is
/// persisted; returning `Err(reason)` causes the upsert to be skipped.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
22
/// Tuning knobs for LLM-driven graph extraction and the background
/// maintenance (note linking, community detection, eviction) it triggers.
#[derive(Debug, Clone, Default)]
pub struct GraphExtractionConfig {
    /// Max entities requested from the graph extractor per call.
    pub max_entities: usize,
    /// Max edges requested from the graph extractor per call.
    pub max_edges: usize,
    /// Hard timeout for one extraction run; a timeout counts as a failure.
    pub extraction_timeout_secs: u64,
    /// Run community detection every N extractions; 0 disables the refresh.
    pub community_refresh_interval: usize,
    /// Passed to graph eviction: how long expired edges are retained.
    pub expired_edge_retention_days: u32,
    /// Passed to graph eviction: cap on total entities kept.
    pub max_entities_cap: usize,
    /// Byte budget for community-summary prompts.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency limit for community summarization.
    pub community_summary_concurrency: usize,
    /// Edge chunk size for label-propagation community detection.
    pub lpa_edge_chunk_size: usize,
    /// Configuration for post-extraction similarity linking of entities.
    pub note_linking: NoteLinkingConfig,
}
41
/// Controls [`link_memory_notes`]: similarity-based `similar_to` edges
/// between freshly extracted entities and existing ones.
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    /// Master switch; linking is skipped entirely when false.
    pub enabled: bool,
    /// Minimum vector-search score for a candidate to be linked.
    pub similarity_threshold: f32,
    /// Max candidates linked per entity.
    pub top_k: usize,
    /// Overall timeout for a linking pass (edges inserted before the
    /// deadline are kept).
    pub timeout_secs: u64,
}
50
51impl Default for NoteLinkingConfig {
52 fn default() -> Self {
53 Self {
54 enabled: false,
55 similarity_threshold: 0.85,
56 top_k: 10,
57 timeout_secs: 5,
58 }
59 }
60}
61
/// Counters produced by one [`extract_and_store`] run.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved/upserted (includes existing matches).
    pub entities_upserted: usize,
    /// Edges actually inserted (deduplicated edges are not counted).
    pub edges_inserted: usize,
}
68
/// Outcome of one [`extract_and_store`] run.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    /// Aggregate counters for the run.
    pub stats: ExtractionStats,
    /// Ids of every resolved entity — note this includes entities that
    /// already existed before the run, not only newly created rows.
    pub entity_ids: Vec<i64>,
}
76
/// Counters produced by one [`link_memory_notes`] pass.
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities whose embedding + similarity search both succeeded.
    pub entities_processed: usize,
    /// `similar_to` edges successfully inserted.
    pub edges_created: usize,
}
83
// Vector-store collection holding graph-entity embeddings, searched during
// note linking.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";
86
/// Per-entity payload prepared for the embed + search pipeline in
/// [`link_memory_notes`].
struct EntityWorkItem {
    entity_id: i64,
    // Kept for log messages only.
    canonical_name: String,
    // Text sent to the embedding provider ("name" or "name: summary").
    embed_text: String,
    // The entity's own vector point, excluded from its search results.
    self_point_id: Option<String>,
}
94
95#[allow(clippy::too_many_lines)]
111pub async fn link_memory_notes(
112 entity_ids: &[i64],
113 pool: sqlx::SqlitePool,
114 embedding_store: Arc<EmbeddingStore>,
115 provider: AnyProvider,
116 cfg: &NoteLinkingConfig,
117) -> LinkingStats {
118 use futures::future;
119
120 use crate::graph::GraphStore;
121
122 let store = GraphStore::new(pool);
123 let mut stats = LinkingStats::default();
124
125 let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
127 for &entity_id in entity_ids {
128 let entity = match store.find_entity_by_id(entity_id).await {
129 Ok(Some(e)) => e,
130 Ok(None) => {
131 tracing::debug!("note_linking: entity {entity_id} not found, skipping");
132 continue;
133 }
134 Err(e) => {
135 tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
136 continue;
137 }
138 };
139 let embed_text = match &entity.summary {
140 Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
141 _ => entity.canonical_name.clone(),
142 };
143 work_items.push(EntityWorkItem {
144 entity_id,
145 canonical_name: entity.canonical_name,
146 embed_text,
147 self_point_id: entity.qdrant_point_id,
148 });
149 }
150
151 if work_items.is_empty() {
152 return stats;
153 }
154
155 let embed_results: Vec<_> =
157 future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;
158
159 let search_limit = cfg.top_k + 1; let valid: Vec<(usize, Vec<f32>)> = embed_results
162 .into_iter()
163 .enumerate()
164 .filter_map(|(i, r)| match r {
165 Ok(v) => Some((i, v)),
166 Err(e) => {
167 tracing::debug!(
168 "note_linking: embed failed for entity {:?}: {e:#}",
169 work_items[i].canonical_name
170 );
171 None
172 }
173 })
174 .collect();
175
176 let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
177 embedding_store.search_collection(
178 ENTITY_COLLECTION,
179 vec,
180 search_limit,
181 None::<VectorFilter>,
182 )
183 }))
184 .await;
185
186 let mut seen_pairs = std::collections::HashSet::new();
191
192 for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
193 let w = &work_items[*work_idx];
194
195 let results = match search_result {
196 Ok(r) => r,
197 Err(e) => {
198 tracing::debug!(
199 "note_linking: search failed for entity {:?}: {e:#}",
200 w.canonical_name
201 );
202 continue;
203 }
204 };
205
206 stats.entities_processed += 1;
207
208 let self_point_id = w.self_point_id.as_deref();
209 let candidates = results
210 .iter()
211 .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
212 .take(cfg.top_k);
213
214 for point in candidates {
215 let Some(target_id) = point
216 .payload
217 .get("entity_id")
218 .and_then(serde_json::Value::as_i64)
219 else {
220 tracing::debug!(
221 "note_linking: missing entity_id in payload for point {}",
222 point.id
223 );
224 continue;
225 };
226
227 if target_id == w.entity_id {
228 continue; }
230
231 let (src, tgt) = if w.entity_id < target_id {
233 (w.entity_id, target_id)
234 } else {
235 (target_id, w.entity_id)
236 };
237
238 if !seen_pairs.insert((src, tgt)) {
240 continue;
241 }
242
243 let fact = format!("Semantically similar entities (score: {:.3})", point.score);
244
245 match store
246 .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
247 .await
248 {
249 Ok(_) => stats.edges_created += 1,
250 Err(e) => {
251 tracing::debug!("note_linking: insert_edge failed: {e:#}");
252 }
253 }
254 }
255 }
256
257 stats
258}
259
260#[allow(clippy::too_many_lines)]
271pub async fn extract_and_store(
272 content: String,
273 context_messages: Vec<String>,
274 provider: AnyProvider,
275 pool: sqlx::SqlitePool,
276 config: GraphExtractionConfig,
277 post_extract_validator: PostExtractValidator,
278 embedding_store: Option<Arc<EmbeddingStore>>,
279) -> Result<ExtractionResult, MemoryError> {
280 use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
281
282 let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
283 let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
284
285 let store = GraphStore::new(pool);
286
287 let pool = store.pool();
288 sqlx::query(
289 "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
290 ON CONFLICT(key) DO NOTHING",
291 )
292 .execute(pool)
293 .await?;
294 sqlx::query(
295 "UPDATE graph_metadata
296 SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
297 WHERE key = 'extraction_count'",
298 )
299 .execute(pool)
300 .await?;
301
302 let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
303 return Ok(ExtractionResult::default());
304 };
305
306 if let Some(ref validator) = post_extract_validator
309 && let Err(reason) = validator(&result)
310 {
311 tracing::warn!(
312 reason,
313 "graph extraction validation failed, skipping upsert"
314 );
315 return Ok(ExtractionResult::default());
316 }
317
318 let resolver = if let Some(ref emb) = embedding_store {
319 EntityResolver::new(&store)
320 .with_embedding_store(emb)
321 .with_provider(&provider)
322 } else {
323 EntityResolver::new(&store)
324 };
325
326 let mut entities_upserted = 0usize;
327 let mut entity_name_to_id: std::collections::HashMap<String, i64> =
328 std::collections::HashMap::new();
329
330 for entity in &result.entities {
331 match resolver
332 .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
333 .await
334 {
335 Ok((id, _outcome)) => {
336 entity_name_to_id.insert(entity.name.clone(), id);
337 entities_upserted += 1;
338 }
339 Err(e) => {
340 tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
341 }
342 }
343 }
344
345 let mut edges_inserted = 0usize;
346 for edge in &result.edges {
347 let (Some(&src_id), Some(&tgt_id)) = (
348 entity_name_to_id.get(&edge.source),
349 entity_name_to_id.get(&edge.target),
350 ) else {
351 tracing::debug!(
352 "graph: skipping edge {:?}->{:?}: entity not resolved",
353 edge.source,
354 edge.target
355 );
356 continue;
357 };
358 let edge_type = edge
361 .edge_type
362 .parse::<crate::graph::EdgeType>()
363 .unwrap_or_else(|_| {
364 tracing::warn!(
365 raw_type = %edge.edge_type,
366 "graph: unknown edge_type from LLM, defaulting to semantic"
367 );
368 crate::graph::EdgeType::Semantic
369 });
370 match resolver
371 .resolve_edge_typed(
372 src_id,
373 tgt_id,
374 &edge.relation,
375 &edge.fact,
376 0.8,
377 None,
378 edge_type,
379 )
380 .await
381 {
382 Ok(Some(_)) => edges_inserted += 1,
383 Ok(None) => {} Err(e) => {
385 tracing::debug!("graph: skipping edge: {e:#}");
386 }
387 }
388 }
389
390 let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();
391
392 Ok(ExtractionResult {
393 stats: ExtractionStats {
394 entities_upserted,
395 edges_inserted,
396 },
397 entity_ids: new_entity_ids,
398 })
399}
400
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    use super::GraphExtractionConfig;

    /// Builds an in-memory SQLite graph store plus an embedding store
    /// backed by an in-memory vector store, sharing one pool.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    /// Regression test (#1829): entities extracted while an embedding store
    /// is configured must get a `qdrant_point_id`.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        // Mock LLM returns a fixed extraction payload and a fixed embedding.
        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // Entities are stored under their normalized (lowercased) name.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    /// Without an embedding store, extraction still upserts entities but
    /// leaves `qdrant_point_id` unset.
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, // no embedding store
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }
}
520
impl SemanticMemory {
    /// Spawns a detached background task that runs graph extraction over
    /// `content` and, on success, optionally links the affected entities
    /// and periodically triggers a community-detection refresh.
    ///
    /// The returned handle covers only the outer task; the community
    /// refresh runs in a nested `tokio::spawn` that is not awaited here.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the detached task needs up front; `self` is not
        // moved into the task.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            // Extraction is bounded by the configured timeout; a timeout
            // counts as a failure below.
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Record metrics and collect the entity ids for note linking.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // Best-effort note linking: only when extraction succeeded,
            // linking is enabled, there is something to link, and a vector
            // store is configured.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        // Edges inserted before the deadline are kept.
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Periodic maintenance: every `community_refresh_interval`
            // extractions (0 disables it entirely).
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                // NOTE: shadows the atomic counter above with the persisted
                // DB counter; the atomic was already bumped and is no longer
                // needed in this scope.
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    // Detached: refresh + eviction run to completion even if
                    // the outer task finishes first.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs regardless of the detection outcome.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }
                    });
                }
            }
        })
    }
}