1use std::path::{Path, PathBuf};
2use std::str::FromStr;
3use std::sync::Arc;
4
5use chrono::Local;
6use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions};
7use sqlx::{Row, SqlitePool};
8use tracing::debug;
9
10use starpod_core::{StarpodError, Result};
11
12use crate::defaults;
13use crate::embedder::{self, Embedder};
14use crate::fusion;
15use crate::indexer::{self, reindex_source, CHUNK_SIZE, CHUNK_OVERLAP};
16use crate::schema;
17use crate::scoring;
18
/// Default cap, in bytes, on how much of `SOUL.md` is included by
/// `MemoryStore::bootstrap_context`; adjustable via `set_bootstrap_file_cap`.
const BOOTSTRAP_FILE_CAP: usize = 20_000;
21
/// Default half-life, in days, handed to the `scoring` decay helpers when
/// ranking search results; adjustable via `set_half_life_days`.
const DEFAULT_HALF_LIFE_DAYS: f64 = 30.0;
24
/// A single chunk returned by one of the `MemoryStore` search methods.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Source name the chunk was indexed under (e.g. `SOUL.md`,
    /// `memory/<date>.md`).
    pub source: String,
    /// Text of the matching chunk. May be empty for vector hits whose chunk
    /// has no matching row in the FTS table.
    pub text: String,
    /// First line of the chunk within its source, as recorded by the indexer.
    pub line_start: usize,
    /// Last line of the chunk within its source, as recorded by the indexer.
    pub line_end: usize,
    /// Ranking score. Results are sorted ascending on this value, so a lower
    /// rank is a better match (FTS rank after decay in `search`, negated
    /// cosine similarity in `vector_search`).
    pub rank: f64,
}
47
/// Markdown-file backed memory with a SQLite index (full-text and, when an
/// embedder is configured, vector embeddings).
pub struct MemoryStore {
    /// Root directory for runtime files; daily logs live under `memory/` here.
    agent_home: PathBuf,
    /// Directory holding the files listed in `CONFIG_FILES`.
    config_dir: PathBuf,
    /// Connection pool for `memory.db`.
    pool: SqlitePool,
    /// Half-life, in days, passed to the `scoring` decay helpers.
    half_life_days: f64,
    /// Lambda passed to `scoring::mmr_rerank` during hybrid search.
    mmr_lambda: f64,
    /// Optional embedder; when `None`, vector search yields nothing and
    /// hybrid search falls back to full-text search.
    embedder: Option<Arc<dyn Embedder>>,
    /// Chunk size handed to the indexer.
    chunk_size: usize,
    /// Chunk overlap handed to the indexer.
    chunk_overlap: usize,
    /// Maximum number of bytes of `SOUL.md` included in the bootstrap context.
    bootstrap_file_cap: usize,
}
86
87impl MemoryStore {
88 pub async fn new(agent_home: &Path, config_dir: &Path, db_dir: &Path) -> Result<Self> {
94 std::fs::create_dir_all(agent_home)
96 .map_err(StarpodError::Io)?;
97 std::fs::create_dir_all(config_dir)
98 .map_err(StarpodError::Io)?;
99 std::fs::create_dir_all(db_dir)
100 .map_err(StarpodError::Io)?;
101
102 let db_path = db_dir.join("memory.db");
105 let opts = SqliteConnectOptions::from_str(
106 &format!("sqlite://{}?mode=rwc", db_path.display()),
107 )
108 .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?
109 .pragma("journal_mode", "WAL")
110 .pragma("busy_timeout", "5000")
111 .pragma("synchronous", "NORMAL");
112
113 let pool = SqlitePoolOptions::new()
114 .max_connections(2)
115 .connect_with(opts)
116 .await
117 .map_err(|e| StarpodError::Database(format!("Failed to open database: {}", e)))?;
118
119 schema::run_migrations(&pool).await?;
121
122 let store = Self {
123 agent_home: agent_home.to_path_buf(),
124 config_dir: config_dir.to_path_buf(),
125 pool,
126 half_life_days: DEFAULT_HALF_LIFE_DAYS,
127 mmr_lambda: 0.7,
128 embedder: None,
129 chunk_size: CHUNK_SIZE,
130 chunk_overlap: CHUNK_OVERLAP,
131 bootstrap_file_cap: BOOTSTRAP_FILE_CAP,
132 };
133
134 store.seed_defaults()?;
136
137 store.reindex().await?;
139
140 Ok(store)
141 }
142
143 pub async fn new_user(user_dir: &Path) -> Result<Self> {
154 std::fs::create_dir_all(user_dir)
155 .map_err(StarpodError::Io)?;
156
157 let db_path = user_dir.join("memory.db");
158 let opts = SqliteConnectOptions::from_str(
159 &format!("sqlite://{}?mode=rwc", db_path.display()),
160 )
161 .map_err(|e| StarpodError::Database(format!("Invalid DB path: {}", e)))?;
162
163 let pool = SqlitePoolOptions::new()
164 .max_connections(1)
165 .connect_with(opts)
166 .await
167 .map_err(|e| StarpodError::Database(format!("Failed to open user database: {}", e)))?;
168
169 schema::run_migrations(&pool).await?;
170
171 let store = Self {
172 agent_home: user_dir.to_path_buf(),
173 config_dir: user_dir.to_path_buf(),
174 pool,
175 half_life_days: DEFAULT_HALF_LIFE_DAYS,
176 mmr_lambda: 0.7,
177 embedder: None,
178 chunk_size: CHUNK_SIZE,
179 chunk_overlap: CHUNK_OVERLAP,
180 bootstrap_file_cap: BOOTSTRAP_FILE_CAP,
181 };
182
183 store.reindex().await?;
185
186 Ok(store)
187 }
188
189 fn seed_defaults(&self) -> Result<bool> {
197 let fresh = !self.config_dir.join("SOUL.md").exists();
198
199 if fresh {
201 let path = self.config_dir.join("SOUL.md");
202 debug!(file = "SOUL.md", "Seeding default SOUL.md");
203 std::fs::write(&path, defaults::DEFAULT_SOUL)?;
204 }
205
206 let lifecycle_files = [
208 ("HEARTBEAT.md", defaults::DEFAULT_HEARTBEAT),
209 ("BOOT.md", defaults::DEFAULT_BOOT),
210 ("BOOTSTRAP.md", defaults::DEFAULT_BOOTSTRAP),
211 ];
212
213 for (name, content) in &lifecycle_files {
214 let path = self.config_dir.join(name);
215 if !path.exists() {
216 debug!(file = %name, "Seeding default file");
217 std::fs::write(&path, content)?;
218 }
219 }
220
221 Ok(fresh)
222 }
223
    /// Returns the directory that holds runtime (non-config) files.
    pub fn agent_home(&self) -> &Path {
        &self.agent_home
    }
228
    /// Returns the directory that holds the files listed in `CONFIG_FILES`.
    pub fn config_dir(&self) -> &Path {
        &self.config_dir
    }
233
    /// File names that are stored in `config_dir` rather than `agent_home`
    /// (see `resolve_path`); `reindex` also skips them when scanning
    /// `agent_home`.
    const CONFIG_FILES: &[&str] = &[
        "SOUL.md", "HEARTBEAT.md", "BOOT.md", "BOOTSTRAP.md",
    ];
238
239 fn resolve_path(&self, name: &str) -> PathBuf {
241 if !name.contains('/') && Self::CONFIG_FILES.iter().any(|&f| f == name) {
243 self.config_dir.join(name)
244 } else {
245 self.agent_home.join(name)
246 }
247 }
248
249 pub fn has_bootstrap(&self) -> bool {
251 let path = self.config_dir.join("BOOTSTRAP.md");
252 path.is_file()
253 && std::fs::read_to_string(&path)
254 .map(|c| !c.trim().is_empty())
255 .unwrap_or(false)
256 }
257
258 pub fn clear_bootstrap(&self) -> Result<()> {
260 let path = self.config_dir.join("BOOTSTRAP.md");
261 if path.exists() {
262 std::fs::write(&path, "")?;
263 }
264 Ok(())
265 }
266
267 pub fn bootstrap_context(&self) -> Result<String> {
272 let content = self.read_file("SOUL.md")?;
273 let capped = if content.len() > self.bootstrap_file_cap {
274 let mut end = self.bootstrap_file_cap;
275 while end > 0 && !content.is_char_boundary(end) { end -= 1; }
276 &content[..end]
277 } else {
278 &content
279 };
280 Ok(format!("--- SOUL.md ---\n{}", capped))
281 }
282
    /// Sets the half-life (in days) passed to the `scoring` decay helpers.
    /// The value is stored as given; no validation is performed here.
    pub fn set_half_life_days(&mut self, days: f64) {
        self.half_life_days = days;
    }
287
    /// Sets the lambda passed to `scoring::mmr_rerank` by `hybrid_search`.
    /// The value is stored as given; no validation is performed here.
    pub fn set_mmr_lambda(&mut self, lambda: f64) {
        self.mmr_lambda = lambda;
    }
292
    /// Sets the chunk size handed to the indexer for subsequent indexing.
    /// Already indexed content is not re-chunked until it is reindexed.
    pub fn set_chunk_size(&mut self, size: usize) {
        self.chunk_size = size;
    }
297
    /// Sets the chunk overlap handed to the indexer for subsequent indexing.
    // NOTE(review): no check that `overlap` is smaller than `chunk_size` —
    // confirm the indexer tolerates that.
    pub fn set_chunk_overlap(&mut self, overlap: usize) {
        self.chunk_overlap = overlap;
    }
302
    /// Sets the maximum number of bytes of `SOUL.md` that
    /// `bootstrap_context` will include.
    pub fn set_bootstrap_file_cap(&mut self, cap: usize) {
        self.bootstrap_file_cap = cap;
    }
307
308 pub async fn search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
314 let fetch_limit = (limit * 3).max(30);
316 let rows = sqlx::query(
317 "SELECT source, chunk_text, line_start, line_end, rank
318 FROM memory_fts
319 WHERE memory_fts MATCH ?1
320 ORDER BY rank
321 LIMIT ?2",
322 )
323 .bind(query)
324 .bind(fetch_limit as i64)
325 .fetch_all(&self.pool)
326 .await
327 .map_err(|e| StarpodError::Database(format!("Search query failed: {}", e)))?;
328
329 let mut results: Vec<SearchResult> = rows
330 .iter()
331 .map(|row| {
332 let source = row.get::<String, _>("source");
333 let raw_rank = row.get::<f64, _>("rank");
334 let adjusted_rank = scoring::apply_decay(raw_rank, &source, self.half_life_days);
335 SearchResult {
336 source,
337 text: row.get::<String, _>("chunk_text"),
338 line_start: row.get::<i64, _>("line_start") as usize,
339 line_end: row.get::<i64, _>("line_end") as usize,
340 rank: adjusted_rank,
341 }
342 })
343 .collect();
344
345 results.sort_by(|a, b| a.rank.partial_cmp(&b.rank).unwrap_or(std::cmp::Ordering::Equal));
347 results.truncate(limit);
348
349 Ok(results)
350 }
351
    /// Installs the embedder used for vector storage and vector/hybrid
    /// search. Content indexed before this call has no stored vectors until
    /// it is written or reindexed again.
    pub fn set_embedder(&mut self, embedder: Arc<dyn Embedder>) {
        self.embedder = Some(embedder);
    }
