1mod algorithms;
24mod corrections;
25mod cross_session;
26mod graph;
27pub(crate) mod importance;
28pub mod persona;
29mod recall;
30mod summarization;
31pub mod trajectory;
32pub mod tree_consolidation;
33pub(crate) mod write_buffer;
34
35#[cfg(test)]
36mod tests;
37
38use std::sync::Arc;
39use std::sync::Mutex;
40use std::sync::atomic::AtomicU64;
41use std::time::Instant;
42
43use tokio::sync::RwLock;
44use zeph_llm::any::AnyProvider;
45use zeph_llm::provider::LlmProvider as _;
46
47use crate::admission::AdmissionControl;
48use crate::embedding_store::EmbeddingStore;
49use crate::error::MemoryError;
50use crate::retrieval_failure_logger::RetrievalFailureLogger;
51use crate::store::SqliteStore;
52use crate::store::retrieval_failures::RetrievalFailureRecord;
53use crate::token_counter::TokenCounter;
54
/// Vector-store collection holding per-session summaries.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Vector-store collection holding extracted key facts.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Vector-store collection holding stored corrections.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
58
/// Progress snapshot for a backfill run.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackfillProgress {
    /// Items processed so far.
    pub done: usize,
    /// Total items to process.
    pub total: usize,
}
67
68pub use algorithms::{apply_mmr, apply_temporal_decay};
69pub use cross_session::SessionSummaryResult;
70pub use graph::{
71 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
72 PostExtractValidator, extract_and_store, link_memory_notes,
73};
74pub use persona::{
75 PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
76};
77pub use recall::{EmbedContext, RecalledMessage};
78pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
79pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
80pub use tree_consolidation::{
81 TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
82 start_tree_consolidation_loop,
83};
84pub use write_buffer::{BufferedWrite, WriteBuffer};
85
/// Cached persona-profile centroid together with its computation time,
/// used for TTL-based expiry in `profile_centroid_cached`.
#[derive(Debug, Clone)]
pub(crate) struct CachedCentroid {
    /// Mean over the persona-fact embeddings (see `compute_profile_centroid`).
    pub vector: Vec<f32>,
    /// When the centroid was computed; compared against the configured TTL.
    pub computed_at: Instant,
}
97
/// Toggle for temporal decay of recall scores.
///
/// A dedicated enum instead of a bare `bool` keeps call sites self-describing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TemporalDecay {
    Enabled,
    #[default]
    Disabled,
}

impl TemporalDecay {
    /// Returns `true` when the toggle is [`TemporalDecay::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for TemporalDecay {
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
125
/// Toggle for MMR (maximal marginal relevance) re-ranking of recall results.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum MmrReranking {
    Enabled,
    #[default]
    Disabled,
}

impl MmrReranking {
    /// Returns `true` when the toggle is [`MmrReranking::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for MmrReranking {
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
153
/// Toggle for importance-weighted scoring during recall ranking.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ImportanceScoring {
    Enabled,
    #[default]
    Disabled,
}

impl ImportanceScoring {
    /// Returns `true` when the toggle is [`ImportanceScoring::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for ImportanceScoring {
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
181
/// Toggle for biasing first-person queries toward the persona centroid.
///
/// Unlike the other feature toggles in this module, this one defaults to
/// `Enabled`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum QueryBiasCorrection {
    #[default]
    Enabled,
    Disabled,
}

impl QueryBiasCorrection {
    /// Returns `true` when the toggle is [`QueryBiasCorrection::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for QueryBiasCorrection {
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
209
/// Toggle for Hebbian reinforcement of graph edges on recall.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum HebbianReinforcement {
    Enabled,
    #[default]
    Disabled,
}

impl HebbianReinforcement {
    /// Returns `true` when the toggle is [`HebbianReinforcement::Enabled`].
    #[must_use]
    #[inline]
    pub fn is_enabled(self) -> bool {
        matches!(self, Self::Enabled)
    }
}

