// tandem_memory/manager.rs
1// Memory Manager Module
2// High-level memory operations (store, retrieve, cleanup)
3
4use crate::chunking::{chunk_text_semantic, ChunkingConfig, Tokenizer};
5use crate::context_layers::ContextLayerGenerator;
6use crate::context_uri::ContextUri;
7use crate::db::MemoryDatabase;
8use crate::embeddings::EmbeddingService;
9use crate::types::{
10    CleanupLogEntry, DirectoryListing, EmbeddingHealth, LayerType, MemoryChunk, MemoryConfig,
11    MemoryContext, MemoryError, MemoryLayer, MemoryNode, MemoryResult, MemoryRetrievalMeta,
12    MemorySearchResult, MemoryStats, MemoryTier, NodeType, StoreMessageRequest, TreeNode,
13};
14use chrono::Utc;
15use std::path::Path;
16use std::sync::Arc;
17use tandem_providers::{MemoryConsolidationConfig, ProviderRegistry};
18use tokio::sync::Mutex;
19
/// High-level memory manager that coordinates database, embeddings, and chunking
pub struct MemoryManager {
    // Shared handle to the SQLite-backed memory store.
    db: Arc<MemoryDatabase>,
    // tokio Mutex so the guard can be held across .await points while embedding.
    embedding_service: Arc<Mutex<EmbeddingService>>,
    // Used for token counting (budgets, layer sizes).
    tokenizer: Tokenizer,
}
26
27impl MemoryManager {
28    fn is_malformed_database_error(err: &crate::types::MemoryError) -> bool {
29        err.to_string()
30            .to_lowercase()
31            .contains("database disk image is malformed")
32    }
33
    /// Borrow the underlying database handle for lower-level operations.
    pub fn db(&self) -> &Arc<MemoryDatabase> {
        &self.db
    }
37
38    /// Initialize the memory manager
39    pub async fn new(db_path: &Path) -> MemoryResult<Self> {
40        let db = Arc::new(MemoryDatabase::new(db_path).await?);
41        let embedding_service = Arc::new(Mutex::new(EmbeddingService::new()));
42        let tokenizer = Tokenizer::new()?;
43
44        Ok(Self {
45            db,
46            embedding_service,
47            tokenizer,
48        })
49    }
50
    /// Store a message in memory
    ///
    /// This will:
    /// 1. Chunk the message content
    /// 2. Generate embeddings for each chunk
    /// 3. Store chunks and embeddings in the database
    ///
    /// Returns the ids of the stored chunks (empty when chunking produced
    /// nothing). Insert failures trigger a vector-table repair and one retry;
    /// a still-malformed database is fully reset as a last resort.
    pub async fn store_message(&self, request: StoreMessageRequest) -> MemoryResult<Vec<String>> {
        // Proactive self-heal; `true` means a repair actually happened.
        if self
            .db
            .ensure_vector_tables_healthy()
            .await
            .unwrap_or(false)
        {
            tracing::warn!("Memory vector tables were repaired before storing message chunks");
        }

        // Project-scoped config when a project id is given, otherwise defaults.
        let config = if let Some(ref pid) = request.project_id {
            self.db.get_or_create_config(pid).await?
        } else {
            MemoryConfig::default()
        };

        // Chunk the content
        let chunking_config = ChunkingConfig {
            chunk_size: config.chunk_size as usize,
            chunk_overlap: config.chunk_overlap as usize,
            separator: None,
        };

        let text_chunks = chunk_text_semantic(&request.content, &chunking_config)?;

        if text_chunks.is_empty() {
            return Ok(Vec::new());
        }

        let mut chunk_ids = Vec::with_capacity(text_chunks.len());
        // Lock is held for the whole loop since every chunk needs an embedding.
        let embedding_service = self.embedding_service.lock().await;

        for text_chunk in text_chunks {
            let chunk_id = uuid::Uuid::new_v4().to_string();

            // Generate embedding
            let embedding = embedding_service.embed(&text_chunk.content).await?;

            // Create memory chunk
            let chunk = MemoryChunk {
                id: chunk_id.clone(),
                content: text_chunk.content,
                tier: request.tier,
                session_id: request.session_id.clone(),
                project_id: request.project_id.clone(),
                source: request.source.clone(),
                source_path: request.source_path.clone(),
                source_mtime: request.source_mtime,
                source_size: request.source_size,
                source_hash: request.source_hash.clone(),
                created_at: Utc::now(),
                token_count: text_chunk.token_count as i64,
                metadata: request.metadata.clone(),
            };

            // Store in database (retry once after vector-table self-heal).
            if let Err(err) = self.db.store_chunk(&chunk, &embedding).await {
                tracing::warn!("Failed to store memory chunk {}: {}", chunk.id, err);
                // Targeted repair first, then the general health check.
                let repaired = self.db.try_repair_after_error(&err).await.unwrap_or(false)
                    || self
                        .db
                        .ensure_vector_tables_healthy()
                        .await
                        .unwrap_or(false);
                if repaired {
                    tracing::warn!(
                        "Retrying memory chunk insert after vector table repair: {}",
                        chunk.id
                    );
                    if let Err(retry_err) = self.db.store_chunk(&chunk, &embedding).await {
                        if Self::is_malformed_database_error(&retry_err) {
                            // Last resort: wipe the memory tables, then retry once more.
                            tracing::warn!(
                                "Memory DB still malformed after vector repair. Resetting memory tables and retrying chunk insert: {}",
                                chunk.id
                            );
                            self.db.reset_all_memory_tables().await?;
                            self.db.store_chunk(&chunk, &embedding).await?;
                        } else {
                            return Err(retry_err);
                        }
                    }
                } else {
                    return Err(err);
                }
            }
            chunk_ids.push(chunk_id);
        }

        // Check if cleanup is needed
        if config.auto_cleanup {
            self.maybe_cleanup(&request.project_id).await?;
        }

        Ok(chunk_ids)
    }
152
153    /// Search memory for relevant chunks
154    pub async fn search(
155        &self,
156        query: &str,
157        tier: Option<MemoryTier>,
158        project_id: Option<&str>,
159        session_id: Option<&str>,
160        limit: Option<i64>,
161    ) -> MemoryResult<Vec<MemorySearchResult>> {
162        let effective_limit = limit.unwrap_or(5);
163
164        // Generate query embedding
165        let embedding_service = self.embedding_service.lock().await;
166        let query_embedding = embedding_service.embed(query).await?;
167        drop(embedding_service);
168
169        let mut results = Vec::new();
170
171        // Search in specified tier or all tiers
172        let tiers_to_search = match tier {
173            Some(t) => vec![t],
174            None => {
175                if project_id.is_some() {
176                    vec![MemoryTier::Session, MemoryTier::Project, MemoryTier::Global]
177                } else {
178                    vec![MemoryTier::Session, MemoryTier::Global]
179                }
180            }
181        };
182
183        for search_tier in tiers_to_search {
184            let tier_results = match self
185                .db
186                .search_similar(
187                    &query_embedding,
188                    search_tier,
189                    project_id,
190                    session_id,
191                    effective_limit,
192                )
193                .await
194            {
195                Ok(results) => results,
196                Err(err) => {
197                    tracing::warn!(
198                        "Memory tier search failed for {:?}: {}. Attempting vector repair.",
199                        search_tier,
200                        err
201                    );
202                    let repaired = self.db.try_repair_after_error(&err).await.unwrap_or(false)
203                        || self
204                            .db
205                            .ensure_vector_tables_healthy()
206                            .await
207                            .unwrap_or(false);
208                    if repaired {
209                        match self
210                            .db
211                            .search_similar(
212                                &query_embedding,
213                                search_tier,
214                                project_id,
215                                session_id,
216                                effective_limit,
217                            )
218                            .await
219                        {
220                            Ok(results) => results,
221                            Err(retry_err) => {
222                                tracing::warn!(
223                                    "Memory tier search still failing for {:?} after repair: {}",
224                                    search_tier,
225                                    retry_err
226                                );
227                                continue;
228                            }
229                        }
230                    } else {
231                        continue;
232                    }
233                }
234            };
235
236            for (chunk, distance) in tier_results {
237                // Convert distance to similarity (cosine similarity)
238                // sqlite-vec returns distance, where lower is more similar
239                // Cosine similarity ranges from -1 to 1, but for normalized vectors it's 0 to 1
240                let similarity = 1.0 - distance.clamp(0.0, 1.0);
241
242                results.push(MemorySearchResult { chunk, similarity });
243            }
244        }
245
246        // Sort by similarity (highest first) and limit results
247        results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
248        results.truncate(effective_limit as usize);
249
250        Ok(results)
251    }
252
253    /// Retrieve context for a message
254    ///
255    /// This retrieves relevant chunks from all tiers and formats them
256    /// for injection into the prompt
257    pub async fn retrieve_context(
258        &self,
259        query: &str,
260        project_id: Option<&str>,
261        session_id: Option<&str>,
262        token_budget: Option<i64>,
263    ) -> MemoryResult<MemoryContext> {
264        let (context, _) = self
265            .retrieve_context_with_meta(query, project_id, session_id, token_budget)
266            .await?;
267        Ok(context)
268    }
269
270    /// Retrieve context plus retrieval metadata for observability.
271    pub async fn retrieve_context_with_meta(
272        &self,
273        query: &str,
274        project_id: Option<&str>,
275        session_id: Option<&str>,
276        token_budget: Option<i64>,
277    ) -> MemoryResult<(MemoryContext, MemoryRetrievalMeta)> {
278        let config = if let Some(pid) = project_id {
279            self.db.get_or_create_config(pid).await?
280        } else {
281            MemoryConfig::default()
282        };
283        let budget = token_budget.unwrap_or(config.token_budget);
284        let retrieval_limit = config.retrieval_k.max(1);
285
286        // Get recent session chunks
287        let current_session = if let Some(sid) = session_id {
288            self.db.get_session_chunks(sid).await?
289        } else {
290            Vec::new()
291        };
292
293        // Search for relevant history
294        let search_results = self
295            .search(query, None, project_id, session_id, Some(retrieval_limit))
296            .await?;
297
298        let mut score_min: Option<f64> = None;
299        let mut score_max: Option<f64> = None;
300        for result in &search_results {
301            score_min = Some(match score_min {
302                Some(current) => current.min(result.similarity),
303                None => result.similarity,
304            });
305            score_max = Some(match score_max {
306                Some(current) => current.max(result.similarity),
307                None => result.similarity,
308            });
309        }
310
311        let mut current_session = current_session;
312        let mut relevant_history = Vec::new();
313        let mut project_facts = Vec::new();
314
315        for result in search_results {
316            match result.chunk.tier {
317                MemoryTier::Project => {
318                    project_facts.push(result.chunk);
319                }
320                MemoryTier::Global => {
321                    project_facts.push(result.chunk);
322                }
323                MemoryTier::Session => {
324                    // Only add to relevant_history if not in current_session
325                    if !current_session.iter().any(|c| c.id == result.chunk.id) {
326                        relevant_history.push(result.chunk);
327                    }
328                }
329            }
330        }
331
332        // Calculate total tokens and trim if necessary
333        let mut total_tokens: i64 = current_session.iter().map(|c| c.token_count).sum();
334        total_tokens += relevant_history.iter().map(|c| c.token_count).sum::<i64>();
335        total_tokens += project_facts.iter().map(|c| c.token_count).sum::<i64>();
336
337        // Trim to fit budget if necessary
338        if total_tokens > budget {
339            let excess = total_tokens - budget;
340            self.trim_context(
341                &mut current_session,
342                &mut relevant_history,
343                &mut project_facts,
344                excess,
345            )?;
346            total_tokens = current_session.iter().map(|c| c.token_count).sum::<i64>()
347                + relevant_history.iter().map(|c| c.token_count).sum::<i64>()
348                + project_facts.iter().map(|c| c.token_count).sum::<i64>();
349        }
350
351        let context = MemoryContext {
352            current_session,
353            relevant_history,
354            project_facts,
355            total_tokens,
356        };
357        let chunks_total = context.current_session.len()
358            + context.relevant_history.len()
359            + context.project_facts.len();
360        let meta = MemoryRetrievalMeta {
361            used: chunks_total > 0,
362            chunks_total,
363            session_chunks: context.current_session.len(),
364            history_chunks: context.relevant_history.len(),
365            project_fact_chunks: context.project_facts.len(),
366            score_min,
367            score_max,
368        };
369
370        Ok((context, meta))
371    }
372
373    /// Trim context to fit within token budget
374    fn trim_context(
375        &self,
376        current_session: &mut Vec<MemoryChunk>,
377        relevant_history: &mut Vec<MemoryChunk>,
378        project_facts: &mut Vec<MemoryChunk>,
379        excess_tokens: i64,
380    ) -> MemoryResult<()> {
381        let mut tokens_to_remove = excess_tokens;
382
383        // First, trim relevant_history (less important than project_facts)
384        while tokens_to_remove > 0 && !relevant_history.is_empty() {
385            if let Some(chunk) = relevant_history.pop() {
386                tokens_to_remove -= chunk.token_count;
387            }
388        }
389
390        // If still over budget, trim project_facts
391        while tokens_to_remove > 0 && !project_facts.is_empty() {
392            if let Some(chunk) = project_facts.pop() {
393                tokens_to_remove -= chunk.token_count;
394            }
395        }
396
397        while tokens_to_remove > 0 && !current_session.is_empty() {
398            if let Some(chunk) = current_session.pop() {
399                tokens_to_remove -= chunk.token_count;
400            }
401        }
402
403        Ok(())
404    }
405
406    /// Clear session memory
407    pub async fn clear_session(&self, session_id: &str) -> MemoryResult<u64> {
408        let count = self.db.clear_session_memory(session_id).await?;
409
410        // Log cleanup
411        self.db
412            .log_cleanup(
413                "manual",
414                MemoryTier::Session,
415                None,
416                Some(session_id),
417                count as i64,
418                0,
419            )
420            .await?;
421
422        Ok(count)
423    }
424
425    /// Clear project memory
426    pub async fn clear_project(&self, project_id: &str) -> MemoryResult<u64> {
427        let count = self.db.clear_project_memory(project_id).await?;
428
429        // Log cleanup
430        self.db
431            .log_cleanup(
432                "manual",
433                MemoryTier::Project,
434                Some(project_id),
435                None,
436                count as i64,
437                0,
438            )
439            .await?;
440
441        Ok(count)
442    }
443
    /// Get memory statistics
    pub async fn get_stats(&self) -> MemoryResult<MemoryStats> {
        self.db.get_stats().await
    }

    /// Get memory configuration for a project
    ///
    /// Creates and persists a default config on first access.
    pub async fn get_config(&self, project_id: &str) -> MemoryResult<MemoryConfig> {
        self.db.get_or_create_config(project_id).await
    }

    /// Update memory configuration for a project
    pub async fn set_config(&self, project_id: &str, config: &MemoryConfig) -> MemoryResult<()> {
        self.db.update_config(project_id, config).await
    }

    /// Resolve a context URI to its node, if one exists.
    pub async fn resolve_uri(&self, uri: &str) -> MemoryResult<Option<MemoryNode>> {
        self.db.get_node_by_uri(uri).await
    }
462
463    pub async fn list_directory(&self, uri: &str) -> MemoryResult<DirectoryListing> {
464        let nodes = self.db.list_directory(uri).await?;
465        let directories: Vec<MemoryNode> = nodes
466            .iter()
467            .filter(|n| n.node_type == NodeType::Directory)
468            .cloned()
469            .collect();
470        let files: Vec<MemoryNode> = nodes
471            .iter()
472            .filter(|n| n.node_type == NodeType::File)
473            .cloned()
474            .collect();
475
476        Ok(DirectoryListing {
477            uri: uri.to_string(),
478            nodes,
479            total_children: directories.len() + files.len(),
480            directories,
481            files,
482        })
483    }
484
    /// Return the subtree rooted at `uri`, limited to `max_depth` levels.
    pub async fn tree(&self, uri: &str, max_depth: usize) -> MemoryResult<Vec<TreeNode>> {
        self.db.get_children_tree(uri, max_depth).await
    }
488
489    pub async fn create_context_node(
490        &self,
491        uri: &str,
492        node_type: NodeType,
493        metadata: Option<serde_json::Value>,
494    ) -> MemoryResult<String> {
495        let parsed_uri =
496            ContextUri::parse(uri).map_err(|e| MemoryError::InvalidConfig(e.message))?;
497        let parent_uri = parsed_uri.parent().map(|p| p.to_string());
498        self.db
499            .create_node(uri, parent_uri.as_deref(), node_type, metadata.as_ref())
500            .await
501    }
502
    /// Fetch a specific layer of a node, if one has been stored.
    pub async fn get_context_layer(
        &self,
        node_id: &str,
        layer_type: LayerType,
    ) -> MemoryResult<Option<MemoryLayer>> {
        self.db.get_layer(node_id, layer_type).await
    }
510
511    pub async fn store_content_with_layers(
512        &self,
513        uri: &str,
514        content: &str,
515        metadata: Option<serde_json::Value>,
516    ) -> MemoryResult<String> {
517        let parsed_uri =
518            ContextUri::parse(uri).map_err(|e| MemoryError::InvalidConfig(e.message))?;
519        let node_type = if parsed_uri
520            .last_segment()
521            .map(|s| s.ends_with(".md") || s.ends_with(".txt") || s.contains("."))
522            .unwrap_or(false)
523        {
524            NodeType::File
525        } else {
526            NodeType::Directory
527        };
528
529        let parent_uri = parsed_uri.parent().map(|p| p.to_string());
530        let node_id = self
531            .db
532            .create_node(uri, parent_uri.as_deref(), node_type, metadata.as_ref())
533            .await?;
534
535        let token_count = self.tokenizer.count_tokens(content) as i64;
536        self.db
537            .create_layer(&node_id, LayerType::L2, content, token_count, None)
538            .await?;
539
540        Ok(node_id)
541    }
542
543    pub async fn generate_layers_for_node(
544        &self,
545        node_id: &str,
546        providers: &ProviderRegistry,
547    ) -> MemoryResult<()> {
548        let l2_layer = self.db.get_layer(node_id, LayerType::L2).await?;
549        let l2_content = match l2_layer {
550            Some(layer) => layer.content,
551            None => return Ok(()),
552        };
553
554        let generator = ContextLayerGenerator::new(Arc::new(providers.clone()));
555
556        let (l0_content, l1_content) = generator.generate_layers(&l2_content).await?;
557
558        let l0_tokens = self.tokenizer.count_tokens(&l0_content) as i64;
559        let l1_tokens = self.tokenizer.count_tokens(&l1_content) as i64;
560
561        if self.db.get_layer(node_id, LayerType::L0).await?.is_none() {
562            self.db
563                .create_layer(node_id, LayerType::L0, &l0_content, l0_tokens, None)
564                .await?;
565        }
566
567        if self.db.get_layer(node_id, LayerType::L1).await?.is_none() {
568            self.db
569                .create_layer(node_id, LayerType::L1, &l1_content, l1_tokens, None)
570                .await?;
571        }
572
573        Ok(())
574    }
575
576    pub async fn get_layer_content(
577        &self,
578        node_id: &str,
579        layer_type: LayerType,
580    ) -> MemoryResult<Option<String>> {
581        let layer = self.db.get_layer(node_id, layer_type).await?;
582        Ok(layer.map(|l| l.content))
583    }
584
585    pub async fn store_content_with_layers_auto(
586        &self,
587        uri: &str,
588        content: &str,
589        metadata: Option<serde_json::Value>,
590        providers: Option<&ProviderRegistry>,
591    ) -> MemoryResult<String> {
592        let node_id = self
593            .store_content_with_layers(uri, content, metadata)
594            .await?;
595
596        if let Some(p) = providers {
597            if let Err(e) = self.generate_layers_for_node(&node_id, p).await {
598                tracing::warn!("Failed to generate layers for node {}: {}", node_id, e);
599            }
600        }
601
602        Ok(node_id)
603    }
604
    /// Run cleanup based on retention policies
    ///
    /// With a project id, honors that project's `auto_cleanup` flag and
    /// configured session retention; without one, falls back to a 30-day
    /// session sweep. Returns the number of chunks removed and vacuums the
    /// database after large cleanups.
    pub async fn run_cleanup(&self, project_id: Option<&str>) -> MemoryResult<u64> {
        let mut total_cleaned = 0u64;

        if let Some(pid) = project_id {
            // Get config for this project
            let config = self.db.get_or_create_config(pid).await?;

            if config.auto_cleanup {
                // Clean up old session memory
                let cleaned = self
                    .db
                    .cleanup_old_sessions(config.session_retention_days)
                    .await?;
                total_cleaned += cleaned;

                if cleaned > 0 {
                    // Record the automatic cleanup for auditability.
                    self.db
                        .log_cleanup(
                            "auto",
                            MemoryTier::Session,
                            Some(pid),
                            None,
                            cleaned as i64,
                            0,
                        )
                        .await?;
                }
            }
        } else {
            // Clean up all projects with auto_cleanup enabled
            // This would require listing all projects, for now just clean session memory
            // with a default retention period
            let cleaned = self.db.cleanup_old_sessions(30).await?;
            total_cleaned += cleaned;
        }

        // Vacuum if significant cleanup occurred
        if total_cleaned > 100 {
            self.db.vacuum().await?;
        }

        Ok(total_cleaned)
    }
649
    /// Check if cleanup is needed and run it
    ///
    /// Currently only detects (and logs) projects exceeding their configured
    /// chunk limit; actual eviction of the oldest chunks is not implemented yet.
    async fn maybe_cleanup(&self, project_id: &Option<String>) -> MemoryResult<()> {
        if let Some(pid) = project_id {
            let stats = self.db.get_stats().await?;
            let config = self.db.get_or_create_config(pid).await?;

            // Check if we're over the chunk limit
            if stats.project_chunks > config.max_chunks {
                // Remove oldest chunks
                let excess = stats.project_chunks - config.max_chunks;
                // This would require a new DB method to delete oldest chunks
                // For now, just log
                tracing::info!("Project {} has {} excess chunks", pid, excess);
            }
        }

        Ok(())
    }
668
    /// Get cleanup log entries
    ///
    /// Stub: always returns an empty list until the DB layer grows a query
    /// for the cleanup log; `_limit` is currently ignored.
    pub async fn get_cleanup_log(&self, _limit: i64) -> MemoryResult<Vec<CleanupLogEntry>> {
        // This would be implemented in the DB layer
        // For now, return empty
        Ok(Vec::new())
    }
675
    /// Count tokens in text using the manager's tokenizer.
    pub fn count_tokens(&self, text: &str) -> usize {
        self.tokenizer.count_tokens(text)
    }
680
681    /// Report embedding backend health for UI/telemetry.
682    pub async fn embedding_health(&self) -> EmbeddingHealth {
683        let service = self.embedding_service.lock().await;
684        if service.is_available() {
685            EmbeddingHealth {
686                status: "ok".to_string(),
687                reason: None,
688            }
689        } else {
690            EmbeddingHealth {
691                status: "degraded_disabled".to_string(),
692                reason: service.disabled_reason().map(ToString::to_string),
693            }
694        }
695    }
696
    /// Consolidate a session's memory into a summary chunk using the cheapest available provider.
    ///
    /// Summarizes all of the session's chunks via an LLM call, stores the
    /// summary as a project-tier chunk, then clears the original session
    /// chunks. Returns `Ok(None)` when consolidation is disabled, the session
    /// has no chunks, or the LLM call fails or yields an empty summary
    /// (best-effort: the originals are kept in those cases).
    pub async fn consolidate_session(
        &self,
        session_id: &str,
        project_id: Option<&str>,
        providers: &ProviderRegistry,
        config: &MemoryConsolidationConfig,
    ) -> MemoryResult<Option<String>> {
        if !config.enabled {
            return Ok(None);
        }

        let chunks = self.db.get_session_chunks(session_id).await?;
        if chunks.is_empty() {
            return Ok(None);
        }

        // Assemble text
        let mut text_parts = Vec::new();
        for chunk in &chunks {
            text_parts.push(chunk.content.clone());
        }
        let full_text = text_parts.join("\n\n---\n\n");

        // Build prompt
        let prompt = format!(
            "Please provide a concise but comprehensive summary of the following chat session. \
            Focus on the key decisions, technical details, code changes, and unresolved issues. \
            Do NOT include conversational filler, greetings, or sign-offs. \
            This summary will be used as long-term memory to recall the context of this work.\n\n\
            Session transcripts:\n\n{}",
            full_text
        );

        // Empty strings in the config mean "no override".
        let provider_override = config.provider.as_deref().filter(|s| !s.is_empty());
        let model_override = config.model.as_deref().filter(|s| !s.is_empty());

        // LLM failures are non-fatal: keep the original chunks and bail out.
        let summary_text = match providers
            .complete_cheapest(&prompt, provider_override, model_override)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!("Memory consolidation LLM failed for session {session_id}: {e}");
                return Ok(None);
            }
        };

        if summary_text.trim().is_empty() {
            return Ok(None);
        }

        // Generate embedding for the summary
        let embedding = {
            let service = self.embedding_service.lock().await;
            service
                .embed(&summary_text)
                .await
                .map_err(|e| crate::types::MemoryError::Embedding(e.to_string()))?
        };

        // Store the summary chunk
        let chunk_id = uuid::Uuid::new_v4().to_string();
        let chunk = MemoryChunk {
            id: chunk_id,
            content: summary_text.clone(),
            tier: MemoryTier::Project,
            session_id: None, // The summary belongs to the project, not the ephemeral session
            project_id: project_id.map(ToString::to_string),
            created_at: Utc::now(),
            source: "consolidation".to_string(),
            token_count: self.count_tokens(&summary_text) as i64,
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };

        self.db.store_chunk(&chunk, &embedding).await?;

        // Clear original chunks now that they are consolidated
        self.db.clear_session_memory(session_id).await?;

        tracing::info!(
            "Session {session_id} consolidated into summary chunk. Original chunks cleared."
        );

        Ok(Some(summary_text))
    }
