1use crate::chunking::{chunk_text_semantic, ChunkingConfig, Tokenizer};
5use crate::db::MemoryDatabase;
6use crate::embeddings::EmbeddingService;
7use crate::types::{
8 CleanupLogEntry, EmbeddingHealth, MemoryChunk, MemoryConfig, MemoryContext, MemoryResult,
9 MemoryRetrievalMeta, MemorySearchResult, MemoryStats, MemoryTier, StoreMessageRequest,
10};
11use chrono::Utc;
12use std::path::Path;
13use std::sync::Arc;
14use tandem_providers::{MemoryConsolidationConfig, ProviderRegistry};
15use tokio::sync::Mutex;
16
/// Ties together the memory database, the embedding service, and the
/// tokenizer to store, search, and consolidate memory chunks.
pub struct MemoryManager {
    /// Chunk + vector store; shared so callers can hold their own handle.
    db: Arc<MemoryDatabase>,
    /// Embedding backend behind an async mutex — one embed call at a time.
    embedding_service: Arc<Mutex<EmbeddingService>>,
    /// Token counter used for context budgeting.
    tokenizer: Tokenizer,
}
23
24impl MemoryManager {
25 fn is_malformed_database_error(err: &crate::types::MemoryError) -> bool {
26 err.to_string()
27 .to_lowercase()
28 .contains("database disk image is malformed")
29 }
30
    /// Shared handle to the underlying memory database.
    pub fn db(&self) -> &Arc<MemoryDatabase> {
        &self.db
    }
34
35 pub async fn new(db_path: &Path) -> MemoryResult<Self> {
37 let db = Arc::new(MemoryDatabase::new(db_path).await?);
38 let embedding_service = Arc::new(Mutex::new(EmbeddingService::new()));
39 let tokenizer = Tokenizer::new()?;
40
41 Ok(Self {
42 db,
43 embedding_service,
44 tokenizer,
45 })
46 }
47
    /// Chunk `request.content`, embed each chunk, and persist the chunks.
    ///
    /// Returns the ids of all stored chunks (empty when the content
    /// produced no chunks). Insert failures trigger a best-effort
    /// vector-table repair and a retry; a still-malformed database is
    /// reset entirely as a last resort before one final retry.
    pub async fn store_message(&self, request: StoreMessageRequest) -> MemoryResult<Vec<String>> {
        // Best-effort health check first; `true` means the vector tables
        // had to be rebuilt before we could write to them.
        if self
            .db
            .ensure_vector_tables_healthy()
            .await
            .unwrap_or(false)
        {
            tracing::warn!("Memory vector tables were repaired before storing message chunks");
        }

        // Per-project chunking settings; defaults when the request is not
        // tied to a project.
        let config = if let Some(ref pid) = request.project_id {
            self.db.get_or_create_config(pid).await?
        } else {
            MemoryConfig::default()
        };

        let chunking_config = ChunkingConfig {
            chunk_size: config.chunk_size as usize,
            chunk_overlap: config.chunk_overlap as usize,
            separator: None,
        };

        let text_chunks = chunk_text_semantic(&request.content, &chunking_config)?;

        if text_chunks.is_empty() {
            return Ok(Vec::new());
        }

        let mut chunk_ids = Vec::with_capacity(text_chunks.len());
        // NOTE(review): the embedding lock is held across the whole loop,
        // which also serializes the DB writes below — presumably so one
        // message's chunks are embedded/stored as a unit; confirm before
        // narrowing the lock scope.
        let embedding_service = self.embedding_service.lock().await;

        for text_chunk in text_chunks {
            let chunk_id = uuid::Uuid::new_v4().to_string();

            let embedding = embedding_service.embed(&text_chunk.content).await?;

            let chunk = MemoryChunk {
                id: chunk_id.clone(),
                content: text_chunk.content,
                tier: request.tier,
                session_id: request.session_id.clone(),
                project_id: request.project_id.clone(),
                source: request.source.clone(),
                source_path: request.source_path.clone(),
                source_mtime: request.source_mtime,
                source_size: request.source_size,
                source_hash: request.source_hash.clone(),
                created_at: Utc::now(),
                token_count: text_chunk.token_count as i64,
                metadata: request.metadata.clone(),
            };

            // Failure path: attempt a targeted repair (or general health
            // fix), retry once, and if the DB file itself is malformed,
            // reset all memory tables and retry a final time. Any other
            // retry error is propagated.
            if let Err(err) = self.db.store_chunk(&chunk, &embedding).await {
                tracing::warn!("Failed to store memory chunk {}: {}", chunk.id, err);
                let repaired = self.db.try_repair_after_error(&err).await.unwrap_or(false)
                    || self
                        .db
                        .ensure_vector_tables_healthy()
                        .await
                        .unwrap_or(false);
                if repaired {
                    tracing::warn!(
                        "Retrying memory chunk insert after vector table repair: {}",
                        chunk.id
                    );
                    if let Err(retry_err) = self.db.store_chunk(&chunk, &embedding).await {
                        if Self::is_malformed_database_error(&retry_err) {
                            tracing::warn!(
                                "Memory DB still malformed after vector repair. Resetting memory tables and retrying chunk insert: {}",
                                chunk.id
                            );
                            self.db.reset_all_memory_tables().await?;
                            self.db.store_chunk(&chunk, &embedding).await?;
                        } else {
                            return Err(retry_err);
                        }
                    }
                } else {
                    return Err(err);
                }
            }
            chunk_ids.push(chunk_id);
        }

        if config.auto_cleanup {
            self.maybe_cleanup(&request.project_id).await?;
        }

        Ok(chunk_ids)
    }
149
150 pub async fn search(
152 &self,
153 query: &str,
154 tier: Option<MemoryTier>,
155 project_id: Option<&str>,
156 session_id: Option<&str>,
157 limit: Option<i64>,
158 ) -> MemoryResult<Vec<MemorySearchResult>> {
159 let effective_limit = limit.unwrap_or(5);
160
161 let embedding_service = self.embedding_service.lock().await;
163 let query_embedding = embedding_service.embed(query).await?;
164 drop(embedding_service);
165
166 let mut results = Vec::new();
167
168 let tiers_to_search = match tier {
170 Some(t) => vec![t],
171 None => {
172 if project_id.is_some() {
173 vec![MemoryTier::Session, MemoryTier::Project, MemoryTier::Global]
174 } else {
175 vec![MemoryTier::Session, MemoryTier::Global]
176 }
177 }
178 };
179
180 for search_tier in tiers_to_search {
181 let tier_results = match self
182 .db
183 .search_similar(
184 &query_embedding,
185 search_tier,
186 project_id,
187 session_id,
188 effective_limit,
189 )
190 .await
191 {
192 Ok(results) => results,
193 Err(err) => {
194 tracing::warn!(
195 "Memory tier search failed for {:?}: {}. Attempting vector repair.",
196 search_tier,
197 err
198 );
199 let repaired = self.db.try_repair_after_error(&err).await.unwrap_or(false)
200 || self
201 .db
202 .ensure_vector_tables_healthy()
203 .await
204 .unwrap_or(false);
205 if repaired {
206 match self
207 .db
208 .search_similar(
209 &query_embedding,
210 search_tier,
211 project_id,
212 session_id,
213 effective_limit,
214 )
215 .await
216 {
217 Ok(results) => results,
218 Err(retry_err) => {
219 tracing::warn!(
220 "Memory tier search still failing for {:?} after repair: {}",
221 search_tier,
222 retry_err
223 );
224 continue;
225 }
226 }
227 } else {
228 continue;
229 }
230 }
231 };
232
233 for (chunk, distance) in tier_results {
234 let similarity = 1.0 - distance.clamp(0.0, 1.0);
238
239 results.push(MemorySearchResult { chunk, similarity });
240 }
241 }
242
243 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
245 results.truncate(effective_limit as usize);
246
247 Ok(results)
248 }
249
250 pub async fn retrieve_context(
255 &self,
256 query: &str,
257 project_id: Option<&str>,
258 session_id: Option<&str>,
259 token_budget: Option<i64>,
260 ) -> MemoryResult<MemoryContext> {
261 let (context, _) = self
262 .retrieve_context_with_meta(query, project_id, session_id, token_budget)
263 .await?;
264 Ok(context)
265 }
266
    /// Like [`Self::retrieve_context`], but also reports retrieval stats.
    ///
    /// Assembles three buckets — current-session chunks, semantically
    /// relevant history, and project/global facts — then trims them to
    /// the token budget (explicit `token_budget`, or the project
    /// config's default).
    pub async fn retrieve_context_with_meta(
        &self,
        query: &str,
        project_id: Option<&str>,
        session_id: Option<&str>,
        token_budget: Option<i64>,
    ) -> MemoryResult<(MemoryContext, MemoryRetrievalMeta)> {
        let config = if let Some(pid) = project_id {
            self.db.get_or_create_config(pid).await?
        } else {
            MemoryConfig::default()
        };
        let budget = token_budget.unwrap_or(config.token_budget);
        // Always request at least one result.
        let retrieval_limit = config.retrieval_k.max(1);

        // Everything already stored for this session is included verbatim.
        let current_session = if let Some(sid) = session_id {
            self.db.get_session_chunks(sid).await?
        } else {
            Vec::new()
        };

        let search_results = self
            .search(query, None, project_id, session_id, Some(retrieval_limit))
            .await?;

        // Track the similarity range across all hits for the metadata.
        let mut score_min: Option<f64> = None;
        let mut score_max: Option<f64> = None;
        for result in &search_results {
            score_min = Some(match score_min {
                Some(current) => current.min(result.similarity),
                None => result.similarity,
            });
            score_max = Some(match score_max {
                Some(current) => current.max(result.similarity),
                None => result.similarity,
            });
        }

        let mut current_session = current_session;
        let mut relevant_history = Vec::new();
        let mut project_facts = Vec::new();

        // Route each hit into its bucket; session hits already present in
        // the current-session list are dropped to avoid duplicates.
        for result in search_results {
            match result.chunk.tier {
                MemoryTier::Project => {
                    project_facts.push(result.chunk);
                }
                MemoryTier::Global => {
                    project_facts.push(result.chunk);
                }
                MemoryTier::Session => {
                    if !current_session.iter().any(|c| c.id == result.chunk.id) {
                        relevant_history.push(result.chunk);
                    }
                }
            }
        }

        let mut total_tokens: i64 = current_session.iter().map(|c| c.token_count).sum();
        total_tokens += relevant_history.iter().map(|c| c.token_count).sum::<i64>();
        total_tokens += project_facts.iter().map(|c| c.token_count).sum::<i64>();

        // Over budget: trim, then recompute the actual total.
        if total_tokens > budget {
            let excess = total_tokens - budget;
            self.trim_context(
                &mut current_session,
                &mut relevant_history,
                &mut project_facts,
                excess,
            )?;
            total_tokens = current_session.iter().map(|c| c.token_count).sum::<i64>()
                + relevant_history.iter().map(|c| c.token_count).sum::<i64>()
                + project_facts.iter().map(|c| c.token_count).sum::<i64>();
        }

        let context = MemoryContext {
            current_session,
            relevant_history,
            project_facts,
            total_tokens,
        };
        let chunks_total = context.current_session.len()
            + context.relevant_history.len()
            + context.project_facts.len();
        let meta = MemoryRetrievalMeta {
            used: chunks_total > 0,
            chunks_total,
            session_chunks: context.current_session.len(),
            history_chunks: context.relevant_history.len(),
            project_fact_chunks: context.project_facts.len(),
            score_min,
            score_max,
        };

        Ok((context, meta))
    }
369
370 fn trim_context(
372 &self,
373 current_session: &mut Vec<MemoryChunk>,
374 relevant_history: &mut Vec<MemoryChunk>,
375 project_facts: &mut Vec<MemoryChunk>,
376 excess_tokens: i64,
377 ) -> MemoryResult<()> {
378 let mut tokens_to_remove = excess_tokens;
379
380 while tokens_to_remove > 0 && !relevant_history.is_empty() {
382 if let Some(chunk) = relevant_history.pop() {
383 tokens_to_remove -= chunk.token_count;
384 }
385 }
386
387 while tokens_to_remove > 0 && !project_facts.is_empty() {
389 if let Some(chunk) = project_facts.pop() {
390 tokens_to_remove -= chunk.token_count;
391 }
392 }
393
394 while tokens_to_remove > 0 && !current_session.is_empty() {
395 if let Some(chunk) = current_session.pop() {
396 tokens_to_remove -= chunk.token_count;
397 }
398 }
399
400 Ok(())
401 }
402
403 pub async fn clear_session(&self, session_id: &str) -> MemoryResult<u64> {
405 let count = self.db.clear_session_memory(session_id).await?;
406
407 self.db
409 .log_cleanup(
410 "manual",
411 MemoryTier::Session,
412 None,
413 Some(session_id),
414 count as i64,
415 0,
416 )
417 .await?;
418
419 Ok(count)
420 }
421
422 pub async fn clear_project(&self, project_id: &str) -> MemoryResult<u64> {
424 let count = self.db.clear_project_memory(project_id).await?;
425
426 self.db
428 .log_cleanup(
429 "manual",
430 MemoryTier::Project,
431 Some(project_id),
432 None,
433 count as i64,
434 0,
435 )
436 .await?;
437
438 Ok(count)
439 }
440
    /// Aggregate memory statistics straight from the database.
    pub async fn get_stats(&self) -> MemoryResult<MemoryStats> {
        self.db.get_stats().await
    }
445
    /// Fetch the memory config for `project_id`, creating a default one
    /// if the project has none yet.
    pub async fn get_config(&self, project_id: &str) -> MemoryResult<MemoryConfig> {
        self.db.get_or_create_config(project_id).await
    }
450
    /// Persist `config` as the memory configuration for `project_id`.
    pub async fn set_config(&self, project_id: &str, config: &MemoryConfig) -> MemoryResult<()> {
        self.db.update_config(project_id, config).await
    }
455
456 pub async fn run_cleanup(&self, project_id: Option<&str>) -> MemoryResult<u64> {
458 let mut total_cleaned = 0u64;
459
460 if let Some(pid) = project_id {
461 let config = self.db.get_or_create_config(pid).await?;
463
464 if config.auto_cleanup {
465 let cleaned = self
467 .db
468 .cleanup_old_sessions(config.session_retention_days)
469 .await?;
470 total_cleaned += cleaned;
471
472 if cleaned > 0 {
473 self.db
474 .log_cleanup(
475 "auto",
476 MemoryTier::Session,
477 Some(pid),
478 None,
479 cleaned as i64,
480 0,
481 )
482 .await?;
483 }
484 }
485 } else {
486 let cleaned = self.db.cleanup_old_sessions(30).await?;
490 total_cleaned += cleaned;
491 }
492
493 if total_cleaned > 100 {
495 self.db.vacuum().await?;
496 }
497
498 Ok(total_cleaned)
499 }
500
    /// Post-store housekeeping hook, invoked when `auto_cleanup` is on.
    ///
    /// Currently only detects and logs when a project exceeds its
    /// `max_chunks` limit — no chunks are actually deleted here yet.
    async fn maybe_cleanup(&self, project_id: &Option<String>) -> MemoryResult<()> {
        if let Some(pid) = project_id {
            let stats = self.db.get_stats().await?;
            let config = self.db.get_or_create_config(pid).await?;

            if stats.project_chunks > config.max_chunks {
                let excess = stats.project_chunks - config.max_chunks;
                tracing::info!("Project {} has {} excess chunks", pid, excess);
            }
        }

        Ok(())
    }
519
    /// Fetch recent cleanup-log entries.
    ///
    /// Stub: always returns an empty list; `_limit` is currently unused.
    pub async fn get_cleanup_log(&self, _limit: i64) -> MemoryResult<Vec<CleanupLogEntry>> {
        Ok(Vec::new())
    }
526
    /// Count tokens in `text` using the manager's tokenizer.
    pub fn count_tokens(&self, text: &str) -> usize {
        self.tokenizer.count_tokens(text)
    }
531
532 pub async fn embedding_health(&self) -> EmbeddingHealth {
534 let service = self.embedding_service.lock().await;
535 if service.is_available() {
536 EmbeddingHealth {
537 status: "ok".to_string(),
538 reason: None,
539 }
540 } else {
541 EmbeddingHealth {
542 status: "degraded_disabled".to_string(),
543 reason: service.disabled_reason().map(ToString::to_string),
544 }
545 }
546 }
547
548 pub async fn consolidate_session(
550 &self,
551 session_id: &str,
552 project_id: Option<&str>,
553 providers: &ProviderRegistry,
554 config: &MemoryConsolidationConfig,
555 ) -> MemoryResult<Option<String>> {
556 if !config.enabled {
557 return Ok(None);
558 }
559
560 let chunks = self.db.get_session_chunks(session_id).await?;
561 if chunks.is_empty() {
562 return Ok(None);
563 }
564
565 let mut text_parts = Vec::new();
567 for chunk in &chunks {
568 text_parts.push(chunk.content.clone());
569 }
570 let full_text = text_parts.join("\n\n---\n\n");
571
572 let prompt = format!(
574 "Please provide a concise but comprehensive summary of the following chat session. \
575 Focus on the key decisions, technical details, code changes, and unresolved issues. \
576 Do NOT include conversational filler, greetings, or sign-offs. \
577 This summary will be used as long-term memory to recall the context of this work.\n\n\
578 Session transcripts:\n\n{}",
579 full_text
580 );
581
582 let provider_override = config.provider.as_deref().filter(|s| !s.is_empty());
583 let model_override = config.model.as_deref().filter(|s| !s.is_empty());
584
585 let summary_text = match providers
586 .complete_cheapest(&prompt, provider_override, model_override)
587 .await
588 {
589 Ok(s) => s,
590 Err(e) => {
591 tracing::warn!("Memory consolidation LLM failed for session {session_id}: {e}");
592 return Ok(None);
593 }
594 };
595
596 if summary_text.trim().is_empty() {
597 return Ok(None);
598 }
599
600 let embedding = {
602 let service = self.embedding_service.lock().await;
603 service
604 .embed(&summary_text)
605 .await
606 .map_err(|e| crate::types::MemoryError::Embedding(e.to_string()))?
607 };
608
609 let chunk_id = uuid::Uuid::new_v4().to_string();
611 let chunk = MemoryChunk {
612 id: chunk_id,
613 content: summary_text.clone(),
614 tier: MemoryTier::Project,
615 session_id: None, project_id: project_id.map(ToString::to_string),
617 created_at: Utc::now(),
618 source: "consolidation".to_string(),
619 token_count: self.count_tokens(&summary_text) as i64,
620 source_path: None,
621 source_mtime: None,
622 source_size: None,
623 source_hash: None,
624 metadata: None,
625 };
626
627 self.db.store_chunk(&chunk, &embedding).await?;
628
629 self.db.clear_session_memory(session_id).await?;
631
632 tracing::info!(
633 "Session {session_id} consolidated into summary chunk. Original chunks cleared."
634 );
635
636 Ok(Some(summary_text))
637 }
638}
639
640pub async fn create_memory_manager(app_data_dir: &Path) -> MemoryResult<MemoryManager> {
642 let db_path = app_data_dir.join("tandem_memory.db");
643 MemoryManager::new(&db_path).await
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Tests that require the embedding backend skip themselves (return
    // early) when embeddings are disabled in the environment; this
    // predicate recognizes that condition.
    fn is_embeddings_disabled(err: &crate::types::MemoryError) -> bool {
        matches!(err, crate::types::MemoryError::Embedding(msg) if msg.to_ascii_lowercase().contains("embeddings disabled"))
    }

    // Build a manager over a throwaway DB; the TempDir must be kept alive
    // by the caller or the database file disappears.
    async fn setup_test_manager() -> (MemoryManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test_memory.db");
        let manager = MemoryManager::new(&db_path).await.unwrap();
        (manager, temp_dir)
    }

    // Round-trip: stored content should be findable by semantic search
    // within the same project.
    #[tokio::test]
    async fn test_store_and_search() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "This is a test message about artificial intelligence and machine learning."
                .to_string(),
            tier: MemoryTier::Project,
            session_id: Some("session-1".to_string()),
            project_id: Some("project-1".to_string()),
            source: "user_message".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };

        let chunk_ids = match manager.store_message(request).await {
            Ok(ids) => ids,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        };
        assert!(!chunk_ids.is_empty());

        let results = manager
            .search(
                "artificial intelligence",
                None,
                Some("project-1"),
                None,
                None,
            )
            .await;
        let results = match results {
            Ok(results) => results,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("search failed: {err}"),
        };

        assert!(!results.is_empty());
        assert!(results[0].similarity >= 0.0);
    }

    // A project-tier chunk should surface in the retrieved context's
    // project facts.
    #[tokio::test]
    async fn test_retrieve_context() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "The project uses React and TypeScript for the frontend.".to_string(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: Some("project-1".to_string()),
            source: "assistant_response".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let context = manager
            .retrieve_context("What technologies are used?", Some("project-1"), None, None)
            .await;
        let context = match context {
            Ok(context) => context,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context failed: {err}"),
        };

        assert!(!context.project_facts.is_empty());
    }

    // The metadata's chunk counts and score range must be consistent with
    // the context it accompanies.
    #[tokio::test]
    async fn test_retrieve_context_with_meta() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "The backend uses Rust and sqlite-vec for retrieval.".to_string(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: Some("project-1".to_string()),
            source: "assistant_response".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let result = manager
            .retrieve_context_with_meta("What does the backend use?", Some("project-1"), None, None)
            .await;
        let (context, meta) = match result {
            Ok(v) => v,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context_with_meta failed: {err}"),
        };

        assert!(meta.chunks_total > 0);
        assert!(meta.used);
        assert_eq!(
            meta.chunks_total,
            context.current_session.len()
                + context.relevant_history.len()
                + context.project_facts.len()
        );
        assert!(meta.score_min.is_some());
        assert!(meta.score_max.is_some());
    }

    // Defaults are created on first read; updates must persist and be
    // visible on the next read.
    #[tokio::test]
    async fn test_config_management() {
        let (manager, _temp) = setup_test_manager().await;

        let config = manager.get_config("project-1").await.unwrap();
        assert_eq!(config.max_chunks, 10000);

        let new_config = MemoryConfig {
            max_chunks: 5000,
            retrieval_k: 10,
            ..Default::default()
        };

        manager.set_config("project-1", &new_config).await.unwrap();

        let updated = manager.get_config("project-1").await.unwrap();
        assert_eq!(updated.max_chunks, 5000);
        assert_eq!(updated.retrieval_k, 10);
    }
}