impl From<bool> for HebbianReinforcement {
    fn from(enabled: bool) -> Self {
        if enabled {
            Self::Enabled
        } else {
            Self::Disabled
        }
    }
}
237
/// Coarse query classification used by query-bias correction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum QueryIntent {
    /// The query contains self-referential (first-person) language.
    FirstPerson,
    /// Any other query.
    Other,
}
249
/// Runtime settings for Hebbian spreading activation over the memory graph.
#[derive(Debug, Clone, Default)]
pub struct HelaSpreadRuntime {
    /// Master switch for spreading activation.
    pub enabled: bool,
    /// Traversal depth limit — presumably hops from seed nodes; confirm in
    /// the spread implementation.
    pub depth: u32,
    /// Cap on the number of nodes visited per spread.
    pub max_visited: usize,
    /// Edge types eligible for traversal.
    pub edge_types: Vec<crate::graph::EdgeType>,
    /// Optional time budget — NOTE(review): per step, by the field name;
    /// confirm against the consumer.
    pub step_budget: Option<std::time::Duration>,
}
266
/// Long-term semantic memory backed by `SQLite`, with an optional vector
/// store for embedding search and optional graph/experience/reasoning layers.
pub struct SemanticMemory {
    /// Durable relational store (messages, summaries, persona facts, …).
    pub(crate) sqlite: SqliteStore,
    /// Vector store; `None` when unavailable, which disables semantic search.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    /// Primary LLM provider; also used for embeddings when no dedicated
    /// embed provider is configured.
    pub(crate) provider: AnyProvider,
    /// Optional dedicated embedding provider (see `effective_embed_provider`).
    pub(crate) embed_provider: Option<AnyProvider>,
    /// Embedding model identifier.
    pub(crate) embedding_model: String,
    /// Weight of the vector-similarity component in hybrid search.
    pub(crate) vector_weight: f64,
    /// Weight of the keyword component in hybrid search.
    pub(crate) keyword_weight: f64,
    /// Whether recall scores decay with age.
    pub(crate) temporal_decay: TemporalDecay,
    /// Half-life, in days, for temporal decay.
    pub(crate) temporal_decay_half_life_days: u32,
    /// Whether MMR re-ranking is applied to recall results.
    pub(crate) mmr_reranking: MmrReranking,
    /// MMR relevance/diversity trade-off parameter.
    pub(crate) mmr_lambda: f32,
    /// Whether importance scores contribute to ranking.
    pub(crate) importance_scoring: ImportanceScoring,
    /// Weight of the importance component in ranking.
    pub(crate) importance_weight: f64,
    /// Score boost for semantic-tier items — TODO confirm exact application.
    pub(crate) tier_boost_semantic: f64,
    /// Shared token counter.
    pub token_counter: Arc<TokenCounter>,
    /// Optional knowledge-graph store.
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    /// Optional experience store.
    pub experience: Option<Arc<crate::graph::experience::ExperienceStore>>,
    /// Optional reasoning memory used for strategy retrieval.
    pub reasoning: Option<Arc<crate::reasoning::ReasoningMemory>>,
    /// Running count of community-detection failures.
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    /// Running count of graph extractions performed.
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    /// Running count of failed graph extractions.
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
    /// Last Qdrant warning marker — presumably a timestamp used to
    /// rate-limit repeated warnings; confirm at the call site.
    pub(crate) last_qdrant_warn: Arc<AtomicU64>,
    /// Optional admission control for writes.
    pub(crate) admission_control: Option<Arc<AdmissionControl>>,
    /// Optional quality gate applied before storage.
    pub(crate) quality_gate: Option<Arc<crate::quality_gate::QualityGate>>,
    /// Similarity threshold for key-fact deduplication — NOTE(review):
    /// presumably cosine similarity; confirm against the dedup code.
    pub(crate) key_facts_dedup_threshold: f32,
    /// Background embedding tasks. std `Mutex` (not tokio) — NOTE(review):
    /// the guard must not be held across an `.await`.
    pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
    /// Configured ANN pool depth; 0 means legacy default (see `effective_depth`).
    pub(crate) retrieval_depth: u32,
    /// Search template with a `{query}` placeholder; empty disables expansion.
    pub(crate) search_prompt_template: String,
    /// One-shot flag: warned that depth < recall limit.
    pub(crate) depth_below_limit_warned: Arc<std::sync::atomic::AtomicBool>,
    /// One-shot flag: warned that the template lacks `{query}`.
    pub(crate) missing_placeholder_warned: Arc<std::sync::atomic::AtomicBool>,
    /// Whether first-person queries are biased toward the persona centroid.
    pub(crate) query_bias_correction: QueryBiasCorrection,
    /// Interpolation weight toward the profile centroid (clamped to [0, 1]).
    pub(crate) query_bias_profile_weight: f32,
    /// Cached persona-profile centroid, refreshed after the TTL expires.
    pub(crate) profile_centroid: RwLock<Option<CachedCentroid>>,
    /// TTL, in seconds, for the cached centroid.
    pub(crate) profile_centroid_ttl_secs: u64,
    /// Whether Hebbian reinforcement is enabled.
    pub(crate) hebbian_reinforcement: HebbianReinforcement,
    /// Hebbian learning rate (clamped to >= 0 by `with_hebbian`).
    pub(crate) hebbian_lr: f32,
    /// Spreading-activation runtime settings.
    pub(crate) hebbian_spread: HelaSpreadRuntime,
    /// Optional logger for retrieval failures.
    pub(crate) retrieval_failure_logger: Option<RetrievalFailureLogger>,
}
357
358impl SemanticMemory {
359 pub async fn new(
370 sqlite_path: &str,
371 qdrant_url: &str,
372 api_key: Option<&str>,
373 provider: AnyProvider,
374 embedding_model: &str,
375 ) -> Result<Self, MemoryError> {
376 Self::with_weights(
377 sqlite_path,
378 qdrant_url,
379 api_key,
380 provider,
381 embedding_model,
382 0.7,
383 0.3,
384 )
385 .await
386 }
387
388 pub async fn with_weights(
397 sqlite_path: &str,
398 qdrant_url: &str,
399 api_key: Option<&str>,
400 provider: AnyProvider,
401 embedding_model: &str,
402 vector_weight: f64,
403 keyword_weight: f64,
404 ) -> Result<Self, MemoryError> {
405 Self::with_weights_and_pool_size(
406 sqlite_path,
407 qdrant_url,
408 api_key,
409 provider,
410 embedding_model,
411 vector_weight,
412 keyword_weight,
413 5,
414 )
415 .await
416 }
417
418 #[allow(clippy::too_many_arguments)]
427 pub async fn with_weights_and_pool_size(
428 sqlite_path: &str,
429 qdrant_url: &str,
430 api_key: Option<&str>,
431 provider: AnyProvider,
432 embedding_model: &str,
433 vector_weight: f64,
434 keyword_weight: f64,
435 pool_size: u32,
436 ) -> Result<Self, MemoryError> {
437 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
438 let pool = sqlite.pool().clone();
439
440 let qdrant = match EmbeddingStore::new(qdrant_url, api_key, pool) {
441 Ok(store) => Some(Arc::new(store)),
442 Err(e) => {
443 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
444 None
445 }
446 };
447
448 Ok(Self {
449 sqlite,
450 qdrant,
451 provider,
452 embed_provider: None,
453 embedding_model: embedding_model.into(),
454 vector_weight,
455 keyword_weight,
456 temporal_decay: TemporalDecay::Disabled,
457 temporal_decay_half_life_days: 30,
458 mmr_reranking: MmrReranking::Disabled,
459 mmr_lambda: 0.7,
460 importance_scoring: ImportanceScoring::Disabled,
461 importance_weight: 0.15,
462 tier_boost_semantic: 1.3,
463 token_counter: Arc::new(TokenCounter::new()),
464 graph_store: None,
465 experience: None,
466 reasoning: None,
467 community_detection_failures: Arc::new(AtomicU64::new(0)),
468 graph_extraction_count: Arc::new(AtomicU64::new(0)),
469 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
470 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
471 admission_control: None,
472 quality_gate: None,
473 key_facts_dedup_threshold: 0.95,
474 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
475 retrieval_depth: 0,
476 search_prompt_template: String::new(),
477 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
478 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
479 query_bias_correction: QueryBiasCorrection::Enabled,
480 query_bias_profile_weight: 0.25,
481 profile_centroid: RwLock::new(None),
482 profile_centroid_ttl_secs: 300,
483 hebbian_reinforcement: HebbianReinforcement::Disabled,
484 hebbian_lr: 0.1,
485 hebbian_spread: HelaSpreadRuntime::default(),
486 retrieval_failure_logger: None,
487 })
488 }
489
490 pub async fn with_qdrant_ops(
499 sqlite_path: &str,
500 ops: crate::QdrantOps,
501 provider: AnyProvider,
502 embedding_model: &str,
503 vector_weight: f64,
504 keyword_weight: f64,
505 pool_size: u32,
506 ) -> Result<Self, MemoryError> {
507 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
508 let pool = sqlite.pool().clone();
509 let store = EmbeddingStore::with_store(Box::new(ops), pool);
510
511 Ok(Self {
512 sqlite,
513 qdrant: Some(Arc::new(store)),
514 provider,
515 embed_provider: None,
516 embedding_model: embedding_model.into(),
517 vector_weight,
518 keyword_weight,
519 temporal_decay: TemporalDecay::Disabled,
520 temporal_decay_half_life_days: 30,
521 mmr_reranking: MmrReranking::Disabled,
522 mmr_lambda: 0.7,
523 importance_scoring: ImportanceScoring::Disabled,
524 importance_weight: 0.15,
525 tier_boost_semantic: 1.3,
526 token_counter: Arc::new(TokenCounter::new()),
527 graph_store: None,
528 experience: None,
529 reasoning: None,
530 community_detection_failures: Arc::new(AtomicU64::new(0)),
531 graph_extraction_count: Arc::new(AtomicU64::new(0)),
532 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
533 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
534 admission_control: None,
535 quality_gate: None,
536 key_facts_dedup_threshold: 0.95,
537 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
538 retrieval_depth: 0,
539 search_prompt_template: String::new(),
540 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
541 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
542 query_bias_correction: QueryBiasCorrection::Enabled,
543 query_bias_profile_weight: 0.25,
544 profile_centroid: RwLock::new(None),
545 profile_centroid_ttl_secs: 300,
546 hebbian_reinforcement: HebbianReinforcement::Disabled,
547 hebbian_lr: 0.1,
548 hebbian_spread: HelaSpreadRuntime::default(),
549 retrieval_failure_logger: None,
550 })
551 }
552
    /// Attach a knowledge-graph store (builder style).
    #[must_use]
    pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
        self.graph_store = Some(store);
        self
    }
