1use anyhow::{Context, Result};
10use parking_lot::RwLock;
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet};
13use std::fs;
14use std::path::{Path, PathBuf};
15use std::sync::Arc;
16use tracing::{info, warn};
17
18use super::introspection::ConsolidationEventBuffer;
19use super::storage::{MemoryStorage, SearchCriteria};
20use super::types::*;
21use crate::constants::{
22 PREFETCH_RECENCY_FULL_BOOST, PREFETCH_RECENCY_FULL_HOURS, PREFETCH_RECENCY_PARTIAL_BOOST,
23 PREFETCH_RECENCY_PARTIAL_HOURS, PREFETCH_TEMPORAL_WINDOW_HOURS,
24 VECTOR_SEARCH_CANDIDATE_MULTIPLIER,
25};
26use crate::embeddings::{minilm::MiniLMEmbedder, Embedder};
27use crate::vector_db::vamana::{VamanaConfig, VamanaIndex};
28
/// File name of the persisted Vamana index inside the `vector_index` directory.
const VAMANA_INDEX_FILE: &str = "vamana.idx";
31
/// Vector-search retrieval engine backed by RocksDB storage and a Vamana ANN index.
pub struct RetrievalEngine {
    /// Durable memory store (RocksDB); source of truth for memories and vector mappings.
    storage: Arc<MemoryStorage>,
    /// MiniLM embedder used to encode memory text and queries (384-dim vectors).
    embedder: Arc<MiniLMEmbedder>,
    /// In-memory ANN index over embeddings; loaded from disk or rebuilt on startup.
    vector_index: Arc<RwLock<VamanaIndex>>,
    /// Bidirectional mapping between memory IDs and their vector IDs in the index.
    id_mapping: Arc<RwLock<IdMapping>>,
    /// Root storage directory; the Vamana index is persisted under `vector_index/`.
    storage_path: PathBuf,
    /// Optional buffer for consolidation introspection events.
    consolidation_events: Option<Arc<RwLock<ConsolidationEventBuffer>>>,
}
63
/// Bidirectional bookkeeping between memory IDs and the vector IDs that
/// represent them in the Vamana index (one memory may own several chunk vectors).
#[derive(serde::Serialize, serde::Deserialize, Default)]
struct IdMapping {
    /// Forward map: memory → all of its vector IDs (single entry for unchunked memories).
    memory_to_vectors: HashMap<MemoryId, Vec<u32>>,
    /// Reverse map: vector → owning memory, used to resolve search hits.
    vector_to_memory: HashMap<u32, MemoryId>,
}
76
77impl IdMapping {
78 fn new() -> Self {
79 Self {
80 memory_to_vectors: HashMap::new(),
81 vector_to_memory: HashMap::new(),
82 }
83 }
84
85 fn insert(&mut self, memory_id: MemoryId, vector_id: u32) {
90 if let Some(old_ids) = self.memory_to_vectors.remove(&memory_id) {
92 for old_id in old_ids {
93 self.vector_to_memory.remove(&old_id);
94 }
95 }
96 self.vector_to_memory.insert(vector_id, memory_id.clone());
97 self.memory_to_vectors.insert(memory_id, vec![vector_id]);
98 }
99
100 fn insert_chunks(&mut self, memory_id: MemoryId, vector_ids: Vec<u32>) {
105 if let Some(old_ids) = self.memory_to_vectors.remove(&memory_id) {
107 for old_id in old_ids {
108 self.vector_to_memory.remove(&old_id);
109 }
110 }
111 for &vid in &vector_ids {
112 self.vector_to_memory.insert(vid, memory_id.clone());
113 }
114 self.memory_to_vectors.insert(memory_id, vector_ids);
115 }
116
117 fn get_memory_id(&self, vector_id: u32) -> Option<&MemoryId> {
118 self.vector_to_memory.get(&vector_id)
119 }
120
121 fn remove_all(&mut self, memory_id: &MemoryId) -> Vec<u32> {
123 if let Some(vector_ids) = self.memory_to_vectors.remove(memory_id) {
124 for vid in &vector_ids {
125 self.vector_to_memory.remove(vid);
126 }
127 vector_ids
128 } else {
129 Vec::new()
130 }
131 }
132
133 fn len(&self) -> usize {
135 self.memory_to_vectors.len()
136 }
137
138 fn clear(&mut self) {
139 self.memory_to_vectors.clear();
140 self.vector_to_memory.clear();
141 }
142}
143
144impl RetrievalEngine {
    /// Creates a retrieval engine without a consolidation event buffer.
    ///
    /// See [`Self::with_event_buffer`] for the full construction path; index
    /// load/rebuild happens there, so construction can be slow on large stores.
    pub fn new(storage: Arc<MemoryStorage>, embedder: Arc<MiniLMEmbedder>) -> Result<Self> {
        Self::with_event_buffer(storage, embedder, None)
    }
154
    /// Creates a retrieval engine, optionally wiring in a consolidation event buffer.
    ///
    /// Initializes a 384-dim Vamana index under `<storage>/vector_index/`,
    /// then loads the persisted index or rebuilds it from RocksDB via
    /// `rebuild_from_rocksdb`, so construction may take a while on large stores.
    ///
    /// # Errors
    /// Fails if the index directory cannot be created, the index cannot be
    /// initialized, or the load/rebuild step fails.
    pub fn with_event_buffer(
        storage: Arc<MemoryStorage>,
        embedder: Arc<MiniLMEmbedder>,
        consolidation_events: Option<Arc<RwLock<ConsolidationEventBuffer>>>,
    ) -> Result<Self> {
        let storage_path = storage.path().to_path_buf();

        // ANN tuning; dimension matches the MiniLM embedding size (384).
        let vamana_config = VamanaConfig {
            dimension: 384,
            max_degree: 32,
            search_list_size: 100,
            alpha: 1.2,
            use_mmap: true,
            ..Default::default()
        };

        let vamana_storage = storage_path.join("vector_index");
        std::fs::create_dir_all(&vamana_storage)?;
        let vector_index = VamanaIndex::with_storage_path(vamana_config, Some(vamana_storage))
            .context("Failed to initialize Vamana vector index")?;
        let id_mapping = IdMapping::new();

        let engine = Self {
            storage,
            embedder,
            vector_index: Arc::new(RwLock::new(vector_index)),
            id_mapping: Arc::new(RwLock::new(id_mapping)),
            storage_path,
            consolidation_events,
        };

        // Populate the index before handing the engine out.
        engine.rebuild_from_rocksdb()?;

        Ok(engine)
    }
205
206 fn rebuild_from_rocksdb(&self) -> Result<()> {
215 let start_time = std::time::Instant::now();
216
217 let vamana_path = self
219 .storage_path
220 .join("vector_index")
221 .join(VAMANA_INDEX_FILE);
222 if vamana_path.exists() {
223 if let Ok(loaded) = self.try_load_persisted_vamana(&vamana_path) {
224 if loaded {
225 info!(
226 "Instant startup: loaded Vamana in {:.2}ms",
227 start_time.elapsed().as_secs_f64() * 1000.0
228 );
229 return Ok(());
230 }
231 }
232 }
233
234 info!("No valid .vamana file, rebuilding from RocksDB...");
236
237 let mappings = self.storage.get_all_vector_mappings()?;
239
240 if !mappings.is_empty() {
241 info!(
243 "Loading {} vector mappings from RocksDB (atomic storage)",
244 mappings.len()
245 );
246
247 let mut vector_index = self.vector_index.write();
249 let mut id_mapping = self.id_mapping.write();
250 id_mapping.clear();
251 let mut indexed = 0;
252 let mut failed = 0;
253
254 for (memory_id, entry) in &mappings {
255 if entry.text_vectors().is_none() {
257 continue;
258 }
259
260 if let Ok(memory) = self.storage.get(memory_id) {
262 if let Some(ref embedding) = memory.experience.embeddings {
263 match vector_index.add_vector(embedding.clone()) {
265 Ok(new_vector_id) => {
266 id_mapping.insert(memory_id.clone(), new_vector_id);
267 indexed += 1;
268 }
269 Err(e) => {
270 tracing::warn!(
271 "Failed to index memory {} during rebuild: {}",
272 memory_id.0,
273 e
274 );
275 failed += 1;
276 }
277 }
278 }
279 }
280 }
281
282 let elapsed = start_time.elapsed();
283 info!(
284 "Rebuilt Vamana from RocksDB: {} indexed, {} failed in {:.2}s",
285 indexed,
286 failed,
287 elapsed.as_secs_f64()
288 );
289 } else {
290 info!("No vector mappings in RocksDB - checking for migration...");
293 self.migrate_to_atomic_storage()?;
294 }
295
296 Ok(())
297 }
298
    /// One-time migration to "atomic storage": for every memory that carries a
    /// stored embedding, inserts its vector into the index and persists the
    /// memory→vector mapping to RocksDB. Memories without embeddings are
    /// counted as skipped. Progress is logged every 500 memories.
    fn migrate_to_atomic_storage(&self) -> Result<()> {
        let start_time = std::time::Instant::now();

        let memories = self.storage.get_all()?;
        let total = memories.len();

        if total == 0 {
            info!("No memories to migrate");
            return Ok(());
        }

        info!("Migrating {} memories to atomic storage...", total);

        // Both locks are held for the whole migration; this runs during startup.
        let mut vector_index = self.vector_index.write();
        let mut id_mapping = self.id_mapping.write();
        let mut migrated = 0;
        let mut skipped = 0;
        let mut failed = 0;

        for (i, memory) in memories.iter().enumerate() {
            if let Some(ref embedding) = memory.experience.embeddings {
                match vector_index.add_vector(embedding.clone()) {
                    Ok(vector_id) => {
                        id_mapping.insert(memory.id.clone(), vector_id);

                        // Persist the mapping; an index entry without a durable
                        // mapping would be lost on the next restart.
                        if let Err(e) = self
                            .storage
                            .update_vector_mapping(&memory.id, vec![vector_id])
                        {
                            tracing::warn!("Failed to persist mapping for {}: {}", memory.id.0, e);
                            failed += 1;
                        } else {
                            migrated += 1;
                        }
                    }
                    Err(e) => {
                        tracing::warn!("Failed to index memory {}: {}", memory.id.0, e);
                        failed += 1;
                    }
                }
            } else {
                skipped += 1;
            }

            if (i + 1) % 500 == 0 || i + 1 == total {
                info!(
                    "Migration progress: {}/{} ({:.1}%)",
                    i + 1,
                    total,
                    (i + 1) as f64 / total as f64 * 100.0
                );
            }
        }

        let elapsed = start_time.elapsed();
        info!(
            "Migration complete: {} migrated, {} skipped (no embeddings), {} failed in {:.2}s",
            migrated,
            skipped,
            failed,
            elapsed.as_secs_f64()
        );

        Ok(())
    }
375
    /// Attempts instant startup by loading the persisted Vamana index file.
    ///
    /// Returns `Ok(false)` — telling the caller to rebuild — when the file
    /// fails its checksum, cannot be deserialized, or its vector count drifts
    /// more than 10% from the RocksDB mapping count (drift is only enforced
    /// once the loaded index holds more than 100 vectors). On success the
    /// loaded index is swapped in, the ID mapping is restored from RocksDB,
    /// and memories whose mapped vector IDs are all out of range for the
    /// loaded index are re-inserted from their stored embeddings.
    fn try_load_persisted_vamana(&self, vamana_path: &Path) -> Result<bool> {
        // Integrity first: a corrupt file is worse than a slow rebuild.
        if !VamanaIndex::verify_index_file(vamana_path)? {
            warn!("Vamana file checksum mismatch, will rebuild");
            return Ok(false);
        }

        let mut loaded_index = match VamanaIndex::load_from_file(vamana_path) {
            Ok(idx) => idx,
            Err(e) => {
                warn!("Failed to load Vamana file: {}, will rebuild", e);
                return Ok(false);
            }
        };

        let loaded_count = loaded_index.len();

        // Compare against RocksDB's view to detect a stale snapshot.
        let mappings = self.storage.get_all_vector_mappings()?;
        let rocksdb_count = mappings
            .iter()
            .filter(|(_, e)| e.text_vectors().is_some())
            .count();

        let drift_ratio = if loaded_count > 0 {
            (loaded_count as f64 - rocksdb_count as f64).abs() / loaded_count as f64
        } else {
            0.0
        };

        // >10% divergence on a non-trivial index means the snapshot is too stale.
        if drift_ratio > 0.1 && loaded_count > 100 {
            warn!(
                "Vamana/RocksDB drift too high ({:.1}%): {} vs {}, will rebuild",
                drift_ratio * 100.0,
                loaded_count,
                rocksdb_count
            );
            return Ok(false);
        }

        // Memories whose every mapped vector ID is beyond the loaded index's
        // range were indexed after the snapshot was saved; they need recovery.
        let missing_vector_mappings: Vec<MemoryId> = mappings
            .iter()
            .filter_map(|(memory_id, entry)| {
                let vector_ids = entry.text_vectors()?;
                if vector_ids.is_empty() {
                    return None;
                }
                let has_in_range_vector =
                    vector_ids.iter().any(|&vid| (vid as usize) < loaded_count);
                if has_in_range_vector {
                    None
                } else {
                    Some(memory_id.clone())
                }
            })
            .collect();

        {
            let mut index = self.vector_index.write();
            // Enforce the current runtime search breadth regardless of the
            // value persisted with the snapshot.
            loaded_index.config.search_list_size = 100;
            *index = loaded_index;
        }

        // Rebuild the in-memory ID mapping from the persisted RocksDB entries.
        let mut id_mapping = self.id_mapping.write();
        id_mapping.clear();

        for (memory_id, entry) in mappings.iter() {
            if let Some(vector_ids) = entry.text_vectors() {
                if !vector_ids.is_empty() {
                    if vector_ids.len() == 1 {
                        id_mapping.insert(memory_id.clone(), vector_ids[0]);
                    } else {
                        id_mapping.insert_chunks(memory_id.clone(), vector_ids.clone());
                    }
                }
            }
        }
        drop(id_mapping); // release before the recovery pass re-locks below

        if !missing_vector_mappings.is_empty() {
            let mut index = self.vector_index.write();
            let mut id_mapping = self.id_mapping.write();
            let mut recovered = 0usize;
            let mut recovery_failed = 0usize;

            for memory_id in &missing_vector_mappings {
                match self.storage.get(memory_id) {
                    Ok(memory) => {
                        if let Some(ref embedding) = memory.experience.embeddings {
                            match index.add_vector(embedding.clone()) {
                                Ok(new_vector_id) => {
                                    // NOTE(review): the fresh vector ID is not written
                                    // back via update_vector_mapping, so RocksDB still
                                    // holds the stale IDs until the next re-index —
                                    // confirm this is intended.
                                    id_mapping.remove_all(memory_id);
                                    id_mapping.insert(memory_id.clone(), new_vector_id);
                                    recovered += 1;
                                }
                                Err(e) => {
                                    warn!(
                                        "Failed to recover missing vector for memory {}: {}",
                                        memory_id.0, e
                                    );
                                    recovery_failed += 1;
                                }
                            }
                        } else {
                            // No stored embedding to recover from.
                            recovery_failed += 1;
                        }
                    }
                    Err(e) => {
                        warn!(
                            "Failed to load memory {} for vector recovery: {}",
                            memory_id.0, e
                        );
                        recovery_failed += 1;
                    }
                }
            }

            if recovered > 0 || recovery_failed > 0 {
                info!(
                    "Recovered {} missing vectors from RocksDB mappings ({} failed)",
                    recovered, recovery_failed
                );
            }
        }

        info!(
            "Loaded {} vectors from .vamana, {} mappings from RocksDB",
            self.vector_index.read().len(),
            self.id_mapping.read().len()
        );

        Ok(true)
    }
527
    /// Injects the consolidation event buffer after construction.
    pub fn set_consolidation_events(&mut self, events: Arc<RwLock<ConsolidationEventBuffer>>) {
        self.consolidation_events = Some(events);
    }
532
    /// Persists the Vamana index for instant startup.
    ///
    /// Uses the write-temp-then-rename pattern so a crash mid-write never
    /// leaves a truncated index file behind. Save failures are logged rather
    /// than propagated: the index can always be rebuilt from RocksDB on
    /// restart. An empty mapping skips persistence entirely.
    pub fn save(&self) -> Result<()> {
        let index_path = self.storage_path.join("vector_index");
        fs::create_dir_all(&index_path)?;

        let vamana_path = index_path.join(VAMANA_INDEX_FILE);

        let vector_index = self.vector_index.read();
        let id_mapping = self.id_mapping.read();
        // `IdMapping::len()` counts mapped memories, not individual chunk vectors.
        let vector_count = id_mapping.len();
        if vector_count > 0 {
            // NOTE(review): with_extension replaces "idx", so this tmp path is
            // "vamana.vamana.tmp" — functional, but the log text below says
            // ".vamana.tmp"/".vamana".
            let tmp_path = vamana_path.with_extension("vamana.tmp");
            match vector_index.save_to_file(&tmp_path) {
                Ok(()) => {
                    // Atomic publish: rename over the previous index file.
                    if let Err(e) = fs::rename(&tmp_path, &vamana_path) {
                        warn!(
                            "Failed to rename .vamana.tmp to .vamana: {} (removing tmp)",
                            e
                        );
                        let _ = fs::remove_file(&tmp_path);
                    } else {
                        info!(
                            "Saved Vamana index: {} vectors to {} (instant startup enabled)",
                            vector_count,
                            vamana_path.display()
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to save Vamana index (will rebuild on restart): {}",
                        e
                    );
                    let _ = fs::remove_file(&tmp_path);
                }
            }
        } else {
            info!("Vamana index empty, skipping persistence");
        }

        Ok(())
    }
585
586 pub fn len(&self) -> usize {
588 self.id_mapping.read().len()
589 }
590
591 pub fn is_empty(&self) -> bool {
593 self.len() == 0
594 }
595
596 pub fn get_indexed_memory_ids(&self) -> HashSet<MemoryId> {
598 self.id_mapping
599 .read()
600 .memory_to_vectors
601 .keys()
602 .cloned()
603 .collect()
604 }
605
    /// Embeds a memory and inserts it into the vector index, persisting the
    /// memory→vector mapping to RocksDB.
    ///
    /// Long texts are split by the chunker and indexed as multiple vectors;
    /// short texts reuse the memory's stored embedding when present, otherwise
    /// a fresh one is generated.
    ///
    /// # Errors
    /// Fails if embedding generation, index insertion, or mapping persistence fails.
    pub fn index_memory(&self, memory: &Memory) -> Result<()> {
        use crate::embeddings::chunking::{chunk_text, ChunkConfig};

        let text = Self::extract_searchable_text(memory);
        let chunk_config = ChunkConfig::default();
        let chunk_result = chunk_text(&text, &chunk_config);

        let vector_ids = if chunk_result.was_chunked {
            // Embed every chunk up front (outside any lock), failing fast on
            // the first embedding error.
            let embeddings: Vec<Vec<f32>> = chunk_result
                .chunks
                .iter()
                .map(|chunk| {
                    self.embedder
                        .encode(chunk)
                        .context("Failed to generate chunk embedding")
                })
                .collect::<Result<Vec<_>>>()?;

            // Insert all chunk vectors under a single index write lock.
            let mut ids = Vec::with_capacity(embeddings.len());
            let mut index = self.vector_index.write();
            for embedding in embeddings {
                let vector_id = index
                    .add_vector(embedding)
                    .context("Failed to add chunk vector to index")?;
                ids.push(vector_id);
            }
            drop(index); // release the index lock before taking the mapping lock

            self.id_mapping
                .write()
                .insert_chunks(memory.id.clone(), ids.clone());

            tracing::debug!(
                "Indexed memory {} with {} chunks (original: {} chars)",
                memory.id.0,
                chunk_result.chunks.len(),
                chunk_result.original_length
            );

            ids
        } else {
            // Single-vector path: prefer the embedding already stored on the
            // memory to avoid re-encoding.
            let embedding = if let Some(emb) = &memory.experience.embeddings {
                emb.clone()
            } else {
                self.embedder
                    .encode(&text)
                    .context("Failed to generate embedding")?
            };

            let mut index = self.vector_index.write();
            let vector_id = index
                .add_vector(embedding)
                .context("Failed to add vector to index")?;

            self.id_mapping.write().insert(memory.id.clone(), vector_id);

            vec![vector_id]
        };

        // Persist the mapping so the index can be reconciled after a restart.
        self.storage
            .update_vector_mapping(&memory.id, vector_ids)
            .context("Failed to persist vector mapping to RocksDB")?;

        Ok(())
    }
686
687 pub fn reindex_memory(&self, memory: &Memory) -> Result<()> {
694 let existing_vector_ids = {
696 let mapping = self.id_mapping.read();
697 mapping
698 .memory_to_vectors
699 .get(&memory.id)
700 .cloned()
701 .unwrap_or_default()
702 };
703
704 if !existing_vector_ids.is_empty() {
705 {
710 let index = self.vector_index.read();
711 for &vid in &existing_vector_ids {
712 index.mark_deleted(vid);
713 }
714 }
715
716 let mut mapping = self.id_mapping.write();
718 mapping.memory_to_vectors.remove(&memory.id);
719 for vector_id in existing_vector_ids {
720 mapping.vector_to_memory.remove(&vector_id);
721 }
722 }
723
724 self.index_memory(memory)
726 }
727
728 pub fn remove_memory(&self, memory_id: &MemoryId) -> bool {
736 let vector_ids = self.id_mapping.write().remove_all(memory_id);
738
739 if !vector_ids.is_empty() {
740 let index = self.vector_index.read();
742 for vid in &vector_ids {
743 index.mark_deleted(*vid);
744 }
745
746 if let Err(e) = self.storage.delete_vector_mapping(memory_id) {
751 tracing::warn!(
752 "Failed to delete vector mapping from RocksDB for {}: {}",
753 memory_id.0,
754 e
755 );
756 }
757
758 tracing::debug!(
759 "Removed memory {:?} from vector index ({} vectors)",
760 memory_id,
761 vector_ids.len()
762 );
763 true
764 } else {
765 tracing::debug!("Memory {:?} not found in vector index", memory_id);
766 false
767 }
768 }
769
770 fn extract_searchable_text(memory: &Memory) -> String {
772 let mut text = memory.experience.content.clone();
774
775 if !memory.experience.entities.is_empty() {
777 text.push(' ');
778 text.push_str(&memory.experience.entities.join(" "));
779 }
780
781 if let Some(context) = &memory.experience.context {
783 if let Some(topic) = &context.conversation.topic {
785 text.push(' ');
786 text.push_str(topic);
787 }
788 if !context.conversation.recent_messages.is_empty() {
790 text.push(' ');
791 text.push_str(&context.conversation.recent_messages.join(" "));
792 }
793 if let Some(name) = &context.project.name {
795 text.push(' ');
796 text.push_str(name);
797 }
798
799 if let Some(emotion) = &context.emotional.dominant_emotion {
801 text.push(' ');
802 text.push_str(emotion);
803 }
804
805 if let Some(episode_type) = &context.episode.episode_type {
807 text.push(' ');
808 text.push_str(episode_type);
809 }
810 }
811
812 if !memory.experience.outcomes.is_empty() {
814 text.push(' ');
815 text.push_str(&memory.experience.outcomes.join(" "));
816 }
817
818 text
819 }
820
    /// Lightweight similarity search that returns `(memory_id, similarity)`
    /// pairs without loading full memories.
    ///
    /// When `query.episode_id` is set and the episode contains memories, hits
    /// are restricted to that episode (an empty episode falls back to global
    /// search). Chunked memories are deduplicated by keeping each memory's
    /// best chunk score. Similarity is the negated index distance — larger is
    /// better. An empty query (no text, no embedding) yields an empty result.
    pub fn search_ids(&self, query: &Query, limit: usize) -> Result<Vec<(MemoryId, f32)>> {
        let query_embedding = if let Some(embedding) = &query.query_embedding {
            embedding.clone()
        } else if let Some(query_text) = &query.query_text {
            self.embedder
                .encode(query_text)
                .context("Failed to generate query embedding")?
        } else {
            tracing::warn!("Empty query in search_ids: no query_text or query_embedding provided");
            return Ok(Vec::new());
        };

        // Optional temporal scoping: resolve the episode's member IDs up front.
        let episode_candidates: Option<HashSet<MemoryId>> =
            if let Some(episode_id) = &query.episode_id {
                let episode_memories = self
                    .storage
                    .search(SearchCriteria::ByEpisode(episode_id.clone()))?;
                if episode_memories.is_empty() {
                    tracing::debug!(
                        "No memories found in episode {}, falling back to global search",
                        episode_id
                    );
                    None
                } else {
                    tracing::debug!(
                        "Episode {} has {} memories, using as temporal filter",
                        episode_id,
                        episode_memories.len()
                    );
                    Some(episode_memories.into_iter().map(|m| m.id).collect())
                }
            } else {
                None
            };

        // Over-fetch: episode filtering and chunk dedup can discard candidates.
        let index = self.vector_index.read();
        let results = index
            .search(
                &query_embedding,
                limit * VECTOR_SEARCH_CANDIDATE_MULTIPLIER * 2,
            )
            .context("Vector search failed")?;

        let id_mapping = self.id_mapping.read();
        let mut best_scores: std::collections::HashMap<MemoryId, f32> =
            std::collections::HashMap::new();

        for (vector_id, distance) in results {
            // The index returns distances; negate so higher means more similar.
            let similarity = -distance;

            if let Some(memory_id) = id_mapping.get_memory_id(vector_id) {
                if let Some(ref candidates) = episode_candidates {
                    if !candidates.contains(memory_id) {
                        continue;
                    }
                }

                // Keep the best-scoring chunk per memory.
                best_scores
                    .entry(memory_id.clone())
                    .and_modify(|score| {
                        if similarity > *score {
                            *score = similarity;
                        }
                    })
                    .or_insert(similarity);
            }
        }

        let mut memory_ids: Vec<(MemoryId, f32)> = best_scores.into_iter().collect();
        memory_ids.sort_by(|a, b| b.1.total_cmp(&a.1));
        memory_ids.truncate(limit);

        Ok(memory_ids)
    }
918
    /// Fetches a memory directly from the backing store by ID.
    pub fn get_from_storage(&self, id: &MemoryId) -> Result<Memory> {
        self.storage.get(id)
    }
923
924 pub fn search_by_embedding(
934 &self,
935 embedding: &[f32],
936 limit: usize,
937 exclude_id: Option<&MemoryId>,
938 ) -> Result<Vec<(MemoryId, f32)>> {
939 let index = self.vector_index.read();
941 let results = index
942 .search(embedding, limit * VECTOR_SEARCH_CANDIDATE_MULTIPLIER * 2)
943 .context("Vector search by embedding failed")?;
944
945 let id_mapping = self.id_mapping.read();
950 let mut best_scores: std::collections::HashMap<MemoryId, f32> =
951 std::collections::HashMap::new();
952
953 for (vector_id, distance) in results {
954 let similarity = -distance;
956
957 if let Some(memory_id) = id_mapping.get_memory_id(vector_id) {
958 if let Some(exclude) = exclude_id {
960 if memory_id == exclude {
961 continue;
962 }
963 }
964
965 best_scores
967 .entry(memory_id.clone())
968 .and_modify(|score| {
969 if similarity > *score {
970 *score = similarity;
971 }
972 })
973 .or_insert(similarity);
974 }
975 }
976
977 let mut memory_ids: Vec<(MemoryId, f32)> = best_scores.into_iter().collect();
979 memory_ids.sort_by(|a, b| b.1.total_cmp(&a.1));
980 memory_ids.truncate(limit);
981
982 Ok(memory_ids)
983 }
984
985 pub fn search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
987 let results = match query.retrieval_mode {
988 RetrievalMode::Similarity => self.similarity_search(query, limit)?,
990 RetrievalMode::Temporal => self.temporal_search(query, limit)?,
991 RetrievalMode::Causal => self.causal_search(query, limit)?,
992 RetrievalMode::Associative => self.associative_search(query, limit)?,
993 RetrievalMode::Hybrid => self.hybrid_search(query, limit)?,
994 RetrievalMode::Spatial => self.spatial_search(query, limit)?,
996 RetrievalMode::Mission => self.mission_search(query, limit)?,
997 RetrievalMode::ActionOutcome => self.action_outcome_search(query, limit)?,
998 };
999
1000 Ok(results)
1001 }
1002
    /// Pure vector-similarity retrieval: embeds the query (or uses the
    /// supplied embedding), over-fetches candidates, then applies `query`'s
    /// metadata filters while deduplicating chunked memories.
    ///
    /// Returns an empty list when the query has neither text nor embedding.
    fn similarity_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
        let query_embedding = if let Some(embedding) = &query.query_embedding {
            embedding.clone()
        } else if let Some(query_text) = &query.query_text {
            self.embedder
                .encode(query_text)
                .context("Failed to generate query embedding")?
        } else {
            tracing::warn!(
                "Empty query in similarity_search: no query_text or query_embedding provided"
            );
            return Ok(Vec::new());
        };

        // Over-fetch: post-filtering and chunk dedup can discard candidates.
        let index = self.vector_index.read();
        let results = index
            .search(&query_embedding, limit * VECTOR_SEARCH_CANDIDATE_MULTIPLIER)
            .context("Vector search failed")?;

        let id_mapping = self.id_mapping.read();
        let mut memories = Vec::new();
        let mut seen_ids = std::collections::HashSet::new();

        for (vector_id, _distance) in results {
            if let Some(memory_id) = id_mapping.get_memory_id(vector_id) {
                // Several chunk vectors may map to the same memory; keep only
                // the first (best-ranked) hit.
                if !seen_ids.insert(memory_id.clone()) {
                    continue;
                }
                if let Ok(memory) = self.storage.get(memory_id) {
                    let shared_memory = Arc::new(memory);
                    if self.matches_filters(&shared_memory, query) {
                        memories.push(shared_memory);
                        if memories.len() >= limit {
                            break;
                        }
                    }
                }
            }
        }

        Ok(memories)
    }
