Skip to main content

zeph_memory/semantic/
graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5use std::sync::atomic::Ordering;
6
7use zeph_llm::any::AnyProvider;
8use zeph_llm::provider::LlmProvider as _;
9
10use crate::embedding_store::EmbeddingStore;
11use crate::error::MemoryError;
12use crate::graph::extractor::ExtractionResult as ExtractorResult;
13use crate::vector_store::VectorFilter;
14
15use super::SemanticMemory;
16
/// Callback type for post-extraction validation.
///
/// A generic predicate opaque to zeph-memory — callers (zeph-core) provide security
/// validation without introducing a dependency on security policy in this crate.
///
/// Returning `Err(reason)` vetoes persistence: `extract_and_store` logs the reason and
/// drops the extraction result without upserting anything. `None` disables validation.
/// The closure is `Send` (but not required to be `Sync`) because it is moved into the
/// spawned extraction task.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
22
/// Config for the spawned background extraction task.
///
/// Owned clone of the relevant fields from `GraphConfig` — no references, safe to send to
/// spawned tasks.
#[derive(Debug, Clone, Default)]
pub struct GraphExtractionConfig {
    /// Max entities the LLM extractor may emit per pass (passed to `GraphExtractor::new`).
    pub max_entities: usize,
    /// Max edges the LLM extractor may emit per pass (passed to `GraphExtractor::new`).
    pub max_edges: usize,
    /// Timeout for one `extract_and_store` run; on expiry the task logs a warning and exits.
    pub extraction_timeout_secs: u64,
    /// Trigger community detection every N extractions; `0` disables the refresh entirely.
    pub community_refresh_interval: usize,
    /// Retention window for expired edges, forwarded to `run_graph_eviction`.
    pub expired_edge_retention_days: u32,
    /// Entity-count cap enforced by `run_graph_eviction`.
    pub max_entities_cap: usize,
    /// Byte budget for community-summary prompts, forwarded to `detect_communities`.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency for community summarisation, forwarded to `detect_communities`.
    pub community_summary_concurrency: usize,
    /// Edge chunk size for the LPA pass, forwarded to `detect_communities`.
    pub lpa_edge_chunk_size: usize,
    /// A-MEM note linking config, cloned from `GraphConfig.note_linking`.
    pub note_linking: NoteLinkingConfig,
}
41
/// Config for A-MEM dynamic note linking, owned by the spawned extraction task.
#[derive(Debug, Clone)]
pub struct NoteLinkingConfig {
    /// Master switch — when `false`, no linking pass runs after extraction.
    pub enabled: bool,
    /// Minimum cosine similarity a candidate point must reach to become an edge.
    pub similarity_threshold: f32,
    /// Maximum number of similar entities linked per newly extracted entity.
    pub top_k: usize,
    /// Time budget for one whole linking pass.
    pub timeout_secs: u64,
}