562
    /// Attach an experience store (builder style).
    #[must_use]
    pub fn with_experience_store(
        mut self,
        store: Arc<crate::graph::experience::ExperienceStore>,
    ) -> Self {
        self.experience = Some(store);
        self
    }
576
    /// Attach a reasoning memory (builder style).
    #[must_use]
    pub fn with_reasoning(mut self, store: Arc<crate::reasoning::ReasoningMemory>) -> Self {
        self.reasoning = Some(store);
        self
    }
587
    /// Attach a retrieval-failure logger (builder style).
    #[must_use]
    pub fn with_retrieval_failure_logger(mut self, logger: RetrievalFailureLogger) -> Self {
        self.retrieval_failure_logger = Some(logger);
        self
    }
597
    /// Record a retrieval failure; no-op when no failure logger is attached.
    pub fn log_retrieval_failure(&self, record: RetrievalFailureRecord) {
        if let Some(logger) = &self.retrieval_failure_logger {
            logger.log(record);
        }
    }
608
609 #[must_use]
611 pub fn community_detection_failures(&self) -> u64 {
612 use std::sync::atomic::Ordering;
613 self.community_detection_failures.load(Ordering::Relaxed)
614 }
615
616 #[must_use]
618 pub fn graph_extraction_count(&self) -> u64 {
619 use std::sync::atomic::Ordering;
620 self.graph_extraction_count.load(Ordering::Relaxed)
621 }
622
623 #[must_use]
625 pub fn graph_extraction_failures(&self) -> u64 {
626 use std::sync::atomic::Ordering;
627 self.graph_extraction_failures.load(Ordering::Relaxed)
628 }
629
    /// Configure temporal decay and MMR re-ranking (builder style).
    #[must_use]
    pub fn with_ranking_options(
        mut self,
        temporal_decay: TemporalDecay,
        temporal_decay_half_life_days: u32,
        mmr_reranking: MmrReranking,
        mmr_lambda: f32,
    ) -> Self {
        self.temporal_decay = temporal_decay;
        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
        self.mmr_reranking = mmr_reranking;
        self.mmr_lambda = mmr_lambda;
        self
    }
