1mod algorithms;
5mod corrections;
6mod cross_session;
7mod graph;
8pub(crate) mod importance;
9pub mod persona;
10mod recall;
11mod summarization;
12pub mod trajectory;
13pub mod tree_consolidation;
14pub(crate) mod write_buffer;
15
16#[cfg(test)]
17mod tests;
18
19use std::sync::Arc;
20use std::sync::Mutex;
21use std::sync::atomic::AtomicU64;
22
23use zeph_llm::any::AnyProvider;
24
25use crate::admission::AdmissionControl;
26use crate::embedding_store::EmbeddingStore;
27use crate::error::MemoryError;
28use crate::store::SqliteStore;
29use crate::token_counter::TokenCounter;
30
/// Collection name under which session summaries are stored.
///
/// NOTE(review): presumably a vector-store collection; the consumers are not
/// visible in this file.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";

/// Collection name under which extracted key facts are stored.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";

/// Collection name under which user corrections are stored.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
34
/// Snapshot of how far a backfill has progressed.
///
/// The type is a plain `Copy` value, so it can be passed around and compared
/// freely. Which backfill produces it is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackfillProgress {
    /// Number of items completed so far.
    pub done: usize,
    /// Total number of items the backfill covers.
    pub total: usize,
}
43
44pub use algorithms::{apply_mmr, apply_temporal_decay};
45pub use cross_session::SessionSummaryResult;
46pub use graph::{
47 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
48 PostExtractValidator, extract_and_store, link_memory_notes,
49};
50pub use persona::{
51 PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
52};
53pub use recall::{EmbedContext, RecalledMessage};
54pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
55pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
56pub use tree_consolidation::{
57 TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
58 start_tree_consolidation_loop,
59};
60pub use write_buffer::{BufferedWrite, WriteBuffer};
61
62pub struct SemanticMemory {
63 pub(crate) sqlite: SqliteStore,
64 pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
65 pub(crate) provider: AnyProvider,
66 pub(crate) embed_provider: Option<AnyProvider>,
72 pub(crate) embedding_model: String,
73 pub(crate) vector_weight: f64,
74 pub(crate) keyword_weight: f64,
75 pub(crate) temporal_decay_enabled: bool,
76 pub(crate) temporal_decay_half_life_days: u32,
77 pub(crate) mmr_enabled: bool,
78 pub(crate) mmr_lambda: f32,
79 pub(crate) importance_enabled: bool,
80 pub(crate) importance_weight: f64,
81 pub(crate) tier_boost_semantic: f64,
84 pub token_counter: Arc<TokenCounter>,
85 pub graph_store: Option<Arc<crate::graph::GraphStore>>,
86 pub(crate) community_detection_failures: Arc<AtomicU64>,
87 pub(crate) graph_extraction_count: Arc<AtomicU64>,
88 pub(crate) graph_extraction_failures: Arc<AtomicU64>,
89 pub(crate) admission_control: Option<Arc<AdmissionControl>>,
91 pub(crate) key_facts_dedup_threshold: f32,
95 pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
101}
102
103impl SemanticMemory {
104 pub async fn new(
115 sqlite_path: &str,
116 qdrant_url: &str,
117 provider: AnyProvider,
118 embedding_model: &str,
119 ) -> Result<Self, MemoryError> {
120 Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
121 }
122
123 pub async fn with_weights(
132 sqlite_path: &str,
133 qdrant_url: &str,
134 provider: AnyProvider,
135 embedding_model: &str,
136 vector_weight: f64,
137 keyword_weight: f64,
138 ) -> Result<Self, MemoryError> {
139 Self::with_weights_and_pool_size(
140 sqlite_path,
141 qdrant_url,
142 provider,
143 embedding_model,
144 vector_weight,
145 keyword_weight,
146 5,
147 )
148 .await
149 }
150
151 pub async fn with_weights_and_pool_size(
160 sqlite_path: &str,
161 qdrant_url: &str,
162 provider: AnyProvider,
163 embedding_model: &str,
164 vector_weight: f64,
165 keyword_weight: f64,
166 pool_size: u32,
167 ) -> Result<Self, MemoryError> {
168 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
169 let pool = sqlite.pool().clone();
170
171 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
172 Ok(store) => Some(Arc::new(store)),
173 Err(e) => {
174 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
175 None
176 }
177 };
178
179 Ok(Self {
180 sqlite,
181 qdrant,
182 provider,
183 embed_provider: None,
184 embedding_model: embedding_model.into(),
185 vector_weight,
186 keyword_weight,
187 temporal_decay_enabled: false,
188 temporal_decay_half_life_days: 30,
189 mmr_enabled: false,
190 mmr_lambda: 0.7,
191 importance_enabled: false,
192 importance_weight: 0.15,
193 tier_boost_semantic: 1.3,
194 token_counter: Arc::new(TokenCounter::new()),
195 graph_store: None,
196 community_detection_failures: Arc::new(AtomicU64::new(0)),
197 graph_extraction_count: Arc::new(AtomicU64::new(0)),
198 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
199 admission_control: None,
200 key_facts_dedup_threshold: 0.95,
201 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
202 })
203 }
204
205 pub async fn with_qdrant_ops(
214 sqlite_path: &str,
215 ops: crate::QdrantOps,
216 provider: AnyProvider,
217 embedding_model: &str,
218 vector_weight: f64,
219 keyword_weight: f64,
220 pool_size: u32,
221 ) -> Result<Self, MemoryError> {
222 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
223 let pool = sqlite.pool().clone();
224 let store = EmbeddingStore::with_store(Box::new(ops), pool);
225
226 Ok(Self {
227 sqlite,
228 qdrant: Some(Arc::new(store)),
229 provider,
230 embed_provider: None,
231 embedding_model: embedding_model.into(),
232 vector_weight,
233 keyword_weight,
234 temporal_decay_enabled: false,
235 temporal_decay_half_life_days: 30,
236 mmr_enabled: false,
237 mmr_lambda: 0.7,
238 importance_enabled: false,
239 importance_weight: 0.15,
240 tier_boost_semantic: 1.3,
241 token_counter: Arc::new(TokenCounter::new()),
242 graph_store: None,
243 community_detection_failures: Arc::new(AtomicU64::new(0)),
244 graph_extraction_count: Arc::new(AtomicU64::new(0)),
245 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
246 admission_control: None,
247 key_facts_dedup_threshold: 0.95,
248 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
249 })
250 }
251
252 #[must_use]
257 pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
258 self.graph_store = Some(store);
259 self
260 }
261
262 #[must_use]
264 pub fn community_detection_failures(&self) -> u64 {
265 use std::sync::atomic::Ordering;
266 self.community_detection_failures.load(Ordering::Relaxed)
267 }
268
269 #[must_use]
271 pub fn graph_extraction_count(&self) -> u64 {
272 use std::sync::atomic::Ordering;
273 self.graph_extraction_count.load(Ordering::Relaxed)
274 }
275
276 #[must_use]
278 pub fn graph_extraction_failures(&self) -> u64 {
279 use std::sync::atomic::Ordering;
280 self.graph_extraction_failures.load(Ordering::Relaxed)
281 }
282
283 #[must_use]
285 pub fn with_ranking_options(
286 mut self,
287 temporal_decay_enabled: bool,
288 temporal_decay_half_life_days: u32,
289 mmr_enabled: bool,
290 mmr_lambda: f32,
291 ) -> Self {
292 self.temporal_decay_enabled = temporal_decay_enabled;
293 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
294 self.mmr_enabled = mmr_enabled;
295 self.mmr_lambda = mmr_lambda;
296 self
297 }
298
299 #[must_use]
301 pub fn with_importance_options(mut self, enabled: bool, weight: f64) -> Self {
302 self.importance_enabled = enabled;
303 self.importance_weight = weight;
304 self
305 }
306
307 #[must_use]
311 pub fn with_tier_boost(mut self, boost: f64) -> Self {
312 self.tier_boost_semantic = boost;
313 self
314 }
315
316 #[must_use]
321 pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
322 self.admission_control = Some(Arc::new(control));
323 self
324 }
325
326 #[must_use]
331 pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
332 self.key_facts_dedup_threshold = threshold;
333 self
334 }
335
336 #[must_use]
342 pub fn with_embed_provider(mut self, embed_provider: AnyProvider) -> Self {
343 self.embed_provider = Some(embed_provider);
344 self
345 }
346
347 pub(crate) fn effective_embed_provider(&self) -> &AnyProvider {
351 self.embed_provider.as_ref().unwrap_or(&self.provider)
352 }
353
354 #[must_use]
358 pub fn from_parts(
359 sqlite: SqliteStore,
360 qdrant: Option<Arc<EmbeddingStore>>,
361 provider: AnyProvider,
362 embedding_model: impl Into<String>,
363 vector_weight: f64,
364 keyword_weight: f64,
365 token_counter: Arc<TokenCounter>,
366 ) -> Self {
367 Self {
368 sqlite,
369 qdrant,
370 provider,
371 embed_provider: None,
372 embedding_model: embedding_model.into(),
373 vector_weight,
374 keyword_weight,
375 temporal_decay_enabled: false,
376 temporal_decay_half_life_days: 30,
377 mmr_enabled: false,
378 mmr_lambda: 0.7,
379 importance_enabled: false,
380 importance_weight: 0.15,
381 tier_boost_semantic: 1.3,
382 token_counter,
383 graph_store: None,
384 community_detection_failures: Arc::new(AtomicU64::new(0)),
385 graph_extraction_count: Arc::new(AtomicU64::new(0)),
386 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
387 admission_control: None,
388 key_facts_dedup_threshold: 0.95,
389 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
390 }
391 }
392
393 pub async fn with_sqlite_backend(
399 sqlite_path: &str,
400 provider: AnyProvider,
401 embedding_model: &str,
402 vector_weight: f64,
403 keyword_weight: f64,
404 ) -> Result<Self, MemoryError> {
405 Self::with_sqlite_backend_and_pool_size(
406 sqlite_path,
407 provider,
408 embedding_model,
409 vector_weight,
410 keyword_weight,
411 5,
412 )
413 .await
414 }
415
416 pub async fn with_sqlite_backend_and_pool_size(
422 sqlite_path: &str,
423 provider: AnyProvider,
424 embedding_model: &str,
425 vector_weight: f64,
426 keyword_weight: f64,
427 pool_size: u32,
428 ) -> Result<Self, MemoryError> {
429 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
430 let pool = sqlite.pool().clone();
431 let store = EmbeddingStore::new_sqlite(pool);
432
433 Ok(Self {
434 sqlite,
435 qdrant: Some(Arc::new(store)),
436 provider,
437 embed_provider: None,
438 embedding_model: embedding_model.into(),
439 vector_weight,
440 keyword_weight,
441 temporal_decay_enabled: false,
442 temporal_decay_half_life_days: 30,
443 mmr_enabled: false,
444 mmr_lambda: 0.7,
445 importance_enabled: false,
446 importance_weight: 0.15,
447 tier_boost_semantic: 1.3,
448 token_counter: Arc::new(TokenCounter::new()),
449 graph_store: None,
450 community_detection_failures: Arc::new(AtomicU64::new(0)),
451 graph_extraction_count: Arc::new(AtomicU64::new(0)),
452 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
453 admission_control: None,
454 key_facts_dedup_threshold: 0.95,
455 embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
456 })
457 }
458
459 #[must_use]
461 pub fn sqlite(&self) -> &SqliteStore {
462 &self.sqlite
463 }
464
465 pub async fn is_vector_store_connected(&self) -> bool {
470 match self.qdrant.as_ref() {
471 Some(store) => store.health_check().await,
472 None => false,
473 }
474 }
475
476 #[must_use]
478 pub fn has_vector_store(&self) -> bool {
479 self.qdrant.is_some()
480 }
481
482 #[must_use]
484 pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
485 self.qdrant.as_ref()
486 }
487
488 pub fn provider(&self) -> &AnyProvider {
490 &self.provider
491 }
492
493 pub async fn message_count(
499 &self,
500 conversation_id: crate::types::ConversationId,
501 ) -> Result<i64, MemoryError> {
502 self.sqlite.count_messages(conversation_id).await
503 }
504
505 pub async fn unsummarized_message_count(
511 &self,
512 conversation_id: crate::types::ConversationId,
513 ) -> Result<i64, MemoryError> {
514 let after_id = self
515 .sqlite
516 .latest_summary_last_message_id(conversation_id)
517 .await?
518 .unwrap_or(crate::types::MessageId(0));
519 self.sqlite
520 .count_messages_after(conversation_id, after_id)
521 .await
522 }
523}