// zeph_memory/semantic/graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5use std::sync::atomic::Ordering;
6
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::LlmProvider as _;
9
10use crate::embedding_store::EmbeddingStore;
11use crate::error::MemoryError;
12use crate::graph::extractor::ExtractionResult as ExtractorResult;
13use crate::vector_store::VectorFilter;
14
15use super::SemanticMemory;
16
/// Callback type for post-extraction validation.
///
/// A generic predicate opaque to zeph-memory — callers (zeph-core) provide security
/// validation without introducing a dependency on security policy in this crate.
///
/// `None` disables validation. A `Some` callback returns `Err(reason)` to reject an
/// extraction before it is persisted (the upsert is skipped and the reason is logged).
/// The `Send` bound allows the callback to be moved into the spawned extraction task.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
/// Config for the spawned background extraction task.
///
/// Owned clone of the relevant fields from `GraphConfig` — no references, safe to send to
/// spawned tasks.
#[derive(Debug, Clone, Default)]
pub struct GraphExtractionConfig {
    /// Maximum number of entities the LLM extractor may return per pass.
    pub max_entities: usize,
    /// Maximum number of edges the LLM extractor may return per pass.
    pub max_edges: usize,
    /// Timeout (seconds) for one `extract_and_store` run inside the spawned task.
    pub extraction_timeout_secs: u64,
    /// Trigger community detection every N extractions; `0` disables the refresh.
    pub community_refresh_interval: usize,
    /// Retention window (days) passed to graph eviction for expired edges.
    pub expired_edge_retention_days: u32,
    /// Hard cap on entity count enforced by graph eviction.
    pub max_entities_cap: usize,
    /// Byte budget for a single community-summary LLM prompt.
    pub community_summary_max_prompt_bytes: usize,
    /// Number of community summaries generated concurrently.
    pub community_summary_concurrency: usize,
    /// Edge chunk size forwarded to community detection (`detect_communities`).
    pub lpa_edge_chunk_size: usize,
    /// A-MEM note linking config, cloned from `GraphConfig.note_linking`.
    pub note_linking: NoteLinkingConfig,
}
41
/// Config for A-MEM dynamic note linking, owned by the spawned extraction task.
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    /// Master switch; linking only runs when `true`.
    pub enabled: bool,
    /// Minimum cosine similarity a candidate must reach to be linked.
    pub similarity_threshold: f32,
    /// Maximum number of similar entities linked per source entity.
    pub top_k: usize,
    /// Timeout (seconds) for one full linking pass.
    pub timeout_secs: u64,
}

