Skip to main content

zeph_memory/semantic/
graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5use std::sync::atomic::Ordering;
6
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::LlmProvider as _;
9
10use crate::embedding_store::EmbeddingStore;
11use crate::error::MemoryError;
12use crate::graph::extractor::ExtractionResult as ExtractorResult;
13use crate::vector_store::VectorFilter;
14
15use super::SemanticMemory;
16
/// Callback type for post-extraction validation.
///
/// A generic predicate opaque to zeph-memory — callers (zeph-core) provide security
/// validation without introducing a dependency on security policy in this crate.
///
/// `None` disables validation. When `Some`, the closure is invoked with the raw
/// extraction result in `extract_and_store`; an `Err(reason)` makes the whole upsert
/// be skipped. The closure is boxed and `Send` because it is moved into a spawned
/// tokio task.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
/// Config for the spawned background extraction task.
///
/// Owned clone of the relevant fields from `GraphConfig` — no references, safe to send to
/// spawned tasks.
#[derive(Debug, Clone, Default)]
pub struct GraphExtractionConfig {
    // Upper bound on entities the LLM extractor may return per pass.
    pub max_entities: usize,
    // Upper bound on edges the LLM extractor may return per pass.
    pub max_edges: usize,
    // Hard timeout for one extraction attempt, in seconds; on expiry the
    // in-flight future is dropped and the attempt is counted as a failure.
    pub extraction_timeout_secs: u64,
    // Community detection runs every N extractions (per the persistent DB
    // counter); 0 disables the refresh entirely.
    pub community_refresh_interval: usize,
    // Passed to graph eviction as the retention window for expired edges.
    pub expired_edge_retention_days: u32,
    // Passed to graph eviction as the global cap on entity count.
    pub max_entities_cap: usize,
    // Byte budget for a single community-summary prompt.
    pub community_summary_max_prompt_bytes: usize,
    // How many community summaries are generated concurrently.
    pub community_summary_concurrency: usize,
    // Edge chunk size handed to LPA-based community detection
    // (presumably label propagation — see graph::community).
    pub lpa_edge_chunk_size: usize,
    /// A-MEM note linking config, cloned from `GraphConfig.note_linking`.
    pub note_linking: NoteLinkingConfig,
}
41
/// Config for A-MEM dynamic note linking, owned by the spawned extraction task.
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    // Master switch — when false, `link_memory_notes` never runs.
    pub enabled: bool,
    // Minimum cosine similarity score for a candidate point to become a
    // `similar_to` edge (compared against the vector-search score).
    pub similarity_threshold: f32,
    // Maximum number of similar entities linked per source entity.
    pub top_k: usize,
    // Overall timeout for one linking pass, in seconds.
    pub timeout_secs: u64,
}
50
51impl Default for NoteLinkingConfig {
52    fn default() -> Self {
53        Self {
54            enabled: false,
55            similarity_threshold: 0.85,
56            top_k: 10,
57            timeout_secs: 5,
58        }
59    }
60}
61
/// Stats returned from a completed extraction.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    // Count of successful resolver upserts (one per extracted entity name that
    // resolved without error).
    pub entities_upserted: usize,
    // Count of edges actually inserted; edges deduplicated by the resolver
    // (resolve_edge returning Ok(None)) are not counted.
    pub edges_inserted: usize,
}
68
/// Result returned from `extract_and_store`, combining stats with entity IDs needed for linking.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    // Counters describing what the extraction pass persisted.
    pub stats: ExtractionStats,
    /// IDs of entities upserted during this extraction pass. Passed to `link_memory_notes`.
    pub entity_ids: Vec<i64>,
}
76
/// Stats returned from a completed note-linking pass.
#[derive(Debug, Default)]
pub struct LinkingStats {
    // Entities whose embed + vector search both succeeded (regardless of
    // whether any edge was ultimately created for them).
    pub entities_processed: usize,
    // `similar_to` edges inserted during this pass, after pair deduplication.
    pub edges_created: usize,
}
83
/// Qdrant collection name for entity embeddings (mirrors the constant in `resolver.rs`).
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Work item for a single entity during a note-linking pass.
///
/// Built in phase 1 of `link_memory_notes` from the SQLite row; carries everything the
/// later embed/search/insert phases need so the DB is not queried again.
struct EntityWorkItem {
    // Primary key of the entity in the graph store.
    entity_id: i64,
    // Canonical entity name — used only in log messages by later phases.
    canonical_name: String,
    // Text sent to the embedding provider ("name: summary" when a non-empty
    // summary exists, otherwise just the canonical name).
    embed_text: String,
    // The entity's own Qdrant point id, used to filter out the self-match;
    // may be None when no embedding was ever stored for this entity.
    self_point_id: Option<String>,
}
94
/// Link newly extracted entities to semantically similar entities in the graph.
///
/// For each entity in `entity_ids`:
/// 1. Load the entity name + summary from `SQLite`.
/// 2. Embed all entity texts in parallel.
/// 3. Search the entity embedding collection in parallel for the `top_k + 1` most similar points.
/// 4. Filter out the entity itself (by `qdrant_point_id` or `entity_id` payload) and points
///    below `similarity_threshold`.
/// 5. Insert a unidirectional `similar_to` edge where `source_id < target_id` to avoid
///    double-counting in BFS recall while still being traversable via the OR clause in
///    `edges_for_entity`. The edge confidence is set to the cosine similarity score.
/// 6. Deduplicate pairs within a single pass so that a pair encountered from both A→B and B→A
///    directions is only inserted once, keeping `edges_created` accurate.
///
/// Errors are logged and not propagated — this is a best-effort background enrichment step.
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: sqlx::SqlitePool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: load entities from DB sequentially (cheap; avoids connection-pool contention).
    // Missing or unloadable entities are skipped — linking is best-effort throughout.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Embed "name: summary" when a non-empty summary exists; otherwise just the name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all entity texts in parallel to reduce N serial HTTP round-trips to 1.
    // embed_results[i] corresponds positionally to work_items[i].
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // Phase 3: search for similar entities in parallel for all successfully embedded entities.
    let search_limit = cfg.top_k + 1; // +1 to account for self-match
    // Keep the index into work_items alongside each successful vector so the
    // search results below can be mapped back to the originating entity.
    let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    // search_results is positionally aligned with `valid` (join_all preserves order).
    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Phase 4: insert edges; deduplicate pairs seen from both A→B and B→A directions.
    // Without deduplication, both directions call insert_edge for the same normalised pair and
    // both return Ok (the second call updates confidence on the existing row), inflating
    // edges_created by the number of bidirectional hits.
    let mut seen_pairs = std::collections::HashSet::new();

    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        // Counted once embed + search both succeeded, even if no candidate survives.
        stats.entities_processed += 1;

        // Drop the self-match (by stored point id) and anything under the threshold,
        // then keep at most top_k candidates.
        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            // The target entity id comes from the point payload written at embed time.
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            if target_id == w.entity_id {
                continue; // secondary self-guard when qdrant_point_id is null
            }

            // Normalise direction: always store source_id < target_id.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            // Skip pairs already processed in this pass to avoid double-counting.
            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            // Edge confidence is the raw similarity score (see doc comment, step 5).
            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
