1mod algorithms;
24mod corrections;
25mod cross_session;
26mod graph;
27pub(crate) mod importance;
28pub mod persona;
29mod recall;
30mod summarization;
31pub mod trajectory;
32pub mod tree_consolidation;
33pub(crate) mod write_buffer;
34
35#[cfg(test)]
36mod tests;
37
38use std::sync::Arc;
39use std::sync::Mutex;
40use std::sync::atomic::AtomicU64;
41
42use zeph_llm::any::AnyProvider;
43
44use crate::admission::AdmissionControl;
45use crate::embedding_store::EmbeddingStore;
46use crate::error::MemoryError;
47use crate::store::SqliteStore;
48use crate::token_counter::TokenCounter;
49
/// Collection name for stored corrections.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
/// Collection name for stored key facts.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Collection name for stored session summaries.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
53
/// Snapshot of a backfill job's progress: `done` out of `total` items.
///
/// Plain `Copy` data, so it can be passed around by value.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct BackfillProgress {
    /// Items processed so far.
    pub done: usize,
    /// Total number of items the job expects to process.
    pub total: usize,
}
62
63pub use algorithms::{apply_mmr, apply_temporal_decay};
64pub use cross_session::SessionSummaryResult;
65pub use graph::{
66 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
67 PostExtractValidator, extract_and_store, link_memory_notes,
68};
69pub use persona::{
70 PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
71};
72pub use recall::{EmbedContext, RecalledMessage};
73pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
74pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
75pub use tree_consolidation::{
76 TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
77 start_tree_consolidation_loop,
78};
79pub use write_buffer::{BufferedWrite, WriteBuffer};
80
/// Conversation memory built on a SQLite message store, an optional vector store, an LLM
/// provider (plus an optional dedicated embedding provider) and a set of ranking options.
///
/// Construct it with `SemanticMemory::new`, one of the `with_*` constructors, or
/// `SemanticMemory::from_parts`, then adjust options with the consuming `with_*` builder methods.
///
/// NOTE: fields are dropped in declaration order; do not reorder them casually.
pub struct SemanticMemory {
    /// SQLite store. The path-based constructors also share its connection pool with the
    /// embedding store.
    pub(crate) sqlite: SqliteStore,
    /// Vector store, if any. Despite the field name this may be SQLite-backed
    /// (`with_sqlite_backend*`), or `None` when the Qdrant store could not be created.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    /// Main LLM provider; also used for embeddings when `embed_provider` is `None`.
    pub(crate) provider: AnyProvider,
    /// Optional dedicated embedding provider (see `effective_embed_provider`).
    pub(crate) embed_provider: Option<AnyProvider>,
    /// Name of the embedding model.
    pub(crate) embedding_model: String,
    /// Weight of the vector-search score (default via `new`: 0.7).
    /// NOTE(review): presumably blended with `keyword_weight` during recall; confirm there.
    pub(crate) vector_weight: f64,
    /// Weight of the keyword-search score (default via `new`: 0.3).
    pub(crate) keyword_weight: f64,
    /// Whether temporal decay is applied when ranking (default: `false`).
    pub(crate) temporal_decay_enabled: bool,
    /// Half-life in days for temporal decay (default: 30).
    pub(crate) temporal_decay_half_life_days: u32,
    /// Whether MMR re-ranking is applied (default: `false`).
    pub(crate) mmr_enabled: bool,
    /// MMR lambda parameter (default: 0.7).
    pub(crate) mmr_lambda: f32,
    /// Whether importance scoring contributes to ranking (default: `false`).
    pub(crate) importance_enabled: bool,
    /// Weight of the importance score (default: 0.15).
    pub(crate) importance_weight: f64,
    /// Boost factor for the semantic tier (default: 1.3).
    /// NOTE(review): presumably a score multiplier for semantic-tier memories; confirm at use site.
    pub(crate) tier_boost_semantic: f64,
    /// Shared token counter (a fresh `TokenCounter::new()` unless supplied to `from_parts`).
    pub token_counter: Arc<TokenCounter>,
    /// Optional graph store, attached with `with_graph_store`.
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    /// Community-detection failure counter; read with a relaxed load by the getter of the
    /// same name. `Arc`-wrapped, presumably so spawned tasks can increment it (not visible here).
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    /// Graph-extraction counter; read by `graph_extraction_count()`.
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    /// Graph-extraction failure counter; read by `graph_extraction_failures()`.
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
    /// NOTE(review): not used in this file; presumably a timestamp used to rate-limit
    /// "Qdrant unavailable" warnings. Confirm at the use site.
    pub(crate) last_qdrant_warn: Arc<AtomicU64>,
    /// Optional admission control, set with `with_admission_control`.
    pub(crate) admission_control: Option<Arc<AdmissionControl>>,
    /// Dedup threshold for key facts (default: 0.95).
    /// NOTE(review): presumably a similarity cutoff above which a fact counts as a duplicate.
    pub(crate) key_facts_dedup_threshold: f32,
    /// Set of spawned `()` tasks, presumably background embedding jobs (spawned elsewhere).
    /// This is a `std::sync::Mutex`: never hold its guard across an `.await`.
    pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
}
126
127impl SemanticMemory {
128 pub async fn new(
139 sqlite_path: &str,
140 qdrant_url: &str,
141 provider: AnyProvider,
142 embedding_model: &str,
143 ) -> Result<Self, MemoryError> {
144 Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
145 }
146
147 pub async fn with_weights(
156 sqlite_path: &str,
157 qdrant_url: &str,
158 provider: AnyProvider,
159 embedding_model: &str,
160 vector_weight: f64,
161 keyword_weight: f64,
162 ) -> Result<Self, MemoryError> {
163 Self::with_weights_and_pool_size(
164 sqlite_path,
165 qdrant_url,
166 provider,
167 embedding_model,
168 vector_weight,
169 keyword_weight,
170 5,
171 )
172 .await
173 }
174
175 pub async fn with_weights_and_pool_size(
184 sqlite_path: &str,
185 qdrant_url: &str,
186 provider: AnyProvider,
187 embedding_model: &str,
188 vector_weight: f64,
189 keyword_weight: f64,
190 pool_size: u32,
191 ) -> Result<Self, MemoryError> {
192 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
193 let pool = sqlite.pool().clone();
194
195 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
196 Ok(store) => Some(Arc::new(store)),
197 Err(e) => {
198 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
199 None
200 }
201 };
202
203 Ok(Self {
204 sqlite,
205 qdrant,
206 provider,
207 embed_provider: None,
208 embedding_model: embedding_model.into(),
209 vector_weight,
210 keyword_weight,
211 temporal_decay_enabled: false,
212 temporal_decay_half_life_days: 30,
213 mmr_enabled: false,
214 mmr_lambda: 0.7,
215 importance_enabled: false,
216 importance_weight: 0.15,
217 tier_boost_semantic: 1.3,
218 token_counter: Arc::new(TokenCounter::new()),
219 graph_store: None,
220 community_detection_failures: Arc::new(AtomicU64::new(0)),
221 graph_extraction_count: Arc::new(AtomicU64::new(0)),
222 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
223 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
224 admission_control: None,
225 key_facts_dedup_threshold: 0.95,
226 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
227 })
228 }
229
230 pub async fn with_qdrant_ops(
239 sqlite_path: &str,
240 ops: crate::QdrantOps,
241 provider: AnyProvider,
242 embedding_model: &str,
243 vector_weight: f64,
244 keyword_weight: f64,
245 pool_size: u32,
246 ) -> Result<Self, MemoryError> {
247 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
248 let pool = sqlite.pool().clone();
249 let store = EmbeddingStore::with_store(Box::new(ops), pool);
250
251 Ok(Self {
252 sqlite,
253 qdrant: Some(Arc::new(store)),
254 provider,
255 embed_provider: None,
256 embedding_model: embedding_model.into(),
257 vector_weight,
258 keyword_weight,
259 temporal_decay_enabled: false,
260 temporal_decay_half_life_days: 30,
261 mmr_enabled: false,
262 mmr_lambda: 0.7,
263 importance_enabled: false,
264 importance_weight: 0.15,
265 tier_boost_semantic: 1.3,
266 token_counter: Arc::new(TokenCounter::new()),
267 graph_store: None,
268 community_detection_failures: Arc::new(AtomicU64::new(0)),
269 graph_extraction_count: Arc::new(AtomicU64::new(0)),
270 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
271 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
272 admission_control: None,
273 key_facts_dedup_threshold: 0.95,
274 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
275 })
276 }
277
278 #[must_use]
283 pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
284 self.graph_store = Some(store);
285 self
286 }
287
288 #[must_use]
290 pub fn community_detection_failures(&self) -> u64 {
291 use std::sync::atomic::Ordering;
292 self.community_detection_failures.load(Ordering::Relaxed)
293 }
294
295 #[must_use]
297 pub fn graph_extraction_count(&self) -> u64 {
298 use std::sync::atomic::Ordering;
299 self.graph_extraction_count.load(Ordering::Relaxed)
300 }
301
302 #[must_use]
304 pub fn graph_extraction_failures(&self) -> u64 {
305 use std::sync::atomic::Ordering;
306 self.graph_extraction_failures.load(Ordering::Relaxed)
307 }
308
309 #[must_use]
311 pub fn with_ranking_options(
312 mut self,
313 temporal_decay_enabled: bool,
314 temporal_decay_half_life_days: u32,
315 mmr_enabled: bool,
316 mmr_lambda: f32,
317 ) -> Self {
318 self.temporal_decay_enabled = temporal_decay_enabled;
319 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
320 self.mmr_enabled = mmr_enabled;
321 self.mmr_lambda = mmr_lambda;
322 self
323 }
324
325 #[must_use]
327 pub fn with_importance_options(mut self, enabled: bool, weight: f64) -> Self {
328 self.importance_enabled = enabled;
329 self.importance_weight = weight;
330 self
331 }
332
333 #[must_use]
337 pub fn with_tier_boost(mut self, boost: f64) -> Self {
338 self.tier_boost_semantic = boost;
339 self
340 }
341
342 #[must_use]
347 pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
348 self.admission_control = Some(Arc::new(control));
349 self
350 }
351
352 #[must_use]
357 pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
358 self.key_facts_dedup_threshold = threshold;
359 self
360 }
361
362 #[must_use]
368 pub fn with_embed_provider(mut self, embed_provider: AnyProvider) -> Self {
369 self.embed_provider = Some(embed_provider);
370 self
371 }
372
373 pub(crate) fn effective_embed_provider(&self) -> &AnyProvider {
377 self.embed_provider.as_ref().unwrap_or(&self.provider)
378 }
379
380 #[must_use]
384 pub fn from_parts(
385 sqlite: SqliteStore,
386 qdrant: Option<Arc<EmbeddingStore>>,
387 provider: AnyProvider,
388 embedding_model: impl Into<String>,
389 vector_weight: f64,
390 keyword_weight: f64,
391 token_counter: Arc<TokenCounter>,
392 ) -> Self {
393 Self {
394 sqlite,
395 qdrant,
396 provider,
397 embed_provider: None,
398 embedding_model: embedding_model.into(),
399 vector_weight,
400 keyword_weight,
401 temporal_decay_enabled: false,
402 temporal_decay_half_life_days: 30,
403 mmr_enabled: false,
404 mmr_lambda: 0.7,
405 importance_enabled: false,
406 importance_weight: 0.15,
407 tier_boost_semantic: 1.3,
408 token_counter,
409 graph_store: None,
410 community_detection_failures: Arc::new(AtomicU64::new(0)),
411 graph_extraction_count: Arc::new(AtomicU64::new(0)),
412 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
413 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
414 admission_control: None,
415 key_facts_dedup_threshold: 0.95,
416 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
417 }
418 }
419
420 pub async fn with_sqlite_backend(
426 sqlite_path: &str,
427 provider: AnyProvider,
428 embedding_model: &str,
429 vector_weight: f64,
430 keyword_weight: f64,
431 ) -> Result<Self, MemoryError> {
432 Self::with_sqlite_backend_and_pool_size(
433 sqlite_path,
434 provider,
435 embedding_model,
436 vector_weight,
437 keyword_weight,
438 5,
439 )
440 .await
441 }
442
443 pub async fn with_sqlite_backend_and_pool_size(
449 sqlite_path: &str,
450 provider: AnyProvider,
451 embedding_model: &str,
452 vector_weight: f64,
453 keyword_weight: f64,
454 pool_size: u32,
455 ) -> Result<Self, MemoryError> {
456 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
457 let pool = sqlite.pool().clone();
458 let store = EmbeddingStore::new_sqlite(pool);
459
460 Ok(Self {
461 sqlite,
462 qdrant: Some(Arc::new(store)),
463 provider,
464 embed_provider: None,
465 embedding_model: embedding_model.into(),
466 vector_weight,
467 keyword_weight,
468 temporal_decay_enabled: false,
469 temporal_decay_half_life_days: 30,
470 mmr_enabled: false,
471 mmr_lambda: 0.7,
472 importance_enabled: false,
473 importance_weight: 0.15,
474 tier_boost_semantic: 1.3,
475 token_counter: Arc::new(TokenCounter::new()),
476 graph_store: None,
477 community_detection_failures: Arc::new(AtomicU64::new(0)),
478 graph_extraction_count: Arc::new(AtomicU64::new(0)),
479 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
480 last_qdrant_warn: Arc::new(AtomicU64::new(0)),
481 admission_control: None,
482 key_facts_dedup_threshold: 0.95,
483 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
484 })
485 }
486
487 #[must_use]
489 pub fn sqlite(&self) -> &SqliteStore {
490 &self.sqlite
491 }
492
493 pub async fn is_vector_store_connected(&self) -> bool {
498 match self.qdrant.as_ref() {
499 Some(store) => store.health_check().await,
500 None => false,
501 }
502 }
503
504 #[must_use]
506 pub fn has_vector_store(&self) -> bool {
507 self.qdrant.is_some()
508 }
509
510 #[must_use]
512 pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
513 self.qdrant.as_ref()
514 }
515
516 pub fn provider(&self) -> &AnyProvider {
518 &self.provider
519 }
520
521 pub async fn message_count(
527 &self,
528 conversation_id: crate::types::ConversationId,
529 ) -> Result<i64, MemoryError> {
530 self.sqlite.count_messages(conversation_id).await
531 }
532
533 pub async fn unsummarized_message_count(
539 &self,
540 conversation_id: crate::types::ConversationId,
541 ) -> Result<i64, MemoryError> {
542 let after_id = self
543 .sqlite
544 .latest_summary_last_message_id(conversation_id)
545 .await?
546 .unwrap_or(crate::types::MessageId(0));
547 self.sqlite
548 .count_messages_after(conversation_id, after_id)
549 .await
550 }
551}