impl Default for NoteLinkingConfig {
    /// Conservative defaults: linking off, a high similarity bar (0.85), at most
    /// ten neighbours per entity, and a five-second budget for the pass.
    fn default() -> Self {
        NoteLinkingConfig {
            enabled: false,
            similarity_threshold: 0.85,
            top_k: 10,
            timeout_secs: 5,
        }
    }
}
61
/// Stats returned from a completed extraction.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Entities successfully resolved and upserted; entities whose resolution
    /// errored are logged and skipped, not counted.
    pub entities_upserted: usize,
    /// Edges newly inserted; deduplicated edges (`resolve_edge` returning `Ok(None)`)
    /// are not counted.
    pub edges_inserted: usize,
}
68
/// Result returned from `extract_and_store`, combining stats with entity IDs needed for linking.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    /// Counts of entities and edges actually persisted during this pass.
    pub stats: ExtractionStats,
    /// IDs of entities upserted during this extraction pass. Passed to `link_memory_notes`.
    pub entity_ids: Vec<i64>,
}
76
/// Stats returned from a completed note-linking pass.
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities whose similarity search completed; entities whose embed or search
    /// failed are logged and skipped.
    pub entities_processed: usize,
    /// `similar_to` edges inserted, after within-pass pair deduplication.
    pub edges_created: usize,
}
83
/// Qdrant collection name for entity embeddings (mirrors the constant in `resolver.rs`).
/// NOTE(review): duplicated rather than shared — keep in sync with `resolver.rs`.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";
86
/// Work item for a single entity during a note-linking pass.
struct EntityWorkItem {
    /// SQLite row id of the entity being linked.
    entity_id: i64,
    /// Canonical entity name; used only in diagnostic log messages.
    canonical_name: String,
    /// Text sent to the embedder: `"name: summary"` when a non-empty summary
    /// exists, otherwise just the canonical name.
    embed_text: String,
    /// The entity's own Qdrant point id, used to filter out self-matches in the
    /// search results; `None` when the entity has no stored point id yet.
    self_point_id: Option<String>,
}
94
/// Link newly extracted entities to semantically similar entities in the graph.
///
/// For each entity in `entity_ids`:
/// 1. Load the entity name + summary from `SQLite`.
/// 2. Embed all entity texts in parallel.
/// 3. Search the entity embedding collection in parallel for the `top_k + 1` most similar points.
/// 4. Filter out the entity itself (by `qdrant_point_id` or `entity_id` payload) and points
///    below `similarity_threshold`.
/// 5. Insert a unidirectional `similar_to` edge where `source_id < target_id` to avoid
///    double-counting in BFS recall while still being traversable via the OR clause in
///    `edges_for_entity`. The edge confidence is set to the cosine similarity score.
/// 6. Deduplicate pairs within a single pass so that a pair encountered from both A→B and B→A
///    directions is only inserted once, keeping `edges_created` accurate.
///
/// Errors are logged and not propagated — this is a best-effort background enrichment step.
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: sqlx::SqlitePool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: load entities from DB sequentially (cheap; avoids connection-pool contention).
    // Missing or unloadable entities are logged at debug level and dropped from the pass.
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Richer text embeds better: include the summary when one exists and is non-empty.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all entity texts in parallel to reduce N serial HTTP round-trips to 1.
    // NOTE(review): join_all issues all embed requests at once — assumes the provider
    // tolerates this fan-out; confirm if rate limits apply for large entity batches.
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // Phase 3: search for similar entities in parallel for all successfully embedded entities.
    let search_limit = cfg.top_k + 1; // +1 to account for self-match
    // `valid` keeps each surviving embedding paired with its index into `work_items`,
    // so failed embeds drop out without breaking the positional alignment with
    // `search_results` in the zip below.
    let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Phase 4: insert edges; deduplicate pairs seen from both A→B and B→A directions.
    // Without deduplication, both directions call insert_edge for the same normalised pair and
    // both return Ok (the second call updates confidence on the existing row), inflating
    // edges_created by the number of bidirectional hits.
    let mut seen_pairs = std::collections::HashSet::new();

    // `valid` and `search_results` are index-aligned: entry i of search_results is the
    // search for the embedding in valid[i].
    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        // Counted only once the search succeeded — embed/search failures are excluded.
        stats.entities_processed += 1;

        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            // The numeric entity id lives in the point payload; points without it
            // cannot be linked and are skipped.
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            if target_id == w.entity_id {
                continue; // secondary self-guard when qdrant_point_id is null
            }

            // Normalise direction: always store source_id < target_id.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            // Skip pairs already processed in this pass to avoid double-counting.
            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