356
357 pub async fn vector_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
361 let embedder = match &self.embedder {
362 Some(e) => e,
363 None => return Ok(Vec::new()),
364 };
365
366 let query_vecs = embedder
368 .embed(&[query.to_string()])
369 .await?;
370 let query_vec = match query_vecs.first() {
371 Some(v) => v,
372 None => return Ok(Vec::new()),
373 };
374
375 let rows = sqlx::query(
377 "SELECT v.source, v.embedding, v.line_start, v.line_end, f.chunk_text
378 FROM memory_vectors v
379 LEFT JOIN memory_fts f ON f.source = v.source
380 AND f.line_start = v.line_start AND f.line_end = v.line_end"
381 )
382 .fetch_all(&self.pool)
383 .await
384 .map_err(|e| StarpodError::Database(format!("Vector search failed: {}", e)))?;
385
386 let mut scored: Vec<(f32, SearchResult)> = Vec::new();
387 for row in &rows {
388 let blob: Vec<u8> = row.get("embedding");
389 let embedding = bytes_to_f32_vec(&blob);
390 let similarity = embedder::cosine_similarity(query_vec, &embedding);
391
392 let source: String = row.get("source");
393 let text: String = row.try_get("chunk_text").unwrap_or_default();
394
395 scored.push((similarity, SearchResult {
396 source,
397 text,
398 line_start: row.get::<i64, _>("line_start") as usize,
399 line_end: row.get::<i64, _>("line_end") as usize,
400 rank: -(similarity as f64), }));
402 }
403
404 scored.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
406 scored.truncate(limit);
407
408 Ok(scored.into_iter().map(|(_, r)| r).collect())
409 }
410
411 pub async fn hybrid_search(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
416 let embedder = match &self.embedder {
417 Some(e) => e,
418 None => return self.search(query, limit).await,
419 };
420
421 let fts_limit = (limit * 3).max(30);
423 let vec_limit = (limit * 3).max(30);
424
425 let (fts_results, vec_results) = tokio::join!(
426 self.fts_search_raw(query, fts_limit),
427 self.vector_search(query, vec_limit),
428 );
429
430 let fts_results = fts_results?;
431 let vec_results = vec_results?;
432
433 let mut fused = fusion::reciprocal_rank_fusion(&fts_results, &vec_results, limit * 3);
435
436 for result in &mut fused {
438 let decay = scoring::decay_factor(&result.source, self.half_life_days);
439 if decay > 0.0 && decay < 1.0 {
440 result.rank *= decay;
441 }
442 }
443
444 fused.sort_by(|a, b| a.rank.partial_cmp(&b.rank).unwrap_or(std::cmp::Ordering::Equal));
446
447 let mmr_pool_size = (limit * 2).min(fused.len());
449 if mmr_pool_size > 0 {
450 let query_vecs = embedder.embed(&[query.to_string()]).await?;
452 if let Some(query_vec) = query_vecs.first() {
453 let texts: Vec<String> = fused[..mmr_pool_size]
455 .iter()
456 .map(|r| r.text.clone())
457 .collect();
458 let embeddings = embedder.embed(&texts).await?;
459
460 let candidates: Vec<(Vec<f32>, usize)> = embeddings
461 .into_iter()
462 .enumerate()
463 .map(|(i, emb)| (emb, i))
464 .collect();
465
466 let selected_indices =
467 scoring::mmr_rerank(query_vec, &candidates, limit, self.mmr_lambda);
468
469 let pool = fused;
470 fused = selected_indices
471 .into_iter()
472 .map(|idx| pool[idx].clone())
473 .collect();
474 } else {
475 fused.truncate(limit);
476 }
477 }
478
479 Ok(fused)
480 }
481
482 async fn fts_search_raw(&self, query: &str, limit: usize) -> Result<Vec<SearchResult>> {
484 let rows = sqlx::query(
485 "SELECT source, chunk_text, line_start, line_end, rank
486 FROM memory_fts
487 WHERE memory_fts MATCH ?1
488 ORDER BY rank
489 LIMIT ?2",
490 )
491 .bind(query)
492 .bind(limit as i64)
493 .fetch_all(&self.pool)
494 .await
495 .map_err(|e| StarpodError::Database(format!("Search query failed: {}", e)))?;
496
497 Ok(rows
498 .iter()
499 .map(|row| SearchResult {
500 source: row.get::<String, _>("source"),
501 text: row.get::<String, _>("chunk_text"),
502 line_start: row.get::<i64, _>("line_start") as usize,
503 line_end: row.get::<i64, _>("line_end") as usize,
504 rank: row.get::<f64, _>("rank"),
505 })
506 .collect())
507 }
508
509 async fn embed_and_store_source(&self, source: &str, text: &str) -> Result<()> {
511 let embedder = match &self.embedder {
512 Some(e) => e,
513 None => return Ok(()),
514 };
515
516 sqlx::query("DELETE FROM memory_vectors WHERE source = ?1")
518 .bind(source)
519 .execute(&self.pool)
520 .await
521 .map_err(|e| StarpodError::Database(format!("Failed to delete old vectors: {}", e)))?;
522
523 let chunks = indexer::chunk_text(source, text, self.chunk_size, self.chunk_overlap);
525 if chunks.is_empty() {
526 return Ok(());
527 }
528
529 let texts: Vec<String> = chunks.iter().map(|c| c.text.clone()).collect();
531 let embeddings = embedder.embed(&texts).await?;
532
533 for (idx, (chunk, embedding)) in chunks.iter().zip(embeddings.iter()).enumerate() {
535 let blob = f32_vec_to_bytes(embedding);
536 sqlx::query(
537 "INSERT INTO memory_vectors (source, chunk_idx, embedding, line_start, line_end)
538 VALUES (?1, ?2, ?3, ?4, ?5)",
539 )
540 .bind(&chunk.source)
541 .bind(idx as i64)
542 .bind(&blob)
543 .bind(chunk.line_start as i64)
544 .bind(chunk.line_end as i64)
545 .execute(&self.pool)
546 .await
547 .map_err(|e| StarpodError::Database(format!("Failed to insert vector: {}", e)))?;
548 }
549
550 Ok(())
551 }
552
553 pub fn read_file(&self, name: &str) -> Result<String> {
555 scoring::validate_path(name, &self.agent_home)?;
557 let path = self.resolve_path(name);
558 if !path.exists() {
559 return Ok(String::new());
560 }
561 std::fs::read_to_string(&path).map_err(StarpodError::Io)
562 }
563
564 pub async fn write_file(&self, name: &str, content: &str) -> Result<()> {
569 scoring::validate_path(name, &self.agent_home)?;
570 scoring::validate_content_size(content)?;
571
572 let path = self.resolve_path(name);
573
574 if let Some(parent) = path.parent() {
576 std::fs::create_dir_all(parent)?;
577 }
578
579 std::fs::write(&path, content)?;
580
581 reindex_source(&self.pool, name, content, self.chunk_size, self.chunk_overlap).await?;
583 self.embed_and_store_source(name, content).await?;
584
585 Ok(())
586 }
587
588 pub async fn append_daily(&self, text: &str) -> Result<()> {
590 let today = Local::now().format("%Y-%m-%d").to_string();
591 let filename = format!("memory/{}.md", today);
592 let path = self.agent_home.join(&filename);
593
594 if let Some(parent) = path.parent() {
596 std::fs::create_dir_all(parent)?;
597 }
598
599 let timestamp = Local::now().format("%H:%M:%S").to_string();
600 let entry = format!("\n## {}\n{}\n", timestamp, text);
601
602 let mut content = if path.exists() {
603 std::fs::read_to_string(&path)?
604 } else {
605 format!("# Daily Log — {}\n", today)
606 };
607
608 content.push_str(&entry);
609 std::fs::write(&path, &content)?;
610
611 reindex_source(&self.pool, &filename, &content, self.chunk_size, self.chunk_overlap).await?;
613 self.embed_and_store_source(&filename, &content).await?;
614
615 Ok(())
616 }
617
618 pub async fn reindex(&self) -> Result<()> {
624 sqlx::query("DELETE FROM memory_fts")
626 .execute(&self.pool)
627 .await
628 .map_err(|e| StarpodError::Database(format!("Failed to clear FTS: {}", e)))?;
629
630 sqlx::query("DELETE FROM memory_vectors")
632 .execute(&self.pool)
633 .await
634 .map_err(|e| StarpodError::Database(format!("Failed to clear vectors: {}", e)))?;
635
636 self.index_dir(&self.config_dir.clone(), "").await?;
638
639 if let Ok(entries) = std::fs::read_dir(&self.agent_home) {
641 for entry in entries.flatten() {
642 let path = entry.path();
643 if path.is_file() && path.extension().is_some_and(|ext| ext == "md") {
644 let filename = entry.file_name().to_string_lossy().to_string();
645 if !Self::CONFIG_FILES.iter().any(|&f| f == filename) {
647 let content = std::fs::read_to_string(&path)?;
648 reindex_source(&self.pool, &filename, &content, self.chunk_size, self.chunk_overlap).await?;
649 self.embed_and_store_source(&filename, &content).await?;
650 }
651 }
652 }
653 }
654
655 Ok(())
656 }
657
658 async fn index_dir(&self, dir: &Path, prefix: &str) -> Result<()> {
660 let entries = std::fs::read_dir(dir).map_err(StarpodError::Io)?;
661
662 for entry in entries {
663 let entry = entry.map_err(StarpodError::Io)?;
664 let path = entry.path();
665 if path.is_file() && path.extension().is_some_and(|ext| ext == "md") {
666 let filename = entry.file_name().to_string_lossy().to_string();
667 let source = format!("{}{}", prefix, filename);
668 let content = std::fs::read_to_string(&path)?;
669 reindex_source(&self.pool, &source, &content, self.chunk_size, self.chunk_overlap).await?;
670 self.embed_and_store_source(&source, &content).await?;
671 }
672 }
673
674 Ok(())
675 }
676}
677
/// Serializes a slice of `f32` values into a byte vector, four
/// little-endian bytes per value, in order.
fn f32_vec_to_bytes(vec: &[f32]) -> Vec<u8> {
    vec.iter().flat_map(|value| value.to_le_bytes()).collect()
}
686
/// Deserializes little-endian bytes into `f32` values, four bytes per value.
/// Trailing bytes that do not form a complete group of four are ignored.
fn bytes_to_f32_vec(bytes: &[u8]) -> Vec<f32> {
    let mut values = Vec::with_capacity(bytes.len() / 4);
    for quad in bytes.chunks_exact(4) {
        let mut raw = [0u8; 4];
        raw.copy_from_slice(quad);
        values.push(f32::from_le_bytes(raw));
    }
    values
}
694
695#[cfg(test)]
696mod tests {
697 use super::*;
698 use tempfile::TempDir;
699
    /// Builds a store inside `tmp` with the layout `agent_home/`,
    /// `agent_home/config/` and `db/`. Panics if the store cannot be opened.
    async fn test_store(tmp: &TempDir) -> MemoryStore {
        let agent_home = tmp.path().join("agent_home");
        let config_dir = tmp.path().join("agent_home").join("config");
        let db_dir = tmp.path().join("db");
        MemoryStore::new(&agent_home, &config_dir, &db_dir).await.unwrap()
    }
709
    /// `MemoryStore::new` seeds the four config files into the config
    /// directory, creates no `USER.md`/`MEMORY.md`, and puts `memory.db` in
    /// the db directory.
    #[tokio::test]
    async fn test_new_seeds_defaults() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let config_dir = tmp.path().join("agent_home").join("config");

        // Config files are seeded.
        assert!(config_dir.join("SOUL.md").exists());
        assert!(config_dir.join("HEARTBEAT.md").exists());
        assert!(config_dir.join("BOOT.md").exists());
        assert!(config_dir.join("BOOTSTRAP.md").exists());

        // User-level files are not.
        assert!(!config_dir.join("USER.md").exists());
        assert!(!config_dir.join("MEMORY.md").exists());

        assert!(tmp.path().join("db").join("memory.db").exists());

        // The default SOUL.md is expected to mention "Aster".
        let soul = store.read_file("SOUL.md").unwrap();
        assert!(soul.contains("Aster"));
    }