645
    /// Configure importance scoring and its ranking weight (builder style).
    #[must_use]
    pub fn with_importance_options(mut self, scoring: ImportanceScoring, weight: f64) -> Self {
        self.importance_scoring = scoring;
        self.importance_weight = weight;
        self
    }
653
    /// Set the score boost for semantic-tier items (builder style).
    #[must_use]
    pub fn with_tier_boost(mut self, boost: f64) -> Self {
        self.tier_boost_semantic = boost;
        self
    }
662
    /// Attach admission control for writes (builder style).
    #[must_use]
    pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
        self.admission_control = Some(Arc::new(control));
        self
    }
672
    /// Attach a quality gate (builder style).
    #[must_use]
    pub fn with_quality_gate(mut self, gate: Arc<crate::quality_gate::QualityGate>) -> Self {
        self.quality_gate = Some(gate);
        self
    }
683
    /// Set the similarity threshold for key-fact deduplication (builder style).
    ///
    /// NOTE(review): unlike `with_query_bias`, the value is stored unclamped —
    /// confirm whether out-of-range thresholds are intended.
    #[must_use]
    pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
        self.key_facts_dedup_threshold = threshold;
        self
    }
693
    /// Configure query-bias correction (builder style).
    ///
    /// `profile_weight` is clamped to `[0.0, 1.0]` before being stored.
    #[must_use]
    pub fn with_query_bias(
        mut self,
        correction: QueryBiasCorrection,
        profile_weight: f32,
        centroid_ttl_secs: u64,
    ) -> Self {
        self.query_bias_correction = correction;
        self.query_bias_profile_weight = profile_weight.clamp(0.0, 1.0);
        self.profile_centroid_ttl_secs = centroid_ttl_secs;
        self
    }