1051
    /// Returns whether `memory` passes all of `query`'s metadata filters
    /// (delegates to `Query::matches`).
    #[inline]
    pub fn matches_filters(&self, memory: &Memory, query: &Query) -> bool {
        query.matches(memory)
    }
1060
1061 fn temporal_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
1062 let criteria = if let Some(episode_id) = &query.episode_id {
1067 SearchCriteria::ByEpisodeSequence {
1069 episode_id: episode_id.clone(),
1070 min_sequence: None, max_sequence: None,
1072 }
1073 } else if let Some((start, end)) = &query.time_range {
1074 SearchCriteria::ByDate {
1076 start: *start,
1077 end: *end,
1078 }
1079 } else {
1080 let end = chrono::Utc::now();
1082 let start = end - chrono::Duration::days(7);
1083 SearchCriteria::ByDate { start, end }
1084 };
1085
1086 let mut memories: Vec<SharedMemory> = self
1087 .storage
1088 .search(criteria)?
1089 .into_iter()
1090 .map(Arc::new)
1091 .collect();
1092
1093 memories.retain(|m| self.matches_filters(m, query));
1094
1095 if query.episode_id.is_some() {
1097 memories.sort_by(|a, b| {
1100 let seq_a = a
1101 .experience
1102 .context
1103 .as_ref()
1104 .and_then(|c| c.episode.sequence_number)
1105 .unwrap_or(0);
1106 let seq_b = b
1107 .experience
1108 .context
1109 .as_ref()
1110 .and_then(|c| c.episode.sequence_number)
1111 .unwrap_or(0);
1112 seq_a.cmp(&seq_b)
1113 });
1114 } else {
1115 memories.sort_by(|a, b| b.created_at.cmp(&a.created_at));
1116 }
1117
1118 memories.truncate(limit);
1119 Ok(memories)
1120 }
1121
1122 fn causal_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
1123 let seeds = self.similarity_search(query, 3)?;
1124
1125 let mut results = HashSet::new();
1126 let mut to_explore = Vec::new();
1127
1128 for seed in &seeds {
1129 to_explore.push(seed.id.clone());
1130 results.insert(seed.id.clone());
1131 }
1132
1133 while !to_explore.is_empty() && results.len() < limit {
1134 if let Some(current_id) = to_explore.pop() {
1135 if let Ok(memory) = self.storage.get(¤t_id) {
1136 for related_id in &memory.experience.related_memories {
1137 if !results.contains(related_id) {
1138 results.insert(related_id.clone());
1139 to_explore.push(related_id.clone());
1140 }
1141 }
1142 }
1143 }
1144 }
1145
1146 let mut memories = Vec::new();
1147 for id in results.into_iter().take(limit) {
1148 if let Ok(memory) = self.storage.get(&id) {
1149 memories.push(Arc::new(memory));
1150 }
1151 }
1152
1153 Ok(memories)
1154 }
1155
    /// Associative retrieval; currently delegates to similarity search.
    fn associative_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
        self.similarity_search(query, limit)
    }