735
    /// Content written with `write_file` is immediately findable via `search`.
    #[tokio::test]
    async fn test_write_and_search() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;

        store
            .write_file("test-content.md", "Rust is a systems programming language focused on safety and performance.")
            .await
            .unwrap();

        let results = store.search("Rust programming", 5).await.unwrap();
        assert!(!results.is_empty());
        assert!(results[0].text.contains("Rust"));
    }
750
    /// Two `append_daily` calls both end up in today's daily log file.
    #[tokio::test]
    async fn test_append_daily() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let agent_home = tmp.path().join("agent_home");

        // Pre-create the directory; the missing-directory case is covered by
        // `test_append_daily_creates_memory_dir`.
        std::fs::create_dir_all(agent_home.join("memory")).unwrap();

        store.append_daily("Had a great conversation about Rust.").await.unwrap();
        store.append_daily("Discussed memory management.").await.unwrap();

        let today = Local::now().format("%Y-%m-%d").to_string();
        let content = store.read_file(&format!("memory/{}.md", today)).unwrap();
        assert!(content.contains("great conversation"));
        assert!(content.contains("memory management"));
    }
768
    /// The bootstrap context contains the SOUL.md section only.
    #[tokio::test]
    async fn test_bootstrap_context() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;

        let ctx = store.bootstrap_context().unwrap();
        assert!(ctx.contains("SOUL.md"));
        assert!(ctx.contains("Aster"));
        // No other file sections are included.
        assert!(!ctx.contains("USER.md"));
        assert!(!ctx.contains("MEMORY.md"));
    }