712
    /// Configure Hebbian spreading-activation runtime settings (builder style).
    #[must_use]
    pub fn with_hebbian_spread(mut self, runtime: HelaSpreadRuntime) -> Self {
        self.hebbian_spread = runtime;
        self
    }
722
    /// Configure Hebbian reinforcement (builder style).
    ///
    /// Negative learning rates are clamped to `0.0`; enabling reinforcement
    /// with an effective rate of `0.0` logs a warning because no updates
    /// would ever be applied.
    #[must_use]
    pub fn with_hebbian(mut self, reinforcement: HebbianReinforcement, lr: f32) -> Self {
        let lr = lr.max(0.0);
        if reinforcement.is_enabled() && lr == 0.0 {
            tracing::warn!("hebbian enabled with lr=0.0 — no reinforcement will occur");
        }
        self.hebbian_reinforcement = reinforcement;
        self.hebbian_lr = lr;
        self
    }
737
738 pub(crate) fn classify_query_intent(query: &str) -> QueryIntent {
743 if persona::contains_self_referential_language(query) {
744 QueryIntent::FirstPerson
745 } else {
746 QueryIntent::Other
747 }
748 }
749
    /// Bias a first-person query embedding toward the persona-profile centroid.
    ///
    /// Returns the embedding unchanged when correction is disabled, the query
    /// is not first-person, no centroid is available, or the centroid and
    /// embedding dimensions disagree.
    #[tracing::instrument(name = "memory.query_bias.apply", skip(self, embedding), fields(query_len = query.len()))]
    pub(crate) async fn apply_query_bias(&self, query: &str, embedding: Vec<f32>) -> Vec<f32> {
        if !self.query_bias_correction.is_enabled() {
            tracing::debug!(reason = "disabled", "query-bias: skipping");
            return embedding;
        }
        if Self::classify_query_intent(query) != QueryIntent::FirstPerson {
            tracing::debug!(reason = "not_first_person", "query-bias: skipping");
            return embedding;
        }
        let Some(centroid) = self.profile_centroid_cached().await else {
            tracing::debug!(reason = "no_centroid", "query-bias: skipping");
            return embedding;
        };
        if centroid.len() != embedding.len() {
            tracing::warn!(
                centroid_dim = centroid.len(),
                query_dim = embedding.len(),
                reason = "dim_mismatch",
                "query-bias: dimension mismatch between profile centroid and query embedding — skipping bias"
            );
            return embedding;
        }
        let w = self.query_bias_profile_weight;
        tracing::debug!(
            intent = "first_person",
            centroid_dim = centroid.len(),
            weight = w,
            "query-bias: applying profile bias"
        );
        // Element-wise linear interpolation: (1 - w) * query + w * centroid.
        embedding
            .iter()
            .zip(centroid.iter())
            .map(|(&q, &c)| (1.0 - w) * q + w * c)
            .collect()
    }
791
    /// Return the persona-profile centroid, serving from cache while within TTL.
    ///
    /// On recompute failure, falls back to the stale cached value if present.
    #[tracing::instrument(name = "memory.query_bias.centroid", skip(self))]
    pub(crate) async fn profile_centroid_cached(&self) -> Option<Vec<f32>> {
        {
            // Fast path: read lock only; serve a fresh-enough cached centroid.
            let guard = self.profile_centroid.read().await;
            if let Some(c) = &*guard
                && c.computed_at.elapsed().as_secs() < self.profile_centroid_ttl_secs
            {
                let ttl_remaining = self
                    .profile_centroid_ttl_secs
                    .saturating_sub(c.computed_at.elapsed().as_secs());
                tracing::debug!(
                    centroid_dim = c.vector.len(),
                    ttl_remaining_secs = ttl_remaining,
                    "query-bias: centroid cache hit"
                );
                return Some(c.vector.clone());
            }
        }
        // Cache miss or expired: recompute without holding any lock, then
        // publish under the write lock.
        let computed = self.compute_profile_centroid().await;
        let mut guard = self.profile_centroid.write().await;
        match computed {
            Some(v) => {
                tracing::debug!(centroid_dim = v.len(), "query-bias: centroid computed");
                *guard = Some(CachedCentroid {
                    vector: v.clone(),
                    computed_at: Instant::now(),
                });
                Some(v)
            }
            None => {
                // Recompute failed: serve the stale centroid if one remains.
                guard.as_ref().map(|c| c.vector.clone())
            }
        }
    }
