1use std::sync::Arc;
5
6use dashmap::DashMap;
7use futures::stream::{self, StreamExt as _};
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tokio::sync::Mutex;
11use zeph_common::sanitize::strip_control_chars;
12use zeph_common::text::truncate_to_bytes_ref;
13use zeph_llm::any::AnyProvider;
14use zeph_llm::provider::{LlmProvider as _, Message, Role};
15
16use super::store::GraphStore;
17use super::types::EntityType;
18use crate::embedding_store::EmbeddingStore;
19use crate::error::MemoryError;
20use crate::graph::extractor::ExtractedEntity;
21use crate::types::MessageId;
22use crate::vector_store::{FieldCondition, FieldValue, VectorFilter};
23
/// Minimum length, in bytes, of a normalized entity name; `resolve` rejects shorter names.
const MIN_ENTITY_NAME_BYTES: usize = 3;
/// Byte limit applied to entity names (and alias forms) when they are normalized.
const MAX_ENTITY_NAME_BYTES: usize = 512;
/// Byte limit applied to edge relation labels in `resolve_edge`.
const MAX_RELATION_BYTES: usize = 256;
/// Byte limit applied to edge facts, to merged summaries, and to the summary part of the
/// text that is sent for embedding.
const MAX_FACT_BYTES: usize = 2048;

/// Name of the vector-store collection that holds entity embeddings.
const ENTITY_COLLECTION: &str = "zeph_graph_entities";

