1use crate::chunking::{chunk_text_semantic, ChunkingConfig, Tokenizer};
5use crate::context_layers::ContextLayerGenerator;
6use crate::context_uri::ContextUri;
7use crate::db::MemoryDatabase;
8use crate::embeddings::EmbeddingService;
9use crate::types::{
10 CleanupLogEntry, DirectoryListing, EmbeddingHealth, LayerType, MemoryChunk, MemoryConfig,
11 MemoryContext, MemoryError, MemoryLayer, MemoryNode, MemoryResult, MemoryRetrievalMeta,
12 MemorySearchResult, MemoryStats, MemoryTier, NodeType, StoreMessageRequest, TreeNode,
13};
14use chrono::Utc;
15use std::path::Path;
16use std::sync::Arc;
17use tandem_providers::{MemoryConsolidationConfig, ProviderRegistry};
18use tokio::sync::Mutex;
19
/// Facade over the memory subsystem: chunking, embedding, storage and
/// retrieval of tiered memory chunks, plus hierarchical context nodes and
/// their content layers.
pub struct MemoryManager {
    // Shared handle to the SQLite-backed memory store.
    db: Arc<MemoryDatabase>,
    // Embedding backend; a tokio Mutex so the guard may be held across awaits.
    embedding_service: Arc<Mutex<EmbeddingService>>,
    // Tokenizer used for token budgeting and layer token counts.
    tokenizer: Tokenizer,
}
26
27impl MemoryManager {
28 fn is_malformed_database_error(err: &crate::types::MemoryError) -> bool {
29 err.to_string()
30 .to_lowercase()
31 .contains("database disk image is malformed")
32 }
33
    /// Direct access to the underlying database handle.
    pub fn db(&self) -> &Arc<MemoryDatabase> {
        &self.db
    }
37
38 pub async fn new(db_path: &Path) -> MemoryResult<Self> {
40 let db = Arc::new(MemoryDatabase::new(db_path).await?);
41 let embedding_service = Arc::new(Mutex::new(EmbeddingService::new()));
42 let tokenizer = Tokenizer::new()?;
43
44 Ok(Self {
45 db,
46 embedding_service,
47 tokenizer,
48 })
49 }
50
    /// Chunks, embeds, and stores a message into tiered memory.
    ///
    /// Returns the ids of the stored chunks (empty when the content produced
    /// no chunks). Insert failures trigger vector-table repair and a retry;
    /// a still-malformed database is fully reset as a last resort.
    pub async fn store_message(&self, request: StoreMessageRequest) -> MemoryResult<Vec<String>> {
        // Best-effort pre-flight: repair the vector tables before writing.
        if self
            .db
            .ensure_vector_tables_healthy()
            .await
            .unwrap_or(false)
        {
            tracing::warn!("Memory vector tables were repaired before storing message chunks");
        }

        // Chunking parameters come from the project config when available.
        let config = if let Some(ref pid) = request.project_id {
            self.db.get_or_create_config(pid).await?
        } else {
            MemoryConfig::default()
        };

        let chunking_config = ChunkingConfig {
            chunk_size: config.chunk_size as usize,
            chunk_overlap: config.chunk_overlap as usize,
            separator: None,
        };

        let text_chunks = chunk_text_semantic(&request.content, &chunking_config)?;

        if text_chunks.is_empty() {
            return Ok(Vec::new());
        }

        let mut chunk_ids = Vec::with_capacity(text_chunks.len());
        // tokio Mutex, so holding the guard across the awaits below is fine.
        let embedding_service = self.embedding_service.lock().await;

        for text_chunk in text_chunks {
            let chunk_id = uuid::Uuid::new_v4().to_string();

            let embedding = embedding_service.embed(&text_chunk.content).await?;

            let chunk = MemoryChunk {
                id: chunk_id.clone(),
                content: text_chunk.content,
                tier: request.tier,
                session_id: request.session_id.clone(),
                project_id: request.project_id.clone(),
                source: request.source.clone(),
                source_path: request.source_path.clone(),
                source_mtime: request.source_mtime,
                source_size: request.source_size,
                source_hash: request.source_hash.clone(),
                created_at: Utc::now(),
                token_count: text_chunk.token_count as i64,
                metadata: request.metadata.clone(),
            };

            if let Err(err) = self.db.store_chunk(&chunk, &embedding).await {
                tracing::warn!("Failed to store memory chunk {}: {}", chunk.id, err);
                // Attempt targeted repair first, then a general health check.
                let repaired = self.db.try_repair_after_error(&err).await.unwrap_or(false)
                    || self
                        .db
                        .ensure_vector_tables_healthy()
                        .await
                        .unwrap_or(false);
                if repaired {
                    tracing::warn!(
                        "Retrying memory chunk insert after vector table repair: {}",
                        chunk.id
                    );
                    if let Err(retry_err) = self.db.store_chunk(&chunk, &embedding).await {
                        if Self::is_malformed_database_error(&retry_err) {
                            // Last resort: drop and recreate all memory tables.
                            tracing::warn!(
                                "Memory DB still malformed after vector repair. Resetting memory tables and retrying chunk insert: {}",
                                chunk.id
                            );
                            self.db.reset_all_memory_tables().await?;
                            self.db.store_chunk(&chunk, &embedding).await?;
                        } else {
                            return Err(retry_err);
                        }
                    }
                } else {
                    return Err(err);
                }
            }
            chunk_ids.push(chunk_id);
        }

        if config.auto_cleanup {
            self.maybe_cleanup(&request.project_id).await?;
        }

        Ok(chunk_ids)
    }
152
153 pub async fn search(
155 &self,
156 query: &str,
157 tier: Option<MemoryTier>,
158 project_id: Option<&str>,
159 session_id: Option<&str>,
160 limit: Option<i64>,
161 ) -> MemoryResult<Vec<MemorySearchResult>> {
162 let effective_limit = limit.unwrap_or(5);
163
164 let embedding_service = self.embedding_service.lock().await;
166 let query_embedding = embedding_service.embed(query).await?;
167 drop(embedding_service);
168
169 let mut results = Vec::new();
170
171 let tiers_to_search = match tier {
173 Some(t) => vec![t],
174 None => {
175 if project_id.is_some() {
176 vec![MemoryTier::Session, MemoryTier::Project, MemoryTier::Global]
177 } else {
178 vec![MemoryTier::Session, MemoryTier::Global]
179 }
180 }
181 };
182
183 for search_tier in tiers_to_search {
184 let tier_results = match self
185 .db
186 .search_similar(
187 &query_embedding,
188 search_tier,
189 project_id,
190 session_id,
191 effective_limit,
192 )
193 .await
194 {
195 Ok(results) => results,
196 Err(err) => {
197 tracing::warn!(
198 "Memory tier search failed for {:?}: {}. Attempting vector repair.",
199 search_tier,
200 err
201 );
202 let repaired = self.db.try_repair_after_error(&err).await.unwrap_or(false)
203 || self
204 .db
205 .ensure_vector_tables_healthy()
206 .await
207 .unwrap_or(false);
208 if repaired {
209 match self
210 .db
211 .search_similar(
212 &query_embedding,
213 search_tier,
214 project_id,
215 session_id,
216 effective_limit,
217 )
218 .await
219 {
220 Ok(results) => results,
221 Err(retry_err) => {
222 tracing::warn!(
223 "Memory tier search still failing for {:?} after repair: {}",
224 search_tier,
225 retry_err
226 );
227 continue;
228 }
229 }
230 } else {
231 continue;
232 }
233 }
234 };
235
236 for (chunk, distance) in tier_results {
237 let similarity = 1.0 - distance.clamp(0.0, 1.0);
241
242 results.push(MemorySearchResult { chunk, similarity });
243 }
244 }
245
246 results.sort_by(|a, b| b.similarity.partial_cmp(&a.similarity).unwrap());
248 results.truncate(effective_limit as usize);
249
250 Ok(results)
251 }
252
253 pub async fn retrieve_context(
258 &self,
259 query: &str,
260 project_id: Option<&str>,
261 session_id: Option<&str>,
262 token_budget: Option<i64>,
263 ) -> MemoryResult<MemoryContext> {
264 let (context, _) = self
265 .retrieve_context_with_meta(query, project_id, session_id, token_budget)
266 .await?;
267 Ok(context)
268 }
269
    /// Builds a token-budgeted memory context for a query and reports
    /// retrieval metadata (chunk counts and similarity score range).
    pub async fn retrieve_context_with_meta(
        &self,
        query: &str,
        project_id: Option<&str>,
        session_id: Option<&str>,
        token_budget: Option<i64>,
    ) -> MemoryResult<(MemoryContext, MemoryRetrievalMeta)> {
        // Per-project config when available, defaults otherwise.
        let config = if let Some(pid) = project_id {
            self.db.get_or_create_config(pid).await?
        } else {
            MemoryConfig::default()
        };
        let budget = token_budget.unwrap_or(config.token_budget);
        let retrieval_limit = config.retrieval_k.max(1);

        // Always start from the current session's own chunks.
        let current_session = if let Some(sid) = session_id {
            self.db.get_session_chunks(sid).await?
        } else {
            Vec::new()
        };

        let search_results = self
            .search(query, None, project_id, session_id, Some(retrieval_limit))
            .await?;

        // Track the min/max similarity over all hits for the metadata.
        let mut score_min: Option<f64> = None;
        let mut score_max: Option<f64> = None;
        for result in &search_results {
            score_min = Some(match score_min {
                Some(current) => current.min(result.similarity),
                None => result.similarity,
            });
            score_max = Some(match score_max {
                Some(current) => current.max(result.similarity),
                None => result.similarity,
            });
        }

        let mut current_session = current_session;
        let mut relevant_history = Vec::new();
        let mut project_facts = Vec::new();

        // Bucket hits by tier. Project and Global hits both land in
        // `project_facts`; session hits are kept only when they are not
        // already part of the current-session chunks.
        for result in search_results {
            match result.chunk.tier {
                MemoryTier::Project => {
                    project_facts.push(result.chunk);
                }
                MemoryTier::Global => {
                    project_facts.push(result.chunk);
                }
                MemoryTier::Session => {
                    if !current_session.iter().any(|c| c.id == result.chunk.id) {
                        relevant_history.push(result.chunk);
                    }
                }
            }
        }

        let mut total_tokens: i64 = current_session.iter().map(|c| c.token_count).sum();
        total_tokens += relevant_history.iter().map(|c| c.token_count).sum::<i64>();
        total_tokens += project_facts.iter().map(|c| c.token_count).sum::<i64>();

        // Over budget: evict chunks, then recompute the real total.
        if total_tokens > budget {
            let excess = total_tokens - budget;
            self.trim_context(
                &mut current_session,
                &mut relevant_history,
                &mut project_facts,
                excess,
            )?;
            total_tokens = current_session.iter().map(|c| c.token_count).sum::<i64>()
                + relevant_history.iter().map(|c| c.token_count).sum::<i64>()
                + project_facts.iter().map(|c| c.token_count).sum::<i64>();
        }

        let context = MemoryContext {
            current_session,
            relevant_history,
            project_facts,
            total_tokens,
        };
        let chunks_total = context.current_session.len()
            + context.relevant_history.len()
            + context.project_facts.len();
        let meta = MemoryRetrievalMeta {
            used: chunks_total > 0,
            chunks_total,
            session_chunks: context.current_session.len(),
            history_chunks: context.relevant_history.len(),
            project_fact_chunks: context.project_facts.len(),
            score_min,
            score_max,
        };

        Ok((context, meta))
    }
372
373 fn trim_context(
375 &self,
376 current_session: &mut Vec<MemoryChunk>,
377 relevant_history: &mut Vec<MemoryChunk>,
378 project_facts: &mut Vec<MemoryChunk>,
379 excess_tokens: i64,
380 ) -> MemoryResult<()> {
381 let mut tokens_to_remove = excess_tokens;
382
383 while tokens_to_remove > 0 && !relevant_history.is_empty() {
385 if let Some(chunk) = relevant_history.pop() {
386 tokens_to_remove -= chunk.token_count;
387 }
388 }
389
390 while tokens_to_remove > 0 && !project_facts.is_empty() {
392 if let Some(chunk) = project_facts.pop() {
393 tokens_to_remove -= chunk.token_count;
394 }
395 }
396
397 while tokens_to_remove > 0 && !current_session.is_empty() {
398 if let Some(chunk) = current_session.pop() {
399 tokens_to_remove -= chunk.token_count;
400 }
401 }
402
403 Ok(())
404 }
405
406 pub async fn clear_session(&self, session_id: &str) -> MemoryResult<u64> {
408 let count = self.db.clear_session_memory(session_id).await?;
409
410 self.db
412 .log_cleanup(
413 "manual",
414 MemoryTier::Session,
415 None,
416 Some(session_id),
417 count as i64,
418 0,
419 )
420 .await?;
421
422 Ok(count)
423 }
424
425 pub async fn clear_project(&self, project_id: &str) -> MemoryResult<u64> {
427 let count = self.db.clear_project_memory(project_id).await?;
428
429 self.db
431 .log_cleanup(
432 "manual",
433 MemoryTier::Project,
434 Some(project_id),
435 None,
436 count as i64,
437 0,
438 )
439 .await?;
440
441 Ok(count)
442 }
443
    /// Global memory statistics, straight from the database.
    pub async fn get_stats(&self) -> MemoryResult<MemoryStats> {
        self.db.get_stats().await
    }
448
    /// Fetches the project's memory config, creating a default one if absent.
    pub async fn get_config(&self, project_id: &str) -> MemoryResult<MemoryConfig> {
        self.db.get_or_create_config(project_id).await
    }
453
    /// Persists a new memory config for the project.
    pub async fn set_config(&self, project_id: &str, config: &MemoryConfig) -> MemoryResult<()> {
        self.db.update_config(project_id, config).await
    }
458
    /// Looks up the context node stored under `uri`, if any.
    pub async fn resolve_uri(&self, uri: &str) -> MemoryResult<Option<MemoryNode>> {
        self.db.get_node_by_uri(uri).await
    }
462
463 pub async fn list_directory(&self, uri: &str) -> MemoryResult<DirectoryListing> {
464 let nodes = self.db.list_directory(uri).await?;
465 let directories: Vec<MemoryNode> = nodes
466 .iter()
467 .filter(|n| n.node_type == NodeType::Directory)
468 .cloned()
469 .collect();
470 let files: Vec<MemoryNode> = nodes
471 .iter()
472 .filter(|n| n.node_type == NodeType::File)
473 .cloned()
474 .collect();
475
476 Ok(DirectoryListing {
477 uri: uri.to_string(),
478 nodes,
479 total_children: directories.len() + files.len(),
480 directories,
481 files,
482 })
483 }
484
    /// Returns the subtree rooted at `uri`, limited to `max_depth` levels.
    pub async fn tree(&self, uri: &str, max_depth: usize) -> MemoryResult<Vec<TreeNode>> {
        self.db.get_children_tree(uri, max_depth).await
    }
488
489 pub async fn create_context_node(
490 &self,
491 uri: &str,
492 node_type: NodeType,
493 metadata: Option<serde_json::Value>,
494 ) -> MemoryResult<String> {
495 let parsed_uri =
496 ContextUri::parse(uri).map_err(|e| MemoryError::InvalidConfig(e.message))?;
497 let parent_uri = parsed_uri.parent().map(|p| p.to_string());
498 self.db
499 .create_node(uri, parent_uri.as_deref(), node_type, metadata.as_ref())
500 .await
501 }
502
    /// Fetches one layer of a context node, or `None` if it does not exist.
    pub async fn get_context_layer(
        &self,
        node_id: &str,
        layer_type: LayerType,
    ) -> MemoryResult<Option<MemoryLayer>> {
        self.db.get_layer(node_id, layer_type).await
    }
510
511 pub async fn store_content_with_layers(
512 &self,
513 uri: &str,
514 content: &str,
515 metadata: Option<serde_json::Value>,
516 ) -> MemoryResult<String> {
517 let parsed_uri =
518 ContextUri::parse(uri).map_err(|e| MemoryError::InvalidConfig(e.message))?;
519 let node_type = if parsed_uri
520 .last_segment()
521 .map(|s| s.ends_with(".md") || s.ends_with(".txt") || s.contains("."))
522 .unwrap_or(false)
523 {
524 NodeType::File
525 } else {
526 NodeType::Directory
527 };
528
529 let parent_uri = parsed_uri.parent().map(|p| p.to_string());
530 let node_id = self
531 .db
532 .create_node(uri, parent_uri.as_deref(), node_type, metadata.as_ref())
533 .await?;
534
535 let token_count = self.tokenizer.count_tokens(content) as i64;
536 self.db
537 .create_layer(&node_id, LayerType::L2, content, token_count, None)
538 .await?;
539
540 Ok(node_id)
541 }
542
    /// Derives the L0/L1 layers for a node from its stored L2 content, using
    /// the provider registry for generation.
    ///
    /// No-ops when the node has no L2 layer; existing L0/L1 layers are never
    /// overwritten, so the call is idempotent.
    pub async fn generate_layers_for_node(
        &self,
        node_id: &str,
        providers: &ProviderRegistry,
    ) -> MemoryResult<()> {
        let l2_layer = self.db.get_layer(node_id, LayerType::L2).await?;
        let l2_content = match l2_layer {
            Some(layer) => layer.content,
            None => return Ok(()),
        };

        let generator = ContextLayerGenerator::new(Arc::new(providers.clone()));

        // Both derived layers are produced from the L2 content in one pass.
        let (l0_content, l1_content) = generator.generate_layers(&l2_content).await?;

        let l0_tokens = self.tokenizer.count_tokens(&l0_content) as i64;
        let l1_tokens = self.tokenizer.count_tokens(&l1_content) as i64;

        // Only create layers that do not exist yet.
        if self.db.get_layer(node_id, LayerType::L0).await?.is_none() {
            self.db
                .create_layer(node_id, LayerType::L0, &l0_content, l0_tokens, None)
                .await?;
        }

        if self.db.get_layer(node_id, LayerType::L1).await?.is_none() {
            self.db
                .create_layer(node_id, LayerType::L1, &l1_content, l1_tokens, None)
                .await?;
        }

        Ok(())
    }
575
576 pub async fn get_layer_content(
577 &self,
578 node_id: &str,
579 layer_type: LayerType,
580 ) -> MemoryResult<Option<String>> {
581 let layer = self.db.get_layer(node_id, layer_type).await?;
582 Ok(layer.map(|l| l.content))
583 }
584
585 pub async fn store_content_with_layers_auto(
586 &self,
587 uri: &str,
588 content: &str,
589 metadata: Option<serde_json::Value>,
590 providers: Option<&ProviderRegistry>,
591 ) -> MemoryResult<String> {
592 let node_id = self
593 .store_content_with_layers(uri, content, metadata)
594 .await?;
595
596 if let Some(p) = providers {
597 if let Err(e) = self.generate_layers_for_node(&node_id, p).await {
598 tracing::warn!("Failed to generate layers for node {}: {}", node_id, e);
599 }
600 }
601
602 Ok(node_id)
603 }
604
    /// Removes expired session memory and returns the number of chunks
    /// deleted.
    ///
    /// With a project id, the project's `auto_cleanup` flag and retention
    /// window are honored; without one, a fixed 30-day retention is used.
    /// The database is vacuumed after large cleanups (> 100 chunks).
    pub async fn run_cleanup(&self, project_id: Option<&str>) -> MemoryResult<u64> {
        let mut total_cleaned = 0u64;

        if let Some(pid) = project_id {
            let config = self.db.get_or_create_config(pid).await?;

            if config.auto_cleanup {
                let cleaned = self
                    .db
                    .cleanup_old_sessions(config.session_retention_days)
                    .await?;
                total_cleaned += cleaned;

                // Only log when something was actually removed.
                if cleaned > 0 {
                    self.db
                        .log_cleanup(
                            "auto",
                            MemoryTier::Session,
                            Some(pid),
                            None,
                            cleaned as i64,
                            0,
                        )
                        .await?;
                }
            }
        } else {
            // No project context: fall back to a 30-day retention window.
            let cleaned = self.db.cleanup_old_sessions(30).await?;
            total_cleaned += cleaned;
        }

        // Reclaim disk space after a large cleanup.
        if total_cleaned > 100 {
            self.db.vacuum().await?;
        }

        Ok(total_cleaned)
    }
649
    /// Opportunistic check invoked after storing new chunks.
    ///
    /// NOTE(review): this currently only *logs* when a project exceeds its
    /// `max_chunks` budget — no chunks are evicted here. Confirm whether
    /// actual eviction is intended to live elsewhere or is still TODO.
    async fn maybe_cleanup(&self, project_id: &Option<String>) -> MemoryResult<()> {
        if let Some(pid) = project_id {
            let stats = self.db.get_stats().await?;
            let config = self.db.get_or_create_config(pid).await?;

            if stats.project_chunks > config.max_chunks {
                let excess = stats.project_chunks - config.max_chunks;
                tracing::info!("Project {} has {} excess chunks", pid, excess);
            }
        }

        Ok(())
    }
668
    /// Returns the cleanup history.
    ///
    /// Currently a stub: always returns an empty list and ignores `_limit`.
    pub async fn get_cleanup_log(&self, _limit: i64) -> MemoryResult<Vec<CleanupLogEntry>> {
        Ok(Vec::new())
    }
675
    /// Token count of `text` under the manager's tokenizer.
    pub fn count_tokens(&self, text: &str) -> usize {
        self.tokenizer.count_tokens(text)
    }
680
681 pub async fn embedding_health(&self) -> EmbeddingHealth {
683 let service = self.embedding_service.lock().await;
684 if service.is_available() {
685 EmbeddingHealth {
686 status: "ok".to_string(),
687 reason: None,
688 }
689 } else {
690 EmbeddingHealth {
691 status: "degraded_disabled".to_string(),
692 reason: service.disabled_reason().map(ToString::to_string),
693 }
694 }
695 }
696
    /// Consolidates a session's memory into a single project-tier summary
    /// chunk, then clears the original session chunks.
    ///
    /// Returns `Ok(None)` (without error) when consolidation is disabled,
    /// the session is empty, the LLM call fails, or the summary comes back
    /// blank; returns the summary text on success.
    pub async fn consolidate_session(
        &self,
        session_id: &str,
        project_id: Option<&str>,
        providers: &ProviderRegistry,
        config: &MemoryConsolidationConfig,
    ) -> MemoryResult<Option<String>> {
        if !config.enabled {
            return Ok(None);
        }

        let chunks = self.db.get_session_chunks(session_id).await?;
        if chunks.is_empty() {
            return Ok(None);
        }

        // Concatenate the transcript with visible separators between chunks.
        let mut text_parts = Vec::new();
        for chunk in &chunks {
            text_parts.push(chunk.content.clone());
        }
        let full_text = text_parts.join("\n\n---\n\n");

        let prompt = format!(
            "Please provide a concise but comprehensive summary of the following chat session. \
            Focus on the key decisions, technical details, code changes, and unresolved issues. \
            Do NOT include conversational filler, greetings, or sign-offs. \
            This summary will be used as long-term memory to recall the context of this work.\n\n\
            Session transcripts:\n\n{}",
            full_text
        );

        // Empty-string overrides are treated as unset.
        let provider_override = config.provider.as_deref().filter(|s| !s.is_empty());
        let model_override = config.model.as_deref().filter(|s| !s.is_empty());

        // Consolidation is best-effort: an LLM failure is logged, not raised.
        let summary_text = match providers
            .complete_cheapest(&prompt, provider_override, model_override)
            .await
        {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!("Memory consolidation LLM failed for session {session_id}: {e}");
                return Ok(None);
            }
        };

        if summary_text.trim().is_empty() {
            return Ok(None);
        }

        // Embed the summary; the block scopes the lock so it drops promptly.
        let embedding = {
            let service = self.embedding_service.lock().await;
            service
                .embed(&summary_text)
                .await
                .map_err(|e| crate::types::MemoryError::Embedding(e.to_string()))?
        };

        let chunk_id = uuid::Uuid::new_v4().to_string();
        // The summary is promoted to the project tier and detached from the
        // session (session_id: None) so it survives session cleanup.
        let chunk = MemoryChunk {
            id: chunk_id,
            content: summary_text.clone(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: project_id.map(ToString::to_string),
            created_at: Utc::now(),
            source: "consolidation".to_string(),
            token_count: self.count_tokens(&summary_text) as i64,
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };

        self.db.store_chunk(&chunk, &embedding).await?;

        // Clear originals only after the summary is safely stored.
        self.db.clear_session_memory(session_id).await?;

        tracing::info!(
            "Session {session_id} consolidated into summary chunk. Original chunks cleared."
        );

        Ok(Some(summary_text))
    }
787}
788
789pub async fn create_memory_manager(app_data_dir: &Path) -> MemoryResult<MemoryManager> {
791 let db_path = app_data_dir.join("tandem_memory.db");
792 MemoryManager::new(&db_path).await
793}
794
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    // The embedding backend may be unavailable in some environments; tests
    // bail out early (rather than fail) when the service reports itself
    // disabled via this error shape.
    fn is_embeddings_disabled(err: &crate::types::MemoryError) -> bool {
        matches!(err, crate::types::MemoryError::Embedding(msg) if msg.to_ascii_lowercase().contains("embeddings disabled"))
    }

    // Builds a manager backed by a fresh database in a temp directory.
    // The TempDir must be kept alive for the duration of the test.
    async fn setup_test_manager() -> (MemoryManager, TempDir) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test_memory.db");
        let manager = MemoryManager::new(&db_path).await.unwrap();
        (manager, temp_dir)
    }

    // Round trip: stored content should be retrievable via semantic search.
    #[tokio::test]
    async fn test_store_and_search() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "This is a test message about artificial intelligence and machine learning."
                .to_string(),
            tier: MemoryTier::Project,
            session_id: Some("session-1".to_string()),
            project_id: Some("project-1".to_string()),
            source: "user_message".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };

        let chunk_ids = match manager.store_message(request).await {
            Ok(ids) => ids,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        };
        assert!(!chunk_ids.is_empty());

        let results = manager
            .search(
                "artificial intelligence",
                None,
                Some("project-1"),
                None,
                None,
            )
            .await;
        let results = match results {
            Ok(results) => results,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("search failed: {err}"),
        };

        assert!(!results.is_empty());
        assert!(results[0].similarity >= 0.0);
    }

    // A stored project fact should appear in the retrieved context.
    #[tokio::test]
    async fn test_retrieve_context() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "The project uses React and TypeScript for the frontend.".to_string(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: Some("project-1".to_string()),
            source: "assistant_response".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let context = manager
            .retrieve_context("What technologies are used?", Some("project-1"), None, None)
            .await;
        let context = match context {
            Ok(context) => context,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context failed: {err}"),
        };

        assert!(!context.project_facts.is_empty());
    }

    // Retrieval metadata must stay consistent with the returned context.
    #[tokio::test]
    async fn test_retrieve_context_with_meta() {
        let (manager, _temp) = setup_test_manager().await;

        let request = StoreMessageRequest {
            content: "The backend uses Rust and sqlite-vec for retrieval.".to_string(),
            tier: MemoryTier::Project,
            session_id: None,
            project_id: Some("project-1".to_string()),
            source: "assistant_response".to_string(),
            source_path: None,
            source_mtime: None,
            source_size: None,
            source_hash: None,
            metadata: None,
        };
        match manager.store_message(request).await {
            Ok(_) => {}
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("store_message failed: {err}"),
        }

        let result = manager
            .retrieve_context_with_meta("What does the backend use?", Some("project-1"), None, None)
            .await;
        let (context, meta) = match result {
            Ok(v) => v,
            Err(err) if is_embeddings_disabled(&err) => return,
            Err(err) => panic!("retrieve_context_with_meta failed: {err}"),
        };

        assert!(meta.chunks_total > 0);
        assert!(meta.used);
        assert_eq!(
            meta.chunks_total,
            context.current_session.len()
                + context.relevant_history.len()
                + context.project_facts.len()
        );
        assert!(meta.score_min.is_some());
        assert!(meta.score_max.is_some());
    }

    // Config round trip: defaults, update, and read-back.
    #[tokio::test]
    async fn test_config_management() {
        let (manager, _temp) = setup_test_manager().await;

        let config = manager.get_config("project-1").await.unwrap();
        assert_eq!(config.max_chunks, 10000);

        let new_config = MemoryConfig {
            max_chunks: 5000,
            retrieval_k: 10,
            ..Default::default()
        };

        manager.set_config("project-1", &new_config).await.unwrap();

        let updated = manager.get_config("project-1").await.unwrap();
        assert_eq!(updated.max_chunks, 5000);
        assert_eq!(updated.retrieval_k, 10);
    }
}