1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3use std::sync::Arc;
4
5use chrono::Local;
6use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
7use sqlx::{Row, SqlitePool};
8use tracing::debug;
9
10use starpod_core::{StarpodError, Result};
11
12use crate::defaults;
13use crate::embedder::{self, Embedder};
14use crate::fusion;
15use crate::indexer::{self, reindex_source, CHUNK_SIZE, CHUNK_OVERLAP};
16use crate::schema;
17use crate::scoring;
18
/// Default upper bound, in bytes, on how much of `SOUL.md` is included by
/// `MemoryStore::bootstrap_context`. Each store can override it through
/// `MemoryStore::set_bootstrap_file_cap`.
const BOOTSTRAP_FILE_CAP: usize = 20_000;

/// Default half-life, in days, that the store hands to the `scoring` decay
/// helpers when ranking search results. Each store can override it through
/// `MemoryStore::set_half_life_days`.
const DEFAULT_HALF_LIFE_DAYS: f64 = 30.0;
24
/// One chunk returned by the `MemoryStore` search methods.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Name of the indexed source the chunk belongs to (for example `"SOUL.md"`).
    pub source: String,
    /// Text of the chunk. `vector_search` leaves this empty when no matching
    /// full-text row is found for the vector.
    pub text: String,
    /// First line of the chunk as recorded in the index.
    pub line_start: usize,
    /// Last line of the chunk as recorded in the index.
    pub line_end: usize,
    /// Ranking score. `search` and `hybrid_search` sort ascending on this
    /// value, and `vector_search` stores the negated cosine similarity, so
    /// smaller values are better matches.
    pub rank: f64,
}
47
/// File-backed memory with a SQLite search index.
///
/// Markdown files live on disk under `agent_home` / `config_dir`; their
/// chunks are mirrored into the `memory_fts` and `memory_vectors` tables of
/// `memory.db` so they can be searched.
pub struct MemoryStore {
    /// Directory that non-config files (and the `memory/` daily logs) are
    /// resolved against.
    agent_home: PathBuf,
    /// Directory holding the files listed in `CONFIG_FILES`.
    config_dir: PathBuf,
    /// Connection pool for `memory.db`.
    pool: SqlitePool,
    /// Half-life, in days, passed to the `scoring` decay helpers.
    half_life_days: f64,
    /// Lambda passed to `scoring::mmr_rerank` by `hybrid_search`.
    mmr_lambda: f64,
    /// Optional embedder. While it is `None`, no vectors are stored,
    /// `vector_search` returns nothing and `hybrid_search` falls back to
    /// `search`.
    embedder: Option<Arc<dyn Embedder>>,
    /// Chunk size handed to the indexer.
    chunk_size: usize,
    /// Chunk overlap handed to the indexer.
    chunk_overlap: usize,
    /// Byte cap applied to `SOUL.md` in `bootstrap_context`.
    bootstrap_file_cap: usize,
}
86
87impl MemoryStore {
88 pub async fn new(agent_home: &Path, config_dir: &Path, db_dir: &Path) -> Result<Self> {
94 std::fs::create_dir_all(agent_home)
96 .map_err(StarpodError::Io)?;
97 std::fs::create_dir_all(config_dir)
98 .map_err(StarpodError::Io)?;
99 std::fs::create_dir_all(db_dir)
100 .map_err(StarpodError::Io)?;
101
102 let db_path = db_dir.join("memory.db");
104 let opts = SqliteConnectOptions::from_str(
105 &format!("sqlite://{}?mode=rwc", db_path.display()),
106 )
107 .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?;
108
109 let pool = SqlitePoolOptions::new()
110 .max_connections(5)
111 .connect_with(opts)
112 .await
113 .map_err(|e| StarpodError::Database(format!("Failed to open database: {}", e)))?;
114
115 schema::run_migrations(&pool).await?;
117
118 let store = Self {
119 agent_home: agent_home.to_path_buf(),
120 config_dir: config_dir.to_path_buf(),
121 pool,
122 half_life_days: DEFAULT_HALF_LIFE_DAYS,
123 mmr_lambda: 0.7,
124 embedder: None,
125 chunk_size: CHUNK_SIZE,
126 chunk_overlap: CHUNK_OVERLAP,
127 bootstrap_file_cap: BOOTSTRAP_FILE_CAP,
128 };
129
130 store.seed_defaults()?;
132
133 store.reindex().await?;
135
136 Ok(store)
137 }
138
139 pub async fn new_user(user_dir: &Path) -> Result<Self> {
150 std::fs::create_dir_all(user_dir)
151 .map_err(StarpodError::Io)?;
152
153 let db_path = user_dir.join("memory.db");
154 let opts = SqliteConnectOptions::from_str(
155 &format!("sqlite://{}?mode=rwc", db_path.display()),
156 )
157 .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?;
158
159 let pool = SqlitePoolOptions::new()
160 .max_connections(1)
161 .connect_with(opts)
162 .await
163 .map_err(|e| StarpodError::Database(format!("Failed to open user database: {}", e)))?;
164
165 schema::run_migrations(&pool).await?;
166
167 let store = Self {
168 agent_home: user_dir.to_path_buf(),
169 config_dir: user_dir.to_path_buf(),
170 pool,
171 half_life_days: DEFAULT_HALF_LIFE_DAYS,
172 mmr_lambda: 0.7,
173 embedder: None,
174 chunk_size: CHUNK_SIZE,
175 chunk_overlap: CHUNK_OVERLAP,
176 bootstrap_file_cap: BOOTSTRAP_FILE_CAP,
177 };
178
179 store.reindex().await?;
181
182 Ok(store)
183 }
184
185 fn seed_defaults(&self) -> Result<bool> {
193 let fresh = !self.config_dir.join("SOUL.md").exists();
194
195 if fresh {
197 let path = self.config_dir.join("SOUL.md");
198 debug!(file = "SOUL.md", "Seeding default SOUL.md");
199 std::fs::write(&path, defaults::DEFAULT_SOUL)?;
200 }
201
202 let lifecycle_files = [
204 ("HEARTBEAT.md", defaults::DEFAULT_HEARTBEAT),
205 ("BOOT.md", defaults::DEFAULT_BOOT),
206 ("BOOTSTRAP.md", defaults::DEFAULT_BOOTSTRAP),
207 ];
208
209 for (name, content) in &lifecycle_files {
210 let path = self.config_dir.join(name);
211 if !path.exists() {
212 debug!(file = %name, "Seeding default file");
213 std::fs::write(&path, content)?;
214 }
215 }
216
217 Ok(fresh)
218 }
219
    /// Directory that non-config files are resolved against.
    pub fn agent_home(&self) -> &Path {
        &self.agent_home
    }
224
    /// Directory that holds the files named in `CONFIG_FILES`.
    pub fn config_dir(&self) -> &Path {
        &self.config_dir
    }
229
    /// Bare file names that `resolve_path` routes to `config_dir` instead of
    /// `agent_home`, and that `reindex` skips when scanning `agent_home`.
    const CONFIG_FILES: &[&str] = &[
        "SOUL.md", "HEARTBEAT.md", "BOOT.md", "BOOTSTRAP.md",
    ];
234
235 fn resolve_path(&self, name: &str) -> PathBuf {
237 if !name.contains('/') && Self::CONFIG_FILES.iter().any(|&f| f == name) {
239 self.config_dir.join(name)
240 } else {
241 self.agent_home.join(name)
242 }
243 }
244
245 pub fn has_bootstrap(&self) -> bool {
247 let path = self.config_dir.join("BOOTSTRAP.md");
248 path.is_file()
249 && std::fs::read_to_string(&path)
250 .map(|c| !c.trim().is_empty())
251 .unwrap_or(false)
252 }
253
254 pub fn clear_bootstrap(&self) -> Result<()> {
256 let path = self.config_dir.join("BOOTSTRAP.md");
257 if path.exists() {
258 std::fs::write(&path, "")?;
259 }
260 Ok(())
261 }
262
263 pub fn bootstrap_context(&self) -> Result<String> {
268 let content = self.read_file("SOUL.md")?;
269 let capped = if content.len() > self.bootstrap_file_cap {
270 let mut end = self.bootstrap_file_cap;
271 while end > 0 && !content.is_char_boundary(end) { end -= 1; }
272 &content[..end]
273 } else {
274 &content
275 };
276 Ok(format!("--- SOUL.md ---\n{}", capped))
277 }
278
    /// Sets the half-life, in days, passed to the `scoring` decay helpers by
    /// later searches. The value is stored as given, without validation.
    pub fn set_half_life_days(&mut self, days: f64) {
        self.half_life_days = days;
    }
283
    /// Sets the lambda passed to `scoring::mmr_rerank` by `hybrid_search`.
    /// The value is stored as given, without validation.
    pub fn set_mmr_lambda(&mut self, lambda: f64) {
        self.mmr_lambda = lambda;
    }
288
    /// Sets the chunk size used for subsequent indexing. Content that is
    /// already indexed is not re-chunked until it is written or reindexed.
    pub fn set_chunk_size(&mut self, size: usize) {
        self.chunk_size = size;
    }
293
    /// Sets the chunk overlap used for subsequent indexing. The value is
    /// stored as given; it is not checked against the chunk size here.
    pub fn set_chunk_overlap(&mut self, overlap: usize) {
        self.chunk_overlap = overlap;
    }
298
    /// Sets the byte cap applied to `SOUL.md` by `bootstrap_context`.
    pub fn set_bootstrap_file_cap(&mut self, cap: usize) {
        self.bootstrap_file_cap = cap;
    }
303
304 pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
310 let fetch_limit = (limit * 3).max(30);
312 let rows = sqlx::query(
313 "SELECT source, chunk_text, line_start, line_end, rank
314 FROM memory_fts
315 WHERE memory_fts MATCH ?1
316 ORDER BY rank
317 LIMIT ?2",
318 )
319 .bind(query)
320 .bind(fetch_limit as i64)
321 .fetch_all(&self.pool)
322 .await
323 .map_err(|e| StarpodError::Database(format!("Search query failed: {}", e)))?;
324
325 let mut results: Vec<SearchResult> = rows
326 .iter()
327 .map(|row| {
328 let source = row.get::<String, _>("source");
329 let raw_rank = row.get::<f64, _>("rank");
330 let adjusted_rank = scoring::apply_decay(raw_rank, &source, self.half_life_days);
331 SearchResult {
332 source,
333 text: row.get::<String, _>("chunk_text"),
334 line_start: row.get::<i64, _>("line_start") as usize,
335 line_end: row.get::<i64, _>("line_end") as usize,
336 rank: adjusted_rank,
337 }
338 })
339 .collect();
340
341 results.sort_by(|a, b| a.rank.partial_cmp(&b.rank).unwrap_or(std::cmp::Ordering::Equal));
343 results.truncate(limit);
344
345 Ok(results)
346 }
347
    /// Installs the embedder used for vector storage and vector/hybrid
    /// search. Content indexed before this call has no stored vectors until
    /// it is written again or `reindex` is run.
    pub fn set_embedder(&mut self, embedder: Arc<dyn Embedder>) {
        self.embedder = Some(embedder);
    }
