1mod algorithms;
24mod corrections;
25mod cross_session;
26mod graph;
27pub(crate) mod importance;
28pub mod persona;
29mod recall;
30mod summarization;
31pub mod trajectory;
32pub mod tree_consolidation;
33pub(crate) mod write_buffer;
34
35#[cfg(test)]
36mod tests;
37
38use std::sync::Arc;
39use std::sync::Mutex;
40use std::sync::atomic::AtomicU64;
41use std::time::Instant;
42
43use tokio::sync::RwLock;
44use zeph_llm::any::AnyProvider;
45use zeph_llm::provider::LlmProvider as _;
46
47use crate::admission::AdmissionControl;
48use crate::embedding_store::EmbeddingStore;
49use crate::error::MemoryError;
50use crate::retrieval_failure_logger::RetrievalFailureLogger;
51use crate::store::SqliteStore;
52use crate::store::retrieval_failures::RetrievalFailureRecord;
53use crate::token_counter::TokenCounter;
54
// Collection-name constants. Their consumers are not visible in this file;
// the per-constant descriptions are inferred from the identifiers and values —
// NOTE(review): confirm against the call sites.

/// Collection name used for session summaries (presumably in the vector store).
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Collection name used for key facts (presumably in the vector store).
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Collection name used for corrections (presumably in the vector store).
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
58
/// Snapshot of how far a backfill operation has progressed.
///
/// NOTE(review): the producer of this value is not visible in this file; the
/// field meanings below are inferred from their names — confirm with the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BackfillProgress {
    /// Number of items handled so far (presumably).
    pub done: usize,
    /// Number of items the backfill covers in total (presumably).
    pub total: usize,
}
67
68pub use algorithms::{apply_mmr, apply_temporal_decay};
69pub use cross_session::SessionSummaryResult;
70pub use graph::{
71 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
72 PostExtractValidator, extract_and_store, link_memory_notes,
73};
74pub use persona::{
75 PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
76};
77pub use recall::{EmbedContext, RecalledMessage};
78pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
79pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
80pub use tree_consolidation::{
81 TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
82 start_tree_consolidation_loop,
83};
84pub use write_buffer::{BufferedWrite, WriteBuffer};
85
/// A profile centroid vector together with the instant it was cached.
///
/// Held in `SemanticMemory::profile_centroid`; `profile_centroid_cached`
/// compares the age of `computed_at` against `profile_centroid_ttl_secs`.
#[derive(Clone, Debug)]
pub(crate) struct CachedCentroid {
    /// The cached centroid vector.
    pub vector: Vec<f32>,
    /// Monotonic timestamp taken when `vector` was written to the cache.
    pub computed_at: Instant,
}
97
/// On/off switch for temporal decay.
///
/// Stored on `SemanticMemory` and set through
/// `SemanticMemory::with_ranking_options`. The default is [`Self::Disabled`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TemporalDecay {
    /// Temporal decay is applied.
    Enabled,
    /// Temporal decay is not applied (default).
    #[default]
    Disabled,
}

impl TemporalDecay {
    /// Returns `true` only for [`Self::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for TemporalDecay {
    /// Maps `true` to [`TemporalDecay::Enabled`] and `false` to
    /// [`TemporalDecay::Disabled`].
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
125
/// On/off switch for MMR re-ranking.
///
/// Stored on `SemanticMemory` and set through
/// `SemanticMemory::with_ranking_options`. The default is [`Self::Disabled`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum MmrReranking {
    /// MMR re-ranking is applied.
    Enabled,
    /// MMR re-ranking is not applied (default).
    #[default]
    Disabled,
}

impl MmrReranking {
    /// Returns `true` only for [`Self::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for MmrReranking {
    /// Maps `true` to [`MmrReranking::Enabled`] and `false` to
    /// [`MmrReranking::Disabled`].
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
153
/// On/off switch for importance scoring.
///
/// Stored on `SemanticMemory` and set through
/// `SemanticMemory::with_importance_options`. The default is
/// [`Self::Disabled`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ImportanceScoring {
    /// Importance scoring is applied.
    Enabled,
    /// Importance scoring is not applied (default).
    #[default]
    Disabled,
}

impl ImportanceScoring {
    /// Returns `true` only for [`Self::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for ImportanceScoring {
    /// Maps `true` to [`ImportanceScoring::Enabled`] and `false` to
    /// [`ImportanceScoring::Disabled`].
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
181
/// On/off switch for query-bias correction.
///
/// Consulted by `SemanticMemory::apply_query_bias` and set through
/// `SemanticMemory::with_query_bias`. Unlike the other toggles in this module,
/// the default is [`Self::Enabled`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum QueryBiasCorrection {
    /// Query-bias correction is applied (default).
    #[default]
    Enabled,
    /// Query-bias correction is skipped.
    Disabled,
}