1162
    /// Blends several retrieval modes with reciprocal-rank fusion:
    /// similarity 0.5, temporal 0.2, causal 0.2, associative 0.1 (weights sum
    /// to 1.0). Each mode contributes `weight / (rank + 1)` per result; the
    /// fused score is then mixed with the memory's access-aware salience at a
    /// 70/30 ratio before the final ranking.
    fn hybrid_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
        let mut all_results: HashMap<MemoryId, SharedMemory> = HashMap::new();
        let mut scores: HashMap<MemoryId, f32> = HashMap::new();

        let weights = [
            (RetrievalMode::Similarity, 0.5),
            (RetrievalMode::Temporal, 0.2),
            (RetrievalMode::Causal, 0.2),
            (RetrievalMode::Associative, 0.1),
        ];

        for (mode, weight) in weights.iter() {
            let mut mode_query = query.clone();
            mode_query.retrieval_mode = mode.clone();

            let results = match mode {
                RetrievalMode::Similarity => self.similarity_search(&mode_query, limit),
                RetrievalMode::Temporal => self.temporal_search(&mode_query, limit),
                RetrievalMode::Causal => self.causal_search(&mode_query, limit),
                RetrievalMode::Associative => self.associative_search(&mode_query, limit),
                _ => continue,
            };

            // A failing mode simply contributes nothing to the fusion.
            if let Ok(memories) = results {
                for (rank, memory) in memories.into_iter().enumerate() {
                    // Reciprocal-rank contribution: rank 0 gets the full weight.
                    let score = weight * (1.0 / (rank as f32 + 1.0));

                    let memory_id = memory.id.clone();
                    *scores.entry(memory_id.clone()).or_insert(0.0) += score;
                    all_results.insert(memory_id, memory);
                }
            }
        }

        // Final score: fused retrieval score blended with salience (70/30).
        let mut sorted: Vec<(f32, SharedMemory)> = all_results
            .into_iter()
            .map(|(id, memory)| {
                let retrieval_score = scores.get(&id).copied().unwrap_or(0.0);
                let salience = memory.salience_score_with_access();
                let final_score = retrieval_score * 0.7 + salience * 0.3;
                (final_score, memory)
            })
            .collect();

        sorted.sort_by(|a, b| b.0.total_cmp(&a.0));

        Ok(sorted.into_iter().take(limit).map(|(_, m)| m).collect())
    }