781
    /// A file dropped into `agent_home` behind the store's back becomes
    /// searchable after `reindex`.
    #[tokio::test]
    async fn test_reindex() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let agent_home = tmp.path().join("agent_home");

        // Written directly to disk, bypassing `write_file` and the index.
        std::fs::write(
            agent_home.join("test-quantum.md"),
            "This is about quantum computing and qubits.",
        )
        .unwrap();

        store.reindex().await.unwrap();

        let results = store.search("quantum computing", 5).await.unwrap();
        assert!(!results.is_empty());
    }
801
    /// `write_file` refuses names containing `..`.
    #[tokio::test]
    async fn write_file_rejects_traversal() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let err = store.write_file("../escape.md", "evil content").await;
        assert!(err.is_err(), "write_file should reject path traversal");
    }
811
    /// `write_file` refuses names that do not end in `.md`.
    #[tokio::test]
    async fn write_file_rejects_non_md() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let err = store.write_file("script.sh", "#!/bin/bash").await;
        assert!(err.is_err(), "write_file should reject non-.md files");
    }
819
    /// `write_file` refuses absolute paths.
    #[tokio::test]
    async fn write_file_rejects_absolute_path() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let err = store.write_file("/tmp/evil.md", "content").await;
        assert!(err.is_err(), "write_file should reject absolute paths");
    }