impl Default for NoteLinkingConfig {
    /// Conservative defaults: linking off, high similarity bar, short timeout.
    fn default() -> Self {
        NoteLinkingConfig {
            enabled: false,
            similarity_threshold: 0.85,
            top_k: 10,
            timeout_secs: 5,
        }
    }
}
61
/// Stats returned from a completed extraction.
///
/// Plain-data counters; `Copy`/`PartialEq`/`Eq` are derived so callers can snapshot
/// and compare stats without hand-written comparisons. Additive derives only — fully
/// backward compatible for existing callers.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct ExtractionStats {
    /// Number of entities successfully resolved and upserted into the graph store.
    pub entities_upserted: usize,
    /// Number of new (non-deduplicated) edges inserted.
    pub edges_inserted: usize,
}
68
/// Result returned from `extract_and_store`, combining stats with entity IDs needed for linking.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    /// Counters for the extraction pass (entities upserted, edges inserted).
    pub stats: ExtractionStats,
    /// IDs of entities upserted during this extraction pass. Passed to `link_memory_notes`.
    pub entity_ids: Vec<i64>,
}
76
/// Stats returned from a completed note-linking pass.
///
/// Plain-data counters; `Copy`/`PartialEq`/`Eq` are derived so callers can snapshot
/// and compare stats directly. Additive derives only — fully backward compatible.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct LinkingStats {
    /// Entities whose embedding and similarity search both succeeded during the pass.
    pub entities_processed: usize,
    /// New `similar_to` edges inserted (pairs deduplicated within the pass).
    pub edges_created: usize,
}
83
/// Qdrant collection name for entity embeddings (mirrors the constant in `resolver.rs`).
/// NOTE(review): keep in sync with `resolver.rs` — a silent mismatch would make every
/// note-linking search come back empty rather than fail loudly.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";
86
/// Work item for a single entity during a note-linking pass.
struct EntityWorkItem {
    // SQLite id of the entity being linked (from `find_entity_by_id`).
    entity_id: i64,
    // Canonical name; used only in debug log messages on embed/search failure.
    canonical_name: String,
    // Text sent to the embedding provider: "name: summary", or just the name when the
    // summary is absent/empty.
    embed_text: String,
    // Qdrant point id of this entity's own embedding, used to filter self-matches out of
    // search results; `None` when the entity was stored without an embedding.
    self_point_id: Option<String>,
}
94
/// Link newly extracted entities to semantically similar entities in the graph.
///
/// For each entity in `entity_ids`:
/// 1. Load the entity name + summary from `SQLite`.
/// 2. Embed all entity texts in parallel.
/// 3. Search the entity embedding collection in parallel for the `top_k + 1` most similar points.
/// 4. Filter out the entity itself (by `qdrant_point_id` or `entity_id` payload) and points
///    below `similarity_threshold`.
/// 5. Insert a unidirectional `similar_to` edge where `source_id < target_id` to avoid
///    double-counting in BFS recall while still being traversable via the OR clause in
///    `edges_for_entity`. The edge confidence is set to the cosine similarity score.
/// 6. Deduplicate pairs within a single pass so that a pair encountered from both A→B and B→A
///    directions is only inserted once, keeping `edges_created` accurate.
///
/// Errors are logged and not propagated — this is a best-effort background enrichment step.
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: sqlx::SqlitePool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: load entities from DB sequentially (cheap; avoids connection-pool contention).
    // Missing or unreadable entities are skipped with a debug log, never fatal.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Embed "name: summary" when a non-empty summary exists, otherwise just the name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all entity texts in parallel to reduce N serial HTTP round-trips to 1.
    // NOTE(review): all embed requests are issued at once with no concurrency cap —
    // presumably fine while N is bounded by the extractor's entity limit; confirm if that grows.
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // Phase 3: search for similar entities in parallel for all successfully embedded entities.
    let search_limit = cfg.top_k + 1; // +1 to account for self-match
    // `valid` carries (index into work_items, embedding) for entities that embedded OK;
    // failed embeds are logged and dropped here.
    let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Phase 4: insert edges; deduplicate pairs seen from both A→B and B→A directions.
    // Without deduplication, both directions call insert_edge for the same normalised pair and
    // both return Ok (the second call updates confidence on the existing row), inflating
    // edges_created by the number of bidirectional hits.
    let mut seen_pairs = std::collections::HashSet::new();

    // `valid` and `search_results` are index-aligned: search_results[i] is the search
    // performed with valid[i]'s embedding.
    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        // Counted once per entity whose embed and search both succeeded — even when no
        // candidate clears the similarity threshold below.
        stats.entities_processed += 1;

        // When `self_point_id` is `None`, `Some(..) != None` is always true and this filter
        // never removes the entity itself; the `target_id == w.entity_id` guard below is the
        // fallback self-check for that case.
        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            // Target entity id comes from the point payload; points without it are skipped.
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            if target_id == w.entity_id {
                continue; // secondary self-guard when qdrant_point_id is null
            }

            // Normalise direction: always store source_id < target_id.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            // Skip pairs already processed in this pass to avoid double-counting.
            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            // Edge confidence carries the raw cosine similarity score.
            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