1218
1219 fn spatial_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
1226 let geo_filter = query
1227 .geo_filter
1228 .as_ref()
1229 .ok_or_else(|| anyhow::anyhow!("Spatial search requires geo_filter"))?;
1230
1231 let criteria = SearchCriteria::ByLocation {
1232 lat: geo_filter.lat,
1233 lon: geo_filter.lon,
1234 radius_meters: geo_filter.radius_meters,
1235 };
1236
1237 let mut memories: Vec<SharedMemory> = self
1238 .storage
1239 .search(criteria)?
1240 .into_iter()
1241 .map(Arc::new)
1242 .collect();
1243
1244 memories.retain(|m| self.matches_filters(m, query));
1246
1247 memories.sort_by(|a, b| {
1249 let dist_a = match a.experience.geo_location {
1250 Some(geo) => geo_filter.haversine_distance(geo[0], geo[1]),
1251 None => f64::MAX,
1252 };
1253 let dist_b = match b.experience.geo_location {
1254 Some(geo) => geo_filter.haversine_distance(geo[0], geo[1]),
1255 None => f64::MAX,
1256 };
1257 dist_a.total_cmp(&dist_b)
1258 });
1259
1260 memories.truncate(limit);
1261 Ok(memories)
1262 }
1263
1264 fn mission_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
1267 let mission_id = query
1268 .mission_id
1269 .as_ref()
1270 .ok_or_else(|| anyhow::anyhow!("Mission search requires mission_id"))?;
1271
1272 let criteria = SearchCriteria::ByMission(mission_id.clone());
1273
1274 let mut memories: Vec<SharedMemory> = self
1275 .storage
1276 .search(criteria)?
1277 .into_iter()
1278 .map(Arc::new)
1279 .collect();
1280
1281 memories.retain(|m| self.matches_filters(m, query));
1283
1284 memories.sort_by(|a, b| a.created_at.cmp(&b.created_at));
1286
1287 memories.truncate(limit);
1288 Ok(memories)
1289 }
1290
1291 fn action_outcome_search(&self, query: &Query, limit: usize) -> Result<Vec<SharedMemory>> {
1294 let (min_reward, max_reward) = query.reward_range.unwrap_or((0.0, 1.0));
1296
1297 let criteria = SearchCriteria::ByReward {
1298 min: min_reward,
1299 max: max_reward,
1300 };
1301
1302 let mut memories: Vec<SharedMemory> = self
1303 .storage
1304 .search(criteria)?
1305 .into_iter()
1306 .map(Arc::new)
1307 .collect();
1308
1309 memories.retain(|m| self.matches_filters(m, query));
1311
1312 memories.sort_by(|a, b| {
1314 let reward_a = a.experience.reward.unwrap_or(0.0);
1315 let reward_b = b.experience.reward.unwrap_or(0.0);
1316 reward_b.total_cmp(&reward_a)
1317 });
1318
1319 memories.truncate(limit);
1320 Ok(memories)
1321 }
1322
    /// Resumable full rebuild of the vector index from storage.
    ///
    /// Walks every stored memory ID, skipping IDs already present in the
    /// mapping and memories marked forgotten, and indexes the rest via
    /// [`Self::index_memory`]. Individual failures are logged and counted,
    /// never fatal. Progress is logged every 1000 memories.
    pub fn rebuild_index(&self) -> Result<()> {
        let all_ids = self.storage.get_all_ids()?;
        let total = all_ids.len();

        if total == 0 {
            tracing::info!("No memories to index");
            return Ok(());
        }

        tracing::info!("Starting resumable index rebuild: {} memories", total);

        // Snapshot of what is already indexed — this makes the rebuild
        // resumable after an interruption.
        let indexed_ids = self.get_indexed_memory_ids();
        let already_indexed = indexed_ids.len();

        let mut indexed = 0;
        let mut skipped = 0;
        let mut failed = 0;
        let start_time = std::time::Instant::now();

        for (i, memory_id) in all_ids.iter().enumerate() {
            if indexed_ids.contains(memory_id) {
                skipped += 1;
            } else {
                match self.storage.get(memory_id) {
                    Ok(memory) => {
                        if memory.is_forgotten() {
                            // Forgotten memories stay out of the index.
                            skipped += 1;
                        } else {
                            match self.index_memory(&memory) {
                                Ok(_) => indexed += 1,
                                Err(e) => {
                                    failed += 1;
                                    tracing::warn!(
                                        "Failed to index memory {} during rebuild: {}",
                                        memory_id.0,
                                        e
                                    );
                                }
                            }
                        }
                    }
                    Err(e) => {
                        failed += 1;
                        tracing::warn!(
                            "Failed to load memory {} during rebuild: {}",
                            memory_id.0,
                            e
                        );
                    }
                }
            }

            if (i + 1) % 1000 == 0 || i + 1 == total {
                let elapsed = start_time.elapsed().as_secs();
                let rate = if elapsed > 0 {
                    (indexed + skipped) as f64 / elapsed as f64
                } else {
                    0.0
                };
                tracing::info!(
                    "Rebuild progress: {}/{} ({:.1}%), {} indexed, {} skipped, {} failed, {:.0}/sec",
                    i + 1,
                    total,
                    (i + 1) as f64 / total as f64 * 100.0,
                    indexed,
                    skipped,
                    failed,
                    rate
                );
            }
        }

        // NOTE(review): `skipped` already counts IDs found in `indexed_ids`,
        // so `already_indexed + skipped` may double-count — confirm intended.
        tracing::info!(
            "Index rebuild complete: {} indexed, {} already present, {} failed (total: {})",
            indexed,
            already_indexed + skipped,
            failed,
            self.len()
        );

        Ok(())
    }
