// zeph_memory/semantic/graph.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5#[allow(unused_imports)]
6use zeph_db::sql;
7
8use std::sync::atomic::Ordering;
9use zeph_db::DbPool;
10
11pub use zeph_common::config::memory::NoteLinkingConfig;
12use zeph_llm::any::AnyProvider;
13use zeph_llm::provider::LlmProvider as _;
14
15use crate::embedding_store::EmbeddingStore;
16use crate::error::MemoryError;
17use crate::graph::extractor::ExtractionResult as ExtractorResult;
18use crate::vector_store::VectorFilter;
19
20use super::SemanticMemory;
21
/// Callback type for post-extraction validation.
///
/// A generic predicate opaque to zeph-memory — callers (zeph-core) provide security
/// validation without introducing a dependency on security policy in this crate.
///
/// `Send` is required because the callback is moved into the tokio task spawned by
/// `SemanticMemory::spawn_graph_extraction` and invoked once inside `extract_and_store`.
pub type PostExtractValidator = Option<Box<dyn Fn(&ExtractorResult) -> Result<(), String> + Send>>;
27
/// Config for the spawned background extraction task.
///
/// Owned clone of the relevant fields from `GraphConfig` — no references, safe to send to
/// spawned tasks.
#[derive(Debug, Clone)]
pub struct GraphExtractionConfig {
    /// Entity limit passed to `GraphExtractor::new`.
    pub max_entities: usize,
    /// Edge limit passed to `GraphExtractor::new`.
    pub max_edges: usize,
    /// Timeout for the whole `extract_and_store` call, in seconds.
    pub extraction_timeout_secs: u64,
    /// Community detection runs when the persistent extraction count is a multiple of this
    /// interval; `0` disables the refresh entirely.
    pub community_refresh_interval: usize,
    /// Retention window passed to `run_graph_eviction` for expired edges.
    pub expired_edge_retention_days: u32,
    /// Entity cap passed to `run_graph_eviction`.
    pub max_entities_cap: usize,
    /// Prompt byte budget passed to `detect_communities` for community summaries.
    pub community_summary_max_prompt_bytes: usize,
    /// Concurrency limit passed to `detect_communities` for summary generation.
    pub community_summary_concurrency: usize,
    /// Edge chunk size passed to `detect_communities` (label-propagation batching).
    pub lpa_edge_chunk_size: usize,
    /// A-MEM note linking config, cloned from `GraphConfig.note_linking`.
    pub note_linking: NoteLinkingConfig,
    /// A-MEM link weight decay lambda. Range: `(0.0, 1.0]`. Default: `0.95`.
    pub link_weight_decay_lambda: f64,
    /// Seconds between link weight decay passes. Default: `86400`.
    pub link_weight_decay_interval_secs: u64,
    /// Kumiho belief revision: enable semantic contradiction detection for edges.
    pub belief_revision_enabled: bool,
    /// Cosine similarity threshold for belief revision contradiction detection.
    pub belief_revision_similarity_threshold: f32,
    /// GAAMA episode linking: `conversation_id` to link extracted entities to their episode.
    /// `None` disables episode linking for this extraction pass.
    pub conversation_id: Option<i64>,
}
57
58impl Default for GraphExtractionConfig {
59    fn default() -> Self {
60        Self {
61            max_entities: 0,
62            max_edges: 0,
63            extraction_timeout_secs: 0,
64            community_refresh_interval: 0,
65            expired_edge_retention_days: 0,
66            max_entities_cap: 0,
67            community_summary_max_prompt_bytes: 0,
68            community_summary_concurrency: 0,
69            lpa_edge_chunk_size: 0,
70            note_linking: NoteLinkingConfig::default(),
71            link_weight_decay_lambda: 0.95,
72            link_weight_decay_interval_secs: 86400,
73            belief_revision_enabled: false,
74            belief_revision_similarity_threshold: 0.85,
75            conversation_id: None,
76        }
77    }
78}
79
/// Stats returned from a completed extraction.
#[derive(Debug, Default)]
pub struct ExtractionStats {
    /// Number of entities successfully resolved/upserted in this pass.
    pub entities_upserted: usize,
    /// Number of edges inserted; edges deduplicated by the resolver are not counted.
    pub edges_inserted: usize,
}
86
/// Result returned from `extract_and_store`, combining stats with entity IDs needed for linking.
#[derive(Debug, Default)]
pub struct ExtractionResult {
    /// Counters for entities/edges persisted in this pass.
    pub stats: ExtractionStats,
    /// IDs of entities upserted during this extraction pass. Passed to `link_memory_notes`.
    pub entity_ids: Vec<i64>,
}
94
/// Stats returned from a completed note-linking pass.
#[derive(Debug, Default)]
pub struct LinkingStats {
    /// Entities whose embed + similarity search both succeeded (regardless of edges found).
    pub entities_processed: usize,
    /// `similar_to` edges inserted in this pass, after pair deduplication.
    pub edges_created: usize,
}
101
/// Qdrant collection name for entity embeddings (mirrors the constant in `resolver.rs`).
// NOTE(review): duplicated with resolver.rs — consider hoisting to a shared module so the
// two cannot drift.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";
104
/// Work item for a single entity during a note-linking pass.
struct EntityWorkItem {
    /// SQLite row id of the entity.
    entity_id: i64,
    /// Canonical name; used in log messages and as the embedding-text fallback.
    canonical_name: String,
    /// Text sent to the embedding provider (`"name: summary"`, or just the name when the
    /// summary is absent/empty).
    embed_text: String,
    /// Qdrant point id stored for this entity, if any; used to filter out self-matches
    /// from the similarity search results.
    self_point_id: Option<String>,
}
112
/// Link newly extracted entities to semantically similar entities in the graph.
///
/// For each entity in `entity_ids`:
/// 1. Load the entity name + summary from `SQLite`.
/// 2. Embed all entity texts in parallel.
/// 3. Search the entity embedding collection in parallel for the `top_k + 1` most similar points.
/// 4. Filter out the entity itself (by `qdrant_point_id` or `entity_id` payload) and points
///    below `similarity_threshold`.
/// 5. Insert a unidirectional `similar_to` edge where `source_id < target_id` to avoid
///    double-counting in BFS recall while still being traversable via the OR clause in
///    `edges_for_entity`. The edge confidence is set to the cosine similarity score.
/// 6. Deduplicate pairs within a single pass so that a pair encountered from both A→B and B→A
///    directions is only inserted once, keeping `edges_created` accurate.
///
/// Errors are logged and not propagated — this is a best-effort background enrichment step.
#[allow(clippy::too_many_lines)]
pub async fn link_memory_notes(
    entity_ids: &[i64],
    pool: DbPool,
    embedding_store: Arc<EmbeddingStore>,
    provider: AnyProvider,
    cfg: &NoteLinkingConfig,
) -> LinkingStats {
    use futures::future;

    use crate::graph::GraphStore;

    let store = GraphStore::new(pool);
    let mut stats = LinkingStats::default();

    // Phase 1: load entities from DB sequentially (cheap; avoids connection-pool contention).
    let mut work_items: Vec<EntityWorkItem> = Vec::with_capacity(entity_ids.len());
    for &entity_id in entity_ids {
        let entity = match store.find_entity_by_id(entity_id).await {
            Ok(Some(e)) => e,
            Ok(None) => {
                tracing::debug!("note_linking: entity {entity_id} not found, skipping");
                continue;
            }
            Err(e) => {
                tracing::debug!("note_linking: DB error loading entity {entity_id}: {e:#}");
                continue;
            }
        };
        // Embed "name: summary" when a non-empty summary exists; otherwise just the name.
        let embed_text = match &entity.summary {
            Some(s) if !s.is_empty() => format!("{}: {s}", entity.canonical_name),
            _ => entity.canonical_name.clone(),
        };
        work_items.push(EntityWorkItem {
            entity_id,
            canonical_name: entity.canonical_name,
            embed_text,
            self_point_id: entity.qdrant_point_id,
        });
    }

    if work_items.is_empty() {
        return stats;
    }

    // Phase 2: embed all entity texts in parallel to reduce N serial HTTP round-trips to 1.
    let embed_results: Vec<_> =
        future::join_all(work_items.iter().map(|w| provider.embed(&w.embed_text))).await;

    // Phase 3: search for similar entities in parallel for all successfully embedded entities.
    let search_limit = cfg.top_k + 1; // +1 to account for self-match
    // `valid` pairs each surviving embedding with its index into `work_items`, so Phase 4
    // can zip the search results back to their originating entity.
    let valid: Vec<(usize, Vec<f32>)> = embed_results
        .into_iter()
        .enumerate()
        .filter_map(|(i, r)| match r {
            Ok(v) => Some((i, v)),
            Err(e) => {
                tracing::debug!(
                    "note_linking: embed failed for entity {:?}: {e:#}",
                    work_items[i].canonical_name
                );
                None
            }
        })
        .collect();

    let search_results: Vec<_> = future::join_all(valid.iter().map(|(_, vec)| {
        embedding_store.search_collection(
            ENTITY_COLLECTION,
            vec,
            search_limit,
            None::<VectorFilter>,
        )
    }))
    .await;

    // Phase 4: insert edges; deduplicate pairs seen from both A→B and B→A directions.
    // Without deduplication, both directions call insert_edge for the same normalised pair and
    // both return Ok (the second call updates confidence on the existing row), inflating
    // edges_created by the number of bidirectional hits.
    let mut seen_pairs = std::collections::HashSet::new();

    // `valid` and `search_results` are index-aligned by construction above.
    for ((work_idx, _), search_result) in valid.iter().zip(search_results.iter()) {
        let w = &work_items[*work_idx];

        let results = match search_result {
            Ok(r) => r,
            Err(e) => {
                tracing::debug!(
                    "note_linking: search failed for entity {:?}: {e:#}",
                    w.canonical_name
                );
                continue;
            }
        };

        // Counted once per successful embed+search, even if no candidate survives filtering.
        stats.entities_processed += 1;

        let self_point_id = w.self_point_id.as_deref();
        let candidates = results
            .iter()
            .filter(|p| Some(p.id.as_str()) != self_point_id && p.score >= cfg.similarity_threshold)
            .take(cfg.top_k);

        for point in candidates {
            let Some(target_id) = point
                .payload
                .get("entity_id")
                .and_then(serde_json::Value::as_i64)
            else {
                tracing::debug!(
                    "note_linking: missing entity_id in payload for point {}",
                    point.id
                );
                continue;
            };

            if target_id == w.entity_id {
                continue; // secondary self-guard when qdrant_point_id is null
            }

            // Normalise direction: always store source_id < target_id.
            let (src, tgt) = if w.entity_id < target_id {
                (w.entity_id, target_id)
            } else {
                (target_id, w.entity_id)
            };

            // Skip pairs already processed in this pass to avoid double-counting.
            if !seen_pairs.insert((src, tgt)) {
                continue;
            }

            let fact = format!("Semantically similar entities (score: {:.3})", point.score);

            match store
                .insert_edge(src, tgt, "similar_to", &fact, point.score, None)
                .await
            {
                Ok(_) => stats.edges_created += 1,
                Err(e) => {
                    tracing::debug!("note_linking: insert_edge failed: {e:#}");
                }
            }
        }
    }

    stats
}
277
/// Extract entities and edges from `content` and persist them to the graph store.
///
/// This function runs inside a spawned task — it receives owned data only.
///
/// The optional `embedding_store` enables entity embedding storage in Qdrant, which is
/// required for A-MEM note linking to find semantically similar entities across sessions.
///
/// # Errors
///
/// Returns an error if the database query fails or LLM extraction fails.
#[allow(clippy::too_many_lines)]
pub async fn extract_and_store(
    content: String,
    context_messages: Vec<String>,
    provider: AnyProvider,
    pool: DbPool,
    config: GraphExtractionConfig,
    post_extract_validator: PostExtractValidator,
    embedding_store: Option<Arc<EmbeddingStore>>,
) -> Result<ExtractionResult, MemoryError> {
    use crate::graph::{EntityResolver, GraphExtractor, GraphStore};

    let extractor = GraphExtractor::new(provider.clone(), config.max_entities, config.max_edges);
    let ctx_refs: Vec<&str> = context_messages.iter().map(String::as_str).collect();

    let store = GraphStore::new(pool);

    // Bump the persistent extraction counter up front — note this runs before the LLM call,
    // so the community-refresh cadence advances even when extraction yields nothing or the
    // validator rejects the result below.
    let pool = store.pool();
    zeph_db::query(sql!(
        "INSERT INTO graph_metadata (key, value) VALUES ('extraction_count', '0')
         ON CONFLICT(key) DO NOTHING"
    ))
    .execute(pool)
    .await?;
    zeph_db::query(sql!(
        "UPDATE graph_metadata
         SET value = CAST(CAST(value AS INTEGER) + 1 AS TEXT)
         WHERE key = 'extraction_count'"
    ))
    .execute(pool)
    .await?;

    // `None` from the extractor means nothing to persist — return empty stats as success.
    let Some(result) = extractor.extract(&content, &ctx_refs).await? else {
        return Ok(ExtractionResult::default());
    };

    // Post-extraction validation callback. zeph-memory does not know the callback is a
    // security validator — it is a generic predicate opaque to this crate (design decision D1).
    if let Some(ref validator) = post_extract_validator
        && let Err(reason) = validator(&result)
    {
        tracing::warn!(
            reason,
            "graph extraction validation failed, skipping upsert"
        );
        return Ok(ExtractionResult::default());
    }

    // Attach both the embedding store and the provider so resolved entities also get their
    // embeddings stored (regression #1829: without the provider, qdrant_point_id stayed NULL).
    let resolver = if let Some(ref emb) = embedding_store {
        EntityResolver::new(&store)
            .with_embedding_store(emb)
            .with_provider(&provider)
    } else {
        EntityResolver::new(&store)
    };

    let mut entities_upserted = 0usize;
    let mut entity_name_to_id: std::collections::HashMap<String, i64> =
        std::collections::HashMap::new();

    // Resolve each extracted entity; individual failures are logged and skipped (best-effort).
    for entity in &result.entities {
        match resolver
            .resolve(&entity.name, &entity.entity_type, entity.summary.as_deref())
            .await
        {
            Ok((id, _outcome)) => {
                entity_name_to_id.insert(entity.name.clone(), id);
                entities_upserted += 1;
            }
            Err(e) => {
                tracing::debug!("graph: skipping entity {:?}: {e:#}", entity.name);
            }
        }
    }

    let mut edges_inserted = 0usize;
    for edge in &result.edges {
        // Both endpoints must have resolved above; otherwise the edge is dropped.
        let (Some(&src_id), Some(&tgt_id)) = (
            entity_name_to_id.get(&edge.source),
            entity_name_to_id.get(&edge.target),
        ) else {
            tracing::debug!(
                "graph: skipping edge {:?}->{:?}: entity not resolved",
                edge.source,
                edge.target
            );
            continue;
        };
        // Self-loop edges (source == target) are silently skipped (#2215).
        if src_id == tgt_id {
            tracing::debug!(
                "graph: skipping self-loop edge {:?}->{:?} (entity_id={src_id})",
                edge.source,
                edge.target
            );
            continue;
        }
        // Parse LLM-provided edge_type; default to Semantic on any parse failure so
        // edges are never dropped due to classification errors.
        let edge_type = edge
            .edge_type
            .parse::<crate::graph::EdgeType>()
            .unwrap_or_else(|_| {
                tracing::warn!(
                    raw_type = %edge.edge_type,
                    "graph: unknown edge_type from LLM, defaulting to semantic"
                );
                crate::graph::EdgeType::Semantic
            });
        let belief_cfg =
            config
                .belief_revision_enabled
                .then_some(crate::graph::BeliefRevisionConfig {
                    similarity_threshold: config.belief_revision_similarity_threshold,
                });
        match resolver
            .resolve_edge_typed(
                src_id,
                tgt_id,
                &edge.relation,
                &edge.fact,
                0.8,
                None,
                edge_type,
                belief_cfg.as_ref(),
            )
            .await
        {
            Ok(Some(_)) => edges_inserted += 1,
            Ok(None) => {} // deduplicated
            Err(e) => {
                tracing::debug!("graph: skipping edge: {e:#}");
            }
        }
    }

    // Checkpoint WAL so FTS5 writes are visible to connection pools opened later (#2166).
    store.checkpoint_wal().await?;

    let new_entity_ids: Vec<i64> = entity_name_to_id.into_values().collect();

    // GAAMA episode linking: link all extracted entities to the episode for this conversation.
    if let Some(conv_id) = config.conversation_id {
        match store.ensure_episode(conv_id).await {
            Ok(episode_id) => {
                for &entity_id in &new_entity_ids {
                    if let Err(e) = store.link_entity_to_episode(episode_id, entity_id).await {
                        tracing::debug!("episode linking skipped for entity {entity_id}: {e:#}");
                    }
                }
            }
            Err(e) => {
                tracing::warn!("failed to ensure episode for conversation {conv_id}: {e:#}");
            }
        }
    }

    Ok(ExtractionResult {
        stats: ExtractionStats {
            entities_upserted,
            edges_inserted,
        },
        entity_ids: new_entity_ids,
    })
}
451
impl SemanticMemory {
    /// Spawn background graph extraction for a message. Fire-and-forget — never blocks.
    ///
    /// Extraction runs in a separate tokio task with a timeout. Any error or timeout is
    /// logged and the task exits silently; the agent response is never blocked.
    ///
    /// The optional `post_extract_validator` is called after extraction, before upsert.
    /// It is a generic predicate opaque to zeph-memory (design decision D1).
    ///
    /// When `config.note_linking.enabled` is `true` and an embedding store is available,
    /// `link_memory_notes` runs after successful extraction inside the same task, bounded
    /// by `config.note_linking.timeout_secs`.
    #[allow(clippy::too_many_lines)]
    pub fn spawn_graph_extraction(
        &self,
        content: String,
        context_messages: Vec<String>,
        config: GraphExtractionConfig,
        post_extract_validator: PostExtractValidator,
    ) -> tokio::task::JoinHandle<()> {
        // Clone everything the task needs — the async block below is `move` + 'static.
        let pool = self.sqlite.pool().clone();
        let provider = self.provider.clone();
        let failure_counter = self.community_detection_failures.clone();
        let extraction_count = self.graph_extraction_count.clone();
        let extraction_failures = self.graph_extraction_failures.clone();
        // Clone the embedding store Arc before moving into the task.
        let embedding_store = self.qdrant.clone();

        tokio::spawn(async move {
            let timeout_dur = std::time::Duration::from_secs(config.extraction_timeout_secs);
            let extraction_result = tokio::time::timeout(
                timeout_dur,
                extract_and_store(
                    content,
                    context_messages,
                    provider.clone(),
                    pool.clone(),
                    config.clone(),
                    post_extract_validator,
                    embedding_store.clone(),
                ),
            )
            .await;

            // Fold the nested timeout/extraction result into (success flag, upserted ids).
            let (extraction_ok, new_entity_ids) = match extraction_result {
                Ok(Ok(result)) => {
                    tracing::debug!(
                        entities = result.stats.entities_upserted,
                        edges = result.stats.edges_inserted,
                        "graph extraction completed"
                    );
                    extraction_count.fetch_add(1, Ordering::Relaxed);
                    (true, result.entity_ids)
                }
                Ok(Err(e)) => {
                    tracing::warn!("graph extraction failed: {e:#}");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
                Err(_elapsed) => {
                    tracing::warn!("graph extraction timed out");
                    extraction_failures.fetch_add(1, Ordering::Relaxed);
                    (false, vec![])
                }
            };

            // A-MEM note linking: run after successful extraction when enabled.
            if extraction_ok
                && config.note_linking.enabled
                && !new_entity_ids.is_empty()
                && let Some(store) = embedding_store
            {
                let linking_timeout =
                    std::time::Duration::from_secs(config.note_linking.timeout_secs);
                match tokio::time::timeout(
                    linking_timeout,
                    link_memory_notes(
                        &new_entity_ids,
                        pool.clone(),
                        store,
                        provider.clone(),
                        &config.note_linking,
                    ),
                )
                .await
                {
                    Ok(stats) => {
                        tracing::debug!(
                            entities_processed = stats.entities_processed,
                            edges_created = stats.edges_created,
                            "note linking completed"
                        );
                    }
                    Err(_elapsed) => {
                        tracing::debug!("note linking timed out (partial edges may exist)");
                    }
                }
            }

            if extraction_ok && config.community_refresh_interval > 0 {
                use crate::graph::GraphStore;

                // NOTE: this shadows the in-process atomic `extraction_count` above with the
                // persistent counter from graph_metadata — the cadence survives restarts.
                let store = GraphStore::new(pool.clone());
                let extraction_count = store.extraction_count().await.unwrap_or(0);
                if extraction_count > 0
                    && i64::try_from(config.community_refresh_interval)
                        .is_ok_and(|interval| extraction_count % interval == 0)
                {
                    tracing::info!(extraction_count, "triggering community detection refresh");
                    let store2 = GraphStore::new(pool);
                    let provider2 = provider;
                    let retention_days = config.expired_edge_retention_days;
                    let max_cap = config.max_entities_cap;
                    let max_prompt_bytes = config.community_summary_max_prompt_bytes;
                    let concurrency = config.community_summary_concurrency;
                    let edge_chunk_size = config.lpa_edge_chunk_size;
                    let decay_lambda = config.link_weight_decay_lambda;
                    let decay_interval_secs = config.link_weight_decay_interval_secs;
                    // Detection, eviction, and decay run in their own task so this one can
                    // finish without waiting for the (LLM-bound) refresh.
                    tokio::spawn(async move {
                        match crate::graph::community::detect_communities(
                            &store2,
                            &provider2,
                            max_prompt_bytes,
                            concurrency,
                            edge_chunk_size,
                        )
                        .await
                        {
                            Ok(count) => {
                                tracing::info!(communities = count, "community detection complete");
                            }
                            Err(e) => {
                                tracing::warn!("community detection failed: {e:#}");
                                failure_counter.fetch_add(1, Ordering::Relaxed);
                            }
                        }
                        // Eviction runs even if detection failed above — both are best-effort.
                        match crate::graph::community::run_graph_eviction(
                            &store2,
                            retention_days,
                            max_cap,
                        )
                        .await
                        {
                            Ok(stats) => {
                                tracing::info!(
                                    expired_edges = stats.expired_edges_deleted,
                                    orphan_entities = stats.orphan_entities_deleted,
                                    capped_entities = stats.capped_entities_deleted,
                                    "graph eviction complete"
                                );
                            }
                            Err(e) => {
                                tracing::warn!("graph eviction failed: {e:#}");
                            }
                        }

                        // Time-based link weight decay — independent of eviction cycle.
                        if decay_lambda > 0.0 && decay_interval_secs > 0 {
                            let now_secs = std::time::SystemTime::now()
                                .duration_since(std::time::UNIX_EPOCH)
                                .map(|d| d.as_secs())
                                .unwrap_or(0);
                            // Last-run timestamp is stored as a stringified u64; any parse or
                            // read failure falls back to 0 (i.e. "run the decay now").
                            let last_decay = store2
                                .get_metadata("last_link_weight_decay_at")
                                .await
                                .ok()
                                .flatten()
                                .and_then(|s| s.parse::<u64>().ok())
                                .unwrap_or(0);
                            if now_secs.saturating_sub(last_decay) >= decay_interval_secs {
                                match store2
                                    .decay_edge_retrieval_counts(decay_lambda, decay_interval_secs)
                                    .await
                                {
                                    Ok(affected) => {
                                        tracing::info!(affected, "link weight decay applied");
                                        // Timestamp write is best-effort; a failure only means
                                        // the next pass may run early.
                                        let _ = store2
                                            .set_metadata(
                                                "last_link_weight_decay_at",
                                                &now_secs.to_string(),
                                            )
                                            .await;
                                    }
                                    Err(e) => {
                                        tracing::warn!("link weight decay failed: {e:#}");
                                    }
                                }
                            }
                        }
                    });
                }
            }
        })
    }
}
647
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use zeph_llm::any::AnyProvider;