787}
788
789/// Create memory manager with default database path
790pub async fn create_memory_manager(app_data_dir: &Path) -> MemoryResult<MemoryManager> {
791    let db_path = app_data_dir.join("tandem_memory.db");
792    MemoryManager::new(&db_path).await
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // Embeddings may be unavailable in some environments; tests treat that
    // specific error as a skip rather than a failure.
    fn is_embeddings_disabled(err: &crate::types::MemoryError) -> bool {
        matches!(err, crate::types::MemoryError::Embedding(msg) if msg.to_ascii_lowercase().contains("embeddings disabled"))
    }

    // Build a manager backed by a throwaway database inside a temp directory.
    // The TempDir is returned so it stays alive for the test's duration.
    async fn setup_test_manager() -> (MemoryManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test_memory.db");
        let manager = MemoryManager::new(&db_path).await.unwrap();
        (manager, temp_dir)
    }

    #[tokio::test]
    async fn test_store_and_search() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "This is a test message about artificial intelligence and machine learning."
                .to_string(),
            tier: MemoryTier::Project,
            session_id: Some("session-1".to_string()),
            project_id: Some("project-1".to_string()),
            source: "user_message".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };

        let chunk_ids = match manager.store_message(request).await {
            Ok(ids) => ids,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        };
        assert!(!chunk_ids.is_empty());

        // Search for the content
        let results = manager
            .search(
                "artificial intelligence",
                None,
                Some("project-1"),
                None,
                None,
            )
            .await;
        let results = match results {
            Ok(results) => results,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("search failed: {err}"),
        };

        assert!(!results.is_empty());
        // Similarity can be 0.0 with random hash embeddings (orthogonal or negative correlation)
        assert!(results[0].similarity >= 0.0);
    }

    #[tokio::test]
    async fn test_retrieve_context() {
        let (manager, _temp) = setup_test_manager().await;

        // Store some test data
        let request = StoreMessageRequest {
            content: "The project uses React and TypeScript for the frontend.".to_string(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: Some("project-1".to_string()),
            source: "assistant_response".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let context = manager
            .retrieve_context("What technologies are used?", Some("project-1"), None, None)
            .await;
        let context = match context {
            Ok(context) => context,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context failed: {err}"),
        };

        assert!(!context.project_facts.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_context_with_meta() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "The backend uses Rust and sqlite-vec for retrieval.".to_string(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: Some("project-1".to_string()),
            source: "assistant_response".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let result = manager
            .retrieve_context_with_meta("What does the backend use?", Some("project-1"), None, None)
            .await;
        let (context, meta) = match result {
            Ok(v) => v,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context_with_meta failed: {err}"),
        };

        // Metadata counts must agree with the context contents.
        assert!(meta.chunks_total > 0);
        assert!(meta.used);
        assert_eq!(
            meta.chunks_total,
            context.current_session.len()
                + context.relevant_history.len()
                + context.project_facts.len()
        );
        assert!(meta.score_min.is_some());
        assert!(meta.score_max.is_some());
    }

    #[tokio::test]
    async fn test_config_management() {
        let (manager, _temp) = setup_test_manager().await;

        // First access creates the default config.
        let config = manager.get_config("project-1").await.unwrap();
        assert_eq!(config.max_chunks, 10000);

        let new_config = MemoryConfig {
            max_chunks: 5000,
            retrieval_k: 10,
            ..Default::default()
        };

        manager.set_config("project-1", &new_config).await.unwrap();

        // Updates must round-trip through the database.
        let updated = manager.get_config("project-1").await.unwrap();
        assert_eq!(updated.max_chunks, 5000);
        assert_eq!(updated.retrieval_k, 10);
    }
}
954}