1417
    /// Deprecated no-op kept for API compatibility; association-graph updates
    /// now happen in `GraphMemory` at the API layer.
    #[deprecated(note = "Use GraphMemory at API layer instead")]
    pub fn add_to_graph(&self, _memory: &Memory) {
    }
1432
    /// Deprecated no-op; co-activation recording moved to
    /// `GraphMemory.record_memory_coactivation()` at the API layer.
    #[deprecated(note = "Use GraphMemory.record_memory_coactivation() at API layer instead")]
    pub fn record_coactivation(&self, _memory_ids: &[MemoryId]) {
    }
1439
    /// Deprecated no-op; graph decay moved to `GraphMemory.apply_decay()` at
    /// the API layer.
    #[deprecated(note = "Use GraphMemory.apply_decay() at API layer instead")]
    pub fn graph_maintenance(&self) {
    }
1446
    /// Deprecated; always returns zeroed stats. Use `GraphMemory.get_stats()`
    /// at the API layer instead.
    #[deprecated(note = "Use GraphMemory.get_stats() at API layer instead")]
    pub fn graph_stats(&self) -> MemoryGraphStats {
        MemoryGraphStats {
            node_count: 0,
            edge_count: 0,
            avg_strength: 0.0,
            potentiated_count: 0,
        }
    }
1459
1460 pub fn auto_rebuild_index_if_needed(&self) -> Result<bool> {
1475 {
1477 let index = self.vector_index.read();
1478 if !index.needs_rebuild() || index.is_rebuilding() {
1479 return Ok(false);
1480 }
1481 }
1482
1483 info!("Index rebuild/compaction needed, performing full rebuild from RocksDB");
1485 self.rebuild_from_rocksdb()?;
1486 Ok(true)
1487 }
1488
    /// Snapshot of index health/fragmentation counters for monitoring and
    /// rebuild decisions.
    pub fn index_health(&self) -> IndexHealth {
        let index = self.vector_index.read();
        IndexHealth {
            total_vectors: index.len(),
            incremental_inserts: index.incremental_insert_count(),
            deleted_count: index.deleted_count(),
            deletion_ratio: index.deletion_ratio(),
            needs_rebuild: index.needs_rebuild(),
            needs_compaction: index.needs_compaction(),
            rebuild_threshold: crate::vector_db::vamana::REBUILD_THRESHOLD,
            deletion_ratio_threshold: crate::vector_db::vamana::DELETION_RATIO_THRESHOLD,
        }
    }
