1mod algorithms;
5mod corrections;
6mod cross_session;
7mod graph;
8mod recall;
9mod summarization;
10
11#[cfg(test)]
12mod tests;
13
14use std::sync::Arc;
15use std::sync::atomic::AtomicU64;
16
17use zeph_llm::any::AnyProvider;
18
19use crate::embedding_store::EmbeddingStore;
20use crate::error::MemoryError;
21use crate::sqlite::SqliteStore;
22use crate::token_counter::TokenCounter;
23
/// Collection name for session summaries.
///
/// NOTE(review): presumably a vector-store collection consumed by the
/// `cross_session` module — confirm at the use sites.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Collection name for key facts.
///
/// NOTE(review): the consumer is not visible in this file — confirm at the use sites.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Collection name for corrections.
///
/// NOTE(review): presumably consumed by the `corrections` module — confirm at the use sites.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
27
28pub use algorithms::{apply_mmr, apply_temporal_decay};
29pub use cross_session::SessionSummaryResult;
30pub use graph::{
31 ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
32 PostExtractValidator, extract_and_store, link_memory_notes,
33};
34pub use recall::RecalledMessage;
35pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
36
/// Memory facade that bundles a SQLite store, an optional embedding store and
/// an LLM provider, together with ranking settings and graph-related counters.
///
/// Instances are created through [`SemanticMemory::new`], one of the `with_*`
/// constructors, or [`SemanticMemory::from_parts`], and can then be adjusted
/// with the builder-style [`SemanticMemory::with_graph_store`] and
/// [`SemanticMemory::with_ranking_options`].
pub struct SemanticMemory {
    /// SQLite store; backs the message-count queries exposed by this type.
    pub(crate) sqlite: SqliteStore,
    /// Embedding store, or `None` when it could not be created at construction
    /// time. Despite the field name, the SQLite-backed embedding store built by
    /// `with_sqlite_backend*` is kept here as well.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    /// LLM provider handle supplied by the caller.
    pub(crate) provider: AnyProvider,
    /// Embedding model identifier supplied by the caller.
    pub(crate) embedding_model: String,
    /// Weight for the vector side of recall scoring.
    /// NOTE(review): the blending itself happens outside this file — confirm in `recall`.
    pub(crate) vector_weight: f64,
    /// Weight for the keyword side of recall scoring.
    /// NOTE(review): the blending itself happens outside this file — confirm in `recall`.
    pub(crate) keyword_weight: f64,
    /// Whether temporal decay is enabled; `false` until `with_ranking_options` sets it.
    pub(crate) temporal_decay_enabled: bool,
    /// Temporal-decay half-life in days; `30` until `with_ranking_options` sets it.
    pub(crate) temporal_decay_half_life_days: u32,
    /// Whether MMR re-ranking is enabled; `false` until `with_ranking_options` sets it.
    pub(crate) mmr_enabled: bool,
    /// MMR lambda; `0.7` until `with_ranking_options` sets it.
    pub(crate) mmr_lambda: f32,
    /// Shared token counter.
    pub token_counter: Arc<TokenCounter>,
    /// Optional graph store; `None` until attached via `with_graph_store`.
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    /// Counter read by `community_detection_failures()`; starts at zero.
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    /// Counter read by `graph_extraction_count()`; starts at zero.
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    /// Counter read by `graph_extraction_failures()`; starts at zero.
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
}
54
55impl SemanticMemory {
56 pub async fn new(
67 sqlite_path: &str,
68 qdrant_url: &str,
69 provider: AnyProvider,
70 embedding_model: &str,
71 ) -> Result<Self, MemoryError> {
72 Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
73 }
74
75 pub async fn with_weights(
84 sqlite_path: &str,
85 qdrant_url: &str,
86 provider: AnyProvider,
87 embedding_model: &str,
88 vector_weight: f64,
89 keyword_weight: f64,
90 ) -> Result<Self, MemoryError> {
91 Self::with_weights_and_pool_size(
92 sqlite_path,
93 qdrant_url,
94 provider,
95 embedding_model,
96 vector_weight,
97 keyword_weight,
98 5,
99 )
100 .await
101 }
102
103 pub async fn with_weights_and_pool_size(
112 sqlite_path: &str,
113 qdrant_url: &str,
114 provider: AnyProvider,
115 embedding_model: &str,
116 vector_weight: f64,
117 keyword_weight: f64,
118 pool_size: u32,
119 ) -> Result<Self, MemoryError> {
120 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
121 let pool = sqlite.pool().clone();
122
123 let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
124 Ok(store) => Some(Arc::new(store)),
125 Err(e) => {
126 tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
127 None
128 }
129 };
130
131 Ok(Self {
132 sqlite,
133 qdrant,
134 provider,
135 embedding_model: embedding_model.into(),
136 vector_weight,
137 keyword_weight,
138 temporal_decay_enabled: false,
139 temporal_decay_half_life_days: 30,
140 mmr_enabled: false,
141 mmr_lambda: 0.7,
142 token_counter: Arc::new(TokenCounter::new()),
143 graph_store: None,
144 community_detection_failures: Arc::new(AtomicU64::new(0)),
145 graph_extraction_count: Arc::new(AtomicU64::new(0)),
146 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
147 })
148 }
149
150 pub async fn with_qdrant_ops(
159 sqlite_path: &str,
160 ops: crate::QdrantOps,
161 provider: AnyProvider,
162 embedding_model: &str,
163 vector_weight: f64,
164 keyword_weight: f64,
165 pool_size: u32,
166 ) -> Result<Self, MemoryError> {
167 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
168 let pool = sqlite.pool().clone();
169 let store = EmbeddingStore::with_store(Box::new(ops), pool);
170
171 Ok(Self {
172 sqlite,
173 qdrant: Some(Arc::new(store)),
174 provider,
175 embedding_model: embedding_model.into(),
176 vector_weight,
177 keyword_weight,
178 temporal_decay_enabled: false,
179 temporal_decay_half_life_days: 30,
180 mmr_enabled: false,
181 mmr_lambda: 0.7,
182 token_counter: Arc::new(TokenCounter::new()),
183 graph_store: None,
184 community_detection_failures: Arc::new(AtomicU64::new(0)),
185 graph_extraction_count: Arc::new(AtomicU64::new(0)),
186 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
187 })
188 }
189
190 #[must_use]
195 pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
196 self.graph_store = Some(store);
197 self
198 }
199
200 #[must_use]
202 pub fn community_detection_failures(&self) -> u64 {
203 use std::sync::atomic::Ordering;
204 self.community_detection_failures.load(Ordering::Relaxed)
205 }
206
207 #[must_use]
209 pub fn graph_extraction_count(&self) -> u64 {
210 use std::sync::atomic::Ordering;
211 self.graph_extraction_count.load(Ordering::Relaxed)
212 }
213
214 #[must_use]
216 pub fn graph_extraction_failures(&self) -> u64 {
217 use std::sync::atomic::Ordering;
218 self.graph_extraction_failures.load(Ordering::Relaxed)
219 }
220
221 #[must_use]
223 pub fn with_ranking_options(
224 mut self,
225 temporal_decay_enabled: bool,
226 temporal_decay_half_life_days: u32,
227 mmr_enabled: bool,
228 mmr_lambda: f32,
229 ) -> Self {
230 self.temporal_decay_enabled = temporal_decay_enabled;
231 self.temporal_decay_half_life_days = temporal_decay_half_life_days;
232 self.mmr_enabled = mmr_enabled;
233 self.mmr_lambda = mmr_lambda;
234 self
235 }
236
237 #[must_use]
241 pub fn from_parts(
242 sqlite: SqliteStore,
243 qdrant: Option<Arc<EmbeddingStore>>,
244 provider: AnyProvider,
245 embedding_model: impl Into<String>,
246 vector_weight: f64,
247 keyword_weight: f64,
248 token_counter: Arc<TokenCounter>,
249 ) -> Self {
250 Self {
251 sqlite,
252 qdrant,
253 provider,
254 embedding_model: embedding_model.into(),
255 vector_weight,
256 keyword_weight,
257 temporal_decay_enabled: false,
258 temporal_decay_half_life_days: 30,
259 mmr_enabled: false,
260 mmr_lambda: 0.7,
261 token_counter,
262 graph_store: None,
263 community_detection_failures: Arc::new(AtomicU64::new(0)),
264 graph_extraction_count: Arc::new(AtomicU64::new(0)),
265 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
266 }
267 }
268
269 pub async fn with_sqlite_backend(
275 sqlite_path: &str,
276 provider: AnyProvider,
277 embedding_model: &str,
278 vector_weight: f64,
279 keyword_weight: f64,
280 ) -> Result<Self, MemoryError> {
281 Self::with_sqlite_backend_and_pool_size(
282 sqlite_path,
283 provider,
284 embedding_model,
285 vector_weight,
286 keyword_weight,
287 5,
288 )
289 .await
290 }
291
292 pub async fn with_sqlite_backend_and_pool_size(
298 sqlite_path: &str,
299 provider: AnyProvider,
300 embedding_model: &str,
301 vector_weight: f64,
302 keyword_weight: f64,
303 pool_size: u32,
304 ) -> Result<Self, MemoryError> {
305 let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
306 let pool = sqlite.pool().clone();
307 let store = EmbeddingStore::new_sqlite(pool);
308
309 Ok(Self {
310 sqlite,
311 qdrant: Some(Arc::new(store)),
312 provider,
313 embedding_model: embedding_model.into(),
314 vector_weight,
315 keyword_weight,
316 temporal_decay_enabled: false,
317 temporal_decay_half_life_days: 30,
318 mmr_enabled: false,
319 mmr_lambda: 0.7,
320 token_counter: Arc::new(TokenCounter::new()),
321 graph_store: None,
322 community_detection_failures: Arc::new(AtomicU64::new(0)),
323 graph_extraction_count: Arc::new(AtomicU64::new(0)),
324 graph_extraction_failures: Arc::new(AtomicU64::new(0)),
325 })
326 }
327
328 #[must_use]
330 pub fn sqlite(&self) -> &SqliteStore {
331 &self.sqlite
332 }
333
334 pub async fn is_vector_store_connected(&self) -> bool {
339 match self.qdrant.as_ref() {
340 Some(store) => store.health_check().await,
341 None => false,
342 }
343 }
344
345 #[must_use]
347 pub fn has_vector_store(&self) -> bool {
348 self.qdrant.is_some()
349 }
350
351 #[must_use]
353 pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
354 self.qdrant.as_ref()
355 }
356
357 pub async fn message_count(
363 &self,
364 conversation_id: crate::types::ConversationId,
365 ) -> Result<i64, MemoryError> {
366 self.sqlite.count_messages(conversation_id).await
367 }
368
369 pub async fn unsummarized_message_count(
375 &self,
376 conversation_id: crate::types::ConversationId,
377 ) -> Result<i64, MemoryError> {
378 let after_id = self
379 .sqlite
380 .latest_summary_last_message_id(conversation_id)
381 .await?
382 .unwrap_or(crate::types::MessageId(0));
383 self.sqlite
384 .count_messages_after(conversation_id, after_id)
385 .await
386 }
387}