    use super::extract_and_store;
    use crate::embedding_store::EmbeddingStore;
    use crate::graph::GraphStore;
    use crate::in_memory_store::InMemoryVectorStore;
    use crate::store::SqliteStore;

    use super::GraphExtractionConfig;

    /// Shared fixture: in-memory SQLite graph store + in-memory vector store.
    async fn setup() -> (GraphStore, Arc<EmbeddingStore>) {
        let sqlite = SqliteStore::new(":memory:").await.unwrap();
        let pool = sqlite.pool().clone();
        let mem_store = Box::new(InMemoryVectorStore::new());
        let emb = Arc::new(EmbeddingStore::with_store(mem_store, pool.clone()));
        let gs = GraphStore::new(pool);
        (gs, emb)
    }

    /// Regression test for #1829: `extract_and_store()` must pass the provider to `EntityResolver`
    /// so that `store_entity_embedding()` is called and `qdrant_point_id` is set in `SQLite`.
    #[tokio::test]
    async fn extract_and_store_sets_qdrant_point_id_when_embedding_store_provided() {
        let (gs, emb) = setup().await;

        // MockProvider: supports embeddings, returns a valid extraction JSON for chat
        let extraction_json = r#"{"entities":[{"name":"Rust","type":"language","summary":"systems language"}],"edges":[]}"#;
        let mut mock =
            zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        mock.supports_embeddings = true;
        mock.embedding = vec![1.0_f32, 0.0, 0.0, 0.0];
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a systems programming language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            Some(emb.clone()),
        )
        .await
        .unwrap();