833
834 async fn compute_profile_centroid(&self) -> Option<Vec<f32>> {
839 let facts = match self.sqlite.load_persona_facts(0.0).await {
840 Ok(f) => f,
841 Err(e) => {
842 tracing::warn!(error = %e, "query-bias: failed to load persona facts");
843 return None;
844 }
845 };
846 if facts.is_empty() {
847 return None;
848 }
849 let provider = self.effective_embed_provider();
850 let texts: Vec<String> = facts.iter().map(|f| f.content.clone()).collect();
851 let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(texts.len());
852 for text in &texts {
853 match provider.embed(text).await {
854 Ok(v) => embeddings.push(v),
855 Err(e) => {
856 tracing::warn!(error = %e, "query-bias: failed to embed persona fact — skipping");
857 }
858 }
859 }
860 if embeddings.is_empty() {
861 return None;
862 }
863 let dim = embeddings[0].len();
864 let mut centroid = vec![0.0f32; dim];
865 for emb in &embeddings {
866 if emb.len() != dim {
867 tracing::warn!(
868 expected = dim,
869 got = emb.len(),
870 "query-bias: persona embedding dimension mismatch — skipping fact"
871 );
872 continue;
873 }
874 for (c, &v) in centroid.iter_mut().zip(emb.iter()) {
875 *c += v;
876 }
877 }
878 #[allow(clippy::cast_precision_loss)]
879 let n = embeddings.len() as f32;
880 for c in &mut centroid {
881 *c /= n;
882 }
883 Some(centroid)
884 }
885
    /// Configure ANN retrieval depth and the search-prompt template
    /// (builder style).
    ///
    /// A `depth` of 0 selects the legacy default (see `effective_depth`);
    /// an empty template disables prompt expansion (see `apply_search_prompt`).
    #[must_use]
    pub fn with_retrieval_options(
        mut self,
        depth: u32,
        search_prompt_template: impl Into<String>,
    ) -> Self {
        self.retrieval_depth = depth;
        self.search_prompt_template = search_prompt_template.into();
        self
    }
903
    /// ANN candidate-pool depth to use for a recall of `limit` items.
    ///
    /// `retrieval_depth == 0` means "unset" and falls back to the legacy
    /// default of `limit * 2`. Otherwise the configured depth is returned,
    /// with diagnostics when it is smaller than `limit` (warned once per
    /// process) or smaller than the legacy default.
    pub(crate) fn effective_depth(&self, limit: usize) -> usize {
        use std::sync::atomic::Ordering;

        let depth = self.retrieval_depth as usize;
        if depth == 0 {
            return limit.saturating_mul(2);
        }
        if depth < limit {
            // swap returns the previous value, so only the first caller warns.
            if !self.depth_below_limit_warned.swap(true, Ordering::Relaxed) {
                tracing::warn!(
                    retrieval_depth = depth,
                    recall_limit = limit,
                    "memory.retrieval.depth < recall_limit; ANN pool cannot saturate top-k — consider raising depth"
                );
            }
        } else if depth < limit.saturating_mul(2) {
            tracing::info!(
                retrieval_depth = depth,
                recall_limit = limit,
                legacy_default = limit.saturating_mul(2),
                "memory.retrieval.depth is below legacy limit*2; ANN pool will be smaller than pre-#3340"
            );
        } else {
            tracing::debug!(
                retrieval_depth = depth,
                recall_limit = limit,
                "recall: using configured ANN depth"
            );
        }
        depth
    }
943
944 pub(crate) fn apply_search_prompt(&self, query: &str) -> String {
949 use std::sync::atomic::Ordering;
950
951 let template = &self.search_prompt_template;
952 if template.is_empty() {
953 return query.to_owned();
954 }
955 if !template.contains("{query}") {
956 if !self
957 .missing_placeholder_warned
958 .swap(true, Ordering::Relaxed)
959 {
960 tracing::warn!(
961 template = template.as_str(),
962 "memory.retrieval.search_prompt_template has no {{query}} placeholder — \
963 using raw query as-is"
964 );
965 }
966 return query.to_owned();
967 }
968 template.replace("{query}", query)
969 }
970
    /// Use a dedicated provider for embeddings instead of the chat provider
    /// (builder style).
    #[must_use]
    pub fn with_embed_provider(mut self, embed_provider: AnyProvider) -> Self {
        self.embed_provider = Some(embed_provider);
        self
    }