1503}
1504
/// Point-in-time health counters for the Vamana vector index.
#[derive(Debug, Clone)]
pub struct IndexHealth {
    /// Vectors currently held by the index (including soft-deleted ones).
    pub total_vectors: usize,
    /// Vectors inserted incrementally since the last full build.
    pub incremental_inserts: usize,
    /// Vectors marked deleted but not yet compacted away.
    pub deleted_count: usize,
    /// Fraction of vectors that are soft-deleted.
    pub deletion_ratio: f32,
    /// Whether the index says a full rebuild is warranted.
    pub needs_rebuild: bool,
    /// Whether the index says compaction is warranted.
    pub needs_compaction: bool,
    /// Incremental-insert count that triggers a rebuild.
    pub rebuild_threshold: usize,
    /// Deletion ratio that triggers compaction.
    pub deletion_ratio_threshold: f32,
}
1517
/// Caller-supplied verdict on how useful a retrieval turned out to be.
///
/// Drives `RetrievalEngine::reinforce_recall`: each variant maps to a
/// different importance adjustment on the recalled memories.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RetrievalOutcome {
    /// The memories helped; importance is boosted.
    Helpful,
    /// The memories misled; importance is decayed and no associations
    /// are strengthened.
    Misleading,
    /// No clear signal; only the access is recorded.
    Neutral,
}
1538
/// A search result set annotated so later feedback can be attributed to it.
#[derive(Debug, Clone)]
pub struct TrackedRetrieval {
    /// Memories returned by the search, in the order the search produced them.
    pub memories: Vec<SharedMemory>,
    /// Unique identifier (UUID v4 string) for this retrieval event.
    pub retrieval_id: String,
    /// Hash of the query text (DefaultHasher; hash of nothing when the
    /// query had no text), usable for grouping similar queries.
    pub query_fingerprint: u64,
    /// When the retrieval happened (UTC).
    pub retrieved_at: chrono::DateTime<chrono::Utc>,
}
1551
1552impl TrackedRetrieval {
1553 fn new(memories: Vec<SharedMemory>, query: &Query) -> Self {
1554 use std::hash::{Hash, Hasher};
1555 let mut hasher = std::collections::hash_map::DefaultHasher::new();
1556 if let Some(text) = &query.query_text {
1557 text.hash(&mut hasher);
1558 }
1559
1560 Self {
1561 memories,
1562 retrieval_id: uuid::Uuid::new_v4().to_string(),
1563 query_fingerprint: hasher.finish(),
1564 retrieved_at: chrono::Utc::now(),
1565 }
1566 }
1567
1568 pub fn memory_ids(&self) -> Vec<MemoryId> {
1570 self.memories.iter().map(|m| m.id.clone()).collect()
1571 }
1572}
1573
/// Feedback item linking an outcome back to a tracked retrieval.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalFeedback {
    /// Id of the `TrackedRetrieval` this feedback refers to.
    pub retrieval_id: String,
    /// How the retrieval turned out.
    pub outcome: RetrievalOutcome,
    /// Optional free-form description of the task the retrieval served.
    pub task_context: Option<String>,
    /// When the feedback was given (UTC).
    pub feedback_at: chrono::DateTime<chrono::Utc>,
}
1586
1587impl RetrievalEngine {
1588 pub fn search_tracked(&self, query: &Query, limit: usize) -> Result<TrackedRetrieval> {
1597 let memories = self.search(query, limit)?;
1598 Ok(TrackedRetrieval::new(memories, query))
1599 }
1600
1601 pub fn reinforce_recall(
1610 &self,
1611 memory_ids: &[MemoryId],
1612 outcome: RetrievalOutcome,
1613 ) -> Result<ReinforcementStats> {
1614 if memory_ids.is_empty() {
1615 return Ok(ReinforcementStats::default());
1616 }
1617
1618 let mut stats = ReinforcementStats {
1619 memories_processed: memory_ids.len(),
1620 ..Default::default()
1621 };
1622
1623 if !matches!(outcome, RetrievalOutcome::Misleading) && memory_ids.len() >= 2 {
1625 let n = memory_ids.len();
1626 stats.associations_strengthened = n * (n - 1) / 2;
1627 }
1628
1629 match outcome {
1630 RetrievalOutcome::Helpful => {
1631 for id in memory_ids {
1633 if let Ok(memory) = self.storage.get(id) {
1634 memory.record_access();
1636 memory.boost_importance(0.05); if self.storage.update(&memory).is_ok() {
1640 stats.importance_boosts += 1;
1641 }
1642 }
1643 }
1644 }
1645 RetrievalOutcome::Misleading => {
1646 for id in memory_ids {
1648 if let Ok(memory) = self.storage.get(id) {
1649 memory.record_access();
1650 memory.decay_importance(0.10); if self.storage.update(&memory).is_ok() {
1654 stats.importance_decays += 1;
1655 }
1656 }
1657 }
1658 }
1660 RetrievalOutcome::Neutral => {
1661 for id in memory_ids {
1663 if let Ok(memory) = self.storage.get(id) {
1664 memory.record_access();
1665
1666 if let Err(e) = self.storage.update(&memory) {
1668 tracing::warn!(
1669 "Failed to persist access update for memory {}: {}",
1670 id.0,
1671 e
1672 );
1673 }
1674 }
1675 }
1676 }
1678 }
1679
1680 stats.outcome = outcome;
1681 Ok(stats)
1682 }
1683
1684 pub fn reinforce_tracked(
1686 &self,
1687 tracked: &TrackedRetrieval,
1688 outcome: RetrievalOutcome,
1689 ) -> Result<ReinforcementStats> {
1690 let ids = tracked.memory_ids();
1691 self.reinforce_recall(&ids, outcome)
1692 }
1693
1694 pub fn reinforce_batch(
1696 &self,
1697 feedbacks: &[RetrievalFeedback],
1698 retrieval_memories: &HashMap<String, Vec<MemoryId>>,
1699 ) -> Result<Vec<ReinforcementStats>> {
1700 let mut results = Vec::with_capacity(feedbacks.len());
1701
1702 for feedback in feedbacks {
1703 if let Some(memory_ids) = retrieval_memories.get(&feedback.retrieval_id) {
1704 let stats = self.reinforce_recall(memory_ids, feedback.outcome)?;
1705 results.push(stats);
1706 }
1707 }
1708
1709 Ok(results)
1710 }
1711}
1712
/// Counters describing what a `reinforce_recall` pass did.
#[derive(Debug, Clone, Default)]
pub struct ReinforcementStats {
    /// Number of memory ids in the reinforced set.
    pub memories_processed: usize,
    /// Pairwise associations counted among co-retrieved memories
    /// (n·(n−1)/2; zero for `Misleading` outcomes).
    pub associations_strengthened: usize,
    /// Memories whose importance boost was persisted successfully.
    pub importance_boosts: usize,
    /// Memories whose importance decay was persisted successfully.
    pub importance_decays: usize,
    /// The outcome that drove this reinforcement pass.
    pub outcome: RetrievalOutcome,
    /// Persistence failures encountered while applying updates.
    pub persist_failures: usize,
}
1729
impl Default for RetrievalOutcome {
    /// `Neutral` is the safe default: it records accesses without boosting
    /// or decaying importance.
    fn default() -> Self {
        Self::Neutral
    }
}
1735
/// Summary statistics for the in-engine memory graph.
///
/// Only produced by the deprecated `RetrievalEngine::graph_stats`, which
/// always reports zeros; live statistics come from `GraphMemory` at the
/// API layer.
#[derive(Debug, Clone, Default)]
pub struct MemoryGraphStats {
    /// Number of nodes in the graph.
    pub node_count: usize,
    /// Number of association edges.
    pub edge_count: usize,
    /// Mean edge strength.
    pub avg_strength: f32,
    /// Count of potentiated associations.
    pub potentiated_count: usize,
}
1750
/// Ambient signals used to anticipate which memories will be needed next.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct PrefetchContext {
    /// Active project identifier, if known.
    pub project_id: Option<String>,
    /// Path of the file currently being worked on.
    pub current_file: Option<String>,
    /// Entities mentioned recently in conversation.
    pub recent_entities: Vec<String>,
    /// Current hour, 0–23 (UTC when built via `from_current_time`).
    pub hour_of_day: Option<u32>,
    /// Day of week, 0 = Sunday .. 6 = Saturday.
    pub day_of_week: Option<u32>,
    /// Recent query strings (left empty by `from_rich_context`).
    pub recent_queries: Vec<String>,
    /// Description of the task in progress, if known.
    pub task_type: Option<String>,

    /// Identifier of the ongoing episode, if any.
    pub episode_id: Option<String>,
    /// Current emotional valence; `None` when not captured
    /// (`from_rich_context` treats exactly 0.0 as absent).
    pub emotional_valence: Option<f32>,
}
1779
1780impl PrefetchContext {
1781 pub fn from_rich_context(ctx: &super::types::RichContext) -> Self {
1783 Self {
1784 project_id: ctx.project.project_id.clone(),
1785 current_file: ctx.code.current_file.clone(),
1786 recent_entities: ctx.conversation.mentioned_entities.clone(),
1787 hour_of_day: ctx
1788 .temporal
1789 .time_of_day
1790 .as_ref()
1791 .and_then(|t| t.parse().ok()),
1792 day_of_week: ctx
1793 .temporal
1794 .day_of_week
1795 .as_ref()
1796 .and_then(|d| match d.as_str() {
1797 "Sunday" => Some(0),
1798 "Monday" => Some(1),
1799 "Tuesday" => Some(2),
1800 "Wednesday" => Some(3),
1801 "Thursday" => Some(4),
1802 "Friday" => Some(5),
1803 "Saturday" => Some(6),
1804 _ => None,
1805 }),
1806 recent_queries: Vec::new(),
1807 task_type: ctx.project.current_task.clone(),
1808 episode_id: ctx.episode.episode_id.clone(),
1810 emotional_valence: if ctx.emotional.valence != 0.0 {
1811 Some(ctx.emotional.valence)
1812 } else {
1813 None
1814 },
1815 }
1816 }
1817
1818 pub fn from_current_time() -> Self {
1820 let now = chrono::Utc::now();
1821 Self {
1822 hour_of_day: Some(now.hour()),
1823 day_of_week: Some(now.weekday().num_days_from_sunday()),
1824 ..Default::default()
1825 }
1826 }
1827}
1828
/// Outcome of one prefetch pass.
#[derive(Debug, Clone, Default)]
pub struct PrefetchResult {
    /// Ids of the memories that were prefetched.
    pub prefetched_ids: Vec<MemoryId>,
    /// Which signal drove the prefetch.
    pub reason: PrefetchReason,
    /// Count of cache hits during the pass.
    pub cache_hits: usize,
    /// Count of actual fetches performed.
    pub fetches: usize,
}
1841
/// Why a set of memories was prefetched.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum PrefetchReason {
    /// Same project as the current context (carries the project id).
    Project(String),
    /// Related to the file currently being worked on.
    RelatedFiles,
    /// Shares entities with recent conversation.
    SharedEntities,
    /// Matches a time-of-day pattern.
    TemporalPattern,
    /// Associated with recently accessed memories.
    AssociatedMemories,
    /// Predicted from recent queries.
    QueryPrediction,
    /// Several signals combined; the default.
    #[default]
    Mixed,
}
1861
/// Heuristic prefetcher: turns ambient context (project, entities, current
/// file, time of day) into speculative queries and scores candidate
/// memories for relevance.
pub struct AnticipatoryPrefetch {
    /// Maximum number of results a generated prefetch query asks for.
    max_prefetch: usize,
}
1873
impl Default for AnticipatoryPrefetch {
    /// Equivalent to `AnticipatoryPrefetch::new` (limit of 20).
    fn default() -> Self {
        Self::new()
    }
}
1879
1880impl AnticipatoryPrefetch {
1881 pub fn new() -> Self {
1882 Self { max_prefetch: 20 }
1883 }
1884
1885 pub fn with_limit(max_prefetch: usize) -> Self {
1887 Self { max_prefetch }
1888 }
1889
1890 pub fn generate_prefetch_query(&self, context: &PrefetchContext) -> Option<Query> {
1895 if let Some(project_id) = &context.project_id {
1897 return Some(self.project_query(project_id));
1898 }
1899
1900 if !context.recent_entities.is_empty() {
1902 return Some(self.entity_query(&context.recent_entities));
1903 }
1904
1905 if let Some(file_path) = &context.current_file {
1907 return Some(self.file_query(file_path));
1908 }
1909
1910 if let (Some(hour), Some(day)) = (context.hour_of_day, context.day_of_week) {
1912 return Some(self.temporal_query(hour, day));
1913 }
1914
1915 None
1916 }
1917
1918 fn project_query(&self, project_id: &str) -> Query {
1920 Query {
1921 query_text: Some(format!("project:{}", project_id)),
1922 max_results: self.max_prefetch,
1923 retrieval_mode: super::types::RetrievalMode::Similarity,
1924 ..Default::default()
1925 }
1926 }
1927
1928 fn entity_query(&self, entities: &[String]) -> Query {
1930 let query_text = entities.join(" ");
1931 Query {
1932 query_text: Some(query_text),
1933 max_results: self.max_prefetch,
1934 retrieval_mode: super::types::RetrievalMode::Similarity,
1935 ..Default::default()
1936 }
1937 }
1938
1939 fn file_query(&self, file_path: &str) -> Query {
1941 let filename = std::path::Path::new(file_path)
1943 .file_name()
1944 .and_then(|n| n.to_str())
1945 .unwrap_or(file_path);
1946
1947 Query {
1948 query_text: Some(format!("file {} code", filename)),
1949 max_results: self.max_prefetch,
1950 retrieval_mode: super::types::RetrievalMode::Similarity,
1951 ..Default::default()
1952 }
1953 }
1954
1955 fn temporal_query(&self, hour: u32, _day: u32) -> Query {
1957 let now = chrono::Utc::now();
1958
1959 let start_hour = if hour >= PREFETCH_TEMPORAL_WINDOW_HOURS as u32 {
1961 hour - PREFETCH_TEMPORAL_WINDOW_HOURS as u32
1962 } else {
1963 0
1964 };
1965 let end_hour = (hour + PREFETCH_TEMPORAL_WINDOW_HOURS as u32).min(23);
1966
1967 let start = now
1969 .with_hour(start_hour)
1970 .unwrap_or(now)
1971 .with_minute(0)
1972 .unwrap_or(now);
1973 let end = now
1974 .with_hour(end_hour)
1975 .unwrap_or(now)
1976 .with_minute(59)
1977 .unwrap_or(now);
1978
1979 Query {
1980 time_range: Some((start, end)),
1981 max_results: self.max_prefetch,
1982 retrieval_mode: super::types::RetrievalMode::Temporal,
1983 ..Default::default()
1984 }
1985 }
1986
1987 pub fn relevance_score(&self, memory: &Memory, context: &PrefetchContext) -> f32 {
1995 let mut score = 0.0;
1996
1997 if let Some(project_id) = &context.project_id {
1999 if let Some(ctx) = &memory.experience.context {
2000 if ctx.project.project_id.as_ref() == Some(project_id) {
2001 score += 0.4;
2002 }
2003 }
2004 }
2005
2006 let memory_entities: HashSet<_> = memory.experience.entities.iter().collect();
2008 let context_entities: HashSet<_> = context.recent_entities.iter().collect();
2009 let overlap = memory_entities.intersection(&context_entities).count();
2010 if overlap > 0 {
2011 score += 0.2 * (overlap as f32 / context_entities.len().max(1) as f32);
2012 }
2013
2014 if let Some(current_file) = &context.current_file {
2016 if memory.experience.content.contains(current_file) {
2017 score += 0.2;
2018 }
2019 if let Some(ctx) = &memory.experience.context {
2021 if ctx.code.related_files.iter().any(|f| f == current_file) {
2022 score += 0.1;
2023 }
2024 }
2025 }
2026
2027 if let Some(hour) = context.hour_of_day {
2029 let memory_hour = memory.created_at.hour();
2030 if (memory_hour as i32 - hour as i32).abs() <= PREFETCH_TEMPORAL_WINDOW_HOURS as i32 {
2031 score += 0.1;
2032 }
2033 }
2034
2035 let age_hours = (chrono::Utc::now() - memory.created_at).num_hours();
2037 if age_hours < PREFETCH_RECENCY_FULL_HOURS {
2038 score += PREFETCH_RECENCY_FULL_BOOST;
2039 } else if age_hours < PREFETCH_RECENCY_PARTIAL_HOURS {
2040 score += PREFETCH_RECENCY_PARTIAL_BOOST;
2041 }
2042
2043 if let Some(ctx) = &memory.experience.context {
2046 if ctx.emotional.arousal > 0.6 {
2048 score += 0.1 * ctx.emotional.arousal;
2049 }
2050
2051 if ctx.source.credibility > 0.8 {
2053 score += 0.05;
2054 }
2055
2056 if let Some(current_episode) = &context.episode_id {
2058 if ctx.episode.episode_id.as_ref() == Some(current_episode) {
2059 score += 0.3; }
2061 }
2062
2063 if let Some(current_valence) = context.emotional_valence {
2066 let valence_diff = (ctx.emotional.valence - current_valence).abs();
2067 if valence_diff < 0.3 {
2068 score += 0.1 * (1.0 - valence_diff / 0.3);
2070 }
2071 }
2072 }
2073
2074 score.min(1.0)
2075 }
2076}
2077
2078use chrono::{Datelike, Timelike};
2079
#[cfg(test)]
mod tests {
    use super::*;