259
/// Extract entities and edges from `content` and persist them to the graph store.
///
/// This function runs inside a spawned task — it receives owned data only.
///
/// The optional `embedding_store` enables entity embedding storage in Qdrant, which is
/// required for A-MEM note linking to find semantically similar entities across sessions.
///
/// # Errors
///
/// Returns an error if the database query fails or LLM extraction fails.
#[allow(clippy::too_many_lines)]
pub async fn extract_and_store(
    content: String,
    context_messages: Vec<String>,
    provider: AnyProvider,
    pool: sqlx::SqlitePool,
    config: GraphExtractionConfig,
    post_extract_validator: PostExtractValidator,
    embedding_store: Option<Arc<EmbeddingStore>>,
) -> Result<ExtractionResult, MemoryError> {
    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};

    let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();

    let store = GraphStore::new(pool);

    // The `pool` parameter was moved into the store; re-borrow it for the raw queries below.
    let pool = store.pool();
    // Bump the persistent extraction counter (insert-if-absent, then increment). Note this
    // runs *before* extraction, so the counter tracks attempts: a run that later fails or is
    // rejected by the validator still counts toward the community-refresh interval.
    sqlx::query(
        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
         ON CONFLICT(key) DO NOTHING",
    )
    .execute(pool)
    .await?;
    sqlx::query(
        "UPDATE graph_metadata
         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
         WHERE key = 'extraction_count'",
    )
    .execute(pool)
    .await?;

    // `Ok(None)` from the extractor means "nothing to extract" — not an error.
    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
        return Ok(ExtractionResult::default());
    };

    // Post-extraction validation callback. zeph-memory does not know the callback is a
    // security validator — it is a generic predicate opaque to this crate (design decision D1).
    if let Some(ref validator) = post_extract_validator
        && let Err(reason) = validator(&result)
    {
        tracing::warn!(
            reason,
            "graph extraction validation failed, skipping upsert"
        );
        return Ok(ExtractionResult::default());
    }

    // Only wire the embedding store *and* the provider into the resolver when embeddings
    // are available — both are required for the resolver to store entity embeddings and
    // set `qdrant_point_id` (regression #1829, see tests below).
    let resolver = if let Some(ref emb) = embedding_store {
        EntityResolver::new(&store)
            .with_embedding_store(emb)
            .with_provider(&provider)
    } else {
        EntityResolver::new(&store)
    };

    let mut entities_upserted = 0usize;
    // Maps extracted entity name -> resolved SQLite id; used below to resolve edge endpoints.
    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
        std::collections::HashMap::new();

    for entity in &result.entities {
        match resolver
            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
            .await
        {
            Ok((id, _outcome)) => {
                entity_name_to_id.insert(entity.name.clone(), id);
                entities_upserted += 1;
            }
            Err(e) => {
                // Best-effort: one bad entity must not abort the whole pass.
                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
            }
        }
    }

    let mut edges_inserted = 0usize;
    for edge in &result.edges {
        // Both endpoints must have resolved above; otherwise the edge is skipped.
        let (Some(&src_id), Some(&tgt_id)) = (
            entity_name_to_id.get(&edge.source),
            entity_name_to_id.get(&edge.target),
        ) else {
            tracing::debug!(
                "graph: skipping edge {:?}->{:?}: entity not resolved",
                edge.source,
                edge.target
            );
            continue;
        };
        // Parse LLM-provided edge_type; default to Semantic on any parse failure so
        // edges are never dropped due to classification errors.
        let edge_type = edge
            .edge_type
            .parse::<crate::graph::EdgeType>()
            .unwrap_or_else(|_| {
                tracing::warn!(
                    raw_type = %edge.edge_type,
                    "graph: unknown edge_type from LLM, defaulting to semantic"
                );
                crate::graph::EdgeType::Semantic
            });
        match resolver
            .resolve_edge_typed(
                src_id,
                tgt_id,
                &edge.relation,
                &edge.fact,
                0.8, // fixed default confidence for LLM-extracted edges
                None,
                edge_type,
            )
            .await
        {
            Ok(Some(_)) => edges_inserted += 1,
            Ok(None) => {} // deduplicated
            Err(e) => {
                tracing::debug!("graph: skipping edge: {e:#}");
            }
        }
    }

    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();

    Ok(ExtractionResult {
        stats: ExtractionStats {
            entities_upserted,
            edges_inserted,
        },
        entity_ids: new_entity_ids,
    })
}
400
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::sqlite::SqliteStore;

    use super::GraphExtractionConfig;

    /// Build an in-memory SQLite graph store plus an in-memory vector store so the
    /// tests run without any external services.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    /// Regression test for #1829: extract_and_store() must pass the provider to EntityResolver
    /// so that store_entity_embedding() is called and qdrant_point_id is set in SQLite.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        // MockProvider: supports embeddings, returns a valid extraction JSON for chat
        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // The entity must have a qdrant_point_id — this proves store_entity_embedding() was called.
        // Before the fix, EntityResolver was built without a provider, so embed() was never called
        // and qdrant_point_id remained NULL.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    /// When no embedding_store is provided, extract_and_store() must still work correctly
    /// (no embeddings stored, but entities are still upserted).
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, // no embedding_store
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }
}
520
impl SemanticMemory {
    /// Spawn background graph extraction for a message. Fire-and-forget — never blocks.
    ///
    /// Extraction runs in a separate tokio task with a timeout. Any error or timeout is
    /// logged and the task exits silently; the agent response is never blocked.
    ///
    /// The optional `post_extract_validator` is called after extraction, before upsert.
    /// It is a generic predicate opaque to zeph-memory (design decision D1).
    ///
    /// When `config.note_linking.enabled` is `true` and an embedding store is available,
    /// `link_memory_notes` runs after successful extraction inside the same task, bounded
    /// by `config.note_linking.timeout_secs`.
    ///
    /// The returned handle covers extraction + note linking only; community detection is
    /// spawned as a *separate* detached task and is not joined by this handle.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        // Clone the embedding store Arc before moving into the task.
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Collapse the three outcomes (success / extraction error / timeout) into a
            // success flag plus the entity ids needed for note linking. Success bumps the
            // in-memory counter; error and timeout bump the failure counter.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    // NOTE(review): on timeout the persistent DB extraction counter may
                    // already have been incremented inside extract_and_store.
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // A-MEM note linking: run after successful extraction when enabled.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                // Shadows the in-memory atomic above — this is the *persistent* DB counter,
                // which drives the refresh cadence.
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                // try_from guards the usize -> i64 interval conversion; a failed conversion
                // simply skips the refresh.
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    // Detached task: detection + eviction run beyond the lifetime of the
                    // outer extraction task and are not covered by the returned JoinHandle.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs even when detection failed — it only depends on the store.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }
                    });
                }
            }
        })
    }
}