981
982 pub fn effective_embed_provider(&self) -> &AnyProvider {
986 self.embed_provider.as_ref().unwrap_or(&self.provider)
987 }
988
    /// Assemble a `SemanticMemory` from pre-built components.
    ///
    /// This is the canonical initializer: every field not supplied as a
    /// parameter receives its documented default here (features disabled
    /// except query-bias correction, decay half-life 30 days, MMR lambda 0.7,
    /// importance weight 0.15, tier boost 1.3, dedup threshold 0.95,
    /// centroid TTL 300 s, Hebbian lr 0.1).
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<Arc<EmbeddingStore>>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embed_provider: None,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay: TemporalDecay::Disabled,
            temporal_decay_half_life_days: 30,
            mmr_reranking: MmrReranking::Disabled,
            mmr_lambda: 0.7,
            importance_scoring: ImportanceScoring::Disabled,
            importance_weight: 0.15,
            tier_boost_semantic: 1.3,
            token_counter,
            graph_store: None,
            experience: None,
            reasoning: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
            last_qdrant_warn: Arc::new(AtomicU64::new(0)),
            admission_control: None,
            quality_gate: None,
            key_facts_dedup_threshold: 0.95,
            embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
            retrieval_depth: 0,
            search_prompt_template: String::new(),
            depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            query_bias_correction: QueryBiasCorrection::Enabled,
            query_bias_profile_weight: 0.25,
            profile_centroid: RwLock::new(None),
            profile_centroid_ttl_secs: 300,
            hebbian_reinforcement: HebbianReinforcement::Disabled,
            hebbian_lr: 0.1,
            hebbian_spread: HelaSpreadRuntime::default(),
            retrieval_failure_logger: None,
        }
    }
1043
1044 pub async fn with_sqlite_backend(
1050 sqlite_path: &str,
1051 provider: AnyProvider,
1052 embedding_model: &str,
1053 vector_weight: f64,
1054 keyword_weight: f64,
1055 ) -> Result<Self, MemoryError> {
1056 Self::with_sqlite_backend_and_pool_size(
1057 sqlite_path,
1058 provider,
1059 embedding_model,
1060 vector_weight,
1061 keyword_weight,
1062 5,
1063 )
1064 .await
1065 }
1066
1067 pub async fn with_sqlite_backend_and_pool_size(
1073 sqlite_path: &str,
1074 provider: AnyProvider,
1075 embedding_model: &str,
1076 vector_weight: f64,
1077 keyword_weight: f64,
1078 pool_size: u32,
1079 ) -> Result<Self, MemoryError> {
1080 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
1081 let pool = sqlite.pool().clone();
1082 let store = EmbeddingStore::new_sqlite(pool);
1083
1084 Ok(Self {
1085 sqlite,
1086 qdrant: Some(Arc::new(store)),
1087 provider,
1088 embed_provider: None,
1089 embedding_model: embedding_model.into(),
1090 vector_weight,
1091 keyword_weight,
1092 temporal_decay: TemporalDecay::Disabled,
1093 temporal_decay_half_life_days: 30,
1094 mmr_reranking: MmrReranking::Disabled,
1095 mmr_lambda: 0.7,
1096 importance_scoring: ImportanceScoring::Disabled,
1097 importance_weight: 0.15,
1098 tier_boost_semantic: 1.3,
1099 token_counter: Arc::new(TokenCounter::new()),
1100 graph_store: None,
1101 experience: None,
1102 reasoning: None,
1103 community_detection_failures: Arc::new(AtomicU64::new(0)),
1104 graph_extraction_count: Arc::new(AtomicU64::new(0)),
1105 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
1106 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
1107 admission_control: None,
1108 quality_gate: None,
1109 key_facts_dedup_threshold: 0.95,
1110 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
1111 retrieval_depth: 0,
1112 search_prompt_template: String::new(),
1113 depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1114 missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
1115 query_bias_correction: QueryBiasCorrection::Enabled,
1116 query_bias_profile_weight: 0.25,
1117 profile_centroid: RwLock::new(None),
1118 profile_centroid_ttl_secs: 300,
1119 hebbian_reinforcement: HebbianReinforcement::Disabled,
1120 hebbian_lr: 0.1,
1121 hebbian_spread: HelaSpreadRuntime::default(),
1122 retrieval_failure_logger: None,
1123 })
1124 }
1125
    /// Borrow the underlying `SQLite` store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }
1131
1132 pub async fn is_vector_store_connected(&self) -> bool {
1137 match self.qdrant.as_ref() {
1138 Some(store) => store.health_check().await,
1139 None => false,
1140 }
1141 }
1142
    /// `true` when a vector store was configured (regardless of connectivity).
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }
1148
    /// Borrow the vector store, if one is configured.
    #[must_use]
    pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
        self.qdrant.as_ref()
    }
1154
    /// Borrow the primary LLM provider.
    pub fn provider(&self) -> &AnyProvider {
        &self.provider
    }
1159
    /// Total number of stored messages for `conversation_id`.
    ///
    /// # Errors
    /// Returns an error when the underlying count query fails.
    pub async fn message_count(
        &self,
        conversation_id: crate::types::ConversationId,
    ) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }
1171
    /// Count messages in `conversation_id` newer than the latest summary.
    ///
    /// When no summary exists yet, all messages are counted (after id 0).
    ///
    /// # Errors
    /// Returns an error when either underlying query fails.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: crate::types::ConversationId,
    ) -> Result<i64, MemoryError> {
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(crate::types::MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
1190
    /// Load the newest non-deleted messages (up to `max_items`) together with
    /// their stored embeddings, as input for the promotion pipeline.
    ///
    /// Vectors whose dimension disagrees with the first returned vector are
    /// dropped (with a warning); messages without a vector get
    /// `embedding: None`.
    ///
    /// # Errors
    /// Returns an error when the database query or the vector lookup fails.
    pub async fn load_promotion_window(
        &self,
        max_items: usize,
    ) -> Result<Vec<crate::compression::promotion::PromotionInput>, MemoryError> {
        use zeph_db::sql;

        // usize can exceed i64 on 64-bit targets; saturate rather than wrap.
        let limit = i64::try_from(max_items).unwrap_or(i64::MAX);
        let rows: Vec<(
            crate::types::MessageId,
            crate::types::ConversationId,
            String,
        )> = zeph_db::query_as(sql!(
            "SELECT id, conversation_id, content \
             FROM messages \
             WHERE deleted_at IS NULL \
             ORDER BY id DESC \
             LIMIT ?"
        ))
        .bind(limit)
        .fetch_all(self.sqlite.pool())
        .await?;

        let mut vectors = if let Some(qdrant) = &self.qdrant {
            let ids: Vec<_> = rows.iter().map(|(id, _, _)| *id).collect();
            let mut raw = qdrant.get_vectors_for_messages(&ids).await?;

            // An arbitrary first vector fixes the reference dimension; any
            // vector that disagrees is removed rather than passed downstream.
            let ref_dim = raw.values().next().map(Vec::len);
            if let Some(ref_dim) = ref_dim {
                let mismatched: Vec<_> = raw
                    .iter()
                    .filter(|(_, v)| v.len() != ref_dim)
                    .map(|(id, v)| (*id, v.len()))
                    .collect();
                if !mismatched.is_empty() {
                    tracing::warn!(
                        expected_dim = ref_dim,
                        dropped_count = mismatched.len(),
                        "load_promotion_window: dimension mismatch — dropping mismatched vectors"
                    );
                    for (id, _) in mismatched {
                        raw.remove(&id);
                    }
                }
            }
            raw
        } else {
            // No vector store: every message gets `embedding: None` below.
            std::collections::HashMap::new()
        };

        Ok(rows
            .into_iter()
            .map(|(message_id, conversation_id, content)| {
                crate::compression::promotion::PromotionInput {
                    message_id,
                    conversation_id,
                    content,
                    embedding: vectors.remove(&message_id),
                }
            })
            .collect())
    }
1273
    /// Retrieve reasoning strategies relevant to `query` via embedding search.
    ///
    /// Returns an empty list when no reasoning memory is attached or the
    /// effective provider does not support embeddings.
    ///
    /// # Errors
    /// Returns an error when embedding the query or the retrieval fails.
    pub async fn retrieve_reasoning_strategies(
        &self,
        query: &str,
        limit: usize,
    ) -> Result<Vec<crate::reasoning::ReasoningStrategy>, MemoryError> {
        let Some(reasoning) = &self.reasoning else {
            return Ok(Vec::new());
        };
        if !self.effective_embed_provider().supports_embeddings() {
            return Ok(Vec::new());
        }
        let embedding = self.effective_embed_provider().embed(query).await?;
        reasoning
            // usize → u64 is lossless on supported (≤ 64-bit) targets.
            .retrieve_by_embedding(&embedding, limit as u64)
            .await
    }
1302}