259
260/// Extract entities and edges from `content` and persist them to the graph store.
261///
262/// This function runs inside a spawned task — it receives owned data only.
263///
264/// The optional `embedding_store` enables entity embedding storage in Qdrant, which is
265/// required for A-MEM note linking to find semantically similar entities across sessions.
266///
267/// # Errors
268///
269/// Returns an error if the database query fails or LLM extraction fails.
270pub async fn extract_and_store(
271    content: String,
272    context_messages: Vec<String>,
273    provider: AnyProvider,
274    pool: sqlx::SqlitePool,
275    config: GraphExtractionConfig,
276    post_extract_validator: PostExtractValidator,
277    embedding_store: Option<Arc<EmbeddingStore>>,
278) -> Result<ExtractionResult, MemoryError> {
279    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
280
281    let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
282    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
283
284    let store = GraphStore::new(pool);
285
286    let pool = store.pool();
287    sqlx::query(
288        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
289         ON CONFLICT(key) DO NOTHING",
290    )
291    .execute(pool)
292    .await?;
293    sqlx::query(
294        "UPDATE graph_metadata
295         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
296         WHERE key = 'extraction_count'",
297    )
298    .execute(pool)
299    .await?;
300
301    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
302        return Ok(ExtractionResult::default());
303    };
304
305    // Post-extraction validation callback. zeph-memory does not know the callback is a
306    // security validator — it is a generic predicate opaque to this crate (design decision D1).
307    if let Some(ref validator) = post_extract_validator
308        && let Err(reason) = validator(&result)
309    {
310        tracing::warn!(
311            reason,
312            "graph extraction validation failed, skipping upsert"
313        );
314        return Ok(ExtractionResult::default());
315    }
316
317    let resolver = if let Some(ref emb) = embedding_store {
318        EntityResolver::new(&store)
319            .with_embedding_store(emb)
320            .with_provider(&provider)
321    } else {
322        EntityResolver::new(&store)
323    };
324
325    let mut entities_upserted = 0usize;
326    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
327        std::collections::HashMap::new();
328
329    for entity in &result.entities {
330        match resolver
331            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
332            .await
333        {
334            Ok((id, _outcome)) => {
335                entity_name_to_id.insert(entity.name.clone(), id);
336                entities_upserted += 1;
337            }
338            Err(e) => {
339                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
340            }
341        }
342    }
343
344    let mut edges_inserted = 0usize;
345    for edge in &result.edges {
346        let (Some(&src_id), Some(&tgt_id)) = (
347            entity_name_to_id.get(&edge.source),
348            entity_name_to_id.get(&edge.target),
349        ) else {
350            tracing::debug!(
351                "graph: skipping edge {:?}->{:?}: entity not resolved",
352                edge.source,
353                edge.target
354            );
355            continue;
356        };
357        match resolver
358            .resolve_edge(src_id, tgt_id, &edge.relation, &edge.fact, 0.8, None)
359            .await
360        {
361            Ok(Some(_)) => edges_inserted += 1,
362            Ok(None) => {} // deduplicated
363            Err(e) => {
364                tracing::debug!("graph: skipping edge: {e:#}");
365            }
366        }
367    }
368
369    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();
370
371    Ok(ExtractionResult {
372        stats: ExtractionStats {
373            entities_upserted,
374            edges_inserted,
375        },
376        entity_ids: new_entity_ids,
377    })
378}
379
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    use super::GraphExtractionConfig;

    /// Build an in-memory SQLite graph store plus an embedding store backed by the
    /// in-memory vector store — no external services required.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    /// Regression test for #1829: extract_and_store() must pass the provider to EntityResolver
    /// so that store_entity_embedding() is called and qdrant_point_id is set in SQLite.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        // MockProvider: supports embeddings, returns a valid extraction JSON for chat
        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // The entity must have a qdrant_point_id — this proves store_entity_embedding() was called.
        // Before the fix, EntityResolver was built without a provider, so embed() was never called
        // and qdrant_point_id remained NULL.
        // Note: lookup uses the lowercased canonical form "rust".
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    /// When no embedding_store is provided, extract_and_store() must still work correctly
    /// (no embeddings stored, but entities are still upserted).
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, // no embedding_store
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        // Without an embedding store the resolver never stores a vector, so the
        // pointer back into Qdrant must stay unset.
        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }
}
499
impl SemanticMemory {
    /// Spawn background graph extraction for a message. Fire-and-forget — never blocks.
    ///
    /// Extraction runs in a separate tokio task with a timeout. Any error or timeout is
    /// logged and the task exits silently; the agent response is never blocked.
    ///
    /// The optional `post_extract_validator` is called after extraction, before upsert.
    /// It is a generic predicate opaque to zeph-memory (design decision D1).
    ///
    /// When `config.note_linking.enabled` is `true` and an embedding store is available,
    /// `link_memory_notes` runs after successful extraction inside the same task, bounded
    /// by `config.note_linking.timeout_secs`.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the task needs up front — the spawned future must be 'static.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        // Clone the embedding store Arc before moving into the task.
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            // NOTE(review): on timeout the extract_and_store future is dropped mid-flight,
            // so partial DB writes (metadata bump, some entities) may remain — confirm
            // this is acceptable for a best-effort pipeline.
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Collapse the nested timeout/extraction result into (success flag, new IDs);
            // failures and timeouts both bump the failure counter and yield no IDs.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // A-MEM note linking: run after successful extraction when enabled.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Periodic community detection: triggered from the persistent DB counter
            // (incremented inside extract_and_store), not the in-process atomic.
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                // NOTE(review): shadows the atomic `extraction_count` captured above;
                // only the DB value is used from here on.
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    // Detection + eviction run in their own detached task so a slow
                    // refresh never delays the end of this extraction task.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs even if detection failed — it only depends on
                        // retention/cap settings, not on fresh communities.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }
                    });
                }
            }
        })
    }
}