    /// Convenience: a fresh random memory id for mapping tests.
    fn fresh_id() -> MemoryId {
        MemoryId(uuid::Uuid::new_v4())
    }

    #[test]
    fn test_id_mapping_basic() {
        let mut map = IdMapping::new();
        let mid = fresh_id();

        map.insert(mid.clone(), 42);

        assert_eq!(map.len(), 1);
        assert_eq!(map.get_memory_id(42), Some(&mid));
    }

    #[test]
    fn test_id_mapping_chunks() {
        let mut map = IdMapping::new();
        let mid = fresh_id();

        map.insert_chunks(mid.clone(), vec![1, 2, 3]);

        assert_eq!(map.len(), 1);
        for vid in [1u32, 2, 3] {
            assert_eq!(map.get_memory_id(vid), Some(&mid));
        }
    }

    #[test]
    fn test_id_mapping_remove_all() {
        let mut map = IdMapping::new();
        let mid = fresh_id();
        map.insert_chunks(mid.clone(), vec![1, 2, 3]);

        let removed = map.remove_all(&mid);

        assert_eq!(removed.len(), 3);
        assert_eq!(map.len(), 0);
        assert!(map.get_memory_id(1).is_none());
    }

    #[test]
    fn test_id_mapping_clear() {
        let mut map = IdMapping::new();
        map.insert(fresh_id(), 1);
        map.insert(fresh_id(), 2);

        map.clear();

        assert_eq!(map.len(), 0);
    }