827
    /// `read_file` refuses names containing `..`.
    #[tokio::test]
    async fn read_file_rejects_traversal() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let err = store.read_file("../../etc/passwd.md");
        assert!(err.is_err(), "read_file should reject path traversal");
    }
835
    /// `read_file` refuses names that do not end in `.md`.
    #[tokio::test]
    async fn read_file_rejects_non_md() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let err = store.read_file("secret.json");
        assert!(err.is_err(), "read_file should reject non-.md files");
    }
843
    /// Content one byte over `scoring::MAX_WRITE_SIZE` is rejected.
    #[tokio::test]
    async fn write_file_rejects_oversized_content() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let big = "x".repeat(scoring::MAX_WRITE_SIZE + 1);
        let err = store.write_file("big.md", &big).await;
        assert!(err.is_err(), "write_file should reject content > 1 MB");
    }
854
    /// Content of exactly `scoring::MAX_WRITE_SIZE` bytes is accepted.
    #[tokio::test]
    async fn write_file_accepts_content_at_limit() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let exact = "x".repeat(scoring::MAX_WRITE_SIZE);
        let result = store.write_file("exact.md", &exact).await;
        assert!(result.is_ok(), "write_file should accept content at exactly 1 MB");
    }
863
    /// `set_half_life_days` stores the given value.
    #[tokio::test]
    async fn set_half_life_days_is_applied() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_half_life_days(7.0);
        assert_eq!(store.half_life_days, 7.0);
    }
