1mod algorithms;
24mod corrections;
25mod cross_session;
26mod graph;
27pub(crate) mod importance;
28pub mod persona;
29mod recall;
30mod summarization;
31pub mod trajectory;
32pub mod tree_consolidation;
33pub(crate) mod write_buffer;
34
35#[cfg(test)]
36mod tests;
37
38use std::sync::Arc;
39use std::sync::Mutex;
40use std::sync::atomic::AtomicU64;
41use std::time::Instant;
42
43use tokio::sync::RwLock;
44use zeph_llm::any::AnyProvider;
45use zeph_llm::provider::LlmProvider as _;
46
47use crate::admission::AdmissionControl;
48use crate::embedding_store::EmbeddingStore;
49use crate::error::MemoryError;
50use crate::retrieval_failure_logger::RetrievalFailureLogger;
51use crate::store::SqliteStore;
52use crate::store::retrieval_failures::RetrievalFailureRecord;
53use crate::token_counter::TokenCounter;
54
/// Name of the collection used for session summaries.
// NOTE(review): consumers of this constant live outside this file — confirm usage there.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Name of the collection used for key facts.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Name of the collection used for corrections.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
58
/// Progress counters for a backfill operation.
// NOTE(review): the producer of this type is not visible in this file; the field
// meanings below are read from the names only — confirm against the caller.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackfillProgress {
    /// Number of items processed so far (presumably).
    pub done: usize,
    /// Total number of items to process (presumably).
    pub total: usize,
}
67
68pub use algorithms::{apply_mmr, apply_temporal_decay};
69pub use cross_session::SessionSummaryResult;
70pub use graph::{
71 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
72 PostExtractValidator, extract_and_store, link_memory_notes,
73};
74pub use persona::{
75 PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
76};
77pub use recall::{EmbedContext, RecalledMessage};
78pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
79pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
80pub use tree_consolidation::{
81 TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
82 start_tree_consolidation_loop,
83};
84pub use write_buffer::{BufferedWrite, WriteBuffer};
85
/// A profile centroid together with the instant it was computed.
///
/// Stored in `SemanticMemory::profile_centroid`; `computed_at` is compared against
/// `profile_centroid_ttl_secs` to decide whether the cached vector is still fresh.
#[derive(Debug, Clone)]
pub(crate) struct CachedCentroid {
    /// The centroid vector.
    pub vector: Vec<f32>,
    /// When `vector` was computed; used for the TTL check.
    pub computed_at: Instant,
}
97
/// On/off switch for temporal decay, configured through
/// `SemanticMemory::with_ranking_options`.
///
/// Defaults to [`TemporalDecay::Disabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum TemporalDecay {
    /// Temporal decay is switched on.
    Enabled,
    /// Temporal decay is switched off (the default).
    #[default]
    Disabled,
}

impl TemporalDecay {
    /// Returns `true` only for [`TemporalDecay::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for TemporalDecay {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
126
/// On/off switch for MMR re-ranking, configured through
/// `SemanticMemory::with_ranking_options`.
///
/// Defaults to [`MmrReranking::Disabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum MmrReranking {
    /// MMR re-ranking is switched on.
    Enabled,
    /// MMR re-ranking is switched off (the default).
    #[default]
    Disabled,
}

impl MmrReranking {
    /// Returns `true` only for [`MmrReranking::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for MmrReranking {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
155
/// On/off switch for importance scoring, configured through
/// `SemanticMemory::with_importance_options`.
///
/// Defaults to [`ImportanceScoring::Disabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum ImportanceScoring {
    /// Importance scoring is switched on.
    Enabled,
    /// Importance scoring is switched off (the default).
    #[default]
    Disabled,
}

impl ImportanceScoring {
    /// Returns `true` only for [`ImportanceScoring::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for ImportanceScoring {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
184
/// On/off switch for query-bias correction, configured through
/// `SemanticMemory::with_query_bias`.
///
/// Unlike the other toggles in this module, this one defaults to
/// [`QueryBiasCorrection::Enabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum QueryBiasCorrection {
    /// Query-bias correction is switched on (the default).
    #[default]
    Enabled,
    /// Query-bias correction is switched off.
    Disabled,
}

