Skip to main content

zeph_memory/semantic/
graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use std::sync::atomic::Ordering;
9use tokio_util::sync::CancellationToken;
10use zeph_db::DbPool;
11
12pub use zeph_common::config::memory::NoteLinkingConfig;
13use zeph_common::sanitize::strip_control_chars;
14use zeph_common::text::truncate_to_bytes_ref;
15use zeph_llm::any::AnyProvider;
16use zeph_llm::provider::LlmProvider as _;
17
18use crate::embedding_store::EmbeddingStore;
19use crate::error::MemoryError;
20use crate::graph::extractor::ExtractionResult as ExtractorResult;
21use crate::vector_store::VectorFilter;
22
23use super::SemanticMemory;
24
/// Callback type for post-extraction validation.
///
/// A generic predicate opaque to zeph-memory — callers (zeph-core) provide security
/// validation without introducing a dependency on security policy in this crate.
///
/// The callback receives the raw extraction result before any entity or edge is persisted.
/// Returning `Err(reason)` discards the whole extraction pass (nothing is upserted) and
/// `reason` is logged as a warning. `None` disables validation.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
30
/// Config for the spawned background extraction task.
///
/// Owned clone of the relevant fields from `GraphConfig` — no references, safe to send to
/// spawned tasks.
#[derive(Debug, Clone)]
pub struct GraphExtractionConfig {
    /// Entity limit handed to `GraphExtractor::new` (presumably caps entities per pass).
    pub max_entities: usize,
    /// Edge limit handed to `GraphExtractor::new` (presumably caps edges per pass).
    pub max_edges: usize,
    /// Wall-clock budget, in seconds, for one whole `extract_and_store` call in the
    /// spawned task; on expiry the pass is abandoned and counted as a failure.
    pub extraction_timeout_secs: u64,
    /// Community detection and graph eviction run when the persisted extraction count is a
    /// multiple of this value (checked only after a successful extraction). `0` disables it.
    pub community_refresh_interval: usize,
    /// Retention period (days) passed to `run_graph_eviction` during a community refresh.
    pub expired_edge_retention_days: u32,
    /// Entity cap passed to `run_graph_eviction` during a community refresh.
    pub max_entities_cap: usize,
    /// Prompt byte budget passed to `detect_communities`.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency limit passed to `detect_communities`.
    pub community_summary_concurrency: usize,
    /// Edge chunk size passed to `detect_communities`.
    pub lpa_edge_chunk_size: usize,
    /// A-MEM note linking config, cloned from `GraphConfig.note_linking`.
    pub note_linking: NoteLinkingConfig,
    /// A-MEM link weight decay lambda. Range: `(0.0, 1.0]`. Default: `0.95`.
    pub link_weight_decay_lambda: f64,
    /// Seconds between link weight decay passes. Default: `86400`.
    pub link_weight_decay_interval_secs: u64,
    /// Kumiho belief revision: enable semantic contradiction detection for edges.
    pub belief_revision_enabled: bool,
    /// Cosine similarity threshold for belief revision contradiction detection.
    pub belief_revision_similarity_threshold: f32,
    /// GAAMA episode linking: `conversation_id` to link extracted entities to their episode.
    /// `None` disables episode linking for this extraction pass.
    pub conversation_id: Option<i64>,
    /// APEX-MEM: use `insert_or_supersede` instead of `resolve_edge_typed`. Default: `false`.
    pub apex_mem_enabled: bool,
    /// LLM call timeout for extraction, in seconds. Default: `30`.
    pub llm_timeout_secs: u64,
    /// Per-call timeout for every `embed()` invocation, in seconds. Default: `5`.
    pub embed_timeout_secs: u64,
}
66
67impl Default for GraphExtractionConfig {
68    fn default() -> Self {
69        Self {
70            max_entities: 0,
71            max_edges: 0,
72            extraction_timeout_secs: 0,
73            community_refresh_interval: 0,
74            expired_edge_retention_days: 0,
75            max_entities_cap: 0,
76            community_summary_max_prompt_bytes: 0,
77            community_summary_concurrency: 0,
78            lpa_edge_chunk_size: 0,
79            note_linking: NoteLinkingConfig::default(),
80            link_weight_decay_lambda: 0.95,
81            link_weight_decay_interval_secs: 86400,
82            belief_revision_enabled: false,
83            belief_revision_similarity_threshold: 0.85,
84            conversation_id: None,
85            apex_mem_enabled: false,
86            llm_timeout_secs: 30,
87            embed_timeout_secs: 5,
88        }
89    }
90}
91
/// Stats returned from a completed extraction.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Number of extracted entities the resolver handled successfully (failures are skipped).
    pub entities_upserted: usize,
    /// Number of edges written; edges deduplicated by the resolver are not counted.
    pub edges_inserted: usize,
}
98
/// Result returned from `extract_and_store`, combining stats with entity IDs needed for linking.
///
/// `Default` (zero stats, no IDs) is what an empty or rejected extraction returns.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    /// Counts for this extraction pass.
    pub stats: ExtractionStats,
    /// IDs of entities upserted during this extraction pass. Passed to `link_memory_notes`
    /// and used for GAAMA episode linking.
    pub entity_ids: Vec<i64>,
}
106
/// Stats returned from a completed note-linking pass.
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities whose similarity search succeeded (embedding or search failures excluded).
    pub entities_processed: usize,
    /// `similar_to` edges whose insert returned `Ok`.
    pub edges_created: usize,
}
113
/// Qdrant collection name for entity embeddings (mirrors the constant in `resolver.rs`).
///
/// Note linking searches this collection. It must stay in sync with the resolver's copy;
/// presumably that is where the resolver stores entity embeddings — confirm if either changes.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Mirrors the constant from `graph/resolver/mod.rs` — used for sanitizing APEX-MEM inputs.
///
/// Byte cap applied to both the display and the canonical relation on the APEX-MEM path.
const MAX_RELATION_BYTES: usize = 256;
/// Mirrors the constant from `graph/resolver/mod.rs` — used for sanitizing APEX-MEM inputs.
///
/// Byte cap applied to the edge fact on the APEX-MEM path.
const MAX_FACT_BYTES: usize = 2048;
121
/// Work item for a single entity during a note-linking pass.
struct EntityWorkItem {
    /// Graph-store ID of the entity; also used as a secondary self-match guard.
    entity_id: i64,
    /// Canonical entity name; used in log messages and as the prefix of `embed_text`.
    canonical_name: String,
    /// Text sent to the embedder: `"<canonical_name>: <summary>"`, or just the name when
    /// the entity has no non-empty summary.
    embed_text: String,
    /// The entity's own `qdrant_point_id`, if set; search hits with this ID are skipped.
    self_point_id: Option<String>,
}
129
130/// Link newly extracted entities to semantically similar entities in the graph.
131///
132/// For each entity in `entity_ids`:
133/// 1. Load the entity name + summary from `SQLite`.
134/// 2. Embed all entity texts in parallel.
135/// 3. Search the entity embedding collection in parallel for the `top_k + 1` most similar points.
136/// 4. Filter out the entity itself (by `qdrant_point_id` or `entity_id` payload) and points
137///    below `similarity_threshold`.
138/// 5. Insert a unidirectional `similar_to` edge where `source_id < target_id` to avoid
139///    double-counting in BFS recall while still being traversable via the OR clause in
140///    `edges_for_entity`. The edge confidence is set to the cosine similarity score.
141/// 6. Deduplicate pairs within a single pass so that a pair encountered from both A→B and B→A
142///    directions is only inserted once, keeping `edges_created` accurate.
143///
144/// Errors are logged and not propagated — this is a best-effort background enrichment step.
145pub async fn link_memory_notes(
146    entity_ids: &[i64],
147    pool: DbPool,
148    embedding_store: Arc<EmbeddingStore>,
149    provider: AnyProvider,
150    cfg: &NoteLinkingConfig,
151) -> LinkingStats {
152    use crate::graph::GraphStore;
153
154    let store = GraphStore::new(pool);
155    let mut stats = LinkingStats::default();
156
157    let work_items = collect_note_link_work_items(entity_ids, &store).await;
158    if work_items.is_empty() {
159        return stats;
160    }
161
162    let valid = embed_work_items(&work_items, &provider, cfg).await;
163
164    let search_limit = cfg.top_k + 1; // +1 to account for self-match
165    let search_results = search_similar_for_items(&valid, &embedding_store, search_limit).await;
166
167    insert_similarity_edges(
168        &work_items,
169        &valid,
170        &search_results,
171        cfg,
172        &store,
173        &mut stats,
174    )
175    .await;
176
177    stats
178}
179
180/// Phase 1: load entities from the DB and build work items for embedding.
181///
182/// Processes entities sequentially to avoid connection-pool contention.
183async fn collect_note_link_work_items(
184    entity_ids: &[i64],
185    store: &crate::graph::GraphStore,
186) -> Vec<EntityWorkItem> {
187    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
188    for &entity_id in entity_ids {
189        let entity = match store.find_entity_by_id(entity_id).await {
190            Ok(Some(e)) => e,
191            Ok(None) => {
192                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
193                continue;
194            }
195            Err(e) => {
196                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
197                continue;
198            }
199        };
200        let embed_text = match &entity.summary {
201            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
202            _ => entity.canonical_name.clone(),
203        };
204        work_items.push(EntityWorkItem {
205            entity_id,
206            canonical_name: entity.canonical_name,
207            embed_text,
208            self_point_id: entity.qdrant_point_id,
209        });
210    }
211    work_items
212}
213
214/// Phase 2: embed all entity texts in parallel.
215///
216/// Returns `(work_idx, embedding)` pairs for successfully embedded items.
217/// Items that fail to embed are logged and dropped.
218async fn embed_work_items(
219    work_items: &[EntityWorkItem],
220    provider: &AnyProvider,
221    cfg: &NoteLinkingConfig,
222) -> Vec<(usize, Vec<f32>)> {
223    use futures::future;
224
225    let Ok(embed_results) = tokio::time::timeout(
226        std::time::Duration::from_secs(cfg.timeout_secs),
227        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))),
228    )
229    .await
230    else {
231        tracing::warn!(
232            count = work_items.len(),
233            "note_linking: batch embed timed out — skipping all entities"
234        );
235        return Vec::new();
236    };
237    embed_results
238        .into_iter()
239        .enumerate()
240        .filter_map(|(i, r)| match r {
241            Ok(v) => Some((i, v)),
242            Err(e) => {
243                tracing::debug!(
244                    "note_linking: embed failed for entity {:?}: {e:#}",
245                    work_items[i].canonical_name
246                );
247                None
248            }
249        })
250        .collect()
251}
252
253/// Phase 3: search the embedding store for similar entities for each embedded work item.
254async fn search_similar_for_items(
255    valid: &[(usize, Vec<f32>)],
256    embedding_store: &EmbeddingStore,
257    search_limit: usize,
258) -> Vec<Result<Vec<crate::ScoredVectorPoint>, MemoryError>> {
259    use futures::future;
260
261    future::join_all(valid.iter().map(|(_, vec)| {
262        embedding_store.search_collection(
263            ENTITY_COLLECTION,
264            vec,
265            search_limit,
266            None::<VectorFilter>,
267        )
268    }))
269    .await
270}
271
272/// Phase 4: insert similarity edges, deduplicating pairs seen from both A→B and B→A.
273///
274/// Without deduplication, both directions would call `insert_edge` for the same normalised
275/// pair and both return `Ok`, inflating `edges_created` by the number of bidirectional hits.
276async fn insert_similarity_edges(
277    work_items: &[EntityWorkItem],
278    valid: &[(usize, Vec<f32>)],
279    search_results: &[Result<Vec<crate::ScoredVectorPoint>, MemoryError>],
280    cfg: &NoteLinkingConfig,
281    store: &crate::graph::GraphStore,
282    stats: &mut LinkingStats,
283) {
284    let mut seen_pairs = std::collections::HashSet::new();
285
286    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
287        let w = &work_items[*work_idx];
288
289        let results = match search_result {
290            Ok(r) => r,
291            Err(e) => {
292                tracing::debug!(
293                    "note_linking: search failed for entity {:?}: {e:#}",
294                    w.canonical_name
295                );
296                continue;
297            }
298        };
299
300        stats.entities_processed += 1;
301
302        let self_point_id = w.self_point_id.as_deref();
303        let candidates = results
304            .iter()
305            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
306            .take(cfg.top_k);
307
308        for point in candidates {
309            let Some(target_id) = point
310                .payload
311                .get("entity_id")
312                .and_then(serde_json::Value::as_i64)
313            else {
314                tracing::debug!(
315                    "note_linking: missing entity_id in payload for point {}",
316                    point.id
317                );
318                continue;
319            };
320
321            if target_id == w.entity_id {
322                continue; // secondary self-guard when qdrant_point_id is null
323            }
324
325            // Normalise direction: always store source_id < target_id.
326            let (src, tgt) = if w.entity_id < target_id {
327                (w.entity_id, target_id)
328            } else {
329                (target_id, w.entity_id)
330            };
331
332            if !seen_pairs.insert((src, tgt)) {
333                continue;
334            }
335
336            let fact = format!("Semantically similar entities (score: {:.3})", point.score);
337
338            match store
339                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
340                .await
341            {
342                Ok(_) => stats.edges_created += 1,
343                Err(e) => {
344                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
345                }
346            }
347        }
348    }
349}
350
351/// Extract entities and edges from `content` and persist them to the graph store.
352///
353/// This function runs inside a spawned task — it receives owned data only.
354///
355/// The optional `embedding_store` enables entity embedding storage in Qdrant, which is
356/// required for A-MEM note linking to find semantically similar entities across sessions.
357///
358/// # Errors
359///
360/// Returns an error if the database query fails or LLM extraction fails.
361#[cfg_attr(
362    feature = "profiling",
363    tracing::instrument(name = "memory.graph_extract", skip_all, fields(entities = tracing::field::Empty, edges = tracing::field::Empty))
364)]
365pub async fn extract_and_store(
366    content: String,
367    context_messages: Vec<String>,
368    provider: AnyProvider,
369    pool: DbPool,
370    config: GraphExtractionConfig,
371    post_extract_validator: PostExtractValidator,
372    embedding_store: Option<Arc<EmbeddingStore>>,
373) -> Result<ExtractionResult, MemoryError> {
374    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};
375
376    let extractor = GraphExtractor::new(
377        provider.clone(),
378        config.max_entities,
379        config.max_edges,
380        config.llm_timeout_secs,
381    );
382    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();
383
384    let store = GraphStore::new(pool);
385
386    bump_extraction_count(store.pool()).await?;
387
388    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
389        return Ok(ExtractionResult::default());
390    };
391
392    // Post-extraction validation callback. zeph-memory does not know the callback is a
393    // security validator — it is a generic predicate opaque to this crate (design decision D1).
394    if let Some(ref validator) = post_extract_validator
395        && let Err(reason) = validator(&result)
396    {
397        tracing::warn!(
398            reason,
399            "graph extraction validation failed, skipping upsert"
400        );
401        return Ok(ExtractionResult::default());
402    }
403
404    let resolver = if let Some(ref emb) = embedding_store {
405        EntityResolver::new(&store)
406            .with_embedding_store(emb)
407            .with_provider(&provider)
408            .with_embed_timeout(config.embed_timeout_secs)
409    } else {
410        EntityResolver::new(&store).with_embed_timeout(config.embed_timeout_secs)
411    };
412
413    let (entity_name_to_id, entities_upserted) = upsert_entities(&resolver, &result.entities).await;
414    let edges_inserted = insert_edges(&resolver, &result.edges, &entity_name_to_id, &config).await;
415
416    #[cfg(any(feature = "sqlite", feature = "postgres"))]
417    store.checkpoint_wal().await?;
418
419    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();
420
421    link_episode(&store, &config, &new_entity_ids).await;
422
423    #[cfg(feature = "profiling")]
424    {
425        let span = tracing::Span::current();
426        span.record("entities", entities_upserted);
427        span.record("edges", edges_inserted);
428    }
429
430    Ok(ExtractionResult {
431        stats: ExtractionStats {
432            entities_upserted,
433            edges_inserted,
434        },
435        entity_ids: new_entity_ids,
436    })
437}
438
439/// Increment the extraction counter in `graph_metadata`.
440async fn bump_extraction_count(pool: &DbPool) -> Result<(), MemoryError> {
441    zeph_db::query(sql!(
442        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
443         ON CONFLICT(key) DO NOTHING"
444    ))
445    .execute(pool)
446    .await?;
447    zeph_db::query(sql!(
448        "UPDATE graph_metadata
449         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
450         WHERE key = 'extraction_count'"
451    ))
452    .execute(pool)
453    .await?;
454    Ok(())
455}
456
457/// Upsert all extracted entities and return the name-to-id map and upsert count.
458async fn upsert_entities(
459    resolver: &crate::graph::EntityResolver<'_>,
460    entities: &[crate::graph::extractor::ExtractedEntity],
461) -> (std::collections::HashMap<String, i64>, usize) {
462    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
463        std::collections::HashMap::new();
464    let mut entities_upserted = 0usize;
465
466    for entity in entities {
467        match resolver
468            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
469            .await
470        {
471            Ok((id, _outcome)) => {
472                entity_name_to_id.insert(entity.name.clone(), id);
473                entities_upserted += 1;
474            }
475            Err(e) => {
476                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
477            }
478        }
479    }
480
481    (entity_name_to_id, entities_upserted)
482}
483
484/// Insert extracted edges that have both endpoints in `name_to_id`.
485///
486/// Returns the number of edges actually inserted.
487async fn insert_edges(
488    resolver: &crate::graph::EntityResolver<'_>,
489    edges: &[crate::graph::extractor::ExtractedEdge],
490    name_to_id: &std::collections::HashMap<String, i64>,
491    config: &GraphExtractionConfig,
492) -> usize {
493    let mut edges_inserted = 0usize;
494    for edge in edges {
495        let (Some(&src_id), Some(&tgt_id)) =
496            (name_to_id.get(&edge.source), name_to_id.get(&edge.target))
497        else {
498            tracing::debug!(
499                "graph: skipping edge {:?}->{:?}: entity not resolved",
500                edge.source,
501                edge.target
502            );
503            continue;
504        };
505        if src_id == tgt_id {
506            tracing::debug!(
507                "graph: skipping self-loop edge {:?}->{:?} (entity_id={src_id})",
508                edge.source,
509                edge.target
510            );
511            continue;
512        }
513        // Parse LLM-provided edge_type; default to Semantic on any parse failure so
514        // edges are never dropped due to classification errors.
515        let edge_type = edge
516            .edge_type
517            .parse::<crate::graph::EdgeType>()
518            .unwrap_or_else(|_| {
519                tracing::warn!(
520                    raw_type = %edge.edge_type,
521                    "graph: unknown edge_type from LLM, defaulting to semantic"
522                );
523                crate::graph::EdgeType::Semantic
524            });
525        if config.apex_mem_enabled {
526            // APEX-MEM: append-only write path with supersession chains.
527            let relation_trimmed = edge.relation.trim();
528            let relation_display_clean = strip_control_chars(relation_trimmed);
529            let relation_display =
530                truncate_to_bytes_ref(&relation_display_clean, MAX_RELATION_BYTES).to_owned();
531            let canonical_clean = strip_control_chars(&relation_trimmed.to_lowercase());
532            let canonical_relation =
533                truncate_to_bytes_ref(&canonical_clean, MAX_RELATION_BYTES).to_owned();
534            let fact_clean = strip_control_chars(edge.fact.trim());
535            let normalized_fact = truncate_to_bytes_ref(&fact_clean, MAX_FACT_BYTES).to_owned();
536            match resolver
537                .graph_store()
538                .insert_or_supersede(
539                    src_id,
540                    tgt_id,
541                    &relation_display,
542                    &canonical_relation,
543                    &normalized_fact,
544                    0.8,
545                    None,
546                    edge_type,
547                    true,
548                )
549                .await
550            {
551                Ok(_) => edges_inserted += 1,
552                Err(e) => {
553                    tracing::debug!("graph: skipping edge (apex): {e:#}");
554                }
555            }
556        } else {
557            let belief_cfg =
558                config
559                    .belief_revision_enabled
560                    .then_some(crate::graph::BeliefRevisionConfig {
561                        similarity_threshold: config.belief_revision_similarity_threshold,
562                    });
563            match resolver
564                .resolve_edge_typed(
565                    src_id,
566                    tgt_id,
567                    &edge.relation,
568                    &edge.fact,
569                    0.8,
570                    None,
571                    edge_type,
572                    belief_cfg.as_ref(),
573                )
574                .await
575            {
576                Ok(Some(_)) => edges_inserted += 1,
577                Ok(None) => {} // deduplicated
578                Err(e) => {
579                    tracing::debug!("graph: skipping edge: {e:#}");
580                }
581            }
582        }
583    }
584    edges_inserted
585}
586
587/// Link extracted entities to their GAAMA episode when a conversation ID is configured.
588async fn link_episode(
589    store: &crate::graph::GraphStore,
590    config: &GraphExtractionConfig,
591    entity_ids: &[i64],
592) {
593    let Some(conv_id) = config.conversation_id else {
594        return;
595    };
596    match store.ensure_episode(conv_id).await {
597        Ok(episode_id) => {
598            for &entity_id in entity_ids {
599                if let Err(e) = store.link_entity_to_episode(episode_id, entity_id).await {
600                    tracing::debug!("episode linking skipped for entity {entity_id}: {e:#}");
601                }
602            }
603        }
604        Err(e) => {
605            tracing::warn!("failed to ensure episode for conversation {conv_id}: {e:#}");
606        }
607    }
608}
609
610impl SemanticMemory {
611    /// Spawn background graph extraction for a message. Fire-and-forget — never blocks.
612    ///
613    /// Extraction runs in a separate tokio task with a timeout. Any error or timeout is
614    /// logged and the task exits silently; the agent response is never blocked.
615    ///
616    /// The optional `post_extract_validator` is called after extraction, before upsert.
617    /// It is a generic predicate opaque to zeph-memory (design decision D1).
618    ///
619    /// When `config.note_linking.enabled` is `true` and an embedding store is available,
620    /// `link_memory_notes` runs after successful extraction inside the same task, bounded
621    /// by `config.note_linking.timeout_secs`.
622    ///
623    /// # Panics
624    ///
625    /// Panics if the internal `graph_cancel` mutex is poisoned (another thread panicked
626    /// while holding the lock).
627    pub fn spawn_graph_extraction(
628        &self,
629        content: String,
630        context_messages: Vec<String>,
631        config: GraphExtractionConfig,
632        post_extract_validator: PostExtractValidator,
633        provider_override: Option<AnyProvider>,
634        cancel: CancellationToken,
635    ) -> tokio::task::JoinHandle<()> {
636        let using_override = provider_override.is_some();
637        let provider = provider_override.unwrap_or_else(|| self.provider.clone());
638        if using_override {
639            tracing::debug!(
640                extract_provider = provider.name(),
641                "graph extraction using override provider (quality_gate bypassed)"
642            );
643        }
644        *self
645            .graph_cancel
646            .lock()
647            .expect("graph_cancel mutex poisoned") = Some(cancel.clone());
648
649        let ctx = GraphExtractionTaskCtx {
650            pool: self.sqlite.pool().clone(),
651            provider,
652            failure_counter: self.community_detection_failures.clone(),
653            extraction_count: self.graph_extraction_count.clone(),
654            extraction_failures: self.graph_extraction_failures.clone(),
655            embedding_store: self.qdrant.clone(),
656            cancel,
657        };
658
659        tokio::spawn(run_graph_extraction_task(
660            content,
661            context_messages,
662            config,
663            post_extract_validator,
664            ctx,
665        ))
666    }
667
668    /// Signal cooperative cancellation to the current background graph-extraction task.
669    ///
670    /// Fires the [`CancellationToken`] stored by the most recent [`spawn_graph_extraction`]
671    /// call. The task checks the token at community-refresh boundaries, so it exits cleanly
672    /// rather than being hard-aborted. This should be called before the supervisor calls
673    /// `abort()` on the underlying `JoinHandle` to give the task a chance to flush state.
674    ///
675    /// No-op if no extraction has been spawned or the previous token has already fired.
676    ///
677    /// # Panics
678    ///
679    /// Panics if the internal `graph_cancel` mutex is poisoned (another thread panicked
680    /// while holding the lock).
681    ///
682    /// [`spawn_graph_extraction`]: SemanticMemory::spawn_graph_extraction
683    pub fn cancel_graph_extraction(&self) {
684        if let Some(token) = self
685            .graph_cancel
686            .lock()
687            .expect("graph_cancel mutex poisoned")
688            .as_ref()
689        {
690            token.cancel();
691        }
692    }
693}
694
/// Owned context bundled for the spawned extraction task.
///
/// Bundles the Arcs that must be cloned before entering `tokio::spawn`.
struct GraphExtractionTaskCtx {
    /// Database pool used by extraction, note linking, and community refresh.
    pool: DbPool,
    /// LLM provider (possibly an override) used by extraction, note linking, and community
    /// detection.
    provider: AnyProvider,
    /// Incremented when community detection fails.
    failure_counter: Arc<std::sync::atomic::AtomicU64>,
    /// In-memory counter incremented after each successful extraction.
    extraction_count: Arc<std::sync::atomic::AtomicU64>,
    /// Incremented when extraction returns an error or times out.
    extraction_failures: Arc<std::sync::atomic::AtomicU64>,
    /// Entity embedding store; `None` means the resolver runs without it and note linking
    /// is skipped.
    embedding_store: Option<Arc<EmbeddingStore>>,
    /// Cancellation signal propagated into background sub-tasks (community refresh).
    cancel: CancellationToken,
}
708
709/// Body of the spawned graph-extraction task.
710async fn run_graph_extraction_task(
711    content: String,
712    context_messages: Vec<String>,
713    config: GraphExtractionConfig,
714    post_extract_validator: PostExtractValidator,
715    ctx: GraphExtractionTaskCtx,
716) {
717    let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
718    let extraction_result = tokio::time::timeout(
719        timeout_dur,
720        extract_and_store(
721            content,
722            context_messages,
723            ctx.provider.clone(),
724            ctx.pool.clone(),
725            config.clone(),
726            post_extract_validator,
727            ctx.embedding_store.clone(),
728        ),
729    )
730    .await;
731
732    let (extraction_ok, new_entity_ids) = match extraction_result {
733        Ok(Ok(result)) => {
734            tracing::debug!(
735                entities = result.stats.entities_upserted,
736                edges = result.stats.edges_inserted,
737                "graph extraction completed"
738            );
739            ctx.extraction_count.fetch_add(1, Ordering::Relaxed);
740            (true, result.entity_ids)
741        }
742        Ok(Err(e)) => {
743            tracing::warn!("graph extraction failed: {e:#}");
744            ctx.extraction_failures.fetch_add(1, Ordering::Relaxed);
745            (false, vec![])
746        }
747        Err(_elapsed) => {
748            tracing::warn!("graph extraction timed out");
749            ctx.extraction_failures.fetch_add(1, Ordering::Relaxed);
750            (false, vec![])
751        }
752    };
753
754    run_note_linking(
755        extraction_ok,
756        &new_entity_ids,
757        ctx.pool.clone(),
758        ctx.embedding_store,
759        ctx.provider.clone(),
760        &config,
761    )
762    .await;
763
764    maybe_refresh_communities(
765        extraction_ok,
766        ctx.pool,
767        ctx.provider,
768        ctx.failure_counter,
769        &config,
770        ctx.cancel,
771    )
772    .await;
773}
774
775/// Run A-MEM note linking after successful extraction when enabled.
776async fn run_note_linking(
777    extraction_ok: bool,
778    new_entity_ids: &[i64],
779    pool: DbPool,
780    embedding_store: Option<Arc<EmbeddingStore>>,
781    provider: AnyProvider,
782    config: &GraphExtractionConfig,
783) {
784    if !extraction_ok || !config.note_linking.enabled || new_entity_ids.is_empty() {
785        return;
786    }
787    let Some(store) = embedding_store else {
788        return;
789    };
790    let linking_timeout = std::time::Duration::from_secs(config.note_linking.timeout_secs);
791    match tokio::time::timeout(
792        linking_timeout,
793        link_memory_notes(new_entity_ids, pool, store, provider, &config.note_linking),
794    )
795    .await
796    {
797        Ok(stats) => {
798            tracing::debug!(
799                entities_processed = stats.entities_processed,
800                edges_created = stats.edges_created,
801                "note linking completed"
802            );
803        }
804        Err(_elapsed) => {
805            tracing::debug!("note linking timed out (partial edges may exist)");
806        }
807    }
808}
809
810/// Trigger community detection, graph eviction, and link-weight decay when the extraction
811/// count hits the configured refresh interval.
812///
813/// Runs inline within the caller's task (no nested `tokio::spawn`). Each long-running step
814/// is guarded by `tokio::select!` on `cancel` so shutdown aborts immediately at the next
815/// yield point without leaving orphaned tasks.
816async fn maybe_refresh_communities(
817    extraction_ok: bool,
818    pool: DbPool,
819    provider: AnyProvider,
820    failure_counter: Arc<std::sync::atomic::AtomicU64>,
821    config: &GraphExtractionConfig,
822    cancel: CancellationToken,
823) {
824    use crate::graph::GraphStore;
825
826    if !extraction_ok || config.community_refresh_interval == 0 {
827        return;
828    }
829
830    let store = GraphStore::new(pool.clone());
831    let extraction_count = store.extraction_count().await.unwrap_or(0);
832    if extraction_count == 0
833        || !i64::try_from(config.community_refresh_interval)
834            .is_ok_and(|interval| extraction_count % interval == 0)
835    {
836        return;
837    }
838
839    tracing::info!(extraction_count, "triggering community detection refresh");
840    let store2 = GraphStore::new(pool);
841    let retention_days = config.expired_edge_retention_days;
842    let max_cap = config.max_entities_cap;
843    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
844    let concurrency = config.community_summary_concurrency;
845    let edge_chunk_size = config.lpa_edge_chunk_size;
846    let decay_lambda = config.link_weight_decay_lambda;
847    let decay_interval_secs = config.link_weight_decay_interval_secs;
848
849    tokio::select! {
850        () = cancel.cancelled() => {
851            tracing::debug!("community refresh cancelled before community detection");
852            return;
853        }
854        result = crate::graph::community::detect_communities(
855            &store2,
856            &provider,
857            max_prompt_bytes,
858            concurrency,
859            edge_chunk_size,
860        ) => {
861            match result {
862                Ok(count) => {
863                    tracing::info!(communities = count, "community detection complete");
864                }
865                Err(e) => {
866                    tracing::warn!("community detection failed: {e:#}");
867                    failure_counter.fetch_add(1, Ordering::Relaxed);
868                }
869            }
870        }
871    }
872
873    tokio::select! {
874        () = cancel.cancelled() => {
875            tracing::debug!("community refresh cancelled before graph eviction");
876            return;
877        }
878        result = crate::graph::community::run_graph_eviction(&store2, retention_days, max_cap) => {
879            match result {
880                Ok(stats) => {
881                    tracing::info!(
882                        expired_edges = stats.expired_edges_deleted,
883                        orphan_entities = stats.orphan_entities_deleted,
884                        capped_entities = stats.capped_entities_deleted,
885                        "graph eviction complete"
886                    );
887                }
888                Err(e) => {
889                    tracing::warn!("graph eviction failed: {e:#}");
890                }
891            }
892        }
893    }
894
895    // Time-based link weight decay — independent of eviction cycle.
896    if decay_lambda > 0.0 && decay_interval_secs > 0 {
897        let now_secs = std::time::SystemTime::now()
898            .duration_since(std::time::UNIX_EPOCH)
899            .map_or(0, |d| d.as_secs());
900        let last_decay = store2
901            .get_metadata("last_link_weight_decay_at")
902            .await
903            .ok()
904            .flatten()
905            .and_then(|s| s.parse::<u64>().ok())
906            .unwrap_or(0);
907        if now_secs.saturating_sub(last_decay) >= decay_interval_secs {
908            tokio::select! {
909                () = cancel.cancelled() => {
910                    tracing::debug!("community refresh cancelled before link weight decay");
911                }
912                result = store2.decay_edge_retrieval_counts(decay_lambda, decay_interval_secs) => {
913                    match result {
914                        Ok(affected) => {
915                            tracing::info!(affected, "link weight decay applied");
916                            let _ = store2
917                                .set_metadata("last_link_weight_decay_at", &now_secs.to_string())
918                                .await;
919                        }
920                        Err(e) => {
921                            tracing::warn!("link weight decay failed: {e:#}");
922                        }
923                    }
924                }
925            }
926        }
927    }
928}
929
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::{NoteLinkingConfig, extract_and_store};
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    use super::GraphExtractionConfig;

    /// Build an in-memory `SQLite` graph store plus an embedding store backed by an
    /// in-memory vector store, both sharing the same connection pool.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    /// Regression test for #1829: `extract_and_store()` must pass the provider to `EntityResolver`
    /// so that `store_entity_embedding()` is called and `qdrant_point_id` is set in `SQLite`.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        // MockProvider: supports embeddings, returns a valid extraction JSON for chat
        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // The entity must have a qdrant_point_id — this proves store_entity_embedding() was called.
        // Before the fix, EntityResolver was built without a provider, so embed() was never called
        // and qdrant_point_id remained NULL.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    /// When no `embedding_store` is provided, `extract_and_store()` must still work correctly
    /// (no embeddings stored, but entities are still upserted).
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, // no embedding_store
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }

    /// Regression test for #2166: FTS5 entity writes must be visible to a new connection pool
    /// opened after extraction completes. Without `checkpoint_wal()` in `extract_and_store`,
    /// a fresh pool sees stale FTS5 shadow tables and `find_entities_fuzzy` returns empty.
    #[tokio::test]
    async fn extract_and_store_fts5_cross_session_visibility() {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        let path = file.path().to_str().expect("valid path").to_string();

        // Session A: run extract_and_store on a file DB (not :memory:) so WAL is used.
        {
            let sqlite = crate::store::SqliteStore::new(&path).await.unwrap();
            let extraction_json = r#"{"entities":[{"name":"Ferris","type":"concept","summary":"Rust mascot"}],"edges":[]}"#;
            let mock =
                zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
            let provider = AnyProvider::Mock(mock);
            let config = GraphExtractionConfig {
                max_entities: 10,
                max_edges: 10,
                extraction_timeout_secs: 10,
                ..Default::default()
            };
            extract_and_store(
                "Ferris is the Rust mascot.".to_owned(),
                vec![],
                provider,
                sqlite.pool().clone(),
                config,
                None,
                None,
            )
            .await
            .unwrap();
        }

        // Session B: new pool — FTS5 must see the entity extracted in session A.
        let sqlite_b = crate::store::SqliteStore::new(&path).await.unwrap();
        let gs_b = crate::graph::GraphStore::new(sqlite_b.pool().clone());
        let results = gs_b.find_entities_fuzzy("Ferris", 10).await.unwrap();
        assert!(
            !results.is_empty(),
            "FTS5 cross-session (#2166): entity extracted in session A must be visible in session B"
        );
    }

    /// Regression test for #2215: self-loop edges (source == target entity) must be silently
    /// skipped; no edge row should be inserted.
    #[tokio::test]
    async fn extract_and_store_skips_self_loop_edges() {
        let (gs, _emb) = setup().await;

        // LLM returns one entity and one self-loop edge (source == target).
        let extraction_json = r#"{
            "entities":[{"name":"Rust","type":"language","summary":"systems language"}],
            "edges":[{"source":"Rust","target":"Rust","relation":"is","fact":"Rust is Rust","edge_type":"semantic"}]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);
        assert_eq!(
            result.stats.edges_inserted, 0,
            "self-loop edge must be rejected (#2215)"
        );
    }

    /// When `apex_mem_enabled = true`, edges must be inserted via `insert_or_supersede`
    /// (the APEX-MEM append-only path) instead of the legacy `resolve_edge_typed` path.
    /// Verifies that edges are still counted as inserted and that the supersession row
    /// is created in the database.
    #[tokio::test]
    async fn apex_mem_path_inserts_edge_via_insert_or_supersede() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{
            "entities":[
                {"name":"Alice","type":"person","summary":"a person"},
                {"name":"Bob","type":"person","summary":"another person"}
            ],
            "edges":[
                {"source":"Alice","target":"Bob","relation":"KNOWS","fact":"Alice knows Bob","edge_type":"semantic"}
            ]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            apex_mem_enabled: true,
            ..Default::default()
        };

        let result = extract_and_store(
            "Alice knows Bob.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 2, "two entities expected");
        assert_eq!(
            result.stats.edges_inserted, 1,
            "APEX-MEM path must insert the edge and count it (#3631)"
        );

        // Verify the edge row exists and its relation preserves display casing.
        let alice_id = gs
            .find_entity("alice", crate::graph::EntityType::Person)
            .await
            .unwrap()
            .expect("entity 'alice' must exist")
            .id
            .0;
        let bob_id = gs
            .find_entity("bob", crate::graph::EntityType::Person)
            .await
            .unwrap()
            .expect("entity 'bob' must exist")
            .id
            .0;
        let edges = gs.edges_exact(alice_id, bob_id).await.unwrap();
        assert_eq!(edges.len(), 1, "exactly one edge expected");
        // canonical_relation is lowercased; relation field preserves original casing post-strip
        assert_eq!(
            edges[0].relation, "KNOWS",
            "display relation must preserve original casing"
        );
    }

    /// Regression for #4297: `embed_work_items` must return an empty Vec (fail-open) when the
    /// batch `join_all` embed call exceeds the 30 s global timeout.
    #[tokio::test]
    async fn embed_work_items_timeout_returns_empty() {
        use zeph_llm::mock::MockProvider;

        // A real embed_delay_ms > 30_000 ms would make the test too slow, so we use
        // tokio::time::pause. With the clock paused, virtual time advances instantly and the
        // timeout fires without any real delay.
        tokio::time::pause();

        // Delay longer than the 30 s timeout (in virtual time).
        let mut mock = MockProvider::default();
        mock.supports_embeddings = true;
        mock.embed_delay_ms = 31_000;
        let provider = AnyProvider::Mock(mock);

        let work_items = vec![super::EntityWorkItem {
            entity_id: 1,
            canonical_name: "Alice".to_owned(),
            embed_text: "Alice".to_owned(),
            self_point_id: None,
        }];

        let cfg = NoteLinkingConfig {
            timeout_secs: 30,
            ..NoteLinkingConfig::default()
        };
        let result = super::embed_work_items(&work_items, &provider, &cfg).await;
        assert!(
            result.is_empty(),
            "embed_work_items must return empty Vec on 30 s timeout (fail-open)"
        );
    }

    /// Regression for #4622: `maybe_refresh_communities` must return promptly when the
    /// `CancellationToken` is already cancelled, without hanging or panicking.
    ///
    /// Before the fix, a nested `tokio::spawn` was used with no `CancellationToken`, so shutdown
    /// could not interrupt community detection. The inline `tokio::select!` path now returns as
    /// soon as it observes the pre-cancelled token.
    #[tokio::test]
    async fn maybe_refresh_communities_respects_cancelled_token() {
        use tokio_util::sync::CancellationToken;

        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let gs = GraphStore::new(pool.clone());

        // Seed extraction_count=1 so the interval check passes (1 % 1 == 0).
        gs.set_metadata("extraction_count", "1").await.unwrap();

        let config = GraphExtractionConfig {
            community_refresh_interval: 1,
            ..Default::default()
        };

        let cancel = CancellationToken::new();
        cancel.cancel(); // pre-cancelled — all select! arms must short-circuit immediately

        let extraction_json = r#"{"entities":[],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let failure_counter = Arc::new(std::sync::atomic::AtomicU64::new(0));

        // Must complete promptly. The explicit timeout turns a regression that blocks on an
        // uncancellable step into a test failure instead of a hung test run.
        tokio::time::timeout(
            std::time::Duration::from_secs(10),
            super::maybe_refresh_communities(
                true,
                pool,
                provider,
                failure_counter.clone(),
                &config,
                cancel,
            ),
        )
        .await
        .expect("maybe_refresh_communities must return promptly when the token is cancelled");

        assert_eq!(
            failure_counter.load(std::sync::atomic::Ordering::Relaxed),
            0,
            "no failures should be recorded when cancelled before any detection step"
        );
    }
}