259
260/// Extract entities and edges from `content` and persist them to the graph store.
261///
262/// This function runs inside a spawned task — it receives owned data only.
263///
264/// # Errors
265///
266/// Returns an error if the database query fails or LLM extraction fails.
267pub async fn extract_and_store(
268    content: String,
269    context_messages: Vec<String>,
270    provider: AnyProvider,
271    pool: sqlx::SqlitePool,
272    config: GraphExtractionConfig,
273    post_extract_validator: PostExtractValidator,
274) -> Result<ExtractionResult, MemoryError> {
275    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
276
277    let extractor = GraphExtractor::new(provider, config.max_entities, config.max_edges);
278    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
279
280    let store = GraphStore::new(pool);
281
282    let pool = store.pool();
283    sqlx::query(
284        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
285         ON CONFLICT(key) DO NOTHING",
286    )
287    .execute(pool)
288    .await?;
289    sqlx::query(
290        "UPDATE graph_metadata
291         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
292         WHERE key = 'extraction_count'",
293    )
294    .execute(pool)
295    .await?;
296
297    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
298        return Ok(ExtractionResult::default());
299    };
300
301    // Post-extraction validation callback. zeph-memory does not know the callback is a
302    // security validator — it is a generic predicate opaque to this crate (design decision D1).
303    if let Some(ref validator) = post_extract_validator
304        && let Err(reason) = validator(&result)
305    {
306        tracing::warn!(
307            reason,
308            "graph extraction validation failed, skipping upsert"
309        );
310        return Ok(ExtractionResult::default());
311    }
312
313    let resolver = EntityResolver::new(&store);
314
315    let mut entities_upserted = 0usize;
316    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
317        std::collections::HashMap::new();
318
319    for entity in &result.entities {
320        match resolver
321            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
322            .await
323        {
324            Ok((id, _outcome)) => {
325                entity_name_to_id.insert(entity.name.clone(), id);
326                entities_upserted += 1;
327            }
328            Err(e) => {
329                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
330            }
331        }
332    }
333
334    let mut edges_inserted = 0usize;
335    for edge in &result.edges {
336        let (Some(&src_id), Some(&tgt_id)) = (
337            entity_name_to_id.get(&edge.source),
338            entity_name_to_id.get(&edge.target),
339        ) else {
340            tracing::debug!(
341                "graph: skipping edge {:?}->{:?}: entity not resolved",
342                edge.source,
343                edge.target
344            );
345            continue;
346        };
347        match resolver
348            .resolve_edge(src_id, tgt_id, &edge.relation, &edge.fact, 0.8, None)
349            .await
350        {
351            Ok(Some(_)) => edges_inserted += 1,
352            Ok(None) => {} // deduplicated
353            Err(e) => {
354                tracing::debug!("graph: skipping edge: {e:#}");
355            }
356        }
357    }
358
359    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();
360
361    Ok(ExtractionResult {
362        stats: ExtractionStats {
363            entities_upserted,
364            edges_inserted,
365        },
366        entity_ids: new_entity_ids,
367    })
368}
369
impl SemanticMemory {
    /// Spawn background graph extraction for a message. Fire-and-forget — never blocks.
    ///
    /// Extraction runs in a separate tokio task with a timeout. Any error or timeout is
    /// logged and the task exits silently; the agent response is never blocked.
    ///
    /// The optional `post_extract_validator` is called after extraction, before upsert.
    /// It is a generic predicate opaque to zeph-memory (design decision D1).
    ///
    /// When `config.note_linking.enabled` is `true` and an embedding store is available,
    /// `link_memory_notes` runs after successful extraction inside the same task, bounded
    /// by `config.note_linking.timeout_secs`.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) {
        // Clone owned handles up front: the async block below is `move` and must not
        // borrow `self`.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        // Metrics counters (cloned handles; updated via fetch_add inside the task).
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        // Clone the embedding store Arc before moving into the task.
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                ),
            )
            .await;

            // Collapse the nested Result (timeout × extraction) into a success flag plus
            // the ids needed for note linking; failures bump the failure counter.
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // A-MEM note linking: run after successful extraction when enabled.
            // Bounded by its own timeout; a timeout here is debug-level only since
            // linking is best-effort enrichment.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            // Periodic community refresh: every `community_refresh_interval` successful
            // extractions (by the DB-persisted counter), kick off detection + eviction.
            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                let store = GraphStore::new(pool.clone());
                // NB: shadows the atomic metrics counter above — this is the persisted
                // count read back from the store, not the in-process metric.
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    // Detection + eviction run in their own task; the handle is not
                    // awaited, so the extraction task finishes without waiting on them.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs even if detection failed — it only needs the store.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }
                    });
                }
            }
        });
    }
}