Skip to main content

zeph_memory/graph/resolver/
mod.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::sync::Arc;
5
6use dashmap::DashMap;
7use futures::stream::{self, StreamExt as _};
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tokio::sync::Mutex;
11use zeph_common::sanitize::strip_control_chars;
12use zeph_common::text::truncate_to_bytes_ref;
13use zeph_llm::any::AnyProvider;
14use zeph_llm::provider::{LlmProvider as _, Message, Role};
15
16use super::store::GraphStore;
17use super::types::EntityType;
18use crate::embedding_store::EmbeddingStore;
19use crate::error::MemoryError;
20use crate::graph::extractor::ExtractedEntity;
21use crate::types::MessageId;
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
23
/// Normalized entity names shorter than this many bytes are rejected as noise, e.g. "go", "cd".
const MIN_ENTITY_NAME_BYTES: usize = 3;
/// Byte cap applied to normalized entity names before they reach the graph store.
const MAX_ENTITY_NAME_BYTES: usize = 512;
/// Byte cap applied to normalized relation strings on edges.
const MAX_RELATION_BYTES: usize = 256;
/// Byte cap applied to edge fact strings; also bounds summaries when they are embedded
/// or merged.
const MAX_FACT_BYTES: usize = 2048;

/// Qdrant collection in which entity embeddings are stored and searched.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Upper bound, in seconds, on a single `embed()` call before it is treated as failed.
const EMBED_TIMEOUT_SECS: u64 = 30;
38
/// How an entity mention was mapped onto the graph by `EntityResolver::resolve`.
#[derive(Clone, Debug, PartialEq)]
pub enum ResolutionOutcome {
    /// Found by an exact (name, type) lookup in `SQLite`, via alias or canonical name.
    ExactMatch,
    /// Merged into an existing entity because cosine similarity reached the merge
    /// threshold; `score` carries that similarity.
    EmbeddingMatch { score: f32 },
    /// Merged into an existing entity after the LLM confirmed a candidate whose
    /// similarity fell in the ambiguous band.
    LlmDisambiguated,
    /// Nothing matched, so a new entity was created.
    Created,
}
51
52/// LLM response for entity disambiguation.
53#[derive(Debug, Deserialize, JsonSchema)]
54struct DisambiguationResponse {
55    same_entity: bool,
56}
57
58/// Per-entity-name lock guard to prevent concurrent duplicate creation.
59///
60/// Keyed by normalized entity name. Entities with different names resolve concurrently;
61/// entities with the same name are serialized.
62///
63/// TODO(SEC-M33-02): This map grows unboundedly — one entry per unique normalized name.
64/// For a short-lived resolver this is acceptable. If the resolver becomes long-lived
65/// (stored in `SemanticMemory`), add eviction or use a fixed-size sharded lock array.
66type NameLockMap = Arc<DashMap<String, Arc<Mutex<()>>>>;
67
68pub struct EntityResolver<'a> {
69    store: &'a GraphStore,
70    embedding_store: Option<&'a Arc<EmbeddingStore>>,
71    provider: Option<&'a AnyProvider>,
72    similarity_threshold: f32,
73    ambiguous_threshold: f32,
74    name_locks: NameLockMap,
75    /// Counter for error-triggered fallbacks (embed/LLM failures). Tests can read this via Arc.
76    fallback_count: Arc<std::sync::atomic::AtomicU64>,
77}
78
79impl<'a> EntityResolver<'a> {
80    #[must_use]
81    pub fn new(store: &'a GraphStore) -> Self {
82        Self {
83            store,
84            embedding_store: None,
85            provider: None,
86            similarity_threshold: 0.85,
87            ambiguous_threshold: 0.70,
88            name_locks: Arc::new(DashMap::new()),
89            fallback_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
90        }
91    }
92
93    #[must_use]
94    pub fn with_embedding_store(mut self, store: &'a Arc<EmbeddingStore>) -> Self {
95        self.embedding_store = Some(store);
96        self
97    }
98
99    #[must_use]
100    pub fn with_provider(mut self, provider: &'a AnyProvider) -> Self {
101        self.provider = Some(provider);
102        self
103    }
104
105    #[must_use]
106    pub fn with_thresholds(mut self, similarity: f32, ambiguous: f32) -> Self {
107        self.similarity_threshold = similarity;
108        self.ambiguous_threshold = ambiguous;
109        self
110    }
111
112    /// Shared fallback counter — tests can clone this Arc to inspect the value.
113    #[must_use]
114    pub fn fallback_count(&self) -> Arc<std::sync::atomic::AtomicU64> {
115        Arc::clone(&self.fallback_count)
116    }
117
118    /// Normalize an entity name: trim, lowercase, strip control chars, truncate.
119    fn normalize_name(name: &str) -> String {
120        let lowered = name.trim().to_lowercase();
121        let cleaned = strip_control_chars(&lowered);
122        let normalized = truncate_to_bytes_ref(&cleaned, MAX_ENTITY_NAME_BYTES).to_owned();
123        if normalized.len() < cleaned.len() {
124            tracing::debug!(
125                "graph resolver: entity name truncated to {} bytes",
126                MAX_ENTITY_NAME_BYTES
127            );
128        }
129        normalized
130    }
131
132    /// Parse an entity type string, falling back to `Concept` on unknown values.
133    fn parse_entity_type(entity_type: &str) -> EntityType {
134        entity_type
135            .trim()
136            .to_lowercase()
137            .parse::<EntityType>()
138            .unwrap_or_else(|_| {
139                tracing::debug!(
140                    "graph resolver: unknown entity type {:?}, falling back to Concept",
141                    entity_type
142                );
143                EntityType::Concept
144            })
145    }
146
147    /// Acquire the per-name lock and return the guard. Keeps lock alive for the caller.
148    async fn lock_name(&self, normalized: &str) -> tokio::sync::OwnedMutexGuard<()> {
149        let lock = self
150            .name_locks
151            .entry(normalized.to_owned())
152            .or_insert_with(|| Arc::new(Mutex::new(())))
153            .clone();
154        lock.lock_owned().await
155    }
156
157    /// Resolve an extracted entity using the alias-first canonicalization pipeline.
158    ///
159    /// Pipeline:
160    /// 1. Normalize: trim, lowercase, strip control chars, truncate to 512 bytes.
161    /// 2. Parse entity type (fallback to Concept on unknown).
162    /// 3. Alias lookup: search `graph_entity_aliases` by normalized name + `entity_type`.
163    ///    If found, touch `last_seen_at` and return the existing entity id.
164    /// 4. Canonical name lookup: search `graph_entities` by `canonical_name` + `entity_type`.
165    ///    If found, touch `last_seen_at` and return the existing entity id.
166    /// 5. When `embedding_store` and `provider` are configured, performs embedding-based fuzzy
167    ///    matching: cosine similarity search (Qdrant), LLM disambiguation for ambiguous range,
168    ///    merge or create based on result. Failures degrade gracefully to step 6.
169    /// 6. Create: upsert new entity with `canonical_name` = normalized name.
170    /// 7. Register the normalized form (and original trimmed form if different) as aliases.
171    ///
172    /// # Errors
173    ///
174    /// Returns an error if the entity name is empty after normalization, or if a DB operation fails.
175    pub async fn resolve(
176        &self,
177        name: &str,
178        entity_type: &str,
179        summary: Option<&str>,
180    ) -> Result<(i64, ResolutionOutcome), MemoryError> {
181        let normalized = Self::normalize_name(name);
182
183        if normalized.is_empty() {
184            return Err(MemoryError::GraphStore("empty entity name".into()));
185        }
186
187        if normalized.len() < MIN_ENTITY_NAME_BYTES {
188            return Err(MemoryError::GraphStore(format!(
189                "entity name too short: {normalized:?} ({} bytes, min {MIN_ENTITY_NAME_BYTES})",
190                normalized.len()
191            )));
192        }
193
194        let et = Self::parse_entity_type(entity_type);
195
196        // The surface form preserves the original casing for user-facing display.
197        let surface_name = name.trim().to_owned();
198
199        // Acquire per-name lock to prevent concurrent duplicate creation.
200        let _guard = self.lock_name(&normalized).await;
201
202        // Step 3: alias-first lookup (filters by entity_type to prevent cross-type collisions).
203        if let Some(entity) = self.store.find_entity_by_alias(&normalized, et).await? {
204            self.store
205                .upsert_entity(&surface_name, &entity.canonical_name, et, summary)
206                .await?;
207            return Ok((entity.id, ResolutionOutcome::ExactMatch));
208        }
209
210        // Step 4: canonical name lookup.
211        if let Some(entity) = self.store.find_entity(&normalized, et).await? {
212            self.store
213                .upsert_entity(&surface_name, &entity.canonical_name, et, summary)
214                .await?;
215            return Ok((entity.id, ResolutionOutcome::ExactMatch));
216        }
217
218        // Step 5: Embedding-based resolution (when configured).
219        if let Some(outcome) = self
220            .resolve_via_embedding(&normalized, name, &surface_name, et, summary)
221            .await?
222        {
223            return Ok(outcome);
224        }
225
226        // Step 6: Create new entity (no embedding store, or embedding failure).
227        let entity_id = self
228            .store
229            .upsert_entity(&surface_name, &normalized, et, summary)
230            .await?;
231
232        self.register_aliases(entity_id, &normalized, name).await?;
233
234        Ok((entity_id, ResolutionOutcome::Created))
235    }
236
237    /// Compute embedding for an entity, incrementing `fallback_count` on failure/timeout.
238    /// Returns `None` when embedding is unavailable (caller should skip vector operations).
239    async fn embed_entity_text(
240        &self,
241        provider: &AnyProvider,
242        normalized: &str,
243        summary: Option<&str>,
244    ) -> Option<Vec<f32>> {
245        let safe_summary = truncate_to_bytes_ref(summary.unwrap_or(""), MAX_FACT_BYTES);
246        let embed_text = format!("{normalized}: {safe_summary}");
247        let embed_result = tokio::time::timeout(
248            std::time::Duration::from_secs(EMBED_TIMEOUT_SECS),
249            provider.embed(&embed_text),
250        )
251        .await;
252        match embed_result {
253            Ok(Ok(v)) => Some(v),
254            Ok(Err(err)) => {
255                self.fallback_count
256                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
257                tracing::warn!(entity_name = %normalized, error = %err,
258                    "embed() failed; falling back to exact-match-only entity creation");
259                None
260            }
261            Err(_timeout) => {
262                self.fallback_count
263                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
264                tracing::warn!(entity_name = %normalized,
265                    "embed() timed out after {}s; falling back to create new entity",
266                    EMBED_TIMEOUT_SECS);
267                None
268            }
269        }
270    }
271
272    /// Handle a candidate in the ambiguous score range by running LLM disambiguation.
273    /// Returns `Ok(Some(...))` if the LLM confirms a match, `Ok(None)` to fall through to create.
274    #[allow(clippy::too_many_arguments)]
275    async fn handle_ambiguous_candidate(
276        &self,
277        emb_store: &EmbeddingStore,
278        provider: &AnyProvider,
279        payload: &std::collections::HashMap<String, serde_json::Value>,
280        score: f32,
281        surface_name: &str,
282        normalized: &str,
283        et: EntityType,
284        summary: Option<&str>,
285    ) -> Result<Option<(i64, ResolutionOutcome)>, MemoryError> {
286        let entity_id = payload
287            .get("entity_id")
288            .and_then(serde_json::Value::as_i64)
289            .ok_or_else(|| MemoryError::GraphStore("missing entity_id in payload".into()))?;
290        let existing_name = payload
291            .get("name")
292            .and_then(|v| v.as_str())
293            .unwrap_or("")
294            .to_owned();
295        let existing_summary = payload
296            .get("summary")
297            .and_then(|v| v.as_str())
298            .unwrap_or("")
299            .to_owned();
300        // Use the existing entity's actual type from the payload (IC-S3)
301        let existing_type = payload
302            .get("entity_type")
303            .and_then(|v| v.as_str())
304            .unwrap_or(et.as_str())
305            .to_owned();
306        match self
307            .llm_disambiguate(
308                provider,
309                normalized,
310                et.as_str(),
311                summary.unwrap_or(""),
312                &existing_name,
313                &existing_type,
314                &existing_summary,
315                score,
316            )
317            .await
318        {
319            Some(true) => {
320                self.merge_entity(
321                    emb_store,
322                    provider,
323                    entity_id,
324                    surface_name,
325                    normalized,
326                    et,
327                    summary,
328                )
329                .await?;
330                Ok(Some((entity_id, ResolutionOutcome::LlmDisambiguated)))
331            }
332            Some(false) => Ok(None),
333            None => {
334                self.fallback_count
335                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
336                tracing::warn!(entity_name = %normalized,
337                    "LLM disambiguation failed; falling back to create new entity");
338                Ok(None)
339            }
340        }
341    }
342
343    /// Attempt embedding-based resolution. Returns `Ok(Some(...))` if resolved (early return),
344    /// `Ok(None)` if no match found (caller should fall through to create), or `Err` on DB error.
345    async fn resolve_via_embedding(
346        &self,
347        normalized: &str,
348        original_name: &str,
349        surface_name: &str,
350        et: EntityType,
351        summary: Option<&str>,
352    ) -> Result<Option<(i64, ResolutionOutcome)>, MemoryError> {
353        let (Some(emb_store), Some(provider)) = (self.embedding_store, self.provider) else {
354            return Ok(None);
355        };
356
357        let Some(query_vec) = self.embed_entity_text(provider, normalized, summary).await else {
358            return Ok(None);
359        };
360
361        let type_filter = VectorFilter {
362            must: vec![FieldCondition {
363                field: "entity_type".into(),
364                value: FieldValue::Text(et.as_str().to_owned()),
365            }],
366            must_not: vec![],
367        };
368        let candidates = match emb_store
369            .search_collection(ENTITY_COLLECTION, &query_vec, 5, Some(type_filter))
370            .await
371        {
372            Ok(c) => c,
373            Err(err) => {
374                self.fallback_count
375                    .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
376                tracing::warn!(entity_name = %normalized, error = %err,
377                    "Qdrant search failed; falling back to create new entity");
378                return self
379                    .create_with_embedding(
380                        emb_store,
381                        surface_name,
382                        normalized,
383                        original_name,
384                        et,
385                        summary,
386                        &query_vec,
387                    )
388                    .await
389                    .map(Some);
390            }
391        };
392
393        if let Some(best) = candidates.first() {
394            let score = best.score;
395            if score >= self.similarity_threshold {
396                let entity_id = best
397                    .payload
398                    .get("entity_id")
399                    .and_then(serde_json::Value::as_i64)
400                    .ok_or_else(|| {
401                        MemoryError::GraphStore("missing entity_id in payload".into())
402                    })?;
403                self.merge_entity(
404                    emb_store,
405                    provider,
406                    entity_id,
407                    surface_name,
408                    normalized,
409                    et,
410                    summary,
411                )
412                .await?;
413                return Ok(Some((
414                    entity_id,
415                    ResolutionOutcome::EmbeddingMatch { score },
416                )));
417            } else if score >= self.ambiguous_threshold
418                && let Some(result) = self
419                    .handle_ambiguous_candidate(
420                        emb_store,
421                        provider,
422                        &best.payload,
423                        score,
424                        surface_name,
425                        normalized,
426                        et,
427                        summary,
428                    )
429                    .await?
430            {
431                return Ok(Some(result));
432            }
433            // score < ambiguous_threshold or LLM said different: fall through to create with embedding
434        }
435
436        // No suitable match — create new entity and store embedding.
437        self.create_with_embedding(
438            emb_store,
439            surface_name,
440            normalized,
441            original_name,
442            et,
443            summary,
444            &query_vec,
445        )
446        .await
447        .map(Some)
448    }
449
450    /// Create a new entity, register aliases, and store its embedding in Qdrant.
451    #[allow(clippy::too_many_arguments)]
452    async fn create_with_embedding(
453        &self,
454        emb_store: &EmbeddingStore,
455        surface_name: &str,
456        normalized: &str,
457        original_name: &str,
458        et: EntityType,
459        summary: Option<&str>,
460        query_vec: &[f32],
461    ) -> Result<(i64, ResolutionOutcome), MemoryError> {
462        let entity_id = self
463            .store
464            .upsert_entity(surface_name, normalized, et, summary)
465            .await?;
466        self.register_aliases(entity_id, normalized, original_name)
467            .await?;
468        self.store_entity_embedding(
469            emb_store,
470            entity_id,
471            None,
472            normalized,
473            et,
474            summary.unwrap_or(""),
475            query_vec,
476        )
477        .await;
478        Ok((entity_id, ResolutionOutcome::Created))
479    }
480
481    /// Register the normalized form and original trimmed form as aliases for an entity.
482    async fn register_aliases(
483        &self,
484        entity_id: i64,
485        normalized: &str,
486        original_name: &str,
487    ) -> Result<(), MemoryError> {
488        self.store.add_alias(entity_id, normalized).await?;
489
490        // Also register the original trimmed lowercased form if it differs from normalized
491        // (e.g. when control chars were stripped, leaving a shorter string).
492        let original_trimmed = original_name.trim().to_lowercase();
493        let original_clean_str = strip_control_chars(&original_trimmed);
494        let original_clean = truncate_to_bytes_ref(&original_clean_str, MAX_ENTITY_NAME_BYTES);
495        if original_clean != normalized {
496            self.store.add_alias(entity_id, original_clean).await?;
497        }
498
499        Ok(())
500    }
501
502    /// Merge an existing entity with new information: combine summaries, update Qdrant.
503    #[allow(clippy::too_many_arguments)]
504    async fn merge_entity(
505        &self,
506        emb_store: &EmbeddingStore,
507        provider: &AnyProvider,
508        entity_id: i64,
509        new_surface_name: &str,
510        new_canonical_name: &str,
511        entity_type: EntityType,
512        new_summary: Option<&str>,
513    ) -> Result<(), MemoryError> {
514        // TODO(PERF-03): The Qdrant payload already contains name/summary at the call site;
515        // pass them in as parameters to eliminate this extra SQLite roundtrip per merge.
516        let existing = self.store.find_entity_by_id(entity_id).await?;
517        let existing_summary = existing
518            .as_ref()
519            .and_then(|e| e.summary.as_deref())
520            .unwrap_or("");
521
522        let merged_summary = if let Some(new) = new_summary {
523            if !new.is_empty() && !existing_summary.is_empty() {
524                let combined = format!("{existing_summary}; {new}");
525                // TODO(S2): use LLM-based summary merge when summary exceeds 512 bytes
526                truncate_to_bytes_ref(&combined, MAX_FACT_BYTES).to_owned()
527            } else if !new.is_empty() {
528                new.to_owned()
529            } else {
530                existing_summary.to_owned()
531            }
532        } else {
533            existing_summary.to_owned()
534        };
535
536        let summary_opt = if merged_summary.is_empty() {
537            None
538        } else {
539            Some(merged_summary.as_str())
540        };
541
542        // Update the EXISTING entity's summary (keep its canonical_name, update surface display name).
543        let existing_canonical = existing.as_ref().map_or_else(
544            || new_canonical_name.to_owned(),
545            |e| e.canonical_name.clone(),
546        );
547        let existing_name_owned = existing
548            .as_ref()
549            .map_or_else(|| new_surface_name.to_owned(), |e| e.name.clone());
550        self.store
551            .upsert_entity(
552                &existing_name_owned,
553                &existing_canonical,
554                entity_type,
555                summary_opt,
556            )
557            .await?;
558
559        // Retrieve existing qdrant_point_id to reuse it (avoids orphaned stale points, IC-S1)
560        let existing_point_id = existing
561            .as_ref()
562            .and_then(|e| e.qdrant_point_id.as_deref())
563            .map(ToOwned::to_owned);
564
565        // Re-embed merged text and upsert to Qdrant
566        let embed_text = format!("{existing_name_owned}: {merged_summary}");
567        let embed_result = tokio::time::timeout(
568            std::time::Duration::from_secs(EMBED_TIMEOUT_SECS),
569            provider.embed(&embed_text),
570        )
571        .await;
572
573        match embed_result {
574            Ok(Ok(vec)) => {
575                self.store_entity_embedding(
576                    emb_store,
577                    entity_id,
578                    existing_point_id.as_deref(),
579                    &existing_name_owned,
580                    entity_type,
581                    &merged_summary,
582                    &vec,
583                )
584                .await;
585            }
586            Ok(Err(err)) => {
587                tracing::warn!(
588                    entity_id,
589                    error = %err,
590                    "merge re-embed failed; Qdrant entry may be stale"
591                );
592            }
593            Err(_) => {
594                tracing::warn!(
595                    entity_id,
596                    "merge re-embed timed out; Qdrant entry may be stale"
597                );
598            }
599        }
600
601        Ok(())
602    }
603
604    /// Store an entity embedding in Qdrant and update `qdrant_point_id` in `SQLite`.
605    ///
606    /// When `existing_point_id` is `Some`, the existing Qdrant point is updated in-place
607    /// (upsert by ID) to avoid orphaned stale points. When `None`, a new point is created.
608    ///
609    /// Failures are logged at warn level but do not propagate — the entity is still
610    /// valid in `SQLite` even if Qdrant upsert fails.
611    #[allow(clippy::too_many_arguments)]
612    async fn store_entity_embedding(
613        &self,
614        emb_store: &EmbeddingStore,
615        entity_id: i64,
616        existing_point_id: Option<&str>,
617        name: &str,
618        entity_type: EntityType,
619        summary: &str,
620        vector: &[f32],
621    ) {
622        // TODO(PERF-05): ensure_named_collection() is called on every store_entity_embedding()
623        // invocation, generating one Qdrant network roundtrip per entity in a batch. Cache this
624        // result at resolver construction time via `std::sync::OnceLock<bool>` to call it once.
625        let vector_size = u64::try_from(vector.len()).unwrap_or(384);
626        if let Err(err) = emb_store
627            .ensure_named_collection(ENTITY_COLLECTION, vector_size)
628            .await
629        {
630            tracing::error!(
631                error = %err,
632                "failed to ensure entity embedding collection; skipping Qdrant upsert"
633            );
634            return;
635        }
636
637        let payload = serde_json::json!({
638            "entity_id": entity_id,
639            "name": name,
640            "entity_type": entity_type.as_str(),
641            "summary": summary,
642        });
643
644        if let Some(point_id) = existing_point_id {
645            // Reuse existing point to avoid orphaned stale points (IC-S1)
646            if let Err(err) = emb_store
647                .upsert_to_collection(ENTITY_COLLECTION, point_id, payload, vector.to_vec())
648                .await
649            {
650                tracing::warn!(
651                    entity_id,
652                    error = %err,
653                    "Qdrant upsert (existing point) failed; Qdrant entry may be stale"
654                );
655            }
656        } else {
657            match emb_store
658                .store_to_collection(ENTITY_COLLECTION, payload, vector.to_vec())
659                .await
660            {
661                Ok(point_id) => {
662                    if let Err(err) = self
663                        .store
664                        .set_entity_qdrant_point_id(entity_id, &point_id)
665                        .await
666                    {
667                        tracing::warn!(
668                            entity_id,
669                            error = %err,
670                            "failed to store qdrant_point_id in SQLite"
671                        );
672                    }
673                }
674                Err(err) => {
675                    tracing::warn!(
676                        entity_id,
677                        error = %err,
678                        "Qdrant upsert failed; entity created in SQLite, qdrant_point_id remains NULL"
679                    );
680                }
681            }
682        }
683    }
684
685    /// Ask the LLM whether two entities are the same.
686    ///
687    /// Returns `Some(true)` for merge, `Some(false)` for separate, `None` on failure.
688    #[allow(clippy::too_many_arguments)]
689    async fn llm_disambiguate(
690        &self,
691        provider: &AnyProvider,
692        new_name: &str,
693        new_type: &str,
694        new_summary: &str,
695        existing_name: &str,
696        existing_type: &str,
697        existing_summary: &str,
698        score: f32,
699    ) -> Option<bool> {
700        let prompt = format!(
701            "New entity:\n\
702             - Name: <external-data>{new_name}</external-data>\n\
703             - Type: <external-data>{new_type}</external-data>\n\
704             - Summary: <external-data>{new_summary}</external-data>\n\
705             \n\
706             Existing entity:\n\
707             - Name: <external-data>{existing_name}</external-data>\n\
708             - Type: <external-data>{existing_type}</external-data>\n\
709             - Summary: <external-data>{existing_summary}</external-data>\n\
710             \n\
711             Cosine similarity: {score:.2}\n\
712             \n\
713             Are these the same entity? Respond with JSON: {{\"same_entity\": true}} or {{\"same_entity\": false}}"
714        );
715
716        let messages = [
717            Message::from_legacy(
718                Role::System,
719                "You are an entity disambiguation assistant. Given a new entity mention and \
720                 an existing entity from the knowledge graph, determine if they refer to the same \
721                 real-world entity. Respond only with JSON.",
722            ),
723            Message::from_legacy(Role::User, prompt),
724        ];
725
726        let response = match provider.chat(&messages).await {
727            Ok(r) => r,
728            Err(err) => {
729                tracing::warn!(error = %err, "LLM disambiguation chat failed");
730                return None;
731            }
732        };
733
734        // Parse JSON response, tolerating markdown code fences
735        let json_str = extract_json(&response);
736        match serde_json::from_str::<DisambiguationResponse>(json_str) {
737            Ok(parsed) => Some(parsed.same_entity),
738            Err(err) => {
739                tracing::warn!(error = %err, response = %response, "failed to parse LLM disambiguation response");
740                None
741            }
742        }
743    }
744
745    /// Resolve a batch of extracted entities concurrently.
746    ///
747    /// Returns a `Vec` of `(entity_id, ResolutionOutcome)` in the same order as input.
748    ///
749    /// # Errors
750    ///
751    /// Returns an error if any DB operation fails.
752    ///
753    /// # Panics
754    ///
755    /// Panics if an internal stream collection bug causes a result index to be missing.
756    /// This indicates a programming error and should never occur in correct usage.
757    pub async fn resolve_batch(
758        &self,
759        entities: &[ExtractedEntity],
760    ) -> Result<Vec<(i64, ResolutionOutcome)>, MemoryError> {
761        if entities.is_empty() {
762            return Ok(Vec::new());
763        }
764
765        // Process up to 4 embed+resolve operations concurrently (IC-S2/PERF-01).
766        let mut results: Vec<Option<(i64, ResolutionOutcome)>> = vec![None; entities.len()];
767
768        let mut stream = stream::iter(entities.iter().enumerate().map(|(i, e)| {
769            let name = e.name.clone();
770            let entity_type = e.entity_type.clone();
771            let summary = e.summary.clone();
772            async move {
773                let result = self.resolve(&name, &entity_type, summary.as_deref()).await;
774                (i, result)
775            }
776        }))
777        .buffer_unordered(4);
778
779        while let Some((i, result)) = stream.next().await {
780            match result {
781                Ok(outcome) => results[i] = Some(outcome),
782                Err(err) => return Err(err),
783            }
784        }
785
786        Ok(results
787            .into_iter()
788            .enumerate()
789            .map(|(i, r)| {
790                r.unwrap_or_else(|| {
791                    tracing::warn!(
792                        index = i,
793                        "resolve_batch: missing result at index — bug in stream collection"
794                    );
795                    panic!("resolve_batch: missing result at index {i}")
796                })
797            })
798            .collect())
799    }
800
801    /// Resolve an extracted edge: deduplicate or supersede existing edges.
802    ///
803    /// - If an active edge with the same direction and relation exists with an identical fact,
804    ///   returns `None` (deduplicated).
805    /// - If an active edge with the same direction and relation exists with a different fact,
806    ///   invalidates the old edge and inserts the new one, returning `Some(new_id)`.
807    /// - If no matching edge exists, inserts a new edge and returns `Some(new_id)`.
808    ///
809    /// Relation and fact strings are sanitized (control chars stripped, length-capped).
810    ///
811    /// # Errors
812    ///
813    /// Returns an error if any database operation fails.
814    pub async fn resolve_edge(
815        &self,
816        source_id: i64,
817        target_id: i64,
818        relation: &str,
819        fact: &str,
820        confidence: f32,
821        episode_id: Option<MessageId>,
822    ) -> Result<Option<i64>, MemoryError> {
823        let relation_clean = strip_control_chars(&relation.trim().to_lowercase());
824        let normalized_relation =
825            truncate_to_bytes_ref(&relation_clean, MAX_RELATION_BYTES).to_owned();
826
827        let fact_clean = strip_control_chars(fact.trim());
828        let normalized_fact = truncate_to_bytes_ref(&fact_clean, MAX_FACT_BYTES).to_owned();
829
830        // Fetch only exact-direction edges — no reverse edges to filter out
831        let existing_edges = self.store.edges_exact(source_id, target_id).await?;
832
833        let matching = existing_edges
834            .iter()
835            .find(|e| e.relation == normalized_relation);
836
837        if let Some(old) = matching {
838            if old.fact == normalized_fact {
839                // Exact duplicate — skip
840                return Ok(None);
841            }
842            // Same relation, different fact — supersede
843            self.store.invalidate_edge(old.id).await?;
844        }
845
846        let new_id = self
847            .store
848            .insert_edge(
849                source_id,
850                target_id,
851                &normalized_relation,
852                &normalized_fact,
853                confidence,
854                episode_id,
855            )
856            .await?;
857        Ok(Some(new_id))
858    }
859
860    /// Resolve a typed edge: deduplicate or supersede existing edges of the same type.
861    ///
862    /// Identical to [`resolve_edge`] but includes `edge_type` in the matching key.
863    /// An active edge with the same `(source, target, relation, edge_type)` and identical
864    /// fact returns `None`; same relation+type with different fact is superseded.
865    ///
866    /// This ensures that different MAGMA edge types for the same entity pair are stored
867    /// independently (critic mitigation: dedup key includes `edge_type`).
868    ///
869    /// # Errors
870    ///
871    /// Returns an error if any database operation fails.
872    #[allow(clippy::too_many_arguments)]
873    pub async fn resolve_edge_typed(
874        &self,
875        source_id: i64,
876        target_id: i64,
877        relation: &str,
878        fact: &str,
879        confidence: f32,
880        episode_id: Option<crate::types::MessageId>,
881        edge_type: crate::graph::EdgeType,
882    ) -> Result<Option<i64>, MemoryError> {
883        let relation_clean = strip_control_chars(&relation.trim().to_lowercase());
884        let normalized_relation =
885            truncate_to_bytes_ref(&relation_clean, MAX_RELATION_BYTES).to_owned();
886
887        let fact_clean = strip_control_chars(fact.trim());
888        let normalized_fact = truncate_to_bytes_ref(&fact_clean, MAX_FACT_BYTES).to_owned();
889
890        let existing_edges = self.store.edges_exact(source_id, target_id).await?;
891
892        // Match on (relation, edge_type) — different types are distinct edges
893        let matching = existing_edges
894            .iter()
895            .find(|e| e.relation == normalized_relation && e.edge_type == edge_type);
896
897        if let Some(old) = matching {
898            if old.fact == normalized_fact {
899                return Ok(None);
900            }
901            self.store.invalidate_edge(old.id).await?;
902        }
903
904        let new_id = self
905            .store
906            .insert_edge_typed(
907                source_id,
908                target_id,
909                &normalized_relation,
910                &normalized_fact,
911                confidence,
912                episode_id,
913                edge_type,
914            )
915            .await?;
916        Ok(Some(new_id))
917    }
918}
919
920/// Extract a JSON object from a string that may contain markdown code fences.
921fn extract_json(s: &str) -> &str {
922    let trimmed = s.trim();
923    // Strip ```json ... ``` or ``` ... ```
924    if let Some(inner) = trimmed.strip_prefix("```json")
925        && let Some(end) = inner.rfind("```")
926    {
927        return inner[..end].trim();
928    }
929    if let Some(inner) = trimmed.strip_prefix("```")
930        && let Some(end) = inner.rfind("```")
931    {
932        return inner[..end].trim();
933    }
934    // Find first '{' to last '}'
935    if let (Some(start), Some(end)) = (trimmed.find('{'), trimmed.rfind('}'))
936        && start <= end
937    {
938        return &trimmed[start..=end];
939    }
940    trimmed
941}
942
943#[cfg(test)]
944mod tests;