1mod algorithms;
24mod corrections;
25mod cross_session;
26mod graph;
27pub(crate) mod importance;
28pub mod persona;
29mod recall;
30mod summarization;
31pub mod trajectory;
32pub mod tree_consolidation;
33pub(crate) mod write_buffer;
34
35#[cfg(test)]
36mod tests;
37
38use std::sync::Arc;
39use std::sync::Mutex;
40use std::sync::atomic::AtomicU64;
41
42use zeph_llm::any::AnyProvider;
43
44use crate::admission::AdmissionControl;
45use crate::embedding_store::EmbeddingStore;
46use crate::error::MemoryError;
47use crate::store::SqliteStore;
48use crate::token_counter::TokenCounter;
49
/// Collection name under which session summaries are stored.
///
/// NOTE(review): presumably a collection in the embedding store used by `cross_session`;
/// confirm against callers.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Collection name under which extracted key facts are stored.
///
/// NOTE(review): presumably a collection in the embedding store; confirm against callers.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Collection name under which corrections are stored.
///
/// NOTE(review): presumably used by the `corrections` submodule; confirm against callers.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
53
/// Progress snapshot for a backfill operation, expressed as `done` out of `total` items.
///
/// NOTE(review): what is being backfilled (e.g. embeddings for existing rows) is not
/// visible in this file; confirm against the producer of this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackfillProgress {
    /// Number of items processed so far (presumably `done <= total`; not enforced here).
    pub done: usize,
    /// Total number of items the backfill expects to process.
    pub total: usize,
}
62
63pub use algorithms::{apply_mmr, apply_temporal_decay};
64pub use cross_session::SessionSummaryResult;
65pub use graph::{
66 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
67 PostExtractValidator, extract_and_store, link_memory_notes,
68};
69pub use persona::{
70 PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
71};
72pub use recall::{EmbedContext, RecalledMessage};
73pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
74pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
75pub use tree_consolidation::{
76 TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
77 start_tree_consolidation_loop,
78};
79pub use write_buffer::{BufferedWrite, WriteBuffer};
80
/// Semantic memory facade combining a SQLite store, an optional vector/embedding store,
/// an LLM provider, and a set of ranking and extraction knobs.
///
/// Instances are created via one of the async constructors (`new`, `with_weights`,
/// `with_qdrant_ops`, `with_sqlite_backend`, …) or `from_parts`, and then tuned with the
/// `with_*` builder methods.
pub struct SemanticMemory {
    // Relational store; used for message counts and summary bookkeeping.
    pub(crate) sqlite: SqliteStore,
    // Optional embedding store. Despite the field name, it may be Qdrant-backed or
    // SQLite-backed (see `with_sqlite_backend_and_pool_size`). It is `None` when the
    // Qdrant store could not be created in `with_weights_and_pool_size`.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    // Primary LLM provider; also used for embeddings when `embed_provider` is `None`.
    pub(crate) provider: AnyProvider,
    // Optional dedicated embedding provider (set via `with_embed_provider`).
    pub(crate) embed_provider: Option<AnyProvider>,
    // Name of the embedding model.
    pub(crate) embedding_model: String,
    // Hybrid-search weights (presumably blending vector and keyword scores).
    pub(crate) vector_weight: f64,
    pub(crate) keyword_weight: f64,
    // Temporal-decay ranking toggle and half-life; defaults false / 30 days.
    pub(crate) temporal_decay_enabled: bool,
    pub(crate) temporal_decay_half_life_days: u32,
    // MMR re-ranking toggle and lambda; defaults false / 0.7.
    pub(crate) mmr_enabled: bool,
    pub(crate) mmr_lambda: f32,
    // Importance scoring toggle and weight; defaults false / 0.15.
    pub(crate) importance_enabled: bool,
    pub(crate) importance_weight: f64,
    // Boost factor for the semantic tier; defaults to 1.3 (set via `with_tier_boost`).
    pub(crate) tier_boost_semantic: f64,
    // Shared token counter.
    pub token_counter: Arc<TokenCounter>,
    // Optional knowledge-graph store (set via `with_graph_store`).
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    // Shared counters exposed through the getter methods of the same names.
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
    // NOTE(review): presumably a timestamp used to rate-limit Qdrant warnings; confirm.
    pub(crate) last_qdrant_warn: Arc<AtomicU64>,
    // Optional admission control (set via `with_admission_control`).
    pub(crate) admission_control: Option<Arc<AdmissionControl>>,
    // Optional quality gate (set via `with_quality_gate`).
    pub(crate) quality_gate: Option<Arc<crate::quality_gate::QualityGate>>,
    // Similarity threshold for key-fact deduplication; defaults to 0.95.
    pub(crate) key_facts_dedup_threshold: f32,
    // Set of spawned tasks (presumably background embedding jobs). This is a
    // `std::sync::Mutex`, so its guard must not be held across an `.await`.
    pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
}
129
130impl SemanticMemory {
131 pub async fn new(
142 sqlite_path: &str,
143 qdrant_url: &str,
144 provider: AnyProvider,
145 embedding_model: &str,
146 ) -> Result<Self, MemoryError> {
147 Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
148 }
149
150 pub async fn with_weights(
159 sqlite_path: &str,
160 qdrant_url: &str,
161 provider: AnyProvider,
162 embedding_model: &str,
163 vector_weight: f64,
164 keyword_weight: f64,
165 ) -> Result<Self, MemoryError> {
166 Self::with_weights_and_pool_size(
167 sqlite_path,
168 qdrant_url,
169 provider,
170 embedding_model,
171 vector_weight,
172 keyword_weight,
173 5,
174 )
175 .await
176 }
177
178 pub async fn with_weights_and_pool_size(
187 sqlite_path: &str,
188 qdrant_url: &str,
189 provider: AnyProvider,
190 embedding_model: &str,
191 vector_weight: f64,
192 keyword_weight: f64,
193 pool_size: u32,
194 ) -> Result<Self, MemoryError> {
195 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
196 let pool = sqlite.pool().clone();
197
198 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
199 Ok(store) => Some(Arc::new(store)),
200 Err(e) => {
201 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
202 None
203 }
204 };
205
206 Ok(Self {
207 sqlite,
208 qdrant,
209 provider,
210 embed_provider: None,
211 embedding_model: embedding_model.into(),
212 vector_weight,
213 keyword_weight,
214 temporal_decay_enabled: false,
215 temporal_decay_half_life_days: 30,
216 mmr_enabled: false,
217 mmr_lambda: 0.7,
218 importance_enabled: false,
219 importance_weight: 0.15,
220 tier_boost_semantic: 1.3,
221 token_counter: Arc::new(TokenCounter::new()),
222 graph_store: None,
223 community_detection_failures: Arc::new(AtomicU64::new(0)),
224 graph_extraction_count: Arc::new(AtomicU64::new(0)),
225 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
226 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
227 admission_control: None,
228 quality_gate: None,
229 key_facts_dedup_threshold: 0.95,
230 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
231 })
232 }
233
234 pub async fn with_qdrant_ops(
243 sqlite_path: &str,
244 ops: crate::QdrantOps,
245 provider: AnyProvider,
246 embedding_model: &str,
247 vector_weight: f64,
248 keyword_weight: f64,
249 pool_size: u32,
250 ) -> Result<Self, MemoryError> {
251 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
252 let pool = sqlite.pool().clone();
253 let store = EmbeddingStore::with_store(Box::new(ops), pool);
254
255 Ok(Self {
256 sqlite,
257 qdrant: Some(Arc::new(store)),
258 provider,
259 embed_provider: None,
260 embedding_model: embedding_model.into(),
261 vector_weight,
262 keyword_weight,
263 temporal_decay_enabled: false,
264 temporal_decay_half_life_days: 30,
265 mmr_enabled: false,
266 mmr_lambda: 0.7,
267 importance_enabled: false,
268 importance_weight: 0.15,
269 tier_boost_semantic: 1.3,
270 token_counter: Arc::new(TokenCounter::new()),
271 graph_store: None,
272 community_detection_failures: Arc::new(AtomicU64::new(0)),
273 graph_extraction_count: Arc::new(AtomicU64::new(0)),
274 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
275 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
276 admission_control: None,
277 quality_gate: None,
278 key_facts_dedup_threshold: 0.95,
279 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
280 })
281 }
282
283 #[must_use]
288 pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
289 self.graph_store = Some(store);
290 self
291 }
292
293 #[must_use]
295 pub fn community_detection_failures(&self) -> u64 {
296 use std::sync::atomic::Ordering;
297 self.community_detection_failures.load(Ordering::Relaxed)
298 }
299
300 #[must_use]
302 pub fn graph_extraction_count(&self) -> u64 {
303 use std::sync::atomic::Ordering;
304 self.graph_extraction_count.load(Ordering::Relaxed)
305 }
306
307 #[must_use]
309 pub fn graph_extraction_failures(&self) -> u64 {
310 use std::sync::atomic::Ordering;
311 self.graph_extraction_failures.load(Ordering::Relaxed)
312 }
313
314 #[must_use]
316 pub fn with_ranking_options(
317 mut self,
318 temporal_decay_enabled: bool,
319 temporal_decay_half_life_days: u32,
320 mmr_enabled: bool,
321 mmr_lambda: f32,
322 ) -> Self {
323 self.temporal_decay_enabled = temporal_decay_enabled;
324 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
325 self.mmr_enabled = mmr_enabled;
326 self.mmr_lambda = mmr_lambda;
327 self
328 }
329
330 #[must_use]
332 pub fn with_importance_options(mut self, enabled: bool, weight: f64) -> Self {
333 self.importance_enabled = enabled;
334 self.importance_weight = weight;
335 self
336 }
337
338 #[must_use]
342 pub fn with_tier_boost(mut self, boost: f64) -> Self {
343 self.tier_boost_semantic = boost;
344 self
345 }
346
347 #[must_use]
352 pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
353 self.admission_control = Some(Arc::new(control));
354 self
355 }
356
357 #[must_use]
363 pub fn with_quality_gate(mut self, gate: Arc<crate::quality_gate::QualityGate>) -> Self {
364 self.quality_gate = Some(gate);
365 self
366 }
367
368 #[must_use]
373 pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
374 self.key_facts_dedup_threshold = threshold;
375 self
376 }
377
378 #[must_use]
384 pub fn with_embed_provider(mut self, embed_provider: AnyProvider) -> Self {
385 self.embed_provider = Some(embed_provider);
386 self
387 }
388
389 pub(crate) fn effective_embed_provider(&self) -> &AnyProvider {
393 self.embed_provider.as_ref().unwrap_or(&self.provider)
394 }
395
396 #[must_use]
400 pub fn from_parts(
401 sqlite: SqliteStore,
402 qdrant: Option<Arc<EmbeddingStore>>,
403 provider: AnyProvider,
404 embedding_model: impl Into<String>,
405 vector_weight: f64,
406 keyword_weight: f64,
407 token_counter: Arc<TokenCounter>,
408 ) -> Self {
409 Self {
410 sqlite,
411 qdrant,
412 provider,
413 embed_provider: None,
414 embedding_model: embedding_model.into(),
415 vector_weight,
416 keyword_weight,
417 temporal_decay_enabled: false,
418 temporal_decay_half_life_days: 30,
419 mmr_enabled: false,
420 mmr_lambda: 0.7,
421 importance_enabled: false,
422 importance_weight: 0.15,
423 tier_boost_semantic: 1.3,
424 token_counter,
425 graph_store: None,
426 community_detection_failures: Arc::new(AtomicU64::new(0)),
427 graph_extraction_count: Arc::new(AtomicU64::new(0)),
428 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
429 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
430 admission_control: None,
431 quality_gate: None,
432 key_facts_dedup_threshold: 0.95,
433 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
434 }
435 }
436
437 pub async fn with_sqlite_backend(
443 sqlite_path: &str,
444 provider: AnyProvider,
445 embedding_model: &str,
446 vector_weight: f64,
447 keyword_weight: f64,
448 ) -> Result<Self, MemoryError> {
449 Self::with_sqlite_backend_and_pool_size(
450 sqlite_path,
451 provider,
452 embedding_model,
453 vector_weight,
454 keyword_weight,
455 5,
456 )
457 .await
458 }
459
460 pub async fn with_sqlite_backend_and_pool_size(
466 sqlite_path: &str,
467 provider: AnyProvider,
468 embedding_model: &str,
469 vector_weight: f64,
470 keyword_weight: f64,
471 pool_size: u32,
472 ) -> Result<Self, MemoryError> {
473 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
474 let pool = sqlite.pool().clone();
475 let store = EmbeddingStore::new_sqlite(pool);
476
477 Ok(Self {
478 sqlite,
479 qdrant: Some(Arc::new(store)),
480 provider,
481 embed_provider: None,
482 embedding_model: embedding_model.into(),
483 vector_weight,
484 keyword_weight,
485 temporal_decay_enabled: false,
486 temporal_decay_half_life_days: 30,
487 mmr_enabled: false,
488 mmr_lambda: 0.7,
489 importance_enabled: false,
490 importance_weight: 0.15,
491 tier_boost_semantic: 1.3,
492 token_counter: Arc::new(TokenCounter::new()),
493 graph_store: None,
494 community_detection_failures: Arc::new(AtomicU64::new(0)),
495 graph_extraction_count: Arc::new(AtomicU64::new(0)),
496 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
497 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
498 admission_control: None,
499 quality_gate: None,
500 key_facts_dedup_threshold: 0.95,
501 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
502 })
503 }
504
505 #[must_use]
507 pub fn sqlite(&self) -> &SqliteStore {
508 &self.sqlite
509 }
510
511 pub async fn is_vector_store_connected(&self) -> bool {
516 match self.qdrant.as_ref() {
517 Some(store) => store.health_check().await,
518 None => false,
519 }
520 }
521
522 #[must_use]
524 pub fn has_vector_store(&self) -> bool {
525 self.qdrant.is_some()
526 }
527
528 #[must_use]
530 pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
531 self.qdrant.as_ref()
532 }
533
534 pub fn provider(&self) -> &AnyProvider {
536 &self.provider
537 }
538
539 pub async fn message_count(
545 &self,
546 conversation_id: crate::types::ConversationId,
547 ) -> Result<i64, MemoryError> {
548 self.sqlite.count_messages(conversation_id).await
549 }
550
551 pub async fn unsummarized_message_count(
557 &self,
558 conversation_id: crate::types::ConversationId,
559 ) -> Result<i64, MemoryError> {
560 let after_id = self
561 .sqlite
562 .latest_summary_last_message_id(conversation_id)
563 .await?
564 .unwrap_or(crate::types::MessageId(0));
565 self.sqlite
566 .count_messages_after(conversation_id, after_id)
567 .await
568 }
569}