        assert_eq!(
            result.stats.entities_upserted, 1,
            "one entity should be upserted"
        );

        // The entity must have a qdrant_point_id — this proves store_entity_embedding() was called.
        // Before the fix, EntityResolver was built without a provider, so embed() was never called
        // and qdrant_point_id remained NULL.
        let entity = gs
            .find_entity("rust", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'rust' must exist in SQLite");

        assert!(
            entity.qdrant_point_id.is_some(),
            "qdrant_point_id must be set when embedding_store + provider are both provided (regression for #1829)"
        );
    }

    /// When no `embedding_store` is provided, `extract_and_store()` must still work correctly
    /// (no embeddings stored, but entities are still upserted).
    #[tokio::test]
    async fn extract_and_store_without_embedding_store_still_upserts_entities() {
        let (gs, _emb) = setup().await;

        let extraction_json = r#"{"entities":[{"name":"Python","type":"language","summary":"scripting"}],"edges":[]}"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Python is a scripting language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None, // no embedding_store
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);

        let entity = gs
            .find_entity("python", crate::graph::EntityType::Language)
            .await
            .unwrap()
            .expect("entity 'python' must exist");

        assert!(
            entity.qdrant_point_id.is_none(),
            "qdrant_point_id must remain None when no embedding_store is provided"
        );
    }

    /// Regression test for #2166: FTS5 entity writes must be visible to a new connection pool
    /// opened after extraction completes. Without `checkpoint_wal()` in `extract_and_store`,
    /// a fresh pool sees stale FTS5 shadow tables and `find_entities_fuzzy` returns empty.
    #[tokio::test]
    async fn extract_and_store_fts5_cross_session_visibility() {
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        let path = file.path().to_str().expect("valid path").to_string();

        // Session A: run extract_and_store on a file DB (not :memory:) so WAL is used.
        // The scope ensures session A's pool is dropped before session B opens the file.
        {
            let sqlite = crate::store::SqliteStore::new(&path).await.unwrap();
            let extraction_json = r#"{"entities":[{"name":"Ferris","type":"concept","summary":"Rust mascot"}],"edges":[]}"#;
            let mock =
                zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
            let provider = AnyProvider::Mock(mock);
            let config = GraphExtractionConfig {
                max_entities: 10,
                max_edges: 10,
                extraction_timeout_secs: 10,
                ..Default::default()
            };
            extract_and_store(
                "Ferris is the Rust mascot.".to_owned(),
                vec![],
                provider,
                sqlite.pool().clone(),
                config,
                None,
                None,
            )
            .await
            .unwrap();
        }

        // Session B: new pool — FTS5 must see the entity extracted in session A.
        let sqlite_b = crate::store::SqliteStore::new(&path).await.unwrap();
        let gs_b = crate::graph::GraphStore::new(sqlite_b.pool().clone());
        let results = gs_b.find_entities_fuzzy("Ferris", 10).await.unwrap();
        assert!(
            !results.is_empty(),
            "FTS5 cross-session (#2166): entity extracted in session A must be visible in session B"
        );
    }

    /// Regression test for #2215: self-loop edges (source == target entity) must be silently
    /// skipped; no edge row should be inserted.
    #[tokio::test]
    async fn extract_and_store_skips_self_loop_edges() {
        let (gs, _emb) = setup().await;

        // LLM returns one entity and one self-loop edge (source == target).
        let extraction_json = r#"{
            "entities":[{"name":"Rust","type":"language","summary":"systems language"}],
            "edges":[{"source":"Rust","target":"Rust","relation":"is","fact":"Rust is Rust","edge_type":"semantic"}]
        }"#;
        let mock = zeph_llm::mock::MockProvider::with_responses(vec![extraction_json.to_owned()]);
        let provider = AnyProvider::Mock(mock);

        let config = GraphExtractionConfig {
            max_entities: 10,
            max_edges: 10,
            extraction_timeout_secs: 10,
            ..Default::default()
        };

        let result = extract_and_store(
            "Rust is a language.".to_owned(),
            vec![],
            provider,
            gs.pool().clone(),
            config,
            None,
            None,
        )
        .await
        .unwrap();

        assert_eq!(result.stats.entities_upserted, 1);
        assert_eq!(
            result.stats.edges_inserted, 0,
            "self-loop edge must be rejected (#2215)"
        );
    }
}