1use crate::chunking::{chunk_text_semantic, ChunkingConfig, Tokenizer};
5use crate::db::MemoryDatabase;
6use crate::embeddings::EmbeddingService;
7use crate::types::{
8 CleanupLogEntry, EmbeddingHealth, MemoryChunk, MemoryConfig, MemoryContext, MemoryResult,
9 MemoryRetrievalMeta, MemorySearchResult, MemoryStats, MemoryTier, StoreMessageRequest,
10};
11use chrono::Utc;
12use std::path::Path;
13use std::sync::Arc;
14use tokio::sync::Mutex;
15
/// Orchestrates chunking, embedding, and storage/retrieval of memory chunks.
pub struct MemoryManager {
    // Handle to the persistent chunk store; shared via `Arc` and exposed
    // through `db()`.
    db: Arc<MemoryDatabase>,
    // Embedding backend behind an async Mutex: `embed` is awaited while the
    // lock is held, so a tokio Mutex (not `std::sync::Mutex`) is required.
    embedding_service: Arc<Mutex<EmbeddingService>>,
    // Tokenizer backing `count_tokens`.
    tokenizer: Tokenizer,
}
22
23impl MemoryManager {
24 pub fn db(&self) -> &Arc<MemoryDatabase> {
25 &self.db
26 }
27
28 pub async fn new(db_path: &Path) -> MemoryResult<Self> {
30 let db = Arc::new(MemoryDatabase::new(db_path).await?);
31 let embedding_service = Arc::new(Mutex::new(EmbeddingService::new()));
32 let tokenizer = Tokenizer::new()?;
33
34 Ok(Self {
35 db,
36 embedding_service,
37 tokenizer,
38 })
39 }
40
41 pub async fn store_message(&self, request: StoreMessageRequest) -> MemoryResult<Vec<String>> {
48 if self
49 .db
50 .ensure_vector_tables_healthy()
51 .await
52 .unwrap_or(false)
53 {
54 tracing::warn!("Memory vector tables were repaired before storing message chunks");
55 }
56
57 let config = if let Some(ref pid) = request.project_id {
58 self.db.get_or_create_config(pid).await?
59 } else {
60 MemoryConfig::default()
61 };
62
63 let chunking_config = ChunkingConfig {
65 chunk_size: config.chunk_size as usize,
66 chunk_overlap: config.chunk_overlap as usize,
67 separator: None,
68 };
69
70 let text_chunks = chunk_text_semantic(&request.content, &chunking_config)?;
71
72 if text_chunks.is_empty() {
73 return Ok(Vec::new());
74 }
75
76 let mut chunk_ids = Vec::with_capacity(text_chunks.len());
77 let embedding_service = self.embedding_service.lock().await;
78
79 for text_chunk in text_chunks {
80 let chunk_id = uuid::Uuid::new_v4().to_string();
81
82 let embedding = embedding_service.embed(&text_chunk.content).await?;
84
85 let chunk = MemoryChunk {
87 id: chunk_id.clone(),
88 content: text_chunk.content,
89 tier: request.tier,
90 session_id: request.session_id.clone(),
91 project_id: request.project_id.clone(),
92 source: request.source.clone(),
93 source_path: request.source_path.clone(),
94 source_mtime: request.source_mtime,
95 source_size: request.source_size,
96 source_hash: request.source_hash.clone(),
97 created_at: Utc::now(),
98 token_count: text_chunk.token_count as i64,
99 metadata: request.metadata.clone(),
100 };
101
102 if let Err(err) = self.db.store_chunk(&chunk, &embedding).await {
104 tracing::warn!("Failed to store memory chunk {}: {}", chunk.id, err);
105 let repaired = self
106 .db
107 .ensure_vector_tables_healthy()
108 .await
109 .unwrap_or(false);
110 if repaired {
111 tracing::warn!(
112 "Retrying memory chunk insert after vector table repair: {}",
113 chunk.id
114 );
115 self.db.store_chunk(&chunk, &embedding).await?;
116 } else {
117 return Err(err);
118 }
119 }
120 chunk_ids.push(chunk_id);
121 }
122
123 if config.auto_cleanup {
125 self.maybe_cleanup(&request.project_id).await?;
126 }
127
128 Ok(chunk_ids)
129 }
130
131 pub async fn search(
133 &self,
134 query: &str,
135 tier: Option<MemoryTier>,
136 project_id: Option<&str>,
137 session_id: Option<&str>,
138 limit: Option<i64>,
139 ) -> MemoryResult<Vec<MemorySearchResult>> {
140 let effective_limit = limit.unwrap_or(5);
141
142 let embedding_service = self.embedding_service.lock().await;
144 let query_embedding = embedding_service.embed(query).await?;
145 drop(embedding_service);
146
147 let mut results = Vec::new();
148
149 let tiers_to_search = match tier {
151 Some(t) => vec![t],
152 None => {
153 if project_id.is_some() {
154 vec![MemoryTier::Session, MemoryTier::Project, MemoryTier::Global]
155 } else {
156 vec![MemoryTier::Session, MemoryTier::Global]
157 }
158 }
159 };
160
161 for search_tier in tiers_to_search {
162 let tier_results = match self
163 .db
164 .search_similar(
165 &query_embedding,
166 search_tier,
167 project_id,
168 session_id,
169 effective_limit,
170 )
171 .await
172 {
173 Ok(results) => results,
174 Err(err) => {
175 tracing::warn!(
176 "Memory tier search failed for {:?}: {}. Attempting vector repair.",
177 search_tier,
178 err
179 );
180 let repaired = self
181 .db
182 .ensure_vector_tables_healthy()
183 .await
184 .unwrap_or(false);
185 if repaired {
186 match self
187 .db
188 .search_similar(
189 &query_embedding,
190 search_tier,
191 project_id,
192 session_id,
193 effective_limit,
194 )
195 .await
196 {
197 Ok(results) => results,
198 Err(retry_err) => {
199 tracing::warn!(
200 "Memory tier search still failing for {:?} after repair: {}",
201 search_tier,
202 retry_err
203 );
204 continue;
205 }
206 }
207 } else {
208 continue;
209 }
210 }
211 };
212
213 for (chunk, distance) in tier_results {
214 let similarity = 1.0 - distance.clamp(0.0, 1.0);
218
219 results.push(MemorySearchResult { chunk, similarity });
220 }
221 }
222
223 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
225 results.truncate(effective_limit as usize);
226
227 Ok(results)
228 }
229
230 pub async fn retrieve_context(
235 &self,
236 query: &str,
237 project_id: Option<&str>,
238 session_id: Option<&str>,
239 token_budget: Option<i64>,
240 ) -> MemoryResult<MemoryContext> {
241 let (context, _) = self
242 .retrieve_context_with_meta(query, project_id, session_id, token_budget)
243 .await?;
244 Ok(context)
245 }
246
247 pub async fn retrieve_context_with_meta(
249 &self,
250 query: &str,
251 project_id: Option<&str>,
252 session_id: Option<&str>,
253 token_budget: Option<i64>,
254 ) -> MemoryResult<(MemoryContext, MemoryRetrievalMeta)> {
255 let config = if let Some(pid) = project_id {
256 self.db.get_or_create_config(pid).await?
257 } else {
258 MemoryConfig::default()
259 };
260 let budget = token_budget.unwrap_or(config.token_budget);
261 let retrieval_limit = config.retrieval_k.max(1);
262
263 let current_session = if let Some(sid) = session_id {
265 self.db.get_session_chunks(sid).await?
266 } else {
267 Vec::new()
268 };
269
270 let search_results = self
272 .search(query, None, project_id, session_id, Some(retrieval_limit))
273 .await?;
274
275 let mut score_min: Option<f64> = None;
276 let mut score_max: Option<f64> = None;
277 for result in &search_results {
278 score_min = Some(match score_min {
279 Some(current) => current.min(result.similarity),
280 None => result.similarity,
281 });
282 score_max = Some(match score_max {
283 Some(current) => current.max(result.similarity),
284 None => result.similarity,
285 });
286 }
287
288 let mut current_session = current_session;
289 let mut relevant_history = Vec::new();
290 let mut project_facts = Vec::new();
291
292 for result in search_results {
293 match result.chunk.tier {
294 MemoryTier::Project => {
295 project_facts.push(result.chunk);
296 }
297 MemoryTier::Global => {
298 project_facts.push(result.chunk);
299 }
300 MemoryTier::Session => {
301 if !current_session.iter().any(|c| c.id == result.chunk.id) {
303 relevant_history.push(result.chunk);
304 }
305 }
306 }
307 }
308
309 let mut total_tokens: i64 = current_session.iter().map(|c| c.token_count).sum();
311 total_tokens += relevant_history.iter().map(|c| c.token_count).sum::<i64>();
312 total_tokens += project_facts.iter().map(|c| c.token_count).sum::<i64>();
313
314 if total_tokens > budget {
316 let excess = total_tokens - budget;
317 self.trim_context(
318 &mut current_session,
319 &mut relevant_history,
320 &mut project_facts,
321 excess,
322 )?;
323 total_tokens = current_session.iter().map(|c| c.token_count).sum::<i64>()
324 + relevant_history.iter().map(|c| c.token_count).sum::<i64>()
325 + project_facts.iter().map(|c| c.token_count).sum::<i64>();
326 }
327
328 let context = MemoryContext {
329 current_session,
330 relevant_history,
331 project_facts,
332 total_tokens,
333 };
334 let chunks_total = context.current_session.len()
335 + context.relevant_history.len()
336 + context.project_facts.len();
337 let meta = MemoryRetrievalMeta {
338 used: chunks_total > 0,
339 chunks_total,
340 session_chunks: context.current_session.len(),
341 history_chunks: context.relevant_history.len(),
342 project_fact_chunks: context.project_facts.len(),
343 score_min,
344 score_max,
345 };
346
347 Ok((context, meta))
348 }
349
350 fn trim_context(
352 &self,
353 current_session: &mut Vec<MemoryChunk>,
354 relevant_history: &mut Vec<MemoryChunk>,
355 project_facts: &mut Vec<MemoryChunk>,
356 excess_tokens: i64,
357 ) -> MemoryResult<()> {
358 let mut tokens_to_remove = excess_tokens;
359
360 while tokens_to_remove > 0 && !relevant_history.is_empty() {
362 if let Some(chunk) = relevant_history.pop() {
363 tokens_to_remove -= chunk.token_count;
364 }
365 }
366
367 while tokens_to_remove > 0 && !project_facts.is_empty() {
369 if let Some(chunk) = project_facts.pop() {
370 tokens_to_remove -= chunk.token_count;
371 }
372 }
373
374 while tokens_to_remove > 0 && !current_session.is_empty() {
375 if let Some(chunk) = current_session.pop() {
376 tokens_to_remove -= chunk.token_count;
377 }
378 }
379
380 Ok(())
381 }
382
383 pub async fn clear_session(&self, session_id: &str) -> MemoryResult<u64> {
385 let count = self.db.clear_session_memory(session_id).await?;
386
387 self.db
389 .log_cleanup(
390 "manual",
391 MemoryTier::Session,
392 None,
393 Some(session_id),
394 count as i64,
395 0,
396 )
397 .await?;
398
399 Ok(count)
400 }
401
402 pub async fn clear_project(&self, project_id: &str) -> MemoryResult<u64> {
404 let count = self.db.clear_project_memory(project_id).await?;
405
406 self.db
408 .log_cleanup(
409 "manual",
410 MemoryTier::Project,
411 Some(project_id),
412 None,
413 count as i64,
414 0,
415 )
416 .await?;
417
418 Ok(count)
419 }
420
421 pub async fn get_stats(&self) -> MemoryResult<MemoryStats> {
423 self.db.get_stats().await
424 }
425
426 pub async fn get_config(&self, project_id: &str) -> MemoryResult<MemoryConfig> {
428 self.db.get_or_create_config(project_id).await
429 }
430
431 pub async fn set_config(&self, project_id: &str, config: &MemoryConfig) -> MemoryResult<()> {
433 self.db.update_config(project_id, config).await
434 }
435
436 pub async fn run_cleanup(&self, project_id: Option<&str>) -> MemoryResult<u64> {
438 let mut total_cleaned = 0u64;
439
440 if let Some(pid) = project_id {
441 let config = self.db.get_or_create_config(pid).await?;
443
444 if config.auto_cleanup {
445 let cleaned = self
447 .db
448 .cleanup_old_sessions(config.session_retention_days)
449 .await?;
450 total_cleaned += cleaned;
451
452 if cleaned > 0 {
453 self.db
454 .log_cleanup(
455 "auto",
456 MemoryTier::Session,
457 Some(pid),
458 None,
459 cleaned as i64,
460 0,
461 )
462 .await?;
463 }
464 }
465 } else {
466 let cleaned = self.db.cleanup_old_sessions(30).await?;
470 total_cleaned += cleaned;
471 }
472
473 if total_cleaned > 100 {
475 self.db.vacuum().await?;
476 }
477
478 Ok(total_cleaned)
479 }
480
481 async fn maybe_cleanup(&self, project_id: &Option<String>) -> MemoryResult<()> {
483 if let Some(pid) = project_id {
484 let stats = self.db.get_stats().await?;
485 let config = self.db.get_or_create_config(pid).await?;
486
487 if stats.project_chunks > config.max_chunks {
489 let excess = stats.project_chunks - config.max_chunks;
491 tracing::info!("Project {} has {} excess chunks", pid, excess);
494 }
495 }
496
497 Ok(())
498 }
499
500 pub async fn get_cleanup_log(&self, _limit: i64) -> MemoryResult<Vec<CleanupLogEntry>> {
502 Ok(Vec::new())
505 }
506
507 pub fn count_tokens(&self, text: &str) -> usize {
509 self.tokenizer.count_tokens(text)
510 }
511
512 pub async fn embedding_health(&self) -> EmbeddingHealth {
514 let service = self.embedding_service.lock().await;
515 if service.is_available() {
516 EmbeddingHealth {
517 status: "ok".to_string(),
518 reason: None,
519 }
520 } else {
521 EmbeddingHealth {
522 status: "degraded_disabled".to_string(),
523 reason: service.disabled_reason().map(ToString::to_string),
524 }
525 }
526 }
527}
528
529pub async fn create_memory_manager(app_data_dir: &Path) -> MemoryResult<MemoryManager> {
531 let db_path = app_data_dir.join("tandem_memory.db");
532 MemoryManager::new(&db_path).await
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// True when `err` means the embedding backend is disabled in this
    /// environment; tests bail out early in that case instead of failing.
    fn is_embeddings_disabled(err: &crate::types::MemoryError) -> bool {
        matches!(err, crate::types::MemoryError::Embedding(msg) if msg.to_ascii_lowercase().contains("embeddings disabled"))
    }

    /// Creates a manager backed by a database in a fresh temp directory.
    /// The `TempDir` must be kept alive for the manager's lifetime.
    async fn setup_test_manager() -> (MemoryManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test_memory.db");
        let manager = MemoryManager::new(&db_path).await.unwrap();
        (manager, temp_dir)
    }

    /// Builds a project-tier request for "project-1". Shared by the tests to
    /// avoid repeating the full field list in each one.
    fn project_request(
        content: &str,
        session_id: Option<&str>,
        source: &str,
    ) -> StoreMessageRequest {
        StoreMessageRequest {
            content: content.to_string(),
            tier: MemoryTier::Project,
            session_id: session_id.map(ToString::to_string),
            project_id: Some("project-1".to_string()),
            source: source.to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_store_and_search() {
        let (manager, _temp) = setup_test_manager().await;

        let request = project_request(
            "This is a test message about artificial intelligence and machine learning.",
            Some("session-1"),
            "user_message",
        );

        let chunk_ids = match manager.store_message(request).await {
            Ok(ids) => ids,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        };
        assert!(!chunk_ids.is_empty());

        let results = manager
            .search(
                "artificial intelligence",
                None,
                Some("project-1"),
                None,
                None,
            )
            .await;
        let results = match results {
            Ok(results) => results,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("search failed: {err}"),
        };

        assert!(!results.is_empty());
        assert!(results[0].similarity >= 0.0);
    }

    #[tokio::test]
    async fn test_retrieve_context() {
        let (manager, _temp) = setup_test_manager().await;

        let request = project_request(
            "The project uses React and TypeScript for the frontend.",
            None,
            "assistant_response",
        );
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let context = manager
            .retrieve_context("What technologies are used?", Some("project-1"), None, None)
            .await;
        let context = match context {
            Ok(context) => context,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context failed: {err}"),
        };

        assert!(!context.project_facts.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_context_with_meta() {
        let (manager, _temp) = setup_test_manager().await;

        let request = project_request(
            "The backend uses Rust and sqlite-vec for retrieval.",
            None,
            "assistant_response",
        );
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let result = manager
            .retrieve_context_with_meta("What does the backend use?", Some("project-1"), None, None)
            .await;
        let (context, meta) = match result {
            Ok(v) => v,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context_with_meta failed: {err}"),
        };

        assert!(meta.chunks_total > 0);
        assert!(meta.used);
        assert_eq!(
            meta.chunks_total,
            context.current_session.len()
                + context.relevant_history.len()
                + context.project_facts.len()
        );
        assert!(meta.score_min.is_some());
        assert!(meta.score_max.is_some());
    }

    #[tokio::test]
    async fn test_config_management() {
        let (manager, _temp) = setup_test_manager().await;

        let config = manager.get_config("project-1").await.unwrap();
        assert_eq!(config.max_chunks, 10000);

        let new_config = MemoryConfig {
            max_chunks: 5000,
            retrieval_k: 10,
            ..Default::default()
        };

        manager.set_config("project-1", &new_config).await.unwrap();

        let updated = manager.get_config("project-1").await.unwrap();
        assert_eq!(updated.max_chunks, 5000);
        assert_eq!(updated.retrieval_k, 10);
    }
}
694}