/// Seconds to wait for a single `embed()` call before treating it as failed.
const EMBED_TIMEOUT_SECS: u64 = 30;
38
/// Describes how `EntityResolver::resolve` arrived at the entity id it returned.
#[derive(Debug, Clone, PartialEq)]
pub enum ResolutionOutcome {
    /// An existing entity was found by alias lookup or by canonical-name lookup.
    ExactMatch,
    /// The best vector-search candidate scored at or above the similarity threshold;
    /// `score` is that candidate's score.
    EmbeddingMatch { score: f32 },
    /// The best candidate scored in the ambiguous band and the LLM answered that it is the
    /// same entity.
    LlmDisambiguated,
    /// No existing entity matched, so a new one was created.
    Created,
}
51
/// JSON shape the LLM is asked to answer with during disambiguation:
/// `{"same_entity": true}` or `{"same_entity": false}`.
#[derive(Debug, Deserialize, JsonSchema)]
struct DisambiguationResponse {
    /// `true` when the LLM judges the new mention and the existing entity to be the same.
    same_entity: bool,
}
57
/// Shared map from a normalized entity name to the async mutex that serializes resolution
/// of that name.
type NameLockMap = Arc<DashMap<String, Arc<Mutex<()>>>>;
67
/// Resolves entity mentions to entity ids in the graph store, creating or merging entities
/// as needed, and resolves edges between already-resolved entities.
pub struct EntityResolver<'a> {
    /// Graph store used for entity, alias, and edge reads and writes.
    store: &'a GraphStore,
    /// Optional vector store; embedding-based matching only runs when both this and
    /// `provider` are set.
    embedding_store: Option<&'a Arc<EmbeddingStore>>,
    /// Optional LLM provider used for `embed()` calls and for disambiguation chat calls.
    provider: Option<&'a AnyProvider>,
    /// A candidate scoring at or above this value is merged without consulting the LLM.
    similarity_threshold: f32,
    /// A candidate scoring at or above this value (but below `similarity_threshold`) is
    /// handed to the LLM for a same-entity decision.
    ambiguous_threshold: f32,
    /// Per-name async locks; `resolve` holds the lock of the normalized name while it runs.
    name_locks: NameLockMap,
    /// Counter bumped whenever embedding, vector search, or LLM disambiguation fails and
    /// resolution falls back to a simpler path.
    fallback_count: Arc<std::sync::atomic::AtomicU64>,
}
78
79impl<'a> EntityResolver<'a> {
80 #[must_use]
81 pub fn new(store: &'a GraphStore) -> Self {
82 Self {
83 store,
84 embedding_store: None,
85 provider: None,
86 similarity_threshold: 0.85,
87 ambiguous_threshold: 0.70,
88 name_locks: Arc::new(DashMap::new()),
89 fallback_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
90 }
91 }
92
93 #[must_use]
94 pub fn with_embedding_store(mut self, store: &'a Arc<EmbeddingStore>) -> Self {
95 self.embedding_store = Some(store);
96 self
97 }
98
99 #[must_use]
100 pub fn with_provider(mut self, provider: &'a AnyProvider) -> Self {
101 self.provider = Some(provider);
102 self
103 }
104
105 #[must_use]
106 pub fn with_thresholds(mut self, similarity: f32, ambiguous: f32) -> Self {
107 self.similarity_threshold = similarity;
108 self.ambiguous_threshold = ambiguous;
109 self
110 }
111
112 #[must_use]
114 pub fn fallback_count(&self) -> Arc<std::sync::atomic::AtomicU64> {
115 Arc::clone(&self.fallback_count)
116 }
117
118 fn normalize_name(name: &str) -> String {
120 let lowered = name.trim().to_lowercase();
121 let cleaned = strip_control_chars(&lowered);
122 let normalized = truncate_to_bytes_ref(&cleaned, MAX_ENTITY_NAME_BYTES).to_owned();
123 if normalized.len() < cleaned.len() {
124 tracing::debug!(
125 "graph resolver: entity name truncated to {} bytes",
126 MAX_ENTITY_NAME_BYTES
127 );
128 }
129 normalized
130 }
131
132 fn parse_entity_type(entity_type: &str) -> EntityType {
134 entity_type
135 .trim()
136 .to_lowercase()
137 .parse::<EntityType>()
138 .unwrap_or_else(|_| {
139 tracing::debug!(
140 "graph resolver: unknown entity type {:?}, falling back to Concept",
141 entity_type
142 );
143 EntityType::Concept
144 })
145 }
146
147 async fn lock_name(&self, normalized: &str) -> tokio::sync::OwnedMutexGuard<()> {
149 let lock = self
150 .name_locks
151 .entry(normalized.to_owned())
152 .or_insert_with(|| Arc::new(Mutex::new(())))
153 .clone();
154 lock.lock_owned().await
155 }
156
157 pub async fn resolve(
176 &self,
177 name: &str,
178 entity_type: &str,
179 summary: Option<&str>,
180 ) -> Result<(i64, ResolutionOutcome), MemoryError> {
181 let normalized = Self::normalize_name(name);
182
183 if normalized.is_empty() {
184 return Err(MemoryError::GraphStore("empty entity name".into()));
185 }
186
187 if normalized.len() < MIN_ENTITY_NAME_BYTES {
188 return Err(MemoryError::GraphStore(format!(
189 "entity name too short: {normalized:?} ({} bytes, min {MIN_ENTITY_NAME_BYTES})",
190 normalized.len()
191 )));
192 }
193
194 let et = Self::parse_entity_type(entity_type);
195
196 let surface_name = name.trim().to_owned();
198
199 let _guard = self.lock_name(&normalized).await;
201
202 if let Some(entity) = self.store.find_entity_by_alias(&normalized, et).await? {
204 self.store
205 .upsert_entity(&surface_name, &entity.canonical_name, et, summary)
206 .await?;
207 return Ok((entity.id, ResolutionOutcome::ExactMatch));
208 }
209
210 if let Some(entity) = self.store.find_entity(&normalized, et).await? {
212 self.store
213 .upsert_entity(&surface_name, &entity.canonical_name, et, summary)
214 .await?;
215 return Ok((entity.id, ResolutionOutcome::ExactMatch));
216 }
217
218 if let Some(outcome) = self
220 .resolve_via_embedding(&normalized, name, &surface_name, et, summary)
221 .await?
222 {
223 return Ok(outcome);
224 }
225
226 let entity_id = self
228 .store
229 .upsert_entity(&surface_name, &normalized, et, summary)
230 .await?;
231
232 self.register_aliases(entity_id, &normalized, name).await?;
233
234 Ok((entity_id, ResolutionOutcome::Created))
235 }
236
237 async fn embed_entity_text(
240 &self,
241 provider: &AnyProvider,
242 normalized: &str,
243 summary: Option<&str>,
244 ) -> Option<Vec<f32>> {
245 let safe_summary = truncate_to_bytes_ref(summary.unwrap_or(""), MAX_FACT_BYTES);
246 let embed_text = format!("{normalized}: {safe_summary}");
247 let embed_result = tokio::time::timeout(
248 std::time::Duration::from_secs(EMBED_TIMEOUT_SECS),
249 provider.embed(&embed_text),
250 )
251 .await;
252 match embed_result {
253 Ok(Ok(v)) => Some(v),
254 Ok(Err(err)) => {
255 self.fallback_count
256 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
257 tracing::warn!(entity_name = %normalized, error = %err,
258 "embed() failed; falling back to exact-match-only entity creation");
259 None
260 }
261 Err(_timeout) => {
262 self.fallback_count
263 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
264 tracing::warn!(entity_name = %normalized,
265 "embed() timed out after {}s; falling back to create new entity",
266 EMBED_TIMEOUT_SECS);
267 None
268 }
269 }
270 }
271
272 #[allow(clippy::too_many_arguments)]
275 async fn handle_ambiguous_candidate(
276 &self,
277 emb_store: &EmbeddingStore,
278 provider: &AnyProvider,
279 payload: &std::collections::HashMap<String, serde_json::Value>,
280 score: f32,
281 surface_name: &str,
282 normalized: &str,
283 et: EntityType,
284 summary: Option<&str>,
285 ) -> Result<Option<(i64, ResolutionOutcome)>, MemoryError> {
286 let entity_id = payload
287 .get("entity_id")
288 .and_then(serde_json::Value::as_i64)
289 .ok_or_else(|| MemoryError::GraphStore("missing entity_id in payload".into()))?;
290 let existing_name = payload
291 .get("name")
292 .and_then(|v| v.as_str())
293 .unwrap_or("")
294 .to_owned();
295 let existing_summary = payload
296 .get("summary")
297 .and_then(|v| v.as_str())
298 .unwrap_or("")
299 .to_owned();
300 let existing_type = payload
302 .get("entity_type")
303 .and_then(|v| v.as_str())
304 .unwrap_or(et.as_str())
305 .to_owned();
306 match self
307 .llm_disambiguate(
308 provider,
309 normalized,
310 et.as_str(),
311 summary.unwrap_or(""),
312 &existing_name,
313 &existing_type,
314 &existing_summary,
315 score,
316 )
317 .await
318 {
319 Some(true) => {
320 self.merge_entity(
321 emb_store,
322 provider,
323 entity_id,
324 surface_name,
325 normalized,
326 et,
327 summary,
328 )
329 .await?;
330 Ok(Some((entity_id, ResolutionOutcome::LlmDisambiguated)))
331 }
332 Some(false) => Ok(None),
333 None => {
334 self.fallback_count
335 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
336 tracing::warn!(entity_name = %normalized,
337 "LLM disambiguation failed; falling back to create new entity");
338 Ok(None)
339 }
340 }
341 }
342
343 async fn resolve_via_embedding(
346 &self,
347 normalized: &str,
348 original_name: &str,
349 surface_name: &str,
350 et: EntityType,
351 summary: Option<&str>,
352 ) -> Result<Option<(i64, ResolutionOutcome)>, MemoryError> {
353 let (Some(emb_store), Some(provider)) = (self.embedding_store, self.provider) else {
354 return Ok(None);
355 };
356
357 let Some(query_vec) = self.embed_entity_text(provider, normalized, summary).await else {
358 return Ok(None);
359 };
360
361 let type_filter = VectorFilter {
362 must: vec![FieldCondition {
363 field: "entity_type".into(),
364 value: FieldValue::Text(et.as_str().to_owned()),
365 }],
366 must_not: vec![],
367 };
368 let candidates = match emb_store
369 .search_collection(ENTITY_COLLECTION, &query_vec, 5, Some(type_filter))
370 .await
371 {
372 Ok(c) => c,
373 Err(err) => {
374 self.fallback_count
375 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
376 tracing::warn!(entity_name = %normalized, error = %err,
377 "Qdrant search failed; falling back to create new entity");
378 return self
379 .create_with_embedding(
380 emb_store,
381 surface_name,
382 normalized,
383 original_name,
384 et,
385 summary,
386 &query_vec,
387 )
388 .await
389 .map(Some);
390 }
391 };
392
393 if let Some(best) = candidates.first() {
394 let score = best.score;
395 if score >= self.similarity_threshold {
396 let entity_id = best
397 .payload
398 .get("entity_id")
399 .and_then(serde_json::Value::as_i64)
400 .ok_or_else(|| {
401 MemoryError::GraphStore("missing entity_id in payload".into())
402 })?;
403 self.merge_entity(
404 emb_store,
405 provider,
406 entity_id,
407 surface_name,
408 normalized,
409 et,
410 summary,
411 )
412 .await?;
413 return Ok(Some((
414 entity_id,
415 ResolutionOutcome::EmbeddingMatch { score },
416 )));
417 } else if score >= self.ambiguous_threshold
418 && let Some(result) = self
419 .handle_ambiguous_candidate(
420 emb_store,
421 provider,
422 &best.payload,
423 score,
424 surface_name,
425 normalized,
426 et,
427 summary,
428 )
429 .await?
430 {
431 return Ok(Some(result));
432 }
433 }
435
436 self.create_with_embedding(
438 emb_store,
439 surface_name,
440 normalized,
441 original_name,
442 et,
443 summary,
444 &query_vec,
445 )
446 .await
447 .map(Some)
448 }
449
450 #[allow(clippy::too_many_arguments)]
452 async fn create_with_embedding(
453 &self,
454 emb_store: &EmbeddingStore,
455 surface_name: &str,
456 normalized: &str,
457 original_name: &str,
458 et: EntityType,
459 summary: Option<&str>,
460 query_vec: &[f32],
461 ) -> Result<(i64, ResolutionOutcome), MemoryError> {
462 let entity_id = self
463 .store
464 .upsert_entity(surface_name, normalized, et, summary)
465 .await?;
466 self.register_aliases(entity_id, normalized, original_name)
467 .await?;
468 self.store_entity_embedding(
469 emb_store,
470 entity_id,
471 None,
472 normalized,
473 et,
474 summary.unwrap_or(""),
475 query_vec,
476 )
477 .await;
478 Ok((entity_id, ResolutionOutcome::Created))
479 }
480
481 async fn register_aliases(
483 &self,
484 entity_id: i64,
485 normalized: &str,
486 original_name: &str,
487 ) -> Result<(), MemoryError> {
488 self.store.add_alias(entity_id, normalized).await?;
489
490 let original_trimmed = original_name.trim().to_lowercase();
493 let original_clean_str = strip_control_chars(&original_trimmed);
494 let original_clean = truncate_to_bytes_ref(&original_clean_str, MAX_ENTITY_NAME_BYTES);
495 if original_clean != normalized {
496 self.store.add_alias(entity_id, original_clean).await?;
497 }
498
499 Ok(())
500 }
501
502 #[allow(clippy::too_many_arguments)]
504 async fn merge_entity(
505 &self,
506 emb_store: &EmbeddingStore,
507 provider: &AnyProvider,
508 entity_id: i64,
509 new_surface_name: &str,
510 new_canonical_name: &str,
511 entity_type: EntityType,
512 new_summary: Option<&str>,
513 ) -> Result<(), MemoryError> {
514 let existing = self.store.find_entity_by_id(entity_id).await?;
517 let existing_summary = existing
518 .as_ref()
519 .and_then(|e| e.summary.as_deref())
520 .unwrap_or("");
521
522 let merged_summary = if let Some(new) = new_summary {
523 if !new.is_empty() && !existing_summary.is_empty() {
524 let combined = format!("{existing_summary}; {new}");
525 truncate_to_bytes_ref(&combined, MAX_FACT_BYTES).to_owned()
527 } else if !new.is_empty() {
528 new.to_owned()
529 } else {
530 existing_summary.to_owned()
531 }
532 } else {
533 existing_summary.to_owned()
534 };
535
536 let summary_opt = if merged_summary.is_empty() {
537 None
538 } else {
539 Some(merged_summary.as_str())
540 };
541
542 let existing_canonical = existing.as_ref().map_or_else(
544 || new_canonical_name.to_owned(),
545 |e| e.canonical_name.clone(),
546 );
547 let existing_name_owned = existing
548 .as_ref()
549 .map_or_else(|| new_surface_name.to_owned(), |e| e.name.clone());
550 self.store
551 .upsert_entity(
552 &existing_name_owned,
553 &existing_canonical,
554 entity_type,
555 summary_opt,
556 )
557 .await?;
558
559 let existing_point_id = existing
561 .as_ref()
562 .and_then(|e| e.qdrant_point_id.as_deref())
563 .map(ToOwned::to_owned);
564
565 let embed_text = format!("{existing_name_owned}: {merged_summary}");
567 let embed_result = tokio::time::timeout(
568 std::time::Duration::from_secs(EMBED_TIMEOUT_SECS),
569 provider.embed(&embed_text),
570 )
571 .await;
572
573 match embed_result {
574 Ok(Ok(vec)) => {
575 self.store_entity_embedding(
576 emb_store,
577 entity_id,
578 existing_point_id.as_deref(),
579 &existing_name_owned,
580 entity_type,
581 &merged_summary,
582 &vec,
583 )
584 .await;
585 }
586 Ok(Err(err)) => {
587 tracing::warn!(
588 entity_id,
589 error = %err,
590 "merge re-embed failed; Qdrant entry may be stale"
591 );
592 }
593 Err(_) => {
594 tracing::warn!(
595 entity_id,
596 "merge re-embed timed out; Qdrant entry may be stale"
597 );
598 }
599 }
600
601 Ok(())
602 }
603
604 #[allow(clippy::too_many_arguments)]
612 async fn store_entity_embedding(
613 &self,
614 emb_store: &EmbeddingStore,
615 entity_id: i64,
616 existing_point_id: Option<&str>,
617 name: &str,
618 entity_type: EntityType,
619 summary: &str,
620 vector: &[f32],
621 ) {
622 let vector_size = u64::try_from(vector.len()).unwrap_or(384);
626 if let Err(err) = emb_store
627 .ensure_named_collection(ENTITY_COLLECTION, vector_size)
628 .await
629 {
630 tracing::error!(
631 error = %err,
632 "failed to ensure entity embedding collection; skipping Qdrant upsert"
633 );
634 return;
635 }
636
637 let payload = serde_json::json!({
638 "entity_id": entity_id,
639 "name": name,
640 "entity_type": entity_type.as_str(),
641 "summary": summary,
642 });
643
644 if let Some(point_id) = existing_point_id {
645 if let Err(err) = emb_store
647 .upsert_to_collection(ENTITY_COLLECTION, point_id, payload, vector.to_vec())
648 .await
649 {
650 tracing::warn!(
651 entity_id,
652 error = %err,
653 "Qdrant upsert (existing point) failed; Qdrant entry may be stale"
654 );
655 }
656 } else {
657 match emb_store
658 .store_to_collection(ENTITY_COLLECTION, payload, vector.to_vec())
659 .await
660 {
661 Ok(point_id) => {
662 if let Err(err) = self
663 .store
664 .set_entity_qdrant_point_id(entity_id, &point_id)
665 .await
666 {
667 tracing::warn!(
668 entity_id,
669 error = %err,
670 "failed to store qdrant_point_id in SQLite"
671 );
672 }
673 }
674 Err(err) => {
675 tracing::warn!(
676 entity_id,
677 error = %err,
678 "Qdrant upsert failed; entity created in SQLite, qdrant_point_id remains NULL"
679 );
680 }
681 }
682 }
683 }
684
685 #[allow(clippy::too_many_arguments)]
689 async fn llm_disambiguate(
690 &self,
691 provider: &AnyProvider,
692 new_name: &str,
693 new_type: &str,
694 new_summary: &str,
695 existing_name: &str,
696 existing_type: &str,
697 existing_summary: &str,
698 score: f32,
699 ) -> Option<bool> {
700 let prompt = format!(
701 "New entity:\n\
702 - Name: <external-data>{new_name}</external-data>\n\
703 - Type: <external-data>{new_type}</external-data>\n\
704 - Summary: <external-data>{new_summary}</external-data>\n\
705 \n\
706 Existing entity:\n\
707 - Name: <external-data>{existing_name}</external-data>\n\
708 - Type: <external-data>{existing_type}</external-data>\n\
709 - Summary: <external-data>{existing_summary}</external-data>\n\
710 \n\
711 Cosine similarity: {score:.2}\n\
712 \n\
713 Are these the same entity? Respond with JSON: {{\"same_entity\": true}} or {{\"same_entity\": false}}"
714 );
715
716 let messages = [
717 Message::from_legacy(
718 Role::System,
719 "You are an entity disambiguation assistant. Given a new entity mention and \
720 an existing entity from the knowledge graph, determine if they refer to the same \
721 real-world entity. Respond only with JSON.",
722 ),
723 Message::from_legacy(Role::User, prompt),
724 ];
725
726 let response = match provider.chat(&messages).await {
727 Ok(r) => r,
728 Err(err) => {
729 tracing::warn!(error = %err, "LLM disambiguation chat failed");
730 return None;
731 }
732 };
733
734 let json_str = extract_json(&response);
736 match serde_json::from_str::<DisambiguationResponse>(json_str) {
737 Ok(parsed) => Some(parsed.same_entity),
738 Err(err) => {
739 tracing::warn!(error = %err, response = %response, "failed to parse LLM disambiguation response");
740 None
741 }
742 }
743 }
744
745 pub async fn resolve_batch(
758 &self,
759 entities: &[ExtractedEntity],
760 ) -> Result<Vec<(i64, ResolutionOutcome)>, MemoryError> {
761 if entities.is_empty() {
762 return Ok(Vec::new());
763 }
764
765 let mut results: Vec<Option<(i64, ResolutionOutcome)>> = vec![None; entities.len()];
767
768 let mut stream = stream::iter(entities.iter().enumerate().map(|(i, e)| {
769 let name = e.name.clone();
770 let entity_type = e.entity_type.clone();
771 let summary = e.summary.clone();
772 async move {
773 let result = self.resolve(&name, &entity_type, summary.as_deref()).await;
774 (i, result)
775 }
776 }))
777 .buffer_unordered(4);
778
779 while let Some((i, result)) = stream.next().await {
780 match result {
781 Ok(outcome) => results[i] = Some(outcome),
782 Err(err) => return Err(err),
783 }
784 }
785
786 Ok(results
787 .into_iter()
788 .enumerate()
789 .map(|(i, r)| {
790 r.unwrap_or_else(|| {
791 tracing::warn!(
792 index = i,
793 "resolve_batch: missing result at index — bug in stream collection"
794 );
795 panic!("resolve_batch: missing result at index {i}")
796 })
797 })
798 .collect())
799 }
800
801 pub async fn resolve_edge(
815 &self,
816 source_id: i64,
817 target_id: i64,
818 relation: &str,
819 fact: &str,
820 confidence: f32,
821 episode_id: Option<MessageId>,
822 ) -> Result<Option<i64>, MemoryError> {
823 let relation_clean = strip_control_chars(&relation.trim().to_lowercase());
824 let normalized_relation =
825 truncate_to_bytes_ref(&relation_clean, MAX_RELATION_BYTES).to_owned();
826
827 let fact_clean = strip_control_chars(fact.trim());
828 let normalized_fact = truncate_to_bytes_ref(&fact_clean, MAX_FACT_BYTES).to_owned();
829
830 let existing_edges = self.store.edges_exact(source_id, target_id).await?;
832
833 let matching = existing_edges
834 .iter()
835 .find(|e| e.relation == normalized_relation);
836
837 if let Some(old) = matching {
838 if old.fact == normalized_fact {
839 return Ok(None);
841 }
842 self.store.invalidate_edge(old.id).await?;
844 }
845
846 let new_id = self
847 .store
848 .insert_edge(
849 source_id,
850 target_id,
851 &normalized_relation,
852 &normalized_fact,
853 confidence,
854 episode_id,
855 )
856 .await?;
857 Ok(Some(new_id))
858 }
859}
860
/// Extracts the most likely JSON object from an LLM reply.
///
/// The reply is trimmed. If it starts with a Markdown code fence (```` ```json ```` or a
/// bare ```` ``` ````) that is closed later, only the fenced content is considered.
/// Within the text under consideration, the span from the first `{` to the last `}` is
/// returned, which drops any language tag (e.g. ```` ```javascript ````) or prose that
/// surrounds the object. When no such span exists, the trimmed fenced content (or the
/// trimmed reply, if there was no fence) is returned unchanged.
///
/// The result always borrows from `s`; nothing is allocated or parsed here.
fn extract_json(s: &str) -> &str {
    let trimmed = s.trim();

    // Content of a leading, closed code fence. "```json" is tried first so that the tag is
    // not left behind when the fenced content contains no braces.
    let fenced = ["```json", "```"].iter().find_map(|fence| {
        let inner = trimmed.strip_prefix(*fence)?;
        let close = inner.rfind("```")?;
        Some(inner[..close].trim())
    });
    let candidate = fenced.unwrap_or(trimmed);

    // Narrow to the outermost brace pair, inside fences as well as outside of them.
    match (candidate.find('{'), candidate.rfind('}')) {
        (Some(start), Some(end)) if start < end => &candidate[start..=end],
        _ => candidate,
    }
}
883
884#[cfg(test)]
885mod tests;