873
    /// `set_mmr_lambda` stores the given value.
    #[tokio::test]
    async fn set_mmr_lambda_is_applied() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_mmr_lambda(0.5);
        assert_eq!(store.mmr_lambda, 0.5);
    }
881
    /// `set_chunk_size` stores the given value.
    #[tokio::test]
    async fn set_chunk_size_is_applied() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_chunk_size(800);
        assert_eq!(store.chunk_size, 800);
    }
889
    /// `set_chunk_overlap` stores the given value.
    #[tokio::test]
    async fn set_chunk_overlap_is_applied() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_chunk_overlap(160);
        assert_eq!(store.chunk_overlap, 160);
    }
897
    /// `set_bootstrap_file_cap` stores the given value.
    #[tokio::test]
    async fn set_bootstrap_file_cap_is_applied() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_bootstrap_file_cap(5000);
        assert_eq!(store.bootstrap_file_cap, 5000);
    }
905
    /// With the cap set to 500, the SOUL.md section of the bootstrap context
    /// is at most 500 bytes even though the file holds 10 000.
    #[tokio::test]
    async fn bootstrap_file_cap_limits_output() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;

        let large_content = "x".repeat(10_000);
        store.write_file("SOUL.md", &large_content).await.unwrap();

        store.set_bootstrap_file_cap(500);

        let ctx = store.bootstrap_context().unwrap();
        // Isolate the text between the SOUL.md header and any later section.
        let soul_section = ctx
            .split("--- SOUL.md ---\n")
            .nth(1)
            .unwrap_or("")
            .split("\n\n--- ")
            .next()
            .unwrap_or("");
        assert!(
            soul_section.len() <= 500,
            "SOUL.md section should be capped at 500 chars, got {}",
            soul_section.len(),
        );
    }
934
    /// Without an embedder, `vector_search` succeeds with no results.
    #[tokio::test]
    async fn vector_search_returns_empty_without_embedder() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let results = store.vector_search("anything", 10).await.unwrap();
        assert!(results.is_empty(), "vector_search should return empty without embedder");
    }
944
    /// Without an embedder, `hybrid_search` still finds content via the
    /// full-text index.
    #[tokio::test]
    async fn hybrid_search_falls_back_to_fts_without_embedder() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;

        store
            .write_file("test-elephants.md", "Unique test content about elephants in Africa.")
            .await
            .unwrap();

        let results = store.hybrid_search("elephants Africa", 5).await.unwrap();
        assert!(!results.is_empty(), "hybrid_search should fall back to FTS without embedder");
        assert!(results[0].text.contains("elephants"));
    }
960
    /// Deterministic test embedder producing 8-dimensional vectors: a
    /// normalized histogram of the letters `a`..=`h` (case-insensitive) in
    /// the text. All other characters are ignored.
    struct MockEmbedder;

    #[async_trait::async_trait]
    impl Embedder for MockEmbedder {
        async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(texts.iter().map(|t| {
                let mut vec = vec![0.0f32; 8];
                for ch in t.chars() {
                    // Characters below 'a' wrap to a huge index and are
                    // filtered out by the `< 8` check, like those above 'h'.
                    let idx = (ch.to_ascii_lowercase() as usize).wrapping_sub('a' as usize);
                    if idx < 8 {
                        vec[idx] += 1.0;
                    }
                }
                // L2-normalize; an all-zero vector is left as is.
                let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
                if norm > 0.0 {
                    for v in &mut vec {
                        *v /= norm;
                    }
                }
                vec
            }).collect())
        }

        fn dimensions(&self) -> usize {
            8
        }
    }
994
    /// Once an embedder is set, `write_file` stores vectors for the file.
    #[tokio::test]
    async fn set_embedder_enables_vector_storage() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_embedder(Arc::new(MockEmbedder));

        store
            .write_file("test-cats.md", "Cats are wonderful animals that love to sleep.")
            .await
            .unwrap();

        let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM memory_vectors WHERE source = 'test-cats.md'")
            .fetch_one(&store.pool)
            .await
            .unwrap();
        assert!(count > 0, "Vectors should be stored after write_file with embedder");
    }
1013
    /// With an embedder set, `vector_search` returns results for written
    /// content.
    #[tokio::test]
    async fn vector_search_with_mock_embedder() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_embedder(Arc::new(MockEmbedder));

        store.write_file("test-abc.md", "aaa bbb ccc abc").await.unwrap();
        store.write_file("test-def.md", "ddd eee fff def").await.unwrap();

        let results = store.vector_search("aaa abc", 5).await.unwrap();
        assert!(!results.is_empty(), "vector_search should return results with embedder");
    }
1027
    /// With an embedder set, `hybrid_search` returns results for written
    /// content.
    #[tokio::test]
    async fn hybrid_search_with_mock_embedder() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_embedder(Arc::new(MockEmbedder));

        store.write_file("test-alpha.md", "Alpha beta gamma delta").await.unwrap();
        store.write_file("test-beta.md", "Beta epsilon zeta eta").await.unwrap();

        let results = store.hybrid_search("alpha beta", 5).await.unwrap();
        assert!(!results.is_empty(), "hybrid_search should return results with embedder");
    }