impl QueryBiasCorrection {
    /// Returns `true` only for [`QueryBiasCorrection::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for QueryBiasCorrection {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
213
/// On/off switch for Hebbian reinforcement, configured through
/// `SemanticMemory::with_hebbian`.
///
/// Defaults to [`HebbianReinforcement::Disabled`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[non_exhaustive]
pub enum HebbianReinforcement {
    /// Hebbian reinforcement is switched on.
    Enabled,
    /// Hebbian reinforcement is switched off (the default).
    #[default]
    Disabled,
}

impl HebbianReinforcement {
    /// Returns `true` only for [`HebbianReinforcement::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for HebbianReinforcement {
    /// Maps `true` to `Enabled` and `false` to `Disabled`.
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
242
/// Coarse classification of a query, produced by
/// `SemanticMemory::classify_query_intent`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum QueryIntent {
    /// The query contains self-referential language
    /// (as judged by `persona::contains_self_referential_language`).
    FirstPerson,
    /// Any query that is not classified as first-person.
    Other,
}
254
255#[derive(Debug, Clone)]
259pub struct HelaSpreadRuntime {
260 pub enabled: bool,
262 pub depth: u32,
264 pub max_visited: usize,
266 pub edge_types: Vec<crate::graph::EdgeType>,
268 pub step_budget: Option<std::time::Duration>,
270 pub embed_timeout: Option<std::time::Duration>,
272}
273
274impl Default for HelaSpreadRuntime {
275 fn default() -> Self {
276 Self {
277 enabled: false,
278 depth: 2,
279 max_visited: 200,
280 edge_types: Vec::new(),
281 step_budget: Some(std::time::Duration::from_millis(8)),
282 embed_timeout: Some(std::time::Duration::from_secs(5)),
283 }
284 }
285}
286
/// Central memory handle combining the SQLite store, an optional embedding store,
/// the LLM provider(s) and all retrieval/ranking configuration.
///
/// Instances are created through the constructors on the impl (`new`, `with_weights`,
/// `with_qdrant_ops`, `with_sqlite_backend`, `from_parts`, ...) and then customised
/// with the consuming `with_*` builder methods.
pub struct SemanticMemory {
    /// Relational store; exposed via `sqlite()`.
    pub(crate) sqlite: SqliteStore,
    /// Embedding store, if one is available. Despite the field name this may also be
    /// the SQLite-backed store created by `with_sqlite_backend`.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    /// Primary provider; also used for embeddings when `embed_provider` is `None`.
    pub(crate) provider: AnyProvider,
    /// Dedicated embedding provider set by `with_embedding_provider`.
    pub(crate) embed_provider: Option<AnyProvider>,
    /// Embedding model name supplied at construction.
    pub(crate) embedding_model: String,
    /// Weight supplied at construction for the vector side of retrieval.
    pub(crate) vector_weight: f64,
    /// Weight supplied at construction for the keyword side of retrieval.
    pub(crate) keyword_weight: f64,
    /// Temporal-decay toggle (see `with_ranking_options`).
    pub(crate) temporal_decay: TemporalDecay,
    /// Half-life for temporal decay, in days. Default: 30.
    pub(crate) temporal_decay_half_life_days: u32,
    /// MMR re-ranking toggle (see `with_ranking_options`).
    pub(crate) mmr_reranking: MmrReranking,
    /// MMR lambda. Default: 0.7.
    pub(crate) mmr_lambda: f32,
    /// Importance-scoring toggle (see `with_importance_options`).
    pub(crate) importance_scoring: ImportanceScoring,
    /// Importance weight. Default: 0.15.
    pub(crate) importance_weight: f64,
    /// Boost value set by `with_tier_boost`. Default: 1.3.
    pub(crate) tier_boost_semantic: f64,
    /// Shared token counter.
    pub token_counter: Arc<TokenCounter>,
    /// Optional graph store (see `with_graph_store`).
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    /// Optional experience store (see `with_experience_store`).
    pub experience: Option<Arc<crate::graph::experience::ExperienceStore>>,
    /// Optional reasoning memory (see `with_reasoning`); queried by
    /// `retrieve_reasoning_strategies`.
    pub reasoning: Option<Arc<crate::reasoning::ReasoningMemory>>,
    /// Counter read by `community_detection_failures()`.
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    /// Counter read by `graph_extraction_count()`.
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    /// Counter read by `graph_extraction_failures()`.
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
    // NOTE(review): presumably a timestamp used to rate-limit Qdrant warnings — the
    // code that reads/writes it is outside this file.
    pub(crate) last_qdrant_warn: Arc<AtomicU64>,
    /// Optional admission control (see `with_admission_control`).
    pub(crate) admission_control: Option<Arc<AdmissionControl>>,
    /// Optional quality gate (see `with_quality_gate`).
    pub(crate) quality_gate: Option<Arc<crate::quality_gate::QualityGate>>,
    /// Dedup threshold for key facts. Default: 0.95.
    pub(crate) key_facts_dedup_threshold: f32,
    /// Set of spawned tasks, guarded by a std mutex.
    pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
    /// Configured ANN depth; `0` makes `effective_depth` fall back to `limit * 2`.
    pub(crate) retrieval_depth: u32,
    /// Template applied by `apply_search_prompt`; empty means "use the raw query".
    pub(crate) search_prompt_template: String,
    /// Ensures the "depth < limit" warning in `effective_depth` is logged only once.
    pub(crate) depth_below_limit_warned: Arc<std::sync::atomic::AtomicBool>,
    /// Ensures the missing-`{query}` warning in `apply_search_prompt` is logged only once.
    pub(crate) missing_placeholder_warned: Arc<std::sync::atomic::AtomicBool>,
    /// Query-bias toggle checked by `apply_query_bias`.
    pub(crate) query_bias_correction: QueryBiasCorrection,
    /// Blend weight of the profile centroid in `apply_query_bias`; kept in `[0, 1]`
    /// by `with_query_bias`.
    pub(crate) query_bias_profile_weight: f32,
    /// Cached profile centroid used by `profile_centroid_cached`.
    pub(crate) profile_centroid: RwLock<Option<CachedCentroid>>,
    /// Lifetime of the cached centroid, in seconds. Default: 300.
    pub(crate) profile_centroid_ttl_secs: u64,
    /// Hebbian reinforcement toggle (see `with_hebbian`).
    pub(crate) hebbian_reinforcement: HebbianReinforcement,
    /// Hebbian learning rate; never negative after `with_hebbian`. Default: 0.1.
    pub(crate) hebbian_lr: f32,
    /// Hebbian spread settings (see `with_hebbian_spread`).
    pub(crate) hebbian_spread: HelaSpreadRuntime,
    /// Optional sink used by `log_retrieval_failure`.
    pub(crate) retrieval_failure_logger: Option<RetrievalFailureLogger>,
    /// Timeout for summarization LLM calls, in seconds. Default: 60.
    pub(crate) summarization_llm_timeout_secs: u64,
    /// Flag set by `with_query_sensitive_cost`. Default: `false`.
    pub(crate) query_sensitive_cost: bool,
    /// Optional five-signal runtime (see `with_five_signal`).
    pub(crate) five_signal: Option<Arc<crate::five_signal::FiveSignalRuntime>>,
    /// Timeout applied to embedding calls made in this file. Default: 5 s.
    pub(crate) embed_timeout: std::time::Duration,
    // NOTE(review): presumably cancels background graph work — the code that uses
    // this token is outside this file.
    pub(crate) graph_cancel: Mutex<Option<tokio_util::sync::CancellationToken>>,
}
399
400impl SemanticMemory {
401 pub async fn new(
412 sqlite_path: &str,
413 qdrant_url: &str,
414 api_key: Option<&str>,
415 provider: AnyProvider,
416 embedding_model: &str,
417 ) -> Result<Self, MemoryError> {
418 Self::with_weights(
419 sqlite_path,
420 qdrant_url,
421 api_key,
422 provider,
423 embedding_model,
424 0.7,
425 0.3,
426 )
427 .await
428 }
429
430 pub async fn with_weights(
439 sqlite_path: &str,
440 qdrant_url: &str,
441 api_key: Option<&str>,
442 provider: AnyProvider,
443 embedding_model: &str,
444 vector_weight: f64,
445 keyword_weight: f64,
446 ) -> Result<Self, MemoryError> {
447 Self::with_weights_and_pool_size(
448 sqlite_path,
449 qdrant_url,
450 api_key,
451 provider,
452 embedding_model,
453 vector_weight,
454 keyword_weight,
455 5,
456 )
457 .await
458 }
459
460 #[allow(clippy::too_many_arguments)]
469 pub async fn with_weights_and_pool_size(
470 sqlite_path: &str,
471 qdrant_url: &str,
472 api_key: Option<&str>,
473 provider: AnyProvider,
474 embedding_model: &str,
475 vector_weight: f64,
476 keyword_weight: f64,
477 pool_size: u32,
478 ) -> Result<Self, MemoryError> {
479 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
480 let pool = sqlite.pool().clone();
481
482 let qdrant = match EmbeddingStore::new(qdrant_url, api_key, pool) {
483 Ok(store) => Some(Arc::new(store)),
484 Err(e) => {
485 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
486 None
487 }
488 };
489
490 Ok(Self {
491 sqlite,
492 qdrant,
493 provider,
494 embed_provider: None,
495 embedding_model: embedding_model.into(),
496 vector_weight,
497 keyword_weight,
498 temporal_decay: TemporalDecay::Disabled,
499 temporal_decay_half_life_days: 30,
500 mmr_reranking: MmrReranking::Disabled,
501 mmr_lambda: 0.7,
502 importance_scoring: ImportanceScoring::Disabled,
503 importance_weight: 0.15,
504 tier_boost_semantic: 1.3,
505 token_counter: Arc::new(TokenCounter::new()),
506 graph_store: None,
507 experience: None,
508 reasoning: None,
509 community_detection_failures: Arc::new(AtomicU64::new(0)),
510 graph_extraction_count: Arc::new(AtomicU64::new(0)),
511 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
512 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
513 admission_control: None,
514 quality_gate: None,
515 key_facts_dedup_threshold: 0.95,
516 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
517 retrieval_depth: 0,
518 search_prompt_template: String::new(),
519 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
520 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
521 query_bias_correction: QueryBiasCorrection::Enabled,
522 query_bias_profile_weight: 0.25,
523 profile_centroid: RwLock::new(None),
524 profile_centroid_ttl_secs: 300,
525 hebbian_reinforcement: HebbianReinforcement::Disabled,
526 hebbian_lr: 0.1,
527 hebbian_spread: HelaSpreadRuntime::default(),
528 retrieval_failure_logger: None,
529 summarization_llm_timeout_secs: 60,
530 query_sensitive_cost: false,
531 five_signal: None,
532 embed_timeout: std::time::Duration::from_secs(5),
533 graph_cancel: Mutex::new(None),
534 })
535 }
536
537 pub async fn with_qdrant_ops(
546 sqlite_path: &str,
547 ops: crate::QdrantOps,
548 provider: AnyProvider,
549 embedding_model: &str,
550 vector_weight: f64,
551 keyword_weight: f64,
552 pool_size: u32,
553 ) -> Result<Self, MemoryError> {
554 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
555 let pool = sqlite.pool().clone();
556 let store = EmbeddingStore::with_store(Box::new(ops), pool);
557
558 Ok(Self {
559 sqlite,
560 qdrant: Some(Arc::new(store)),
561 provider,
562 embed_provider: None,
563 embedding_model: embedding_model.into(),
564 vector_weight,
565 keyword_weight,
566 temporal_decay: TemporalDecay::Disabled,
567 temporal_decay_half_life_days: 30,
568 mmr_reranking: MmrReranking::Disabled,
569 mmr_lambda: 0.7,
570 importance_scoring: ImportanceScoring::Disabled,
571 importance_weight: 0.15,
572 tier_boost_semantic: 1.3,
573 token_counter: Arc::new(TokenCounter::new()),
574 graph_store: None,
575 experience: None,
576 reasoning: None,
577 community_detection_failures: Arc::new(AtomicU64::new(0)),
578 graph_extraction_count: Arc::new(AtomicU64::new(0)),
579 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
580 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
581 admission_control: None,
582 quality_gate: None,
583 key_facts_dedup_threshold: 0.95,
584 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
585 retrieval_depth: 0,
586 search_prompt_template: String::new(),
587 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
588 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
589 query_bias_correction: QueryBiasCorrection::Enabled,
590 query_bias_profile_weight: 0.25,
591 profile_centroid: RwLock::new(None),
592 profile_centroid_ttl_secs: 300,
593 hebbian_reinforcement: HebbianReinforcement::Disabled,
594 hebbian_lr: 0.1,
595 hebbian_spread: HelaSpreadRuntime::default(),
596 retrieval_failure_logger: None,
597 summarization_llm_timeout_secs: 60,
598 query_sensitive_cost: false,
599 five_signal: None,
600 embed_timeout: std::time::Duration::from_secs(5),
601 graph_cancel: Mutex::new(None),
602 })
603 }
604
605 #[must_use]
610 pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
611 self.graph_store = Some(store);
612 self
613 }
614
615 #[must_use]
621 pub fn with_experience_store(
622 mut self,
623 store: Arc<crate::graph::experience::ExperienceStore>,
624 ) -> Self {
625 self.experience = Some(store);
626 self
627 }
628
629 #[must_use]
635 pub fn with_reasoning(mut self, store: Arc<crate::reasoning::ReasoningMemory>) -> Self {
636 self.reasoning = Some(store);
637 self
638 }
639
640 #[must_use]
645 pub fn with_retrieval_failure_logger(mut self, logger: RetrievalFailureLogger) -> Self {
646 self.retrieval_failure_logger = Some(logger);
647 self
648 }
649
650 pub fn log_retrieval_failure(&self, record: RetrievalFailureRecord) {
656 if let Some(logger) = &self.retrieval_failure_logger {
657 logger.log(record);
658 }
659 }
660
661 #[must_use]
663 pub fn community_detection_failures(&self) -> u64 {
664 use std::sync::atomic::Ordering;
665 self.community_detection_failures.load(Ordering::Relaxed)
666 }
667
668 #[must_use]
670 pub fn graph_extraction_count(&self) -> u64 {
671 use std::sync::atomic::Ordering;
672 self.graph_extraction_count.load(Ordering::Relaxed)
673 }
674
675 #[must_use]
677 pub fn graph_extraction_failures(&self) -> u64 {
678 use std::sync::atomic::Ordering;
679 self.graph_extraction_failures.load(Ordering::Relaxed)
680 }
681
682 #[must_use]
684 pub fn with_ranking_options(
685 mut self,
686 temporal_decay: TemporalDecay,
687 temporal_decay_half_life_days: u32,
688 mmr_reranking: MmrReranking,
689 mmr_lambda: f32,
690 ) -> Self {
691 self.temporal_decay = temporal_decay;
692 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
693 self.mmr_reranking = mmr_reranking;
694 self.mmr_lambda = mmr_lambda;
695 self
696 }
697
698 #[must_use]
700 pub fn with_importance_options(mut self, scoring: ImportanceScoring, weight: f64) -> Self {
701 self.importance_scoring = scoring;
702 self.importance_weight = weight;
703 self
704 }
705
706 #[must_use]
710 pub fn with_tier_boost(mut self, boost: f64) -> Self {
711 self.tier_boost_semantic = boost;
712 self
713 }
714
715 #[must_use]
720 pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
721 self.admission_control = Some(Arc::new(control));
722 self
723 }
724
725 #[must_use]
731 pub fn with_quality_gate(mut self, gate: Arc<crate::quality_gate::QualityGate>) -> Self {
732 self.quality_gate = Some(gate);
733 self
734 }
735
736 #[must_use]
741 pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
742 self.key_facts_dedup_threshold = threshold;
743 self
744 }
745
746 #[must_use]
750 pub fn with_summarization_timeout(mut self, timeout_secs: u64) -> Self {
751 self.summarization_llm_timeout_secs = timeout_secs;
752 self
753 }
754
755 #[must_use]
761 pub fn with_embed_timeout(mut self, timeout_secs: u64) -> Self {
762 let t = std::time::Duration::from_secs(timeout_secs.max(1));
763 self.embed_timeout = t;
764 self.hebbian_spread.embed_timeout = Some(t);
765 self
766 }
767
768 #[must_use]
775 pub fn with_query_bias(
776 mut self,
777 correction: QueryBiasCorrection,
778 profile_weight: f32,
779 centroid_ttl_secs: u64,
780 ) -> Self {
781 self.query_bias_correction = correction;
782 self.query_bias_profile_weight = profile_weight.clamp(0.0, 1.0);
783 self.profile_centroid_ttl_secs = centroid_ttl_secs;
784 self
785 }
786
787 #[must_use]
792 pub fn with_hebbian_spread(mut self, runtime: HelaSpreadRuntime) -> Self {
793 self.hebbian_spread = runtime;
794 self
795 }
796
797 #[must_use]
802 pub fn with_hebbian(mut self, reinforcement: HebbianReinforcement, lr: f32) -> Self {
803 let lr = lr.max(0.0);
804 if reinforcement.is_enabled() && lr == 0.0 {
805 tracing::warn!("hebbian enabled with lr=0.0 — no reinforcement will occur");
806 }
807 self.hebbian_reinforcement = reinforcement;
808 self.hebbian_lr = lr;
809 self
810 }
811
812 #[must_use]
817 pub fn with_query_sensitive_cost(mut self, enabled: bool) -> Self {
818 self.query_sensitive_cost = enabled;
819 self
820 }
821
822 #[must_use]
839 pub fn with_five_signal(mut self, runtime: Arc<crate::five_signal::FiveSignalRuntime>) -> Self {
840 self.five_signal = Some(runtime);
841 self
842 }
843
844 #[must_use]
846 pub fn five_signal_runtime(&self) -> Option<Arc<crate::five_signal::FiveSignalRuntime>> {
847 self.five_signal.clone()
848 }
849
850 pub(crate) fn classify_query_intent(query: &str) -> QueryIntent {
855 if persona::contains_self_referential_language(query) {
856 QueryIntent::FirstPerson
857 } else {
858 QueryIntent::Other
859 }
860 }
861
862 #[tracing::instrument(name = "memory.query_bias.apply", skip(self, embedding), fields(query_len = query.len()))]
868 pub(crate) async fn apply_query_bias(&self, query: &str, embedding: Vec<f32>) -> Vec<f32> {
869 if !self.query_bias_correction.is_enabled() {
870 tracing::debug!(reason = "disabled", "query-bias: skipping");
871 return embedding;
872 }
873 if Self::classify_query_intent(query) != QueryIntent::FirstPerson {
874 tracing::debug!(reason = "not_first_person", "query-bias: skipping");
875 return embedding;
876 }
877 let Some(centroid) = self.profile_centroid_cached().await else {
878 tracing::debug!(reason = "no_centroid", "query-bias: skipping");
879 return embedding;
880 };
881 if centroid.len() != embedding.len() {
882 tracing::warn!(
883 centroid_dim = centroid.len(),
884 query_dim = embedding.len(),
885 reason = "dim_mismatch",
886 "query-bias: dimension mismatch between profile centroid and query embedding — skipping bias"
887 );
888 return embedding;
889 }
890 let w = self.query_bias_profile_weight;
891 tracing::debug!(
892 intent = "first_person",
893 centroid_dim = centroid.len(),
894 weight = w,
895 "query-bias: applying profile bias"
896 );
897 embedding
898 .iter()
899 .zip(centroid.iter())
900 .map(|(&q, &c)| (1.0 - w) * q + w * c)
901 .collect()
902 }
903
904 #[tracing::instrument(name = "memory.query_bias.centroid", skip(self))]
909 pub(crate) async fn profile_centroid_cached(&self) -> Option<Vec<f32>> {
910 {
912 let guard = self.profile_centroid.read().await;
913 if let Some(c) = &*guard
914 && c.computed_at.elapsed().as_secs() < self.profile_centroid_ttl_secs
915 {
916 let ttl_remaining = self
917 .profile_centroid_ttl_secs
918 .saturating_sub(c.computed_at.elapsed().as_secs());
919 tracing::debug!(
920 centroid_dim = c.vector.len(),
921 ttl_remaining_secs = ttl_remaining,
922 "query-bias: centroid cache hit"
923 );
924 return Some(c.vector.clone());
925 }
926 }
927 let computed = self.compute_profile_centroid().await;
929 let mut guard = self.profile_centroid.write().await;
930 match computed {
931 Some(v) => {
932 tracing::debug!(centroid_dim = v.len(), "query-bias: centroid computed");
933 *guard = Some(CachedCentroid {
934 vector: v.clone(),
935 computed_at: Instant::now(),
936 });
937 Some(v)
938 }
939 None => {
940 guard.as_ref().map(|c| c.vector.clone())
942 }
943 }
944 }
945
946 async fn compute_profile_centroid(&self) -> Option<Vec<f32>> {
951 let facts = match self.sqlite.load_persona_facts(0.0).await {
952 Ok(f) => f,
953 Err(e) => {
954 tracing::warn!(error = %e, "query-bias: failed to load persona facts");
955 return None;
956 }
957 };
958 if facts.is_empty() {
959 return None;
960 }
961 let provider = self.effective_embed_provider();
962 let texts: Vec<String> = facts.iter().map(|f| f.content.clone()).collect();
963 let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
964 for text in &texts {
965 match tokio::time::timeout(self.embed_timeout, provider.embed(text)).await {
966 Ok(Ok(v)) => embeddings.push(v),
967 Ok(Err(e)) => {
968 tracing::warn!(error = %e, "query-bias: failed to embed persona fact — skipping");
969 }
970 Err(_) => {
971 tracing::warn!("query-bias: embed timed out for persona fact — skipping");
972 }
973 }
974 }
975 if embeddings.is_empty() {
976 return None;
977 }
978 let dim = embeddings[0].len();
979 let mut centroid = vec![0.0f32; dim];
980 for emb in &embeddings {
981 if emb.len() != dim {
982 tracing::warn!(
983 expected = dim,
984 got = emb.len(),
985 "query-bias: persona embedding dimension mismatch — skipping fact"
986 );
987 continue;
988 }
989 for (c, &v) in centroid.iter_mut().zip(emb.iter()) {
990 *c += v;
991 }
992 }
993 #[allow(clippy::cast_precision_loss)]
994 let n = embeddings.len() as f32;
995 for c in &mut centroid {
996 *c /= n;
997 }
998 Some(centroid)
999 }
1000
1001 #[must_use]
1009 pub fn with_retrieval_options(
1010 mut self,
1011 depth: u32,
1012 search_prompt_template: impl Into<String>,
1013 ) -> Self {
1014 self.retrieval_depth = depth;
1015 self.search_prompt_template = search_prompt_template.into();
1016 self
1017 }
1018
1019 pub(crate) fn effective_depth(&self, limit: usize) -> usize {
1028 use std::sync::atomic::Ordering;
1029
1030 let depth = self.retrieval_depth as usize;
1031 if depth == 0 {
1032 return limit.saturating_mul(2);
1033 }
1034 if depth < limit {
1035 if !self.depth_below_limit_warned.swap(true, Ordering::Relaxed) {
1036 tracing::warn!(
1037 retrieval_depth = depth,
1038 recall_limit = limit,
1039 "memory.retrieval.depth < recall_limit; ANN pool cannot saturate top-k — consider raising depth"
1040 );
1041 }
1042 } else if depth < limit.saturating_mul(2) {
1043 tracing::info!(
1044 retrieval_depth = depth,
1045 recall_limit = limit,
1046 legacy_default = limit.saturating_mul(2),
1047 "memory.retrieval.depth is below legacy limit*2; ANN pool will be smaller than pre-#3340"
1048 );
1049 } else {
1050 tracing::debug!(
1051 retrieval_depth = depth,
1052 recall_limit = limit,
1053 "recall: using configured ANN depth"
1054 );
1055 }
1056 depth
1057 }
1058
1059 pub(crate) fn apply_search_prompt(&self, query: &str) -> String {
1064 use std::sync::atomic::Ordering;
1065
1066 let template = &self.search_prompt_template;
1067 if template.is_empty() {
1068 return query.to_owned();
1069 }
1070 if !template.contains("{query}") {
1071 if !self
1072 .missing_placeholder_warned
1073 .swap(true, Ordering::Relaxed)
1074 {
1075 tracing::warn!(
1076 template = template.as_str(),
1077 "memory.retrieval.search_prompt_template has no {{query}} placeholder — \
1078 using raw query as-is"
1079 );
1080 }
1081 return query.to_owned();
1082 }
1083 template.replace("{query}", query)
1084 }
1085
1086 #[must_use]
1092 pub fn with_embedding_provider(mut self, embed_provider: AnyProvider) -> Self {
1093 self.embed_provider = Some(embed_provider);
1094 self
1095 }
1096
1097 pub fn effective_embed_provider(&self) -> &AnyProvider {
1101 self.embed_provider.as_ref().unwrap_or(&self.provider)
1102 }
1103
1104 #[must_use]
1108 pub fn from_parts(
1109 sqlite: SqliteStore,
1110 qdrant: Option<Arc<EmbeddingStore>>,
1111 provider: AnyProvider,
1112 embedding_model: impl Into<String>,
1113 vector_weight: f64,
1114 keyword_weight: f64,
1115 token_counter: Arc<TokenCounter>,
1116 ) -> Self {
1117 Self {
1118 sqlite,
1119 qdrant,
1120 provider,
1121 embed_provider: None,
1122 embedding_model: embedding_model.into(),
1123 vector_weight,
1124 keyword_weight,
1125 temporal_decay: TemporalDecay::Disabled,
1126 temporal_decay_half_life_days: 30,
1127 mmr_reranking: MmrReranking::Disabled,
1128 mmr_lambda: 0.7,
1129 importance_scoring: ImportanceScoring::Disabled,
1130 importance_weight: 0.15,
1131 tier_boost_semantic: 1.3,
1132 token_counter,
1133 graph_store: None,
1134 experience: None,
1135 reasoning: None,
1136 community_detection_failures: Arc::new(AtomicU64::new(0)),
1137 graph_extraction_count: Arc::new(AtomicU64::new(0)),
1138 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
1139 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
1140 admission_control: None,
1141 quality_gate: None,
1142 key_facts_dedup_threshold: 0.95,
1143 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
1144 retrieval_depth: 0,
1145 search_prompt_template: String::new(),
1146 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1147 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1148 query_bias_correction: QueryBiasCorrection::Enabled,
1149 query_bias_profile_weight: 0.25,
1150 profile_centroid: RwLock::new(None),
1151 profile_centroid_ttl_secs: 300,
1152 hebbian_reinforcement: HebbianReinforcement::Disabled,
1153 hebbian_lr: 0.1,
1154 hebbian_spread: HelaSpreadRuntime::default(),
1155 retrieval_failure_logger: None,
1156 summarization_llm_timeout_secs: 60,
1157 query_sensitive_cost: false,
1158 five_signal: None,
1159 embed_timeout: std::time::Duration::from_secs(5),
1160 graph_cancel: Mutex::new(None),
1161 }
1162 }
1163
1164 pub async fn with_sqlite_backend(
1170 sqlite_path: &str,
1171 provider: AnyProvider,
1172 embedding_model: &str,
1173 vector_weight: f64,
1174 keyword_weight: f64,
1175 ) -> Result<Self, MemoryError> {
1176 Self::with_sqlite_backend_and_pool_size(
1177 sqlite_path,
1178 provider,
1179 embedding_model,
1180 vector_weight,
1181 keyword_weight,
1182 5,
1183 )
1184 .await
1185 }
1186
1187 pub async fn with_sqlite_backend_and_pool_size(
1193 sqlite_path: &str,
1194 provider: AnyProvider,
1195 embedding_model: &str,
1196 vector_weight: f64,
1197 keyword_weight: f64,
1198 pool_size: u32,
1199 ) -> Result<Self, MemoryError> {
1200 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
1201 let pool = sqlite.pool().clone();
1202 let store = EmbeddingStore::new_sqlite(pool);
1203
1204 Ok(Self {
1205 sqlite,
1206 qdrant: Some(Arc::new(store)),
1207 provider,
1208 embed_provider: None,
1209 embedding_model: embedding_model.into(),
1210 vector_weight,
1211 keyword_weight,
1212 temporal_decay: TemporalDecay::Disabled,
1213 temporal_decay_half_life_days: 30,
1214 mmr_reranking: MmrReranking::Disabled,
1215 mmr_lambda: 0.7,
1216 importance_scoring: ImportanceScoring::Disabled,
1217 importance_weight: 0.15,
1218 tier_boost_semantic: 1.3,
1219 token_counter: Arc::new(TokenCounter::new()),
1220 graph_store: None,
1221 experience: None,
1222 reasoning: None,
1223 community_detection_failures: Arc::new(AtomicU64::new(0)),
1224 graph_extraction_count: Arc::new(AtomicU64::new(0)),
1225 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
1226 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
1227 admission_control: None,
1228 quality_gate: None,
1229 key_facts_dedup_threshold: 0.95,
1230 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
1231 retrieval_depth: 0,
1232 search_prompt_template: String::new(),
1233 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1234 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1235 query_bias_correction: QueryBiasCorrection::Enabled,
1236 query_bias_profile_weight: 0.25,
1237 profile_centroid: RwLock::new(None),
1238 profile_centroid_ttl_secs: 300,
1239 hebbian_reinforcement: HebbianReinforcement::Disabled,
1240 hebbian_lr: 0.1,
1241 hebbian_spread: HelaSpreadRuntime::default(),
1242 retrieval_failure_logger: None,
1243 summarization_llm_timeout_secs: 60,
1244 query_sensitive_cost: false,
1245 five_signal: None,
1246 embed_timeout: std::time::Duration::from_secs(5),
1247 graph_cancel: Mutex::new(None),
1248 })
1249 }
1250
1251 #[must_use]
1253 pub fn sqlite(&self) -> &SqliteStore {
1254 &self.sqlite
1255 }
1256
1257 #[must_use]
1259 pub fn embed_timeout(&self) -> std::time::Duration {
1260 self.embed_timeout
1261 }
1262
1263 pub async fn is_vector_store_connected(&self) -> bool {
1268 match self.qdrant.as_ref() {
1269 Some(store) => store.health_check().await,
1270 None => false,
1271 }
1272 }
1273
1274 #[must_use]
1276 pub fn has_vector_store(&self) -> bool {
1277 self.qdrant.is_some()
1278 }
1279
1280 #[must_use]
1282 pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
1283 self.qdrant.as_ref()
1284 }
1285
1286 pub fn provider(&self) -> &AnyProvider {
1288 &self.provider
1289 }
1290
1291 pub async fn message_count(
1297 &self,
1298 conversation_id: crate::types::ConversationId,
1299 ) -> Result<i64, MemoryError> {
1300 self.sqlite.count_messages(conversation_id).await
1301 }
1302
1303 pub async fn unsummarized_message_count(
1309 &self,
1310 conversation_id: crate::types::ConversationId,
1311 ) -> Result<i64, MemoryError> {
1312 let after_id = self
1313 .sqlite
1314 .latest_summary_last_message_id(conversation_id)
1315 .await?
1316 .unwrap_or(crate::types::MessageId(0));
1317 self.sqlite
1318 .count_messages_after(conversation_id, after_id)
1319 .await
1320 }
1321
1322 pub async fn load_promotion_window(
1343 &self,
1344 max_items: usize,
1345 ) -> Result<Vec<crate::compression::promotion::PromotionInput>, MemoryError> {
1346 use zeph_db::sql;
1347
1348 let limit = i64::try_from(max_items).unwrap_or(i64::MAX);
1349 let rows: Vec<(
1350 crate::types::MessageId,
1351 crate::types::ConversationId,
1352 String,
1353 )> = zeph_db::query_as(sql!(
1354 "SELECT id, conversation_id, content \
1355 FROM messages \
1356 WHERE deleted_at IS NULL \
1357 ORDER BY id DESC \
1358 LIMIT ?"
1359 ))
1360 .bind(limit)
1361 .fetch_all(self.sqlite.pool())
1362 .await?;
1363
1364 let mut vectors = if let Some(qdrant) = &self.qdrant {
1365 let ids: Vec<_> = rows.iter().map(|(id, _, _)| *id).collect();
1366 let mut raw = qdrant.get_vectors_for_messages(&ids).await?;
1367
1368 let ref_dim = raw.values().next().map(Vec::len);
1370 if let Some(ref_dim) = ref_dim {
1371 let mismatched: Vec<_> = raw
1372 .iter()
1373 .filter(|(_, v)| v.len() != ref_dim)
1374 .map(|(id, v)| (*id, v.len()))
1375 .collect();
1376 if !mismatched.is_empty() {
1377 tracing::warn!(
1378 expected_dim = ref_dim,
1379 dropped_count = mismatched.len(),
1380 "load_promotion_window: dimension mismatch — dropping mismatched vectors"
1381 );
1382 for (id, _) in mismatched {
1383 raw.remove(&id);
1384 }
1385 }
1386 }
1387 raw
1388 } else {
1389 std::collections::HashMap::new()
1390 };
1391
1392 Ok(rows
1393 .into_iter()
1394 .map(|(message_id, conversation_id, content)| {
1395 crate::compression::promotion::PromotionInput {
1396 message_id,
1397 conversation_id,
1398 content,
1399 embedding: vectors.remove(&message_id),
1400 }
1401 })
1402 .collect())
1403 }
1404
1405 pub async fn retrieve_reasoning_strategies(
1418 &self,
1419 query: &str,
1420 limit: usize,
1421 ) -> Result<Vec<crate::reasoning::ReasoningStrategy>, MemoryError> {
1422 let Some(reasoning) = &self.reasoning else {
1423 return Ok(Vec::new());
1424 };
1425 if !self.effective_embed_provider().supports_embeddings() {
1426 return Ok(Vec::new());
1427 }
1428 let embedding = match tokio::time::timeout(
1429 self.embed_timeout,
1430 self.effective_embed_provider().embed(query),
1431 )
1432 .await
1433 {
1434 Ok(Ok(v)) => v,
1435 Ok(Err(e)) => return Err(e.into()),
1436 Err(_) => {
1437 tracing::warn!("retrieve_reasoning_strategies: embed timed out, returning empty");
1438 return Ok(Vec::new());
1439 }
1440 };
1441 reasoning
1442 .retrieve_by_embedding(&embedding, limit as u64)
1443 .await
1444 }
1445}