    #[test]
    fn test_retrieval_outcome_default() {
        assert_eq!(RetrievalOutcome::default(), RetrievalOutcome::Neutral);
    }

    #[test]
    fn test_reinforcement_stats_default() {
        let stats = ReinforcementStats::default();

        assert_eq!(stats.memories_processed, 0);
        assert_eq!(stats.associations_strengthened, 0);
        assert_eq!(stats.importance_boosts, 0);
        assert_eq!(stats.importance_decays, 0);
    }

    #[test]
    fn test_memory_graph_stats_default() {
        let stats = MemoryGraphStats::default();

        assert_eq!(stats.node_count, 0);
        assert_eq!(stats.edge_count, 0);
        assert_eq!(stats.avg_strength, 0.0);
        assert_eq!(stats.potentiated_count, 0);
    }

    #[test]
    fn test_prefetch_context_default() {
        let context = PrefetchContext::default();

        assert!(context.project_id.is_none());
        assert!(context.current_file.is_none());
        assert!(context.recent_entities.is_empty());
    }

    #[test]
    fn test_prefetch_context_from_current_time() {
        let context = PrefetchContext::from_current_time();

        assert!(context.hour_of_day.is_some());
        assert!(context.day_of_week.is_some());
    }

    #[test]
    fn test_anticipatory_prefetch_new() {
        assert_eq!(AnticipatoryPrefetch::new().max_prefetch, 20);
    }

    #[test]
    fn test_anticipatory_prefetch_with_limit() {
        assert_eq!(AnticipatoryPrefetch::with_limit(50).max_prefetch, 50);
    }

    #[test]
    fn test_generate_prefetch_query_project() {
        let pf = AnticipatoryPrefetch::new();
        let context = PrefetchContext {
            project_id: Some("my-project".to_string()),
            ..Default::default()
        };

        let query = pf.generate_prefetch_query(&context);

        assert!(query.is_some());
        assert!(query.unwrap().query_text.unwrap().contains("my-project"));
    }

    #[test]
    fn test_generate_prefetch_query_entities() {
        let pf = AnticipatoryPrefetch::new();
        let context = PrefetchContext {
            recent_entities: vec!["Rust".to_string(), "memory".to_string()],
            ..Default::default()
        };

        let query = pf.generate_prefetch_query(&context);

        assert!(query.is_some());
        let text = query.unwrap().query_text.unwrap();
        assert!(text.contains("Rust"));
        assert!(text.contains("memory"));
    }

    #[test]
    fn test_generate_prefetch_query_file() {
        let pf = AnticipatoryPrefetch::new();
        let context = PrefetchContext {
            current_file: Some("/src/memory/retrieval.rs".to_string()),
            ..Default::default()
        };

        let query = pf.generate_prefetch_query(&context);

        assert!(query.is_some());
        assert!(query.unwrap().query_text.unwrap().contains("retrieval.rs"));
    }

    #[test]
    fn test_generate_prefetch_query_temporal() {
        let pf = AnticipatoryPrefetch::new();
        let context = PrefetchContext {
            hour_of_day: Some(14),
            day_of_week: Some(1),
            ..Default::default()
        };

        let query = pf.generate_prefetch_query(&context);

        assert!(query.is_some());
        assert!(query.unwrap().time_range.is_some());
    }

    #[test]
    fn test_generate_prefetch_query_empty() {
        let pf = AnticipatoryPrefetch::new();

        assert!(pf
            .generate_prefetch_query(&PrefetchContext::default())
            .is_none());
    }

    #[test]
    fn test_prefetch_reason_default() {
        assert!(matches!(PrefetchReason::default(), PrefetchReason::Mixed));
    }

    #[test]
    fn test_prefetch_result_default() {
        let result = PrefetchResult::default();

        assert!(result.prefetched_ids.is_empty());
        assert_eq!(result.cache_hits, 0);
        assert_eq!(result.fetches, 0);
    }

    #[test]
    fn test_index_health_struct() {
        let health = IndexHealth {
            total_vectors: 1000,
            incremental_inserts: 100,
            deleted_count: 50,
            deletion_ratio: 0.05,
            needs_rebuild: false,
            needs_compaction: false,
            rebuild_threshold: 500,
            deletion_ratio_threshold: 0.2,
        };

        assert_eq!(health.total_vectors, 1000);
        assert!(!health.needs_rebuild);
    }

    #[test]
    fn test_id_mapping_insert_is_idempotent() {
        let mut map = IdMapping::new();
        let mid = fresh_id();

        map.insert(mid.clone(), 10);
        assert_eq!(map.len(), 1);
        assert_eq!(map.get_memory_id(10), Some(&mid));

        // Re-inserting the same memory must replace, not accumulate.
        map.insert(mid.clone(), 20);
        assert_eq!(map.len(), 1);
        assert_eq!(map.get_memory_id(20), Some(&mid));
        assert!(
            map.get_memory_id(10).is_none(),
            "old vector_id should be removed to prevent orphan"
        );
        assert_eq!(map.memory_to_vectors[&mid], vec![20]);
    }

    #[test]
    fn test_id_mapping_insert_chunks_is_idempotent() {
        let mut map = IdMapping::new();
        let mid = fresh_id();

        map.insert_chunks(mid.clone(), vec![1, 2, 3]);
        assert_eq!(map.len(), 1);
        assert_eq!(map.memory_to_vectors[&mid], vec![1, 2, 3]);

        // Re-chunking must fully replace the old chunk set.
        map.insert_chunks(mid.clone(), vec![10, 11]);
        assert_eq!(map.len(), 1);
        assert_eq!(map.memory_to_vectors[&mid], vec![10, 11]);
        assert!(map.get_memory_id(1).is_none(), "old chunk 1 orphaned");
        assert!(map.get_memory_id(2).is_none(), "old chunk 2 orphaned");
        assert!(map.get_memory_id(3).is_none(), "old chunk 3 orphaned");
        assert_eq!(map.get_memory_id(10), Some(&mid));
        assert_eq!(map.get_memory_id(11), Some(&mid));
    }

    #[test]
    fn test_id_mapping_vector_count_stable_after_reinsert() {
        let mut map = IdMapping::new();
        let m1 = fresh_id();
        let m2 = fresh_id();

        map.insert(m1.clone(), 1);
        map.insert(m2.clone(), 2);
        assert_eq!(map.vector_to_memory.len(), 2);

        map.insert(m1.clone(), 3);
        assert_eq!(
            map.vector_to_memory.len(),
            2,
            "vector_to_memory should not grow on re-insert"
        );
    }
}