1040
    /// Vectors exist before and after `reindex` when an embedder is set.
    // NOTE(review): only checks that some vectors exist afterwards, not that
    // the table was actually cleared in between.
    #[tokio::test]
    async fn reindex_clears_and_rebuilds_vectors() {
        let tmp = TempDir::new().unwrap();
        let mut store = test_store(&tmp).await;
        store.set_embedder(Arc::new(MockEmbedder));

        store.write_file("test-vectors.md", "Test content here").await.unwrap();

        let before: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM memory_vectors")
            .fetch_one(&store.pool)
            .await
            .unwrap();
        assert!(before > 0);

        store.reindex().await.unwrap();

        let after: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM memory_vectors")
            .fetch_one(&store.pool)
            .await
            .unwrap();
        assert!(after > 0, "Reindex should rebuild vectors");
    }
1066
    /// Encoding then decoding a mix of values (including the extremes)
    /// returns the original vector, four bytes per value.
    #[test]
    fn f32_bytes_round_trip() {
        let original = vec![1.0f32, -2.5, 0.0, std::f32::consts::PI, f32::MAX, f32::MIN];
        let bytes = f32_vec_to_bytes(&original);
        assert_eq!(bytes.len(), original.len() * 4);
        let restored = bytes_to_f32_vec(&bytes);
        assert_eq!(original, restored);
    }
1077
    /// An empty vector encodes to no bytes and decodes back to empty.
    #[test]
    fn f32_bytes_empty_round_trip() {
        let original: Vec<f32> = vec![];
        let bytes = f32_vec_to_bytes(&original);
        assert!(bytes.is_empty());
        let restored = bytes_to_f32_vec(&bytes);
        assert!(restored.is_empty());
    }
1086
    /// A single value encodes to exactly four bytes and round-trips.
    #[test]
    fn f32_bytes_single_value() {
        let original = vec![42.0f32];
        let bytes = f32_vec_to_bytes(&original);
        assert_eq!(bytes.len(), 4);
        let restored = bytes_to_f32_vec(&bytes);
        assert_eq!(original, restored);
    }
1095
    /// Smoke test for search when identical content exists in an undated
    /// file and in a 90-day-old daily log.
    // NOTE(review): despite the name, this only asserts that the search
    // returns something; it does not compare the ranks of the two files.
    #[tokio::test]
    async fn search_applies_temporal_decay() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let agent_home = tmp.path().join("agent_home");

        let content = "Temporal decay test content about quantum physics and relativity.";
        store.write_file("test-physics.md", content).await.unwrap();

        // Same content in a daily log dated 90 days ago, written directly to
        // disk and picked up (if at all) by the reindex below.
        let old_date = Local::now().date_naive() - chrono::Duration::days(90);
        let old_filename = format!("memory/{}.md", old_date.format("%Y-%m-%d"));
        std::fs::create_dir_all(agent_home.join("memory")).unwrap();
        let old_path = agent_home.join(&old_filename);
        std::fs::write(&old_path, content).unwrap();
        store.reindex().await.unwrap();

        let results = store.search("quantum physics relativity", 10).await.unwrap();
        assert!(!results.is_empty(), "Should find at least the evergreen file");
    }
1122
    /// `append_daily` creates the `memory/` directory when it is missing.
    #[tokio::test]
    async fn test_append_daily_creates_memory_dir() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let agent_home = tmp.path().join("agent_home");

        // Precondition: a fresh store has no memory directory.
        assert!(!agent_home.join("memory").exists());

        store.append_daily("First entry without pre-existing dir.").await.unwrap();

        assert!(agent_home.join("memory").exists());
        let today = Local::now().format("%Y-%m-%d").to_string();
        let content = store.read_file(&format!("memory/{}.md", today)).unwrap();
        assert!(content.contains("First entry"));
    }
1139
1140 #[tokio::test]
1141 async fn test_bootstrap_context_multibyte_safe() {
1142 let tmp = TempDir::new().unwrap();
1143 let agent_home = tmp.path().join("agent_home");
1144 let config_dir = agent_home.join("config");
1145 let db_dir = tmp.path().join("db");
1146 std::fs::create_dir_all(&config_dir).unwrap();
1147
1148 let soul = "# Soul\n".to_string() + &"café 🌟 ".repeat(5000);
1151 std::fs::write(config_dir.join("SOUL.md"), &soul).unwrap();
1152
1153 let store = MemoryStore::new(&agent_home, &config_dir, &db_dir).await.unwrap();
1154 let ctx = store.bootstrap_context().unwrap();
1156 assert!(ctx.contains("SOUL.md"));
1157 assert!(ctx.is_char_boundary(ctx.len()));
1159 }
1160
    /// Writing and reading `SOUL.md` goes through the config directory.
    #[tokio::test]
    async fn config_files_routed_to_config_dir() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let config_dir = tmp.path().join("agent_home").join("config");

        store.write_file("SOUL.md", "# Soul\nCustom soul.").await.unwrap();
        assert!(config_dir.join("SOUL.md").is_file());
        let content = std::fs::read_to_string(config_dir.join("SOUL.md")).unwrap();
        assert!(content.contains("Custom soul"));

        let read = store.read_file("SOUL.md").unwrap();
        assert!(read.contains("Custom soul"));
    }