352
353 pub async fn vector_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
357 let embedder = match &self.embedder {
358 Some(e) => e,
359 None => return Ok(Vec::new()),
360 };
361
362 let query_vecs = embedder
364 .embed(&[query.to_string()])
365 .await?;
366 let query_vec = match query_vecs.first() {
367 Some(v) => v,
368 None => return Ok(Vec::new()),
369 };
370
371 let rows = sqlx::query(
373 "SELECT v.source, v.embedding, v.line_start, v.line_end, f.chunk_text
374 FROM memory_vectors v
375 LEFT JOIN memory_fts f ON f.source = v.source
376 AND f.line_start = v.line_start AND f.line_end = v.line_end"
377 )
378 .fetch_all(&self.pool)
379 .await
380 .map_err(|e| StarpodError::Database(format!("Vector search failed: {}", e)))?;
381
382 let mut scored: Vec<(f32, SearchResult)> = Vec::new();
383 for row in &rows {
384 let blob: Vec<u8> = row.get("embedding");
385 let embedding = bytes_to_f32_vec(&blob);
386 let similarity = embedder::cosine_similarity(query_vec, &embedding);
387
388 let source: String = row.get("source");
389 let text: String = row.try_get("chunk_text").unwrap_or_default();
390
391 scored.push((similarity, SearchResult {
392 source,
393 text,
394 line_start: row.get::<i64, _>("line_start") as usize,
395 line_end: row.get::<i64, _>("line_end") as usize,
396 rank: -(similarity as f64), }));
398 }
399
400 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
402 scored.truncate(limit);
403
404 Ok(scored.into_iter().map(|(_, r)| r).collect())
405 }
406
407 pub async fn hybrid_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
412 let embedder = match &self.embedder {
413 Some(e) => e,
414 None => return self.search(query, limit).await,
415 };
416
417 let fts_limit = (limit * 3).max(30);
419 let vec_limit = (limit * 3).max(30);
420
421 let (fts_results, vec_results) = tokio::join!(
422 self.fts_search_raw(query, fts_limit),
423 self.vector_search(query, vec_limit),
424 );
425
426 let fts_results = fts_results?;
427 let vec_results = vec_results?;
428
429 let mut fused = fusion::reciprocal_rank_fusion(&fts_results, &vec_results, limit * 3);
431
432 for result in &mut fused {
434 let decay = scoring::decay_factor(&result.source, self.half_life_days);
435 if decay > 0.0 && decay < 1.0 {
436 result.rank *= decay;
437 }
438 }
439
440 fused.sort_by(|a, b| a.rank.partial_cmp(&b.rank).unwrap_or(std::cmp::Ordering::Equal));
442
443 let mmr_pool_size = (limit * 2).min(fused.len());
445 if mmr_pool_size > 0 {
446 let query_vecs = embedder.embed(&[query.to_string()]).await?;
448 if let Some(query_vec) = query_vecs.first() {
449 let texts: Vec<String> = fused[..mmr_pool_size]
451 .iter()
452 .map(|r| r.text.clone())
453 .collect();
454 let embeddings = embedder.embed(&texts).await?;
455
456 let candidates: Vec<(Vec<f32>, usize)> = embeddings
457 .into_iter()
458 .enumerate()
459 .map(|(i, emb)| (emb, i))
460 .collect();
461
462 let selected_indices =
463 scoring::mmr_rerank(query_vec, &candidates, limit, self.mmr_lambda);
464
465 let pool = fused;
466 fused = selected_indices
467 .into_iter()
468 .map(|idx| pool[idx].clone())
469 .collect();
470 } else {
471 fused.truncate(limit);
472 }
473 }
474
475 Ok(fused)
476 }
477
478 async fn fts_search_raw(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
480 let rows = sqlx::query(
481 "SELECT source, chunk_text, line_start, line_end, rank
482 FROM memory_fts
483 WHERE memory_fts MATCH ?1
484 ORDER BY rank
485 LIMIT ?2",
486 )
487 .bind(query)
488 .bind(limit as i64)
489 .fetch_all(&self.pool)
490 .await
491 .map_err(|e| StarpodError::Database(format!("Search query failed: {}", e)))?;
492
493 Ok(rows
494 .iter()
495 .map(|row| SearchResult {
496 source: row.get::<String, _>("source"),
497 text: row.get::<String, _>("chunk_text"),
498 line_start: row.get::<i64, _>("line_start") as usize,
499 line_end: row.get::<i64, _>("line_end") as usize,
500 rank: row.get::<f64, _>("rank"),
501 })
502 .collect())
503 }
504
505 async fn embed_and_store_source(&self, source: &str, text: &str) -> Result<()> {
507 let embedder = match &self.embedder {
508 Some(e) => e,
509 None => return Ok(()),
510 };
511
512 sqlx::query("DELETE FROM memory_vectors WHERE source = ?1")
514 .bind(source)
515 .execute(&self.pool)
516 .await
517 .map_err(|e| StarpodError::Database(format!("Failed to delete old vectors: {}", e)))?;
518
519 let chunks = indexer::chunk_text(source, text, self.chunk_size, self.chunk_overlap);
521 if chunks.is_empty() {
522 return Ok(());
523 }
524
525 let texts: Vec<String> = chunks.iter().map(|c| c.text.clone()).collect();
527 let embeddings = embedder.embed(&texts).await?;
528
529 for (idx, (chunk, embedding)) in chunks.iter().zip(embeddings.iter()).enumerate() {
531 let blob = f32_vec_to_bytes(embedding);
532 sqlx::query(
533 "INSERT INTO memory_vectors (source, chunk_idx, embedding, line_start, line_end)
534 VALUES (?1, ?2, ?3, ?4, ?5)",
535 )
536 .bind(&chunk.source)
537 .bind(idx as i64)
538 .bind(&blob)
539 .bind(chunk.line_start as i64)
540 .bind(chunk.line_end as i64)
541 .execute(&self.pool)
542 .await
543 .map_err(|e| StarpodError::Database(format!("Failed to insert vector: {}", e)))?;
544 }
545
546 Ok(())
547 }
548
549 pub fn read_file(&self, name: &str) -> Result<String> {
551 scoring::validate_path(name, &self.agent_home)?;
553 let path = self.resolve_path(name);
554 if !path.exists() {
555 return Ok(String::new());
556 }
557 std::fs::read_to_string(&path).map_err(StarpodError::Io)
558 }
559
560 pub async fn write_file(&self, name: &str, content: &str) -> Result<()> {
565 scoring::validate_path(name, &self.agent_home)?;
566 scoring::validate_content_size(content)?;
567
568 let path = self.resolve_path(name);
569
570 if let Some(parent) = path.parent() {
572 std::fs::create_dir_all(parent)?;
573 }
574
575 std::fs::write(&path, content)?;
576
577 reindex_source(&self.pool, name, content, self.chunk_size, self.chunk_overlap).await?;
579 self.embed_and_store_source(name, content).await?;
580
581 Ok(())
582 }
583
584 pub async fn append_daily(&self, text: &str) -> Result<()> {
586 let today = Local::now().format("%Y-%m-%d").to_string();
587 let filename = format!("memory/{}.md", today);
588 let path = self.agent_home.join(&filename);
589
590 if let Some(parent) = path.parent() {
592 std::fs::create_dir_all(parent)?;
593 }
594
595 let timestamp = Local::now().format("%H:%M:%S").to_string();
596 let entry = format!("\n## {}\n{}\n", timestamp, text);
597
598 let mut content = if path.exists() {
599 std::fs::read_to_string(&path)?
600 } else {
601 format!("# Daily Log — {}\n", today)
602 };
603
604 content.push_str(&entry);
605 std::fs::write(&path, &content)?;
606
607 reindex_source(&self.pool, &filename, &content, self.chunk_size, self.chunk_overlap).await?;
609 self.embed_and_store_source(&filename, &content).await?;
610
611 Ok(())
612 }
613
614 pub async fn reindex(&self) -> Result<()> {
620 sqlx::query("DELETE FROM memory_fts")
622 .execute(&self.pool)
623 .await
624 .map_err(|e| StarpodError::Database(format!("Failed to clear FTS: {}", e)))?;
625
626 sqlx::query("DELETE FROM memory_vectors")
628 .execute(&self.pool)
629 .await
630 .map_err(|e| StarpodError::Database(format!("Failed to clear vectors: {}", e)))?;
631
632 self.index_dir(&self.config_dir.clone(), "").await?;
634
635 if let Ok(entries) = std::fs::read_dir(&self.agent_home) {
637 for entry in entries.flatten() {
638 let path = entry.path();
639 if path.is_file() && path.extension().is_some_and(|ext| ext == "md") {
640 let filename = entry.file_name().to_string_lossy().to_string();
641 if !Self::CONFIG_FILES.iter().any(|&f| f == filename) {
643 let content = std::fs::read_to_string(&path)?;
644 reindex_source(&self.pool, &filename, &content, self.chunk_size, self.chunk_overlap).await?;
645 self.embed_and_store_source(&filename, &content).await?;
646 }
647 }
648 }
649 }
650
651 Ok(())
652 }
653
654 async fn index_dir(&self, dir: &Path, prefix: &str) -> Result<()> {
656 let entries = std::fs::read_dir(dir).map_err(StarpodError::Io)?;
657
658 for entry in entries {
659 let entry = entry.map_err(StarpodError::Io)?;
660 let path = entry.path();
661 if path.is_file() && path.extension().is_some_and(|ext| ext == "md") {
662 let filename = entry.file_name().to_string_lossy().to_string();
663 let source = format!("{}{}", prefix, filename);
664 let content = std::fs::read_to_string(&path)?;
665 reindex_source(&self.pool, &source, &content, self.chunk_size, self.chunk_overlap).await?;
666 self.embed_and_store_source(&source, &content).await?;
667 }
668 }
669
670 Ok(())
671 }
672}
673
/// Serialises a slice of `f32` values into a flat byte vector, four
/// little-endian bytes per value, in input order.
fn f32_vec_to_bytes(vec: &[f32]) -> Vec<u8> {
    vec.iter().flat_map(|value| value.to_le_bytes()).collect()
}
682
/// Decodes a byte slice into `f32` values, reading four little-endian bytes
/// per value. Trailing bytes that do not fill a whole value are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut floats = Vec::with_capacity(bytes.len() / 4);
    for word in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(word);
        floats.push(f32::from_le_bytes(raw));
    }
    floats
}
690
691#[cfg(test)]
692mod tests {
693 use super::*;
694 use tempfile::TempDir;
695
696 async fn test_store(tmp: &TempDir) -> MemoryStore {
700 let agent_home = tmp.path().join("agent_home");
701 let config_dir = tmp.path().join("agent_home").join("config");
702 let db_dir = tmp.path().join("db");
703 MemoryStore::new(&agent_home, &config_dir, &db_dir).await.unwrap()
704 }
705
706 #[tokio::test]
709 async fn test_new_seeds_defaults() {
710 let tmp = TempDir::new().unwrap();
711 let store = test_store(&tmp).await;
712 let config_dir = tmp.path().join("agent_home").join("config");
713
714 assert!(config_dir.join("SOUL.md").exists());
716 assert!(config_dir.join("HEARTBEAT.md").exists());
717 assert!(config_dir.join("BOOT.md").exists());
718 assert!(config_dir.join("BOOTSTRAP.md").exists());
719
720 assert!(!config_dir.join("USER.md").exists());
722 assert!(!config_dir.join("MEMORY.md").exists());
723
724 assert!(tmp.path().join("db").join("memory.db").exists());
726
727 let soul = store.read_file("SOUL.md").unwrap();
729 assert!(soul.contains("Aster"));
730 }
731
732 #[tokio::test]
733 async fn test_write_and_search() {
734 let tmp = TempDir::new().unwrap();
735 let store = test_store(&tmp).await;
736
737 store
738 .write_file("test-content.md", "Rust is a systems programming language focused on safety and performance.")
739 .await
740 .unwrap();
741
742 let results = store.search("Rust programming", 5).await.unwrap();
743 assert!(!results.is_empty());
744 assert!(results[0].text.contains("Rust"));
745 }
746
747 #[tokio::test]
748 async fn test_append_daily() {
749 let tmp = TempDir::new().unwrap();
750 let store = test_store(&tmp).await;
751 let agent_home = tmp.path().join("agent_home");
752
753 std::fs::create_dir_all(agent_home.join("memory")).unwrap();
755
756 store.append_daily("Had a great conversation about Rust.").await.unwrap();
757 store.append_daily("Discussed memory management.").await.unwrap();
758
759 let today = Local::now().format("%Y-%m-%d").to_string();
760 let content = store.read_file(&format!("memory/{}.md", today)).unwrap();
761 assert!(content.contains("great conversation"));
762 assert!(content.contains("memory management"));
763 }
764
765 #[tokio::test]
766 async fn test_bootstrap_context() {
767 let tmp = TempDir::new().unwrap();
768 let store = test_store(&tmp).await;
769
770 let ctx = store.bootstrap_context().unwrap();
771 assert!(ctx.contains("SOUL.md"));
772 assert!(ctx.contains("Aster"));
773 assert!(!ctx.contains("USER.md"));
775 assert!(!ctx.contains("MEMORY.md"));
776 }
777
778 #[tokio::test]
779 async fn test_reindex() {
780 let tmp = TempDir::new().unwrap();
781 let store = test_store(&tmp).await;
782 let agent_home = tmp.path().join("agent_home");
783
784 std::fs::write(
786 agent_home.join("test-quantum.md"),
787 "This is about quantum computing and qubits.",
788 )
789 .unwrap();
790
791 store.reindex().await.unwrap();
793
794 let results = store.search("quantum computing", 5).await.unwrap();
795 assert!(!results.is_empty());
796 }
797
798 #[tokio::test]
801 async fn write_file_rejects_traversal() {
802 let tmp = TempDir::new().unwrap();
803 let store = test_store(&tmp).await;
804 let err = store.write_file("../escape.md", "evil content").await;
805 assert!(err.is_err(), "write_file should reject path traversal");
806 }
807
808 #[tokio::test]
809 async fn write_file_rejects_non_md() {
810 let tmp = TempDir::new().unwrap();
811 let store = test_store(&tmp).await;
812 let err = store.write_file("script.sh", "#!/bin/bash").await;
813 assert!(err.is_err(), "write_file should reject non-.md files");
814 }
815
816 #[tokio::test]
817 async fn write_file_rejects_absolute_path() {
818 let tmp = TempDir::new().unwrap();
819 let store = test_store(&tmp).await;
820 let err = store.write_file("/tmp/evil.md", "content").await;
821 assert!(err.is_err(), "write_file should reject absolute paths");
822 }
823
824 #[tokio::test]
825 async fn read_file_rejects_traversal() {
826 let tmp = TempDir::new().unwrap();
827 let store = test_store(&tmp).await;
828 let err = store.read_file("../../etc/passwd.md");
829 assert!(err.is_err(), "read_file should reject path traversal");
830 }
831
832 #[tokio::test]
833 async fn read_file_rejects_non_md() {
834 let tmp = TempDir::new().unwrap();
835 let store = test_store(&tmp).await;
836 let err = store.read_file("secret.json");
837 assert!(err.is_err(), "read_file should reject non-.md files");
838 }
839
840 #[tokio::test]
843 async fn write_file_rejects_oversized_content() {
844 let tmp = TempDir::new().unwrap();
845 let store = test_store(&tmp).await;
846 let big = "x".repeat(scoring::MAX_WRITE_SIZE + 1);
847 let err = store.write_file("big.md", &big).await;
848 assert!(err.is_err(), "write_file should reject content > 1 MB");
849 }
850
851 #[tokio::test]
852 async fn write_file_accepts_content_at_limit() {
853 let tmp = TempDir::new().unwrap();
854 let store = test_store(&tmp).await;
855 let exact = "x".repeat(scoring::MAX_WRITE_SIZE);
856 let result = store.write_file("exact.md", &exact).await;
857 assert!(result.is_ok(), "write_file should accept content at exactly 1 MB");
858 }
859
860 #[tokio::test]
863 async fn set_half_life_days_is_applied() {
864 let tmp = TempDir::new().unwrap();
865 let mut store = test_store(&tmp).await;
866 store.set_half_life_days(7.0);
867 assert_eq!(store.half_life_days, 7.0);
868 }
869
870 #[tokio::test]
871 async fn set_mmr_lambda_is_applied() {
872 let tmp = TempDir::new().unwrap();
873 let mut store = test_store(&tmp).await;
874 store.set_mmr_lambda(0.5);
875 assert_eq!(store.mmr_lambda, 0.5);
876 }
877
878 #[tokio::test]
879 async fn set_chunk_size_is_applied() {
880 let tmp = TempDir::new().unwrap();
881 let mut store = test_store(&tmp).await;
882 store.set_chunk_size(800);
883 assert_eq!(store.chunk_size, 800);
884 }
885
886 #[tokio::test]
887 async fn set_chunk_overlap_is_applied() {
888 let tmp = TempDir::new().unwrap();
889 let mut store = test_store(&tmp).await;
890 store.set_chunk_overlap(160);
891 assert_eq!(store.chunk_overlap, 160);
892 }
893
894 #[tokio::test]
895 async fn set_bootstrap_file_cap_is_applied() {
896 let tmp = TempDir::new().unwrap();
897 let mut store = test_store(&tmp).await;
898 store.set_bootstrap_file_cap(5000);
899 assert_eq!(store.bootstrap_file_cap, 5000);
900 }
901
902 #[tokio::test]
903 async fn bootstrap_file_cap_limits_output() {
904 let tmp = TempDir::new().unwrap();
905 let mut store = test_store(&tmp).await;
906
907 let large_content = "x".repeat(10_000);
909 store.write_file("SOUL.md", &large_content).await.unwrap();
910
911 store.set_bootstrap_file_cap(500);
913
914 let ctx = store.bootstrap_context().unwrap();
915 let soul_section = ctx
918 .split("--- SOUL.md ---\n")
919 .nth(1)
920 .unwrap_or("")
921 .split("\n\n--- ")
922 .next()
923 .unwrap_or("");
924 assert!(
925 soul_section.len() <= 500,
926 "SOUL.md section should be capped at 500 chars, got {}",
927 soul_section.len(),
928 );
929 }
930
931 #[tokio::test]
934 async fn vector_search_returns_empty_without_embedder() {
935 let tmp = TempDir::new().unwrap();
936 let store = test_store(&tmp).await;
937 let results = store.vector_search("anything", 10).await.unwrap();
938 assert!(results.is_empty(), "vector_search should return empty without embedder");
939 }
940
941 #[tokio::test]
942 async fn hybrid_search_falls_back_to_fts_without_embedder() {
943 let tmp = TempDir::new().unwrap();
944 let store = test_store(&tmp).await;
945
946 store
947 .write_file("test-elephants.md", "Unique test content about elephants in Africa.")
948 .await
949 .unwrap();
950
951 let results = store.hybrid_search("elephants Africa", 5).await.unwrap();
953 assert!(!results.is_empty(), "hybrid_search should fall back to FTS without embedder");
954 assert!(results[0].text.contains("elephants"));
955 }
956
957 struct MockEmbedder;
963
964 #[async_trait::async_trait]
965 impl Embedder for MockEmbedder {
966 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
967 Ok(texts.iter().map(|t| {
968 let mut vec = vec![0.0f32; 8];
969 for ch in t.chars() {
970 let idx = (ch.to_ascii_lowercase() as usize).wrapping_sub('a' as usize);
971 if idx < 8 {
972 vec[idx] += 1.0;
973 }
974 }
975 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
977 if norm > 0.0 {
978 for v in &mut vec {
979 *v /= norm;
980 }
981 }
982 vec
983 }).collect())
984 }
985
986 fn dimensions(&self) -> usize {
987 8
988 }
989 }
990
991 #[tokio::test]
992 async fn set_embedder_enables_vector_storage() {
993 let tmp = TempDir::new().unwrap();
994 let mut store = test_store(&tmp).await;
995 store.set_embedder(Arc::new(MockEmbedder));
996
997 store
998 .write_file("test-cats.md", "Cats are wonderful animals that love to sleep.")
999 .await
1000 .unwrap();
1001
1002 let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM memory_vectors WHERE source = 'test-cats.md'")
1004 .fetch_one(&store.pool)
1005 .await
1006 .unwrap();
1007 assert!(count > 0, "Vectors should be stored after write_file with embedder");
1008 }
1009
1010 #[tokio::test]
1011 async fn vector_search_with_mock_embedder() {
1012 let tmp = TempDir::new().unwrap();
1013 let mut store = test_store(&tmp).await;
1014 store.set_embedder(Arc::new(MockEmbedder));
1015
1016 store.write_file("test-abc.md", "aaa bbb ccc abc").await.unwrap();
1017 store.write_file("test-def.md", "ddd eee fff def").await.unwrap();
1018
1019 let results = store.vector_search("aaa abc", 5).await.unwrap();
1021 assert!(!results.is_empty(), "vector_search should return results with embedder");
1022 }
1023
1024 #[tokio::test]
1025 async fn hybrid_search_with_mock_embedder() {
1026 let tmp = TempDir::new().unwrap();
1027 let mut store = test_store(&tmp).await;
1028 store.set_embedder(Arc::new(MockEmbedder));
1029
1030 store.write_file("test-alpha.md", "Alpha beta gamma delta").await.unwrap();
1031 store.write_file("test-beta.md", "Beta epsilon zeta eta").await.unwrap();
1032
1033 let results = store.hybrid_search("alpha beta", 5).await.unwrap();
1034 assert!(!results.is_empty(), "hybrid_search should return results with embedder");
1035 }
1036
1037 #[tokio::test]
1038 async fn reindex_clears_and_rebuilds_vectors() {
1039 let tmp = TempDir::new().unwrap();
1040 let mut store = test_store(&tmp).await;
1041 store.set_embedder(Arc::new(MockEmbedder));
1042
1043 store.write_file("test-vectors.md", "Test content here").await.unwrap();
1044
1045 let before: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM memory_vectors")
1047 .fetch_one(&store.pool)
1048 .await
1049 .unwrap();
1050 assert!(before > 0);
1051
1052 store.reindex().await.unwrap();
1054
1055 let after: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM memory_vectors")
1056 .fetch_one(&store.pool)
1057 .await
1058 .unwrap();
1059 assert!(after > 0, "Reindex should rebuild vectors");
1061 }
1062
1063 #[test]
1066 fn f32_bytes_round_trip() {
1067 let original = vec![1.0f32, -2.5, 0.0, std::f32::consts::PI, f32::MAX, f32::MIN];
1068 let bytes = f32_vec_to_bytes(&original);
1069 assert_eq!(bytes.len(), original.len() * 4);
1070 let restored = bytes_to_f32_vec(&bytes);
1071 assert_eq!(original, restored);
1072 }
1073
1074 #[test]
1075 fn f32_bytes_empty_round_trip() {
1076 let original: Vec<f32> = vec![];
1077 let bytes = f32_vec_to_bytes(&original);
1078 assert!(bytes.is_empty());
1079 let restored = bytes_to_f32_vec(&bytes);
1080 assert!(restored.is_empty());
1081 }
1082
1083 #[test]
1084 fn f32_bytes_single_value() {
1085 let original = vec![42.0f32];
1086 let bytes = f32_vec_to_bytes(&original);
1087 assert_eq!(bytes.len(), 4);
1088 let restored = bytes_to_f32_vec(&bytes);
1089 assert_eq!(original, restored);
1090 }
1091
1092 #[tokio::test]
1097 async fn search_applies_temporal_decay() {
1098 let tmp = TempDir::new().unwrap();
1099 let store = test_store(&tmp).await;
1100 let agent_home = tmp.path().join("agent_home");
1101
1102 let content = "Temporal decay test content about quantum physics and relativity.";
1104 store.write_file("test-physics.md", content).await.unwrap();
1105
1106 let old_date = Local::now().date_naive() - chrono::Duration::days(90);
1108 let old_filename = format!("memory/{}.md", old_date.format("%Y-%m-%d"));
1109 std::fs::create_dir_all(agent_home.join("memory")).unwrap();
1110 let old_path = agent_home.join(&old_filename);
1111 std::fs::write(&old_path, content).unwrap();
1112 store.reindex().await.unwrap();
1113
1114 let results = store.search("quantum physics relativity", 10).await.unwrap();
1115 assert!(!results.is_empty(), "Should find at least the evergreen file");
1117 }
1118
1119 #[tokio::test]
1120 async fn test_append_daily_creates_memory_dir() {
1121 let tmp = TempDir::new().unwrap();
1122 let store = test_store(&tmp).await;
1123 let agent_home = tmp.path().join("agent_home");
1124
1125 assert!(!agent_home.join("memory").exists());
1127
1128 store.append_daily("First entry without pre-existing dir.").await.unwrap();
1129
1130 assert!(agent_home.join("memory").exists());
1131 let today = Local::now().format("%Y-%m-%d").to_string();
1132 let content = store.read_file(&format!("memory/{}.md", today)).unwrap();
1133 assert!(content.contains("First entry"));
1134 }
1135
1136 #[tokio::test]
1137 async fn test_bootstrap_context_multibyte_safe() {
1138 let tmp = TempDir::new().unwrap();
1139 let agent_home = tmp.path().join("agent_home");
1140 let config_dir = agent_home.join("config");
1141 let db_dir = tmp.path().join("db");
1142 std::fs::create_dir_all(&config_dir).unwrap();
1143
1144 let soul = "# Soul\n".to_string() + &"café 🌟 ".repeat(5000);
1147 std::fs::write(config_dir.join("SOUL.md"), &soul).unwrap();
1148
1149 let store = MemoryStore::new(&agent_home, &config_dir, &db_dir).await.unwrap();
1150 let ctx = store.bootstrap_context().unwrap();
1152 assert!(ctx.contains("SOUL.md"));
1153 assert!(ctx.is_char_boundary(ctx.len()));
1155 }
1156
1157 #[tokio::test]
1160 async fn config_files_routed_to_config_dir() {
1161 let tmp = TempDir::new().unwrap();
1162 let store = test_store(&tmp).await;
1163 let config_dir = tmp.path().join("agent_home").join("config");
1164
1165 store.write_file("SOUL.md", "# Soul\nCustom soul.").await.unwrap();
1167 assert!(config_dir.join("SOUL.md").is_file());
1168 let content = std::fs::read_to_string(config_dir.join("SOUL.md")).unwrap();
1169 assert!(content.contains("Custom soul"));
1170
1171 let read = store.read_file("SOUL.md").unwrap();
1173 assert!(read.contains("Custom soul"));
1174 }
1175
1176 #[tokio::test]
1177 async fn runtime_files_routed_to_agent_home() {
1178 let tmp = TempDir::new().unwrap();
1179 let store = test_store(&tmp).await;
1180 let agent_home = tmp.path().join("agent_home");
1181 let config_dir = agent_home.join("config");
1182
1183 store.write_file("notes.md", "Some notes.").await.unwrap();
1185 assert!(agent_home.join("notes.md").is_file());
1186 assert!(!config_dir.join("notes.md").exists());
1187
1188 let content = store.read_file("notes.md").unwrap();
1190 assert!(content.contains("Some notes"));
1191 }
1192
1193 #[tokio::test]
1194 async fn reindex_covers_both_config_and_agent_home() {
1195 let tmp = TempDir::new().unwrap();
1196 let store = test_store(&tmp).await;
1197
1198 store.write_file("SOUL.md", "Soul content about quantum.").await.unwrap();
1200
1201 store.write_file("notes.md", "Notes about quantum.").await.unwrap();
1203
1204 store.reindex().await.unwrap();
1206
1207 let results = store.search("quantum", 10).await.unwrap();
1208 let sources: Vec<&str> = results.iter().map(|r| r.source.as_str()).collect();
1209 assert!(sources.contains(&"SOUL.md"), "SOUL.md from config_dir should be indexed");
1210 assert!(sources.contains(&"notes.md"), "notes.md from agent_home should be indexed");
1211 }
1212
1213 #[tokio::test]
1214 async fn bootstrap_context_reads_from_config_dir() {
1215 let tmp = TempDir::new().unwrap();
1216 let store = test_store(&tmp).await;
1217
1218 store.write_file("SOUL.md", "# Soul\nI am ConfigBot.").await.unwrap();
1220
1221 let ctx = store.bootstrap_context().unwrap();
1222 assert!(ctx.contains("ConfigBot"), "bootstrap should read from config_dir");
1223 }
1224
1225 #[tokio::test]
1226 async fn has_bootstrap_checks_config_dir() {
1227 let tmp = TempDir::new().unwrap();
1228 let store = test_store(&tmp).await;
1229 let config_dir = tmp.path().join("agent_home").join("config");
1230
1231 assert!(!store.has_bootstrap(), "Default BOOTSTRAP.md should be empty");
1233
1234 std::fs::write(config_dir.join("BOOTSTRAP.md"), "Do something on first run.").unwrap();
1236 assert!(store.has_bootstrap(), "BOOTSTRAP.md with content should be detected");
1237
1238 store.clear_bootstrap().unwrap();
1240 assert!(!store.has_bootstrap(), "Cleared BOOTSTRAP.md should not be detected");
1241 }
1242
1243 #[tokio::test]
1246 async fn new_user_creates_db_in_user_dir() {
1247 let tmp = TempDir::new().unwrap();
1248 let user_dir = tmp.path().join("users").join("alice");
1249
1250 let _store = MemoryStore::new_user(&user_dir).await.unwrap();
1251
1252 assert!(user_dir.join("memory.db").exists(), "memory.db should be in user_dir");
1253 }
1254
1255 #[tokio::test]
1256 async fn new_user_does_not_seed_defaults() {
1257 let tmp = TempDir::new().unwrap();
1258 let user_dir = tmp.path().join("users").join("bob");
1259
1260 let _store = MemoryStore::new_user(&user_dir).await.unwrap();
1261
1262 assert!(!user_dir.join("SOUL.md").exists(), "new_user should not seed SOUL.md");
1264 assert!(!user_dir.join("HEARTBEAT.md").exists(), "new_user should not seed HEARTBEAT.md");
1265 }
1266
1267 #[tokio::test]
1268 async fn new_user_indexes_existing_files() {
1269 let tmp = TempDir::new().unwrap();
1270 let user_dir = tmp.path().join("users").join("carol");
1271 std::fs::create_dir_all(&user_dir).unwrap();
1272
1273 std::fs::write(
1275 user_dir.join("MEMORY.md"),
1276 "# Memory\n\nCarol likes functional programming.\n",
1277 ).unwrap();
1278
1279 let store = MemoryStore::new_user(&user_dir).await.unwrap();
1280
1281 let results = store.search("functional programming", 5).await.unwrap();
1283 assert!(!results.is_empty(), "Pre-existing file should be indexed on startup");
1284 assert!(results.iter().any(|r| r.text.contains("functional programming")));
1285 }
1286
1287 #[tokio::test]
1288 async fn new_user_write_and_search() {
1289 let tmp = TempDir::new().unwrap();
1290 let user_dir = tmp.path().join("users").join("dave");
1291
1292 let store = MemoryStore::new_user(&user_dir).await.unwrap();
1293
1294 store.write_file("MEMORY.md", "# Memory\n\nDave prefers dark theme.\n")
1295 .await
1296 .unwrap();
1297
1298 let results = store.search("dark theme", 5).await.unwrap();
1299 assert!(!results.is_empty(), "Written file should be searchable");
1300 assert!(results.iter().any(|r| r.text.contains("dark theme")));
1301 }
1302
1303 #[tokio::test]
1304 async fn new_user_append_daily_and_search() {
1305 let tmp = TempDir::new().unwrap();
1306 let user_dir = tmp.path().join("users").join("eve");
1307
1308 let store = MemoryStore::new_user(&user_dir).await.unwrap();
1309
1310 store.append_daily("Discussed API design patterns").await.unwrap();
1311
1312 let results = store.search("API design", 5).await.unwrap();
1313 assert!(!results.is_empty(), "Daily log should be searchable");
1314 }
1315}