// tandem_memory/manager.rs
// Memory Manager Module
// High-level memory operations (store, retrieve, cleanup)

use crate::chunking::{chunk_text_semantic, ChunkingConfig, Tokenizer};
use crate::db::MemoryDatabase;
use crate::embeddings::EmbeddingService;
use crate::types::{
    CleanupLogEntry, EmbeddingHealth, MemoryChunk, MemoryConfig, MemoryContext, MemoryResult,
    MemoryRetrievalMeta, MemorySearchResult, MemoryStats, MemoryTier, StoreMessageRequest,
};
use chrono::Utc;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::Mutex;
16/// High-level memory manager that coordinates database, embeddings, and chunking
17pub struct MemoryManager {
18    db: Arc<MemoryDatabase>,
19    embedding_service: Arc<Mutex<EmbeddingService>>,
20    tokenizer: Tokenizer,
21}
22
impl MemoryManager {
    /// Borrow the underlying database handle.
    pub fn db(&self) -> &Arc<MemoryDatabase> {
        &self.db
    }
28    /// Initialize the memory manager
29    pub async fn new(db_path: &Path) -> MemoryResult<Self> {
30        let db = Arc::new(MemoryDatabase::new(db_path).await?);
31        let embedding_service = Arc::new(Mutex::new(EmbeddingService::new()));
32        let tokenizer = Tokenizer::new()?;
33
34        Ok(Self {
35            db,
36            embedding_service,
37            tokenizer,
38        })
39    }
40
41    /// Store a message in memory
42    ///
43    /// This will:
44    /// 1. Chunk the message content
45    /// 2. Generate embeddings for each chunk
46    /// 3. Store chunks and embeddings in the database
47    pub async fn store_message(&self, request: StoreMessageRequest) -> MemoryResult<Vec<String>> {
48        if self
49            .db
50            .ensure_vector_tables_healthy()
51            .await
52            .unwrap_or(false)
53        {
54            tracing::warn!("Memory vector tables were repaired before storing message chunks");
55        }
56
57        let config = if let Some(ref pid) = request.project_id {
58            self.db.get_or_create_config(pid).await?
59        } else {
60            MemoryConfig::default()
61        };
62
63        // Chunk the content
64        let chunking_config = ChunkingConfig {
65            chunk_size: config.chunk_size as usize,
66            chunk_overlap: config.chunk_overlap as usize,
67            separator: None,
68        };
69
70        let text_chunks = chunk_text_semantic(&request.content, &chunking_config)?;
71
72        if text_chunks.is_empty() {
73            return Ok(Vec::new());
74        }
75
76        let mut chunk_ids = Vec::with_capacity(text_chunks.len());
77        let embedding_service = self.embedding_service.lock().await;
78
79        for text_chunk in text_chunks {
80            let chunk_id = uuid::Uuid::new_v4().to_string();
81
82            // Generate embedding
83            let embedding = embedding_service.embed(&text_chunk.content).await?;
84
85            // Create memory chunk
86            let chunk = MemoryChunk {
87                id: chunk_id.clone(),
88                content: text_chunk.content,
89                tier: request.tier,
90                session_id: request.session_id.clone(),
91                project_id: request.project_id.clone(),
92                source: request.source.clone(),
93                source_path: request.source_path.clone(),
94                source_mtime: request.source_mtime,
95                source_size: request.source_size,
96                source_hash: request.source_hash.clone(),
97                created_at: Utc::now(),
98                token_count: text_chunk.token_count as i64,
99                metadata: request.metadata.clone(),
100            };
101
102            // Store in database (retry once after vector-table self-heal).
103            if let Err(err) = self.db.store_chunk(&chunk, &embedding).await {
104                tracing::warn!("Failed to store memory chunk {}: {}", chunk.id, err);
105                let repaired = self
106                    .db
107                    .ensure_vector_tables_healthy()
108                    .await
109                    .unwrap_or(false);
110                if repaired {
111                    tracing::warn!(
112                        "Retrying memory chunk insert after vector table repair: {}",
113                        chunk.id
114                    );
115                    self.db.store_chunk(&chunk, &embedding).await?;
116                } else {
117                    return Err(err);
118                }
119            }
120            chunk_ids.push(chunk_id);
121        }
122
123        // Check if cleanup is needed
124        if config.auto_cleanup {
125            self.maybe_cleanup(&request.project_id).await?;
126        }
127
128        Ok(chunk_ids)
129    }
130
131    /// Search memory for relevant chunks
132    pub async fn search(
133        &self,
134        query: &str,
135        tier: Option<MemoryTier>,
136        project_id: Option<&str>,
137        session_id: Option<&str>,
138        limit: Option<i64>,
139    ) -> MemoryResult<Vec<MemorySearchResult>> {
140        let effective_limit = limit.unwrap_or(5);
141
142        // Generate query embedding
143        let embedding_service = self.embedding_service.lock().await;
144        let query_embedding = embedding_service.embed(query).await?;
145        drop(embedding_service);
146
147        let mut results = Vec::new();
148
149        // Search in specified tier or all tiers
150        let tiers_to_search = match tier {
151            Some(t) => vec![t],
152            None => {
153                if project_id.is_some() {
154                    vec![MemoryTier::Session, MemoryTier::Project, MemoryTier::Global]
155                } else {
156                    vec![MemoryTier::Session, MemoryTier::Global]
157                }
158            }
159        };
160
161        for search_tier in tiers_to_search {
162            let tier_results = match self
163                .db
164                .search_similar(
165                    &query_embedding,
166                    search_tier,
167                    project_id,
168                    session_id,
169                    effective_limit,
170                )
171                .await
172            {
173                Ok(results) => results,
174                Err(err) => {
175                    tracing::warn!(
176                        "Memory tier search failed for {:?}: {}. Attempting vector repair.",
177                        search_tier,
178                        err
179                    );
180                    let repaired = self
181                        .db
182                        .ensure_vector_tables_healthy()
183                        .await
184                        .unwrap_or(false);
185                    if repaired {
186                        match self
187                            .db
188                            .search_similar(
189                                &query_embedding,
190                                search_tier,
191                                project_id,
192                                session_id,
193                                effective_limit,
194                            )
195                            .await
196                        {
197                            Ok(results) => results,
198                            Err(retry_err) => {
199                                tracing::warn!(
200                                    "Memory tier search still failing for {:?} after repair: {}",
201                                    search_tier,
202                                    retry_err
203                                );
204                                continue;
205                            }
206                        }
207                    } else {
208                        continue;
209                    }
210                }
211            };
212
213            for (chunk, distance) in tier_results {
214                // Convert distance to similarity (cosine similarity)
215                // sqlite-vec returns distance, where lower is more similar
216                // Cosine similarity ranges from -1 to 1, but for normalized vectors it's 0 to 1
217                let similarity = 1.0 - distance.clamp(0.0, 1.0);
218
219                results.push(MemorySearchResult { chunk, similarity });
220            }
221        }
222
223        // Sort by similarity (highest first) and limit results
224        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
225        results.truncate(effective_limit as usize);
226
227        Ok(results)
228    }
229
230    /// Retrieve context for a message
231    ///
232    /// This retrieves relevant chunks from all tiers and formats them
233    /// for injection into the prompt
234    pub async fn retrieve_context(
235        &self,
236        query: &str,
237        project_id: Option<&str>,
238        session_id: Option<&str>,
239        token_budget: Option<i64>,
240    ) -> MemoryResult<MemoryContext> {
241        let (context, _) = self
242            .retrieve_context_with_meta(query, project_id, session_id, token_budget)
243            .await?;
244        Ok(context)
245    }
246
247    /// Retrieve context plus retrieval metadata for observability.
248    pub async fn retrieve_context_with_meta(
249        &self,
250        query: &str,
251        project_id: Option<&str>,
252        session_id: Option<&str>,
253        token_budget: Option<i64>,
254    ) -> MemoryResult<(MemoryContext, MemoryRetrievalMeta)> {
255        let config = if let Some(pid) = project_id {
256            self.db.get_or_create_config(pid).await?
257        } else {
258            MemoryConfig::default()
259        };
260        let budget = token_budget.unwrap_or(config.token_budget);
261        let retrieval_limit = config.retrieval_k.max(1);
262
263        // Get recent session chunks
264        let current_session = if let Some(sid) = session_id {
265            self.db.get_session_chunks(sid).await?
266        } else {
267            Vec::new()
268        };
269
270        // Search for relevant history
271        let search_results = self
272            .search(query, None, project_id, session_id, Some(retrieval_limit))
273            .await?;
274
275        let mut score_min: Option<f64> = None;
276        let mut score_max: Option<f64> = None;
277        for result in &search_results {
278            score_min = Some(match score_min {
279                Some(current) => current.min(result.similarity),
280                None => result.similarity,
281            });
282            score_max = Some(match score_max {
283                Some(current) => current.max(result.similarity),
284                None => result.similarity,
285            });
286        }
287
288        let mut current_session = current_session;
289        let mut relevant_history = Vec::new();
290        let mut project_facts = Vec::new();
291
292        for result in search_results {
293            match result.chunk.tier {
294                MemoryTier::Project => {
295                    project_facts.push(result.chunk);
296                }
297                MemoryTier::Global => {
298                    project_facts.push(result.chunk);
299                }
300                MemoryTier::Session => {
301                    // Only add to relevant_history if not in current_session
302                    if !current_session.iter().any(|c| c.id == result.chunk.id) {
303                        relevant_history.push(result.chunk);
304                    }
305                }
306            }
307        }
308
309        // Calculate total tokens and trim if necessary
310        let mut total_tokens: i64 = current_session.iter().map(|c| c.token_count).sum();
311        total_tokens += relevant_history.iter().map(|c| c.token_count).sum::<i64>();
312        total_tokens += project_facts.iter().map(|c| c.token_count).sum::<i64>();
313
314        // Trim to fit budget if necessary
315        if total_tokens > budget {
316            let excess = total_tokens - budget;
317            self.trim_context(
318                &mut current_session,
319                &mut relevant_history,
320                &mut project_facts,
321                excess,
322            )?;
323            total_tokens = current_session.iter().map(|c| c.token_count).sum::<i64>()
324                + relevant_history.iter().map(|c| c.token_count).sum::<i64>()
325                + project_facts.iter().map(|c| c.token_count).sum::<i64>();
326        }
327
328        let context = MemoryContext {
329            current_session,
330            relevant_history,
331            project_facts,
332            total_tokens,
333        };
334        let chunks_total = context.current_session.len()
335            + context.relevant_history.len()
336            + context.project_facts.len();
337        let meta = MemoryRetrievalMeta {
338            used: chunks_total > 0,
339            chunks_total,
340            session_chunks: context.current_session.len(),
341            history_chunks: context.relevant_history.len(),
342            project_fact_chunks: context.project_facts.len(),
343            score_min,
344            score_max,
345        };
346
347        Ok((context, meta))
348    }
349
350    /// Trim context to fit within token budget
351    fn trim_context(
352        &self,
353        current_session: &mut Vec<MemoryChunk>,
354        relevant_history: &mut Vec<MemoryChunk>,
355        project_facts: &mut Vec<MemoryChunk>,
356        excess_tokens: i64,
357    ) -> MemoryResult<()> {
358        let mut tokens_to_remove = excess_tokens;
359
360        // First, trim relevant_history (less important than project_facts)
361        while tokens_to_remove > 0 && !relevant_history.is_empty() {
362            if let Some(chunk) = relevant_history.pop() {
363                tokens_to_remove -= chunk.token_count;
364            }
365        }
366
367        // If still over budget, trim project_facts
368        while tokens_to_remove > 0 && !project_facts.is_empty() {
369            if let Some(chunk) = project_facts.pop() {
370                tokens_to_remove -= chunk.token_count;
371            }
372        }
373
374        while tokens_to_remove > 0 && !current_session.is_empty() {
375            if let Some(chunk) = current_session.pop() {
376                tokens_to_remove -= chunk.token_count;
377            }
378        }
379
380        Ok(())
381    }
382
383    /// Clear session memory
384    pub async fn clear_session(&self, session_id: &str) -> MemoryResult<u64> {
385        let count = self.db.clear_session_memory(session_id).await?;
386
387        // Log cleanup
388        self.db
389            .log_cleanup(
390                "manual",
391                MemoryTier::Session,
392                None,
393                Some(session_id),
394                count as i64,
395                0,
396            )
397            .await?;
398
399        Ok(count)
400    }
401
402    /// Clear project memory
403    pub async fn clear_project(&self, project_id: &str) -> MemoryResult<u64> {
404        let count = self.db.clear_project_memory(project_id).await?;
405
406        // Log cleanup
407        self.db
408            .log_cleanup(
409                "manual",
410                MemoryTier::Project,
411                Some(project_id),
412                None,
413                count as i64,
414                0,
415            )
416            .await?;
417
418        Ok(count)
419    }
420
421    /// Get memory statistics
422    pub async fn get_stats(&self) -> MemoryResult<MemoryStats> {
423        self.db.get_stats().await
424    }
425
426    /// Get memory configuration for a project
427    pub async fn get_config(&self, project_id: &str) -> MemoryResult<MemoryConfig> {
428        self.db.get_or_create_config(project_id).await
429    }
430
431    /// Update memory configuration for a project
432    pub async fn set_config(&self, project_id: &str, config: &MemoryConfig) -> MemoryResult<()> {
433        self.db.update_config(project_id, config).await
434    }
435
436    /// Run cleanup based on retention policies
437    pub async fn run_cleanup(&self, project_id: Option<&str>) -> MemoryResult<u64> {
438        let mut total_cleaned = 0u64;
439
440        if let Some(pid) = project_id {
441            // Get config for this project
442            let config = self.db.get_or_create_config(pid).await?;
443
444            if config.auto_cleanup {
445                // Clean up old session memory
446                let cleaned = self
447                    .db
448                    .cleanup_old_sessions(config.session_retention_days)
449                    .await?;
450                total_cleaned += cleaned;
451
452                if cleaned > 0 {
453                    self.db
454                        .log_cleanup(
455                            "auto",
456                            MemoryTier::Session,
457                            Some(pid),
458                            None,
459                            cleaned as i64,
460                            0,
461                        )
462                        .await?;
463                }
464            }
465        } else {
466            // Clean up all projects with auto_cleanup enabled
467            // This would require listing all projects, for now just clean session memory
468            // with a default retention period
469            let cleaned = self.db.cleanup_old_sessions(30).await?;
470            total_cleaned += cleaned;
471        }
472
473        // Vacuum if significant cleanup occurred
474        if total_cleaned > 100 {
475            self.db.vacuum().await?;
476        }
477
478        Ok(total_cleaned)
479    }
480
481    /// Check if cleanup is needed and run it
482    async fn maybe_cleanup(&self, project_id: &Option<String>) -> MemoryResult<()> {
483        if let Some(pid) = project_id {
484            let stats = self.db.get_stats().await?;
485            let config = self.db.get_or_create_config(pid).await?;
486
487            // Check if we're over the chunk limit
488            if stats.project_chunks > config.max_chunks {
489                // Remove oldest chunks
490                let excess = stats.project_chunks - config.max_chunks;
491                // This would require a new DB method to delete oldest chunks
492                // For now, just log
493                tracing::info!("Project {} has {} excess chunks", pid, excess);
494            }
495        }
496
497        Ok(())
498    }
499
500    /// Get cleanup log entries
501    pub async fn get_cleanup_log(&self, _limit: i64) -> MemoryResult<Vec<CleanupLogEntry>> {
502        // This would be implemented in the DB layer
503        // For now, return empty
504        Ok(Vec::new())
505    }
506
507    /// Count tokens in text
508    pub fn count_tokens(&self, text: &str) -> usize {
509        self.tokenizer.count_tokens(text)
510    }
511
512    /// Report embedding backend health for UI/telemetry.
513    pub async fn embedding_health(&self) -> EmbeddingHealth {
514        let service = self.embedding_service.lock().await;
515        if service.is_available() {
516            EmbeddingHealth {
517                status: "ok".to_string(),
518                reason: None,
519            }
520        } else {
521            EmbeddingHealth {
522                status: "degraded_disabled".to_string(),
523                reason: service.disabled_reason().map(ToString::to_string),
524            }
525        }
526    }
527}
528
529/// Create memory manager with default database path
530pub async fn create_memory_manager(app_data_dir: &Path) -> MemoryResult<MemoryManager> {
531    let db_path = app_data_dir.join("tandem_memory.db");
532    MemoryManager::new(&db_path).await
533}
534
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// True when `err` is the "embeddings disabled" error. On machines
    /// without an embedding backend the tests bail out early instead of
    /// failing.
    fn is_embeddings_disabled(err: &crate::types::MemoryError) -> bool {
        matches!(err, crate::types::MemoryError::Embedding(msg) if msg.to_ascii_lowercase().contains("embeddings disabled"))
    }

    /// Build a manager backed by a fresh database in its own temp dir.
    /// The returned `TempDir` must outlive the manager.
    async fn setup_test_manager() -> (MemoryManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test_memory.db");
        let manager = MemoryManager::new(&db_path).await.unwrap();
        (manager, temp_dir)
    }

    /// Project-tier store request for `project-1` with the given content.
    fn project_request(
        content: &str,
        session_id: Option<&str>,
        source: &str,
    ) -> StoreMessageRequest {
        StoreMessageRequest {
            content: content.to_string(),
            tier: MemoryTier::Project,
            session_id: session_id.map(ToString::to_string),
            project_id: Some("project-1".to_string()),
            source: source.to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        }
    }

    #[tokio::test]
    async fn test_store_and_search() {
        let (manager, _temp) = setup_test_manager().await;

        let request = project_request(
            "This is a test message about artificial intelligence and machine learning.",
            Some("session-1"),
            "user_message",
        );

        let chunk_ids = match manager.store_message(request).await {
            Ok(ids) => ids,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        };
        assert!(!chunk_ids.is_empty());

        // Search for the content we just stored.
        let results = match manager
            .search("artificial intelligence", None, Some("project-1"), None, None)
            .await
        {
            Ok(results) => results,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("search failed: {err}"),
        };

        assert!(!results.is_empty());
        // Similarity can be 0.0 with random hash embeddings (orthogonal or
        // negative correlation), so only assert non-negativity.
        assert!(results[0].similarity >= 0.0);
    }

    #[tokio::test]
    async fn test_retrieve_context() {
        let (manager, _temp) = setup_test_manager().await;

        let request = project_request(
            "The project uses React and TypeScript for the frontend.",
            None,
            "assistant_response",
        );
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let context = match manager
            .retrieve_context("What technologies are used?", Some("project-1"), None, None)
            .await
        {
            Ok(context) => context,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context failed: {err}"),
        };

        assert!(!context.project_facts.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_context_with_meta() {
        let (manager, _temp) = setup_test_manager().await;

        let request = project_request(
            "The backend uses Rust and sqlite-vec for retrieval.",
            None,
            "assistant_response",
        );
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let (context, meta) = match manager
            .retrieve_context_with_meta("What does the backend use?", Some("project-1"), None, None)
            .await
        {
            Ok(v) => v,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context_with_meta failed: {err}"),
        };

        assert!(meta.chunks_total > 0);
        assert!(meta.used);
        // Metadata counts must agree with the context buckets.
        assert_eq!(
            meta.chunks_total,
            context.current_session.len()
                + context.relevant_history.len()
                + context.project_facts.len()
        );
        assert!(meta.score_min.is_some());
        assert!(meta.score_max.is_some());
    }

    #[tokio::test]
    async fn test_config_management() {
        let (manager, _temp) = setup_test_manager().await;

        // Defaults apply on first access.
        let config = manager.get_config("project-1").await.unwrap();
        assert_eq!(config.max_chunks, 10000);

        let new_config = MemoryConfig {
            max_chunks: 5000,
            retrieval_k: 10,
            ..Default::default()
        };
        manager.set_config("project-1", &new_config).await.unwrap();

        let updated = manager.get_config("project-1").await.unwrap();
        assert_eq!(updated.max_chunks, 5000);
        assert_eq!(updated.retrieval_k, 10);
    }
}