1179
    /// Non-config files are stored in `agent_home`, not the config directory.
    #[tokio::test]
    async fn runtime_files_routed_to_agent_home() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let agent_home = tmp.path().join("agent_home");
        let config_dir = agent_home.join("config");

        store.write_file("notes.md", "Some notes.").await.unwrap();
        assert!(agent_home.join("notes.md").is_file());
        assert!(!config_dir.join("notes.md").exists());

        let content = store.read_file("notes.md").unwrap();
        assert!(content.contains("Some notes"));
    }
1196
    /// After `reindex`, files from both the config directory and
    /// `agent_home` are searchable.
    #[tokio::test]
    async fn reindex_covers_both_config_and_agent_home() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;

        // One file per location.
        store.write_file("SOUL.md", "Soul content about quantum.").await.unwrap();

        store.write_file("notes.md", "Notes about quantum.").await.unwrap();

        store.reindex().await.unwrap();

        let results = store.search("quantum", 10).await.unwrap();
        let sources: Vec<&str> = results.iter().map(|r| r.source.as_str()).collect();
        assert!(sources.contains(&"SOUL.md"), "SOUL.md from config_dir should be indexed");
        assert!(sources.contains(&"notes.md"), "notes.md from agent_home should be indexed");
    }
1216
    /// `bootstrap_context` reflects the SOUL.md written via `write_file`.
    #[tokio::test]
    async fn bootstrap_context_reads_from_config_dir() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;

        store.write_file("SOUL.md", "# Soul\nI am ConfigBot.").await.unwrap();

        let ctx = store.bootstrap_context().unwrap();
        assert!(ctx.contains("ConfigBot"), "bootstrap should read from config_dir");
    }
1228
    /// `has_bootstrap` follows the content of `config_dir/BOOTSTRAP.md`:
    /// false for the seeded default, true once it has content, false again
    /// after `clear_bootstrap`.
    #[tokio::test]
    async fn has_bootstrap_checks_config_dir() {
        let tmp = TempDir::new().unwrap();
        let store = test_store(&tmp).await;
        let config_dir = tmp.path().join("agent_home").join("config");

        // Relies on the seeded default being empty or whitespace-only.
        assert!(!store.has_bootstrap(), "Default BOOTSTRAP.md should be empty");

        std::fs::write(config_dir.join("BOOTSTRAP.md"), "Do something on first run.").unwrap();
        assert!(store.has_bootstrap(), "BOOTSTRAP.md with content should be detected");

        store.clear_bootstrap().unwrap();
        assert!(!store.has_bootstrap(), "Cleared BOOTSTRAP.md should not be detected");
    }
1246
    /// `new_user` creates the (nested) user directory and puts `memory.db`
    /// inside it.
    #[tokio::test]
    async fn new_user_creates_db_in_user_dir() {
        let tmp = TempDir::new().unwrap();
        let user_dir = tmp.path().join("users").join("alice");

        let _store = MemoryStore::new_user(&user_dir).await.unwrap();

        assert!(user_dir.join("memory.db").exists(), "memory.db should be in user_dir");
    }
1258
    /// `new_user` does not write any default config files.
    #[tokio::test]
    async fn new_user_does_not_seed_defaults() {
        let tmp = TempDir::new().unwrap();
        let user_dir = tmp.path().join("users").join("bob");

        let _store = MemoryStore::new_user(&user_dir).await.unwrap();

        assert!(!user_dir.join("SOUL.md").exists(), "new_user should not seed SOUL.md");
        assert!(!user_dir.join("HEARTBEAT.md").exists(), "new_user should not seed HEARTBEAT.md");
    }
1270
    /// Files already present in the user directory are indexed when the
    /// store is opened.
    #[tokio::test]
    async fn new_user_indexes_existing_files() {
        let tmp = TempDir::new().unwrap();
        let user_dir = tmp.path().join("users").join("carol");
        std::fs::create_dir_all(&user_dir).unwrap();

        // Written before the store exists.
        std::fs::write(
            user_dir.join("MEMORY.md"),
            "# Memory\n\nCarol likes functional programming.\n",
        ).unwrap();

        let store = MemoryStore::new_user(&user_dir).await.unwrap();

        let results = store.search("functional programming", 5).await.unwrap();
        assert!(!results.is_empty(), "Pre-existing file should be indexed on startup");
        assert!(results.iter().any(|r| r.text.contains("functional programming")));
    }
1290
    /// A per-user store supports the write-then-search round trip.
    #[tokio::test]
    async fn new_user_write_and_search() {
        let tmp = TempDir::new().unwrap();
        let user_dir = tmp.path().join("users").join("dave");

        let store = MemoryStore::new_user(&user_dir).await.unwrap();

        store.write_file("MEMORY.md", "# Memory\n\nDave prefers dark theme.\n")
            .await
            .unwrap();

        let results = store.search("dark theme", 5).await.unwrap();
        assert!(!results.is_empty(), "Written file should be searchable");
        assert!(results.iter().any(|r| r.text.contains("dark theme")));
    }
1306
    /// Daily log entries appended to a per-user store are searchable.
    #[tokio::test]
    async fn new_user_append_daily_and_search() {
        let tmp = TempDir::new().unwrap();
        let user_dir = tmp.path().join("users").join("eve");

        let store = MemoryStore::new_user(&user_dir).await.unwrap();

        store.append_daily("Discussed API design patterns").await.unwrap();

        let results = store.search("API design", 5).await.unwrap();
        assert!(!results.is_empty(), "Daily log should be searchable");
    }
1319}