impl QueryBiasCorrection {
    /// Returns `true` only for [`Self::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for QueryBiasCorrection {
    /// Maps `true` to [`QueryBiasCorrection::Enabled`] and `false` to
    /// [`QueryBiasCorrection::Disabled`].
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
209
/// On/off switch for Hebbian reinforcement.
///
/// Stored on `SemanticMemory` and set through `SemanticMemory::with_hebbian`.
/// The default is [`Self::Disabled`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum HebbianReinforcement {
    /// Hebbian reinforcement is applied.
    Enabled,
    /// Hebbian reinforcement is not applied (default).
    #[default]
    Disabled,
}

impl HebbianReinforcement {
    /// Returns `true` only for [`Self::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for HebbianReinforcement {
    /// Maps `true` to [`HebbianReinforcement::Enabled`] and `false` to
    /// [`HebbianReinforcement::Disabled`].
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
237
/// Coarse classification of a query, produced by
/// `SemanticMemory::classify_query_intent`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum QueryIntent {
    /// The query was flagged by `persona::contains_self_referential_language`.
    FirstPerson,
    /// Any query that was not flagged as first-person.
    Other,
}
249
250#[derive(Debug, Clone)]
254pub struct HelaSpreadRuntime {
255 pub enabled: bool,
257 pub depth: u32,
259 pub max_visited: usize,
261 pub edge_types: Vec<crate::graph::EdgeType>,
263 pub step_budget: Option<std::time::Duration>,
265 pub embed_timeout: Option<std::time::Duration>,
267}
268
269impl Default for HelaSpreadRuntime {
270 fn default() -> Self {
271 Self {
272 enabled: false,
273 depth: 2,
274 max_visited: 200,
275 edge_types: Vec::new(),
276 step_budget: Some(std::time::Duration::from_millis(8)),
277 embed_timeout: Some(std::time::Duration::from_secs(5)),
278 }
279 }
280}
281
/// Central memory handle combining the SQLite store, an optional vector
/// store, LLM providers and the tunables configured through the `with_*`
/// builder methods.
///
/// Field order is left untouched: it determines drop order.
pub struct SemanticMemory {
    /// Relational store; exposed read-only through `sqlite()`.
    pub(crate) sqlite: SqliteStore,
    /// Optional embedding store. Despite the name it may also be the
    /// SQLite-backed store created by `with_sqlite_backend_and_pool_size`.
    /// `None` when `EmbeddingStore::new` failed during construction.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    /// Primary provider; also the embedding fallback when `embed_provider` is unset.
    pub(crate) provider: AnyProvider,
    /// Dedicated embedding provider set by `with_embed_provider`.
    pub(crate) embed_provider: Option<AnyProvider>,
    /// Embedding model name as passed to the constructor.
    pub(crate) embedding_model: String,
    /// Vector weight passed to the constructor (`new` uses 0.7).
    pub(crate) vector_weight: f64,
    /// Keyword weight passed to the constructor (`new` uses 0.3).
    pub(crate) keyword_weight: f64,
    /// Temporal-decay toggle; see `with_ranking_options`.
    pub(crate) temporal_decay: TemporalDecay,
    /// Half-life in days for temporal decay; constructors default to 30.
    pub(crate) temporal_decay_half_life_days: u32,
    /// MMR re-ranking toggle; see `with_ranking_options`.
    pub(crate) mmr_reranking: MmrReranking,
    /// MMR lambda; constructors default to 0.7.
    pub(crate) mmr_lambda: f32,
    /// Importance-scoring toggle; see `with_importance_options`.
    pub(crate) importance_scoring: ImportanceScoring,
    /// Importance weight; constructors default to 0.15.
    pub(crate) importance_weight: f64,
    /// Tier boost set by `with_tier_boost`; constructors default to 1.3.
    pub(crate) tier_boost_semantic: f64,
    /// Shared token counter.
    pub token_counter: Arc<TokenCounter>,
    /// Optional graph store attached via `with_graph_store`.
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    /// Optional experience store attached via `with_experience_store`.
    pub experience: Option<Arc<crate::graph::experience::ExperienceStore>>,
    /// Optional reasoning memory attached via `with_reasoning`.
    pub reasoning: Option<Arc<crate::reasoning::ReasoningMemory>>,
    /// Counter read by `community_detection_failures()`.
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    /// Counter read by `graph_extraction_count()`.
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    /// Counter read by `graph_extraction_failures()`.
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
    /// NOTE(review): not used in this file; presumably rate-limits vector-store warnings.
    pub(crate) last_qdrant_warn: Arc<AtomicU64>,
    /// Optional admission control attached via `with_admission_control`.
    pub(crate) admission_control: Option<Arc<AdmissionControl>>,
    /// Optional quality gate attached via `with_quality_gate`.
    pub(crate) quality_gate: Option<Arc<crate::quality_gate::QualityGate>>,
    /// Key-facts dedup threshold; constructors default to 0.95.
    pub(crate) key_facts_dedup_threshold: f32,
    /// Task set behind a std mutex; NOTE(review): its users are outside this file.
    pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
    /// Configured retrieval depth; `0` makes `effective_depth` fall back to `limit * 2`.
    pub(crate) retrieval_depth: u32,
    /// Template applied by `apply_search_prompt`; expects a `{query}` placeholder.
    pub(crate) search_prompt_template: String,
    /// One-shot flag so `effective_depth` warns only once.
    pub(crate) depth_below_limit_warned: Arc<std::sync::atomic::AtomicBool>,
    /// One-shot flag so `apply_search_prompt` warns only once.
    pub(crate) missing_placeholder_warned: Arc<std::sync::atomic::AtomicBool>,
    /// Query-bias toggle; see `with_query_bias`.
    pub(crate) query_bias_correction: QueryBiasCorrection,
    /// Blend weight used by `apply_query_bias`; `with_query_bias` clamps it to `[0, 1]`.
    pub(crate) query_bias_profile_weight: f32,
    /// Cached profile centroid maintained by `profile_centroid_cached`.
    pub(crate) profile_centroid: RwLock<Option<CachedCentroid>>,
    /// Maximum age, in seconds, of a cached centroid before it is recomputed.
    pub(crate) profile_centroid_ttl_secs: u64,
    /// Hebbian toggle; see `with_hebbian`.
    pub(crate) hebbian_reinforcement: HebbianReinforcement,
    /// Hebbian learning rate; `with_hebbian` floors it at 0.0.
    pub(crate) hebbian_lr: f32,
    /// Spread settings; see `with_hebbian_spread`.
    pub(crate) hebbian_spread: HelaSpreadRuntime,
    /// Optional sink used by `log_retrieval_failure`.
    pub(crate) retrieval_failure_logger: Option<RetrievalFailureLogger>,
    /// Summarization timeout in seconds; constructors default to 60.
    pub(crate) summarization_llm_timeout_secs: u64,
    /// Flag set by `with_query_sensitive_cost`; defaults to `false`.
    pub(crate) query_sensitive_cost: bool,
}
379
380impl SemanticMemory {
381 pub async fn new(
392 sqlite_path: &str,
393 qdrant_url: &str,
394 api_key: Option<&str>,
395 provider: AnyProvider,
396 embedding_model: &str,
397 ) -> Result<Self, MemoryError> {
398 Self::with_weights(
399 sqlite_path,
400 qdrant_url,
401 api_key,
402 provider,
403 embedding_model,
404 0.7,
405 0.3,
406 )
407 .await
408 }
409
410 pub async fn with_weights(
419 sqlite_path: &str,
420 qdrant_url: &str,
421 api_key: Option<&str>,
422 provider: AnyProvider,
423 embedding_model: &str,
424 vector_weight: f64,
425 keyword_weight: f64,
426 ) -> Result<Self, MemoryError> {
427 Self::with_weights_and_pool_size(
428 sqlite_path,
429 qdrant_url,
430 api_key,
431 provider,
432 embedding_model,
433 vector_weight,
434 keyword_weight,
435 5,
436 )
437 .await
438 }
439
440 #[allow(clippy::too_many_arguments)]
449 pub async fn with_weights_and_pool_size(
450 sqlite_path: &str,
451 qdrant_url: &str,
452 api_key: Option<&str>,
453 provider: AnyProvider,
454 embedding_model: &str,
455 vector_weight: f64,
456 keyword_weight: f64,
457 pool_size: u32,
458 ) -> Result<Self, MemoryError> {
459 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
460 let pool = sqlite.pool().clone();
461
462 let qdrant = match EmbeddingStore::new(qdrant_url, api_key, pool) {
463 Ok(store) => Some(Arc::new(store)),
464 Err(e) => {
465 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
466 None
467 }
468 };
469
470 Ok(Self {
471 sqlite,
472 qdrant,
473 provider,
474 embed_provider: None,
475 embedding_model: embedding_model.into(),
476 vector_weight,
477 keyword_weight,
478 temporal_decay: TemporalDecay::Disabled,
479 temporal_decay_half_life_days: 30,
480 mmr_reranking: MmrReranking::Disabled,
481 mmr_lambda: 0.7,
482 importance_scoring: ImportanceScoring::Disabled,
483 importance_weight: 0.15,
484 tier_boost_semantic: 1.3,
485 token_counter: Arc::new(TokenCounter::new()),
486 graph_store: None,
487 experience: None,
488 reasoning: None,
489 community_detection_failures: Arc::new(AtomicU64::new(0)),
490 graph_extraction_count: Arc::new(AtomicU64::new(0)),
491 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
492 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
493 admission_control: None,
494 quality_gate: None,
495 key_facts_dedup_threshold: 0.95,
496 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
497 retrieval_depth: 0,
498 search_prompt_template: String::new(),
499 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
500 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
501 query_bias_correction: QueryBiasCorrection::Enabled,
502 query_bias_profile_weight: 0.25,
503 profile_centroid: RwLock::new(None),
504 profile_centroid_ttl_secs: 300,
505 hebbian_reinforcement: HebbianReinforcement::Disabled,
506 hebbian_lr: 0.1,
507 hebbian_spread: HelaSpreadRuntime::default(),
508 retrieval_failure_logger: None,
509 summarization_llm_timeout_secs: 60,
510 query_sensitive_cost: false,
511 })
512 }
513
514 pub async fn with_qdrant_ops(
523 sqlite_path: &str,
524 ops: crate::QdrantOps,
525 provider: AnyProvider,
526 embedding_model: &str,
527 vector_weight: f64,
528 keyword_weight: f64,
529 pool_size: u32,
530 ) -> Result<Self, MemoryError> {
531 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
532 let pool = sqlite.pool().clone();
533 let store = EmbeddingStore::with_store(Box::new(ops), pool);
534
535 Ok(Self {
536 sqlite,
537 qdrant: Some(Arc::new(store)),
538 provider,
539 embed_provider: None,
540 embedding_model: embedding_model.into(),
541 vector_weight,
542 keyword_weight,
543 temporal_decay: TemporalDecay::Disabled,
544 temporal_decay_half_life_days: 30,
545 mmr_reranking: MmrReranking::Disabled,
546 mmr_lambda: 0.7,
547 importance_scoring: ImportanceScoring::Disabled,
548 importance_weight: 0.15,
549 tier_boost_semantic: 1.3,
550 token_counter: Arc::new(TokenCounter::new()),
551 graph_store: None,
552 experience: None,
553 reasoning: None,
554 community_detection_failures: Arc::new(AtomicU64::new(0)),
555 graph_extraction_count: Arc::new(AtomicU64::new(0)),
556 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
557 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
558 admission_control: None,
559 quality_gate: None,
560 key_facts_dedup_threshold: 0.95,
561 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
562 retrieval_depth: 0,
563 search_prompt_template: String::new(),
564 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
565 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
566 query_bias_correction: QueryBiasCorrection::Enabled,
567 query_bias_profile_weight: 0.25,
568 profile_centroid: RwLock::new(None),
569 profile_centroid_ttl_secs: 300,
570 hebbian_reinforcement: HebbianReinforcement::Disabled,
571 hebbian_lr: 0.1,
572 hebbian_spread: HelaSpreadRuntime::default(),
573 retrieval_failure_logger: None,
574 summarization_llm_timeout_secs: 60,
575 query_sensitive_cost: false,
576 })
577 }
578
579 #[must_use]
584 pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
585 self.graph_store = Some(store);
586 self
587 }
588
589 #[must_use]
595 pub fn with_experience_store(
596 mut self,
597 store: Arc<crate::graph::experience::ExperienceStore>,
598 ) -> Self {
599 self.experience = Some(store);
600 self
601 }
602
603 #[must_use]
609 pub fn with_reasoning(mut self, store: Arc<crate::reasoning::ReasoningMemory>) -> Self {
610 self.reasoning = Some(store);
611 self
612 }
613
614 #[must_use]
619 pub fn with_retrieval_failure_logger(mut self, logger: RetrievalFailureLogger) -> Self {
620 self.retrieval_failure_logger = Some(logger);
621 self
622 }
623
624 pub fn log_retrieval_failure(&self, record: RetrievalFailureRecord) {
630 if let Some(logger) = &self.retrieval_failure_logger {
631 logger.log(record);
632 }
633 }
634
635 #[must_use]
637 pub fn community_detection_failures(&self) -> u64 {
638 use std::sync::atomic::Ordering;
639 self.community_detection_failures.load(Ordering::Relaxed)
640 }
641
642 #[must_use]
644 pub fn graph_extraction_count(&self) -> u64 {
645 use std::sync::atomic::Ordering;
646 self.graph_extraction_count.load(Ordering::Relaxed)
647 }
648
649 #[must_use]
651 pub fn graph_extraction_failures(&self) -> u64 {
652 use std::sync::atomic::Ordering;
653 self.graph_extraction_failures.load(Ordering::Relaxed)
654 }
655
656 #[must_use]
658 pub fn with_ranking_options(
659 mut self,
660 temporal_decay: TemporalDecay,
661 temporal_decay_half_life_days: u32,
662 mmr_reranking: MmrReranking,
663 mmr_lambda: f32,
664 ) -> Self {
665 self.temporal_decay = temporal_decay;
666 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
667 self.mmr_reranking = mmr_reranking;
668 self.mmr_lambda = mmr_lambda;
669 self
670 }
671
672 #[must_use]
674 pub fn with_importance_options(mut self, scoring: ImportanceScoring, weight: f64) -> Self {
675 self.importance_scoring = scoring;
676 self.importance_weight = weight;
677 self
678 }
679
680 #[must_use]
684 pub fn with_tier_boost(mut self, boost: f64) -> Self {
685 self.tier_boost_semantic = boost;
686 self
687 }
688
689 #[must_use]
694 pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
695 self.admission_control = Some(Arc::new(control));
696 self
697 }
698
699 #[must_use]
705 pub fn with_quality_gate(mut self, gate: Arc<crate::quality_gate::QualityGate>) -> Self {
706 self.quality_gate = Some(gate);
707 self
708 }
709
710 #[must_use]
715 pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
716 self.key_facts_dedup_threshold = threshold;
717 self
718 }
719
720 #[must_use]
724 pub fn with_summarization_timeout(mut self, timeout_secs: u64) -> Self {
725 self.summarization_llm_timeout_secs = timeout_secs;
726 self
727 }
728
729 #[must_use]
736 pub fn with_query_bias(
737 mut self,
738 correction: QueryBiasCorrection,
739 profile_weight: f32,
740 centroid_ttl_secs: u64,
741 ) -> Self {
742 self.query_bias_correction = correction;
743 self.query_bias_profile_weight = profile_weight.clamp(0.0, 1.0);
744 self.profile_centroid_ttl_secs = centroid_ttl_secs;
745 self
746 }
747
748 #[must_use]
753 pub fn with_hebbian_spread(mut self, runtime: HelaSpreadRuntime) -> Self {
754 self.hebbian_spread = runtime;
755 self
756 }
757
758 #[must_use]
763 pub fn with_hebbian(mut self, reinforcement: HebbianReinforcement, lr: f32) -> Self {
764 let lr = lr.max(0.0);
765 if reinforcement.is_enabled() && lr == 0.0 {
766 tracing::warn!("hebbian enabled with lr=0.0 — no reinforcement will occur");
767 }
768 self.hebbian_reinforcement = reinforcement;
769 self.hebbian_lr = lr;
770 self
771 }
772
773 #[must_use]
778 pub fn with_query_sensitive_cost(mut self, enabled: bool) -> Self {
779 self.query_sensitive_cost = enabled;
780 self
781 }
782
783 pub(crate) fn classify_query_intent(query: &str) -> QueryIntent {
788 if persona::contains_self_referential_language(query) {
789 QueryIntent::FirstPerson
790 } else {
791 QueryIntent::Other
792 }
793 }
794
795 #[tracing::instrument(name = "memory.query_bias.apply", skip(self, embedding), fields(query_len = query.len()))]
801 pub(crate) async fn apply_query_bias(&self, query: &str, embedding: Vec<f32>) -> Vec<f32> {
802 if !self.query_bias_correction.is_enabled() {
803 tracing::debug!(reason = "disabled", "query-bias: skipping");
804 return embedding;
805 }
806 if Self::classify_query_intent(query) != QueryIntent::FirstPerson {
807 tracing::debug!(reason = "not_first_person", "query-bias: skipping");
808 return embedding;
809 }
810 let Some(centroid) = self.profile_centroid_cached().await else {
811 tracing::debug!(reason = "no_centroid", "query-bias: skipping");
812 return embedding;
813 };
814 if centroid.len() != embedding.len() {
815 tracing::warn!(
816 centroid_dim = centroid.len(),
817 query_dim = embedding.len(),
818 reason = "dim_mismatch",
819 "query-bias: dimension mismatch between profile centroid and query embedding — skipping bias"
820 );
821 return embedding;
822 }
823 let w = self.query_bias_profile_weight;
824 tracing::debug!(
825 intent = "first_person",
826 centroid_dim = centroid.len(),
827 weight = w,
828 "query-bias: applying profile bias"
829 );
830 embedding
831 .iter()
832 .zip(centroid.iter())
833 .map(|(&q, &c)| (1.0 - w) * q + w * c)
834 .collect()
835 }
836
837 #[tracing::instrument(name = "memory.query_bias.centroid", skip(self))]
842 pub(crate) async fn profile_centroid_cached(&self) -> Option<Vec<f32>> {
843 {
845 let guard = self.profile_centroid.read().await;
846 if let Some(c) = &*guard
847 && c.computed_at.elapsed().as_secs() < self.profile_centroid_ttl_secs
848 {
849 let ttl_remaining = self
850 .profile_centroid_ttl_secs
851 .saturating_sub(c.computed_at.elapsed().as_secs());
852 tracing::debug!(
853 centroid_dim = c.vector.len(),
854 ttl_remaining_secs = ttl_remaining,
855 "query-bias: centroid cache hit"
856 );
857 return Some(c.vector.clone());
858 }
859 }
860 let computed = self.compute_profile_centroid().await;
862 let mut guard = self.profile_centroid.write().await;
863 match computed {
864 Some(v) => {
865 tracing::debug!(centroid_dim = v.len(), "query-bias: centroid computed");
866 *guard = Some(CachedCentroid {
867 vector: v.clone(),
868 computed_at: Instant::now(),
869 });
870 Some(v)
871 }
872 None => {
873 guard.as_ref().map(|c| c.vector.clone())
875 }
876 }
877 }
878
879 async fn compute_profile_centroid(&self) -> Option<Vec<f32>> {
884 let facts = match self.sqlite.load_persona_facts(0.0).await {
885 Ok(f) => f,
886 Err(e) => {
887 tracing::warn!(error = %e, "query-bias: failed to load persona facts");
888 return None;
889 }
890 };
891 if facts.is_empty() {
892 return None;
893 }
894 let provider = self.effective_embed_provider();
895 let texts: Vec<String> = facts.iter().map(|f| f.content.clone()).collect();
896 let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
897 for text in &texts {
898 match tokio::time::timeout(std::time::Duration::from_secs(5), provider.embed(text))
899 .await
900 {
901 Ok(Ok(v)) => embeddings.push(v),
902 Ok(Err(e)) => {
903 tracing::warn!(error = %e, "query-bias: failed to embed persona fact — skipping");
904 }
905 Err(_) => {
906 tracing::warn!("query-bias: embed timed out for persona fact — skipping");
907 }
908 }
909 }
910 if embeddings.is_empty() {
911 return None;
912 }
913 let dim = embeddings[0].len();
914 let mut centroid = vec![0.0f32; dim];
915 for emb in &embeddings {
916 if emb.len() != dim {
917 tracing::warn!(
918 expected = dim,
919 got = emb.len(),
920 "query-bias: persona embedding dimension mismatch — skipping fact"
921 );
922 continue;
923 }
924 for (c, &v) in centroid.iter_mut().zip(emb.iter()) {
925 *c += v;
926 }
927 }
928 #[allow(clippy::cast_precision_loss)]
929 let n = embeddings.len() as f32;
930 for c in &mut centroid {
931 *c /= n;
932 }
933 Some(centroid)
934 }
935
936 #[must_use]
944 pub fn with_retrieval_options(
945 mut self,
946 depth: u32,
947 search_prompt_template: impl Into<String>,
948 ) -> Self {
949 self.retrieval_depth = depth;
950 self.search_prompt_template = search_prompt_template.into();
951 self
952 }
953
954 pub(crate) fn effective_depth(&self, limit: usize) -> usize {
963 use std::sync::atomic::Ordering;
964
965 let depth = self.retrieval_depth as usize;
966 if depth == 0 {
967 return limit.saturating_mul(2);
968 }
969 if depth < limit {
970 if !self.depth_below_limit_warned.swap(true, Ordering::Relaxed) {
971 tracing::warn!(
972 retrieval_depth = depth,
973 recall_limit = limit,
974 "memory.retrieval.depth < recall_limit; ANN pool cannot saturate top-k — consider raising depth"
975 );
976 }
977 } else if depth < limit.saturating_mul(2) {
978 tracing::info!(
979 retrieval_depth = depth,
980 recall_limit = limit,
981 legacy_default = limit.saturating_mul(2),
982 "memory.retrieval.depth is below legacy limit*2; ANN pool will be smaller than pre-#3340"
983 );
984 } else {
985 tracing::debug!(
986 retrieval_depth = depth,
987 recall_limit = limit,
988 "recall: using configured ANN depth"
989 );
990 }
991 depth
992 }
993
994 pub(crate) fn apply_search_prompt(&self, query: &str) -> String {
999 use std::sync::atomic::Ordering;
1000
1001 let template = &self.search_prompt_template;
1002 if template.is_empty() {
1003 return query.to_owned();
1004 }
1005 if !template.contains("{query}") {
1006 if !self
1007 .missing_placeholder_warned
1008 .swap(true, Ordering::Relaxed)
1009 {
1010 tracing::warn!(
1011 template = template.as_str(),
1012 "memory.retrieval.search_prompt_template has no {{query}} placeholder — \
1013 using raw query as-is"
1014 );
1015 }
1016 return query.to_owned();
1017 }
1018 template.replace("{query}", query)
1019 }
1020
1021 #[must_use]
1027 pub fn with_embed_provider(mut self, embed_provider: AnyProvider) -> Self {
1028 self.embed_provider = Some(embed_provider);
1029 self
1030 }
1031
1032 pub fn effective_embed_provider(&self) -> &AnyProvider {
1036 self.embed_provider.as_ref().unwrap_or(&self.provider)
1037 }
1038
1039 #[must_use]
1043 pub fn from_parts(
1044 sqlite: SqliteStore,
1045 qdrant: Option<Arc<EmbeddingStore>>,
1046 provider: AnyProvider,
1047 embedding_model: impl Into<String>,
1048 vector_weight: f64,
1049 keyword_weight: f64,
1050 token_counter: Arc<TokenCounter>,
1051 ) -> Self {
1052 Self {
1053 sqlite,
1054 qdrant,
1055 provider,
1056 embed_provider: None,
1057 embedding_model: embedding_model.into(),
1058 vector_weight,
1059 keyword_weight,
1060 temporal_decay: TemporalDecay::Disabled,
1061 temporal_decay_half_life_days: 30,
1062 mmr_reranking: MmrReranking::Disabled,
1063 mmr_lambda: 0.7,
1064 importance_scoring: ImportanceScoring::Disabled,
1065 importance_weight: 0.15,
1066 tier_boost_semantic: 1.3,
1067 token_counter,
1068 graph_store: None,
1069 experience: None,
1070 reasoning: None,
1071 community_detection_failures: Arc::new(AtomicU64::new(0)),
1072 graph_extraction_count: Arc::new(AtomicU64::new(0)),
1073 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
1074 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
1075 admission_control: None,
1076 quality_gate: None,
1077 key_facts_dedup_threshold: 0.95,
1078 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
1079 retrieval_depth: 0,
1080 search_prompt_template: String::new(),
1081 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1082 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1083 query_bias_correction: QueryBiasCorrection::Enabled,
1084 query_bias_profile_weight: 0.25,
1085 profile_centroid: RwLock::new(None),
1086 profile_centroid_ttl_secs: 300,
1087 hebbian_reinforcement: HebbianReinforcement::Disabled,
1088 hebbian_lr: 0.1,
1089 hebbian_spread: HelaSpreadRuntime::default(),
1090 retrieval_failure_logger: None,
1091 summarization_llm_timeout_secs: 60,
1092 query_sensitive_cost: false,
1093 }
1094 }
1095
1096 pub async fn with_sqlite_backend(
1102 sqlite_path: &str,
1103 provider: AnyProvider,
1104 embedding_model: &str,
1105 vector_weight: f64,
1106 keyword_weight: f64,
1107 ) -> Result<Self, MemoryError> {
1108 Self::with_sqlite_backend_and_pool_size(
1109 sqlite_path,
1110 provider,
1111 embedding_model,
1112 vector_weight,
1113 keyword_weight,
1114 5,
1115 )
1116 .await
1117 }
1118
1119 pub async fn with_sqlite_backend_and_pool_size(
1125 sqlite_path: &str,
1126 provider: AnyProvider,
1127 embedding_model: &str,
1128 vector_weight: f64,
1129 keyword_weight: f64,
1130 pool_size: u32,
1131 ) -> Result<Self, MemoryError> {
1132 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
1133 let pool = sqlite.pool().clone();
1134 let store = EmbeddingStore::new_sqlite(pool);
1135
1136 Ok(Self {
1137 sqlite,
1138 qdrant: Some(Arc::new(store)),
1139 provider,
1140 embed_provider: None,
1141 embedding_model: embedding_model.into(),
1142 vector_weight,
1143 keyword_weight,
1144 temporal_decay: TemporalDecay::Disabled,
1145 temporal_decay_half_life_days: 30,
1146 mmr_reranking: MmrReranking::Disabled,
1147 mmr_lambda: 0.7,
1148 importance_scoring: ImportanceScoring::Disabled,
1149 importance_weight: 0.15,
1150 tier_boost_semantic: 1.3,
1151 token_counter: Arc::new(TokenCounter::new()),
1152 graph_store: None,
1153 experience: None,
1154 reasoning: None,
1155 community_detection_failures: Arc::new(AtomicU64::new(0)),
1156 graph_extraction_count: Arc::new(AtomicU64::new(0)),
1157 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
1158 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
1159 admission_control: None,
1160 quality_gate: None,
1161 key_facts_dedup_threshold: 0.95,
1162 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
1163 retrieval_depth: 0,
1164 search_prompt_template: String::new(),
1165 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1166 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1167 query_bias_correction: QueryBiasCorrection::Enabled,
1168 query_bias_profile_weight: 0.25,
1169 profile_centroid: RwLock::new(None),
1170 profile_centroid_ttl_secs: 300,
1171 hebbian_reinforcement: HebbianReinforcement::Disabled,
1172 hebbian_lr: 0.1,
1173 hebbian_spread: HelaSpreadRuntime::default(),
1174 retrieval_failure_logger: None,
1175 summarization_llm_timeout_secs: 60,
1176 query_sensitive_cost: false,
1177 })
1178 }
1179
1180 #[must_use]
1182 pub fn sqlite(&self) -> &SqliteStore {
1183 &self.sqlite
1184 }
1185
1186 pub async fn is_vector_store_connected(&self) -> bool {
1191 match self.qdrant.as_ref() {
1192 Some(store) => store.health_check().await,
1193 None => false,
1194 }
1195 }
1196
1197 #[must_use]
1199 pub fn has_vector_store(&self) -> bool {
1200 self.qdrant.is_some()
1201 }
1202
1203 #[must_use]
1205 pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
1206 self.qdrant.as_ref()
1207 }
1208
1209 pub fn provider(&self) -> &AnyProvider {
1211 &self.provider
1212 }
1213
1214 pub async fn message_count(
1220 &self,
1221 conversation_id: crate::types::ConversationId,
1222 ) -> Result<i64, MemoryError> {
1223 self.sqlite.count_messages(conversation_id).await
1224 }
1225
1226 pub async fn unsummarized_message_count(
1232 &self,
1233 conversation_id: crate::types::ConversationId,
1234 ) -> Result<i64, MemoryError> {
1235 let after_id = self
1236 .sqlite
1237 .latest_summary_last_message_id(conversation_id)
1238 .await?
1239 .unwrap_or(crate::types::MessageId(0));
1240 self.sqlite
1241 .count_messages_after(conversation_id, after_id)
1242 .await
1243 }
1244
1245 pub async fn load_promotion_window(
1266 &self,
1267 max_items: usize,
1268 ) -> Result<Vec<crate::compression::promotion::PromotionInput>, MemoryError> {
1269 use zeph_db::sql;
1270
1271 let limit = i64::try_from(max_items).unwrap_or(i64::MAX);
1272 let rows: Vec<(
1273 crate::types::MessageId,
1274 crate::types::ConversationId,
1275 String,
1276 )> = zeph_db::query_as(sql!(
1277 "SELECT id, conversation_id, content \
1278 FROM messages \
1279 WHERE deleted_at IS NULL \
1280 ORDER BY id DESC \
1281 LIMIT ?"
1282 ))
1283 .bind(limit)
1284 .fetch_all(self.sqlite.pool())
1285 .await?;
1286
1287 let mut vectors = if let Some(qdrant) = &self.qdrant {
1288 let ids: Vec<_> = rows.iter().map(|(id, _, _)| *id).collect();
1289 let mut raw = qdrant.get_vectors_for_messages(&ids).await?;
1290
1291 let ref_dim = raw.values().next().map(Vec::len);
1293 if let Some(ref_dim) = ref_dim {
1294 let mismatched: Vec<_> = raw
1295 .iter()
1296 .filter(|(_, v)| v.len() != ref_dim)
1297 .map(|(id, v)| (*id, v.len()))
1298 .collect();
1299 if !mismatched.is_empty() {
1300 tracing::warn!(
1301 expected_dim = ref_dim,
1302 dropped_count = mismatched.len(),
1303 "load_promotion_window: dimension mismatch — dropping mismatched vectors"
1304 );
1305 for (id, _) in mismatched {
1306 raw.remove(&id);
1307 }
1308 }
1309 }
1310 raw
1311 } else {
1312 std::collections::HashMap::new()
1313 };
1314
1315 Ok(rows
1316 .into_iter()
1317 .map(|(message_id, conversation_id, content)| {
1318 crate::compression::promotion::PromotionInput {
1319 message_id,
1320 conversation_id,
1321 content,
1322 embedding: vectors.remove(&message_id),
1323 }
1324 })
1325 .collect())
1326 }
1327
1328 pub async fn retrieve_reasoning_strategies(
1341 &self,
1342 query: &str,
1343 limit: usize,
1344 ) -> Result<Vec<crate::reasoning::ReasoningStrategy>, MemoryError> {
1345 let Some(reasoning) = &self.reasoning else {
1346 return Ok(Vec::new());
1347 };
1348 if !self.effective_embed_provider().supports_embeddings() {
1349 return Ok(Vec::new());
1350 }
1351 let embedding = match tokio::time::timeout(
1352 std::time::Duration::from_secs(5),
1353 self.effective_embed_provider().embed(query),
1354 )
1355 .await
1356 {
1357 Ok(Ok(v)) => v,
1358 Ok(Err(e)) => return Err(e.into()),
1359 Err(_) => {
1360 tracing::warn!("retrieve_reasoning_strategies: embed timed out, returning empty");
1361 return Ok(Vec::new());
1362 }
1363 };
1364 reasoning
1365 .retrieve_by_embedding(&embedding, limit as u64)
1366 .await
1367 }
1368}