1#[cfg(not(any(feature = "hnsw", feature = "brute-force")))]
25compile_error!("At least one search backend feature must be enabled: 'hnsw' or 'brute-force'");
26
27pub mod chunker;
28pub mod config;
29pub mod conversation;
30pub mod db;
31pub mod documents;
32pub mod embedder;
33pub mod error;
34#[cfg(feature = "hnsw")]
35pub mod hnsw;
36pub mod knowledge;
37pub mod quantize;
38pub mod search;
39pub mod storage;
40pub mod tokenizer;
41pub mod types;
42
43pub use config::{ChunkingConfig, EmbeddingConfig, MemoryConfig, SearchConfig};
45pub use embedder::{Embedder, MockEmbedder, OllamaEmbedder};
46pub use error::MemoryError;
47#[cfg(feature = "hnsw")]
48pub use hnsw::{HnswConfig, HnswHit, HnswIndex};
49pub use quantize::{pack_quantized, unpack_quantized, QuantizedVector, Quantizer};
50pub use storage::StoragePaths;
51pub use tokenizer::{EstimateTokenCounter, TokenCounter};
52pub use types::{
53 Document, Fact, MemoryStats, Message, Role, SearchResult, SearchSource, SearchSourceType,
54 Session, TextChunk,
55};
56
57use std::sync::{Arc, Mutex};
58
/// Cheaply cloneable handle to a memory store.
///
/// All state lives behind an `Arc`, so clones share the same SQLite
/// connection, embedder, configuration, and (with the `hnsw` feature) index.
#[derive(Clone)]
pub struct MemoryStore {
    // Shared state; see `MemoryStoreInner`.
    inner: Arc<MemoryStoreInner>,
}
66
/// State shared by every `MemoryStore` clone.
struct MemoryStoreInner {
    // SQLite connection; always accessed through `with_conn`, which takes the
    // lock on a blocking thread so the async executor is never blocked.
    conn: Mutex<rusqlite::Connection>,
    // Embedding backend used for all vector computations.
    embedder: Box<dyn Embedder>,
    config: MemoryConfig,
    // Filesystem layout (database path, HNSW directory/basename).
    paths: StoragePaths,
    token_counter: Arc<dyn TokenCounter>,
    // In-memory ANN index; persisted on drop (see `Drop` impl) and `flush_hnsw`.
    #[cfg(feature = "hnsw")]
    hnsw_index: std::sync::RwLock<HnswIndex>,
}
76
/// Best-effort persistence of the HNSW index when the last store handle is
/// dropped. Failures are logged rather than propagated: `drop` cannot return
/// an error, and panicking during drop risks an abort while unwinding.
#[cfg(feature = "hnsw")]
impl Drop for MemoryStoreInner {
    fn drop(&mut self) {
        // A poisoned lock means a writer panicked mid-update; the in-memory
        // index may be inconsistent, so do not persist it.
        let hnsw_guard = match self.hnsw_index.read() {
            Ok(g) => g,
            Err(_) => {
                tracing::warn!("HNSW RwLock poisoned on drop — skipping save");
                return;
            }
        };

        // `save` touches the filesystem; catch panics (e.g. the directory was
        // removed underneath us) so drop never unwinds.
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            hnsw_guard.save(&self.paths.hnsw_dir, &self.paths.hnsw_basename)
        }));
        match result {
            Ok(Err(e)) => tracing::error!("Failed to save HNSW index on drop: {}", e),
            Err(_) => tracing::warn!("HNSW save panicked on drop (directory may have been removed)"),
            Ok(Ok(())) => {}
        }

        // Flush the id -> key mappings to SQLite. A poisoned connection mutex
        // is logged explicitly (matching the RwLock handling above) instead of
        // being skipped silently.
        match self.conn.lock() {
            Ok(conn) => {
                if let Err(e) = hnsw_guard.flush_keymap(&conn) {
                    tracing::error!("Failed to flush HNSW keymap on drop: {}", e);
                }
            }
            Err(_) => {
                tracing::warn!("Connection mutex poisoned on drop — skipping HNSW keymap flush");
            }
        }
    }
}
107
/// Converts an optional slice of borrowed strings into owned `String`s,
/// preserving `None`.
fn to_owned_string_vec(opt: Option<&[&str]>) -> Option<Vec<String>> {
    opt.map(|slice| slice.iter().copied().map(String::from).collect())
}
113
/// Borrows an optional `Vec<String>` as a vector of `&str`, preserving `None`.
fn as_str_slice(opt: &Option<Vec<String>>) -> Option<Vec<&str>> {
    opt.as_deref()
        .map(|strings| strings.iter().map(String::as_str).collect())
}
118
119impl MemoryStore {
    /// Runs `f` against the SQLite connection on a blocking thread.
    ///
    /// The connection mutex is locked inside `spawn_blocking`, so the async
    /// executor never waits on SQLite. A panic inside the closure surfaces as
    /// `MemoryError::Other` instead of propagating.
    async fn with_conn<F, T>(&self, f: F) -> Result<T, MemoryError>
    where
        F: FnOnce(&rusqlite::Connection) -> Result<T, MemoryError> + Send + 'static,
        T: Send + 'static,
    {
        let inner = self.inner.clone();
        tokio::task::spawn_blocking(move || {
            let conn = inner.conn.lock().expect("mutex poisoned");
            f(&conn)
        })
        .await
        .map_err(|e| MemoryError::Other(format!("Blocking task panicked: {}", e)))?
    }
137
    /// Opens (or creates) a memory store using the default Ollama embedder
    /// built from `config.embedding`.
    pub fn open(config: MemoryConfig) -> Result<Self, MemoryError> {
        let embedder = Box::new(OllamaEmbedder::new(&config.embedding));
        Self::open_with_embedder(config, embedder)
    }
146
    /// Opens (or creates) a memory store with a caller-supplied embedder.
    ///
    /// Creates the base directory, opens the SQLite database, validates the
    /// stored embedding metadata against `config.embedding`, and — with the
    /// `hnsw` feature — loads or creates the on-disk HNSW index.
    ///
    /// `mut config` is only mutated under the `hnsw` feature, hence the
    /// `#[allow(unused_mut)]`.
    #[allow(unused_mut)] pub fn open_with_embedder(
        mut config: MemoryConfig,
        embedder: Box<dyn Embedder>,
    ) -> Result<Self, MemoryError> {
        let paths = StoragePaths::new(&config.base_dir);

        std::fs::create_dir_all(&paths.base_dir).map_err(|e| {
            MemoryError::StorageError(format!(
                "Failed to create directory {}: {}",
                paths.base_dir.display(),
                e
            ))
        })?;

        let conn = db::open_database(&paths.sqlite_path)?;
        db::check_embedding_metadata(&conn, &config.embedding)?;

        // Keep the HNSW dimensionality in lockstep with the embedder's output.
        #[cfg(feature = "hnsw")]
        {
            config.hnsw.dimensions = config.embedding.dimensions;
        }

        let token_counter = config
            .token_counter
            .clone()
            .unwrap_or_else(tokenizer::default_token_counter);

        // Decide between a fresh index and the persisted one.
        #[cfg(feature = "hnsw")]
        let hnsw_index = {
            let hnsw_config = config.hnsw.clone();

            let embeddings_dirty = db::is_embeddings_dirty(&conn)?;

            if embeddings_dirty {
                // Stored vectors no longer match the configured model, so the
                // persisted index would return meaningless neighbors.
                tracing::warn!(
                    "Embedding model changed — creating fresh HNSW index (old index is stale)"
                );
                HnswIndex::new(hnsw_config)?
            } else if paths.hnsw_files_exist() {
                tracing::info!("Loading HNSW index from {:?}", paths.hnsw_dir);
                match HnswIndex::load(&paths.hnsw_dir, &paths.hnsw_basename, hnsw_config.clone()) {
                    Ok(index) => {
                        // The graph-id -> record-key map is persisted in SQLite.
                        if let Err(e) = index.load_keymap(&conn) {
                            tracing::warn!("Failed to load HNSW key mappings: {}. Mappings will be empty until rebuild.", e);
                        }
                        tracing::info!(
                            "HNSW index loaded ({} active keys)",
                            index.len()
                        );
                        index
                    }
                    Err(e) => {
                        // Corrupt/unreadable index files are recoverable:
                        // start empty (rebuild_hnsw_index can repopulate).
                        tracing::warn!(
                            "Failed to load HNSW index: {}. Creating new empty index.",
                            e
                        );
                        HnswIndex::new(hnsw_config)?
                    }
                }
            } else {
                tracing::info!("Creating new empty HNSW index");
                HnswIndex::new(hnsw_config)?
            }
        };

        Ok(Self {
            inner: Arc::new(MemoryStoreInner {
                conn: Mutex::new(conn),
                embedder,
                config,
                paths,
                token_counter,
                #[cfg(feature = "hnsw")]
                hnsw_index: std::sync::RwLock::new(hnsw_index),
            }),
        })
    }
231
    /// Rebuilds the HNSW index from the embeddings persisted in SQLite
    /// (facts, chunks, messages), swaps it in, and saves it to disk.
    ///
    /// Also serves as compaction: a rebuilt index carries no deletion
    /// tombstones.
    #[cfg(feature = "hnsw")]
    pub async fn rebuild_hnsw_index(&self) -> Result<(), MemoryError> {
        tracing::info!("Rebuilding HNSW index from SQLite embeddings...");

        let hnsw_config = self.inner.config.hnsw.clone();
        let new_index = HnswIndex::new(hnsw_config)?;

        // Facts are keyed "fact:<id>".
        let fact_data: Vec<(String, Vec<u8>)> = self
            .with_conn(|conn| {
                let mut stmt =
                    conn.prepare("SELECT id, embedding FROM facts WHERE embedding IS NOT NULL")?;
                let result = stmt
                    .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(result)
            })
            .await?;

        for (fact_id, blob) in &fact_data {
            let embedding = db::bytes_to_embedding(blob)?;
            let key = format!("fact:{}", fact_id);
            new_index.insert(key, &embedding)?;
        }

        // Document chunks are keyed "chunk:<id>".
        let chunk_data: Vec<(String, Vec<u8>)> = self
            .with_conn(|conn| {
                let mut stmt =
                    conn.prepare("SELECT id, embedding FROM chunks WHERE embedding IS NOT NULL")?;
                let result = stmt
                    .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(result)
            })
            .await?;

        for (chunk_id, blob) in &chunk_data {
            let embedding = db::bytes_to_embedding(blob)?;
            let key = format!("chunk:{}", chunk_id);
            new_index.insert(key, &embedding)?;
        }

        // Messages are keyed "msg:<id>" (integer rowid).
        let msg_data: Vec<(i64, Vec<u8>)> = self
            .with_conn(|conn| {
                let mut stmt = conn
                    .prepare("SELECT id, embedding FROM messages WHERE embedding IS NOT NULL")?;
                let result = stmt
                    .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(result)
            })
            .await?;

        for (msg_id, blob) in &msg_data {
            let embedding = db::bytes_to_embedding(blob)?;
            let key = format!("msg:{}", msg_id);
            new_index.insert(key, &embedding)?;
        }

        let total = fact_data.len() + chunk_data.len() + msg_data.len();
        tracing::info!(
            facts = fact_data.len(),
            chunks = chunk_data.len(),
            messages = msg_data.len(),
            total = total,
            "HNSW index rebuilt"
        );

        // Swap the new index in under a short write lock...
        {
            let mut guard = self.inner.hnsw_index.write().unwrap();
            *guard = new_index;
        }

        // ...then persist the index files and keymap under a read lock.
        {
            let guard = self.inner.hnsw_index.read().unwrap();
            guard.save(&self.inner.paths.hnsw_dir, &self.inner.paths.hnsw_basename)?;
            let conn = self.inner.conn.lock().expect("mutex poisoned");
            guard.flush_keymap(&conn)?;
        }

        Ok(())
    }
323
    /// Persists the HNSW index files to disk and flushes the key mappings to
    /// SQLite. Synchronous: takes the index read lock and the connection
    /// mutex directly.
    #[cfg(feature = "hnsw")]
    pub fn flush_hnsw(&self) -> Result<(), MemoryError> {
        let guard = self.inner.hnsw_index.read().unwrap();
        guard.save(&self.inner.paths.hnsw_dir, &self.inner.paths.hnsw_basename)?;

        let conn = self.inner.conn.lock().expect("mutex poisoned");
        guard.flush_keymap(&conn)?;
        Ok(())
    }
337
338 #[cfg(feature = "hnsw")]
342 pub async fn compact_hnsw(&self) -> Result<(), MemoryError> {
343 if !self.inner.hnsw_index.read().unwrap().needs_compaction() {
344 tracing::info!("HNSW compaction not needed (deleted ratio below threshold)");
345 return Ok(());
346 }
347 self.rebuild_hnsw_index().await
348 }
349
350 pub async fn create_session(&self, channel: &str) -> Result<String, MemoryError> {
354 let channel = channel.to_string();
355 self.with_conn(move |conn| conversation::create_session(conn, &channel, None))
356 .await
357 }
358
    /// Lists sessions with pagination (`limit` rows starting at `offset`).
    pub async fn list_sessions(
        &self,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<Session>, MemoryError> {
        self.with_conn(move |conn| conversation::list_sessions(conn, limit, offset))
            .await
    }
368
369 pub async fn delete_session(&self, session_id: &str) -> Result<(), MemoryError> {
371 let session_id = session_id.to_string();
372 self.with_conn(move |conn| conversation::delete_session(conn, &session_id))
373 .await
374 }
375
376 pub async fn add_message(
380 &self,
381 session_id: &str,
382 role: Role,
383 content: &str,
384 token_count: Option<u32>,
385 metadata: Option<serde_json::Value>,
386 ) -> Result<i64, MemoryError> {
387 let effective_token_count =
388 token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));
389 let sid = session_id.to_string();
390 let ct = content.to_string();
391 let meta = metadata;
392 self.with_conn(move |conn| {
393 conversation::add_message(conn, &sid, role, &ct, effective_token_count, meta.as_ref())
394 })
395 .await
396 }
397
398 pub async fn get_recent_messages(
400 &self,
401 session_id: &str,
402 limit: usize,
403 ) -> Result<Vec<Message>, MemoryError> {
404 let sid = session_id.to_string();
405 self.with_conn(move |conn| conversation::get_recent_messages(conn, &sid, limit))
406 .await
407 }
408
409 pub async fn get_messages_within_budget(
411 &self,
412 session_id: &str,
413 max_tokens: u32,
414 ) -> Result<Vec<Message>, MemoryError> {
415 let sid = session_id.to_string();
416 self.with_conn(move |conn| conversation::get_messages_within_budget(conn, &sid, max_tokens))
417 .await
418 }
419
420 pub async fn session_token_count(&self, session_id: &str) -> Result<u64, MemoryError> {
422 let sid = session_id.to_string();
423 self.with_conn(move |conn| conversation::session_token_count(conn, &sid))
424 .await
425 }
426
    /// Embeds `content` and stores it as a new fact in `namespace`, returning
    /// the generated fact id.
    ///
    /// Writes the full embedding plus (best-effort) an int8-quantized copy to
    /// SQLite and, with the `hnsw` feature, inserts the vector into the ANN
    /// index under key `fact:<id>`.
    pub async fn add_fact(
        &self,
        namespace: &str,
        content: &str,
        source: Option<&str>,
        metadata: Option<serde_json::Value>,
    ) -> Result<String, MemoryError> {
        let embedding = self.inner.embedder.embed(content).await?;
        let embedding_bytes = db::embedding_to_bytes(&embedding);
        let fact_id = uuid::Uuid::new_v4().to_string();

        // Quantization failure is tolerated — `.ok()` drops the error and the
        // q8 column is simply left NULL.
        let quantizer = Quantizer::new(self.inner.config.embedding.dimensions);
        let q8_bytes = quantizer.quantize(&embedding)
            .map(|qv| quantize::pack_quantized(&qv))
            .ok();

        let ns = namespace.to_string();
        let ct = content.to_string();
        let fid = fact_id.clone();
        let src = source.map(|s| s.to_string());
        let meta = metadata;
        self.with_conn(move |conn| {
            knowledge::insert_fact_with_fts_q8(
                conn,
                &fid,
                &ns,
                &ct,
                &embedding_bytes,
                q8_bytes.as_deref(),
                src.as_deref(),
                meta.as_ref(),
            )
        })
        .await?;

        #[cfg(feature = "hnsw")]
        {
            let key = format!("fact:{}", fact_id);
            // NOTE(review): insert through a read guard — HnswIndex presumably
            // uses interior mutability; confirm this is intentional.
            self.inner.hnsw_index.read().unwrap().insert(key, &embedding)?;
        }

        Ok(fact_id)
    }
475
    /// Stores a fact with a caller-supplied embedding (skips the embedder).
    ///
    /// NOTE(review): the embedding's length is not validated against the
    /// configured dimensions here — confirm downstream checks exist.
    pub async fn add_fact_with_embedding(
        &self,
        namespace: &str,
        content: &str,
        embedding: &[f32],
        source: Option<&str>,
        metadata: Option<serde_json::Value>,
    ) -> Result<String, MemoryError> {
        let embedding_bytes = db::embedding_to_bytes(embedding);
        let fact_id = uuid::Uuid::new_v4().to_string();

        // Best-effort int8 quantization; NULL q8 column on failure.
        let quantizer = Quantizer::new(self.inner.config.embedding.dimensions);
        let q8_bytes = quantizer.quantize(embedding)
            .map(|qv| quantize::pack_quantized(&qv))
            .ok();

        let ns = namespace.to_string();
        let ct = content.to_string();
        let fid = fact_id.clone();
        let src = source.map(|s| s.to_string());
        let meta = metadata;
        self.with_conn(move |conn| {
            knowledge::insert_fact_with_fts_q8(
                conn,
                &fid,
                &ns,
                &ct,
                &embedding_bytes,
                q8_bytes.as_deref(),
                src.as_deref(),
                meta.as_ref(),
            )
        })
        .await?;

        #[cfg(feature = "hnsw")]
        {
            let key = format!("fact:{}", fact_id);
            self.inner.hnsw_index.read().unwrap().insert(key, embedding)?;
        }

        Ok(fact_id)
    }
522
    /// Re-embeds `content` and updates an existing fact in place (SQLite,
    /// FTS, and — with the `hnsw` feature — the ANN entry).
    ///
    /// NOTE(review): unlike `add_fact`, this path does not write a refreshed
    /// q8 quantized embedding — confirm whether `update_fact_with_fts` clears
    /// or keeps the now-stale `embedding_q8` column.
    pub async fn update_fact(&self, fact_id: &str, content: &str) -> Result<(), MemoryError> {
        let embedding = self.inner.embedder.embed(content).await?;
        let embedding_bytes = db::embedding_to_bytes(&embedding);

        let fid = fact_id.to_string();
        let ct = content.to_string();
        self.with_conn(move |conn| {
            knowledge::update_fact_with_fts(conn, &fid, &ct, &embedding_bytes)
        })
        .await?;

        #[cfg(feature = "hnsw")]
        {
            let key = format!("fact:{}", fact_id);
            self.inner.hnsw_index.read().unwrap().update(key, &embedding)?;
        }

        Ok(())
    }
544
    /// Deletes a fact (and its FTS row); with the `hnsw` feature, also
    /// removes its ANN entry.
    pub async fn delete_fact(&self, fact_id: &str) -> Result<(), MemoryError> {
        let fid = fact_id.to_string();
        self.with_conn(move |conn| knowledge::delete_fact_with_fts(conn, &fid))
            .await?;

        #[cfg(feature = "hnsw")]
        {
            let key = format!("fact:{}", fact_id);
            self.inner.hnsw_index.read().unwrap().delete(&key)?;
        }

        Ok(())
    }
560
    /// Deletes every fact in `namespace`, returning how many were removed.
    ///
    /// NOTE(review): the fact ids are snapshotted in one call and deleted in
    /// a second — a fact inserted between the two could leave an orphaned
    /// HNSW entry until the next rebuild. Confirm this window is acceptable.
    pub async fn delete_namespace(&self, namespace: &str) -> Result<usize, MemoryError> {
        let ns = namespace.to_string();

        // Snapshot the ids first so their HNSW keys can be removed below.
        #[cfg(feature = "hnsw")]
        let fact_ids: Vec<String> = {
            let ns_clone = ns.clone();
            self.with_conn(move |conn| {
                let mut stmt = conn.prepare("SELECT id FROM facts WHERE namespace = ?1")?;
                let ids = stmt
                    .query_map(rusqlite::params![ns_clone], |row| row.get(0))?
                    .collect::<Result<Vec<String>, _>>()?;
                Ok(ids)
            })
            .await?
        };

        let count = self
            .with_conn(move |conn| knowledge::delete_namespace(conn, &ns))
            .await?;

        #[cfg(feature = "hnsw")]
        {
            for fact_id in &fact_ids {
                let key = format!("fact:{}", fact_id);
                self.inner.hnsw_index.read().unwrap().delete(&key)?;
            }
        }

        Ok(count)
    }
594
595 pub async fn get_fact(&self, fact_id: &str) -> Result<Option<Fact>, MemoryError> {
597 let fid = fact_id.to_string();
598 self.with_conn(move |conn| knowledge::get_fact(conn, &fid))
599 .await
600 }
601
602 pub async fn get_fact_embedding(&self, fact_id: &str) -> Result<Option<Vec<f32>>, MemoryError> {
604 let fid = fact_id.to_string();
605 self.with_conn(move |conn| knowledge::get_fact_embedding(conn, &fid))
606 .await
607 }
608
609 pub async fn list_facts(
611 &self,
612 namespace: &str,
613 limit: usize,
614 offset: usize,
615 ) -> Result<Vec<Fact>, MemoryError> {
616 let ns = namespace.to_string();
617 self.with_conn(move |conn| knowledge::list_facts(conn, &ns, limit, offset))
618 .await
619 }
620
    /// Chunks `content`, embeds every chunk in one batch, and stores the
    /// document plus its chunks; returns the generated document id.
    ///
    /// With the `hnsw` feature, chunk ids are pre-generated so each chunk's
    /// embedding can be inserted into the ANN index under `chunk:<id>`.
    pub async fn ingest_document(
        &self,
        title: &str,
        content: &str,
        namespace: &str,
        source_path: Option<&str>,
        metadata: Option<serde_json::Value>,
    ) -> Result<String, MemoryError> {
        let text_chunks = chunker::chunk_text(
            content,
            &self.inner.config.chunking,
            self.inner.token_counter.as_ref(),
        );

        let chunk_texts: Vec<String> = text_chunks.iter().map(|c| c.content.clone()).collect();
        let embeddings = self.inner.embedder.embed_batch(chunk_texts).await?;

        // Row tuples: (content, embedding bytes, optional q8 bytes, token estimate).
        let quantizer = Quantizer::new(self.inner.config.embedding.dimensions);
        let chunks: Vec<documents::ChunkRow> = text_chunks
            .iter()
            .zip(embeddings.iter())
            .map(|(tc, emb)| {
                // Quantization is best-effort; NULL q8 column on failure.
                let q8 = quantizer.quantize(emb)
                    .map(|qv| quantize::pack_quantized(&qv))
                    .ok();
                (
                    tc.content.clone(),
                    db::embedding_to_bytes(emb),
                    q8,
                    tc.token_count_estimate,
                )
            })
            .collect();

        let doc_id = uuid::Uuid::new_v4().to_string();

        let did = doc_id.clone();
        let t = title.to_string();
        let ns = namespace.to_string();
        let sp = source_path.map(|s| s.to_string());
        let meta = metadata;

        // hnsw path: pre-generate chunk ids so they can be indexed afterwards.
        #[cfg(feature = "hnsw")]
        let chunk_ids: Vec<String> = {
            let chunk_ids: Vec<String> = (0..chunks.len())
                .map(|_| uuid::Uuid::new_v4().to_string())
                .collect();
            let cids = chunk_ids.clone();

            let did_clone = did.clone();
            self.with_conn(move |conn| {
                documents::insert_document_with_chunks_and_ids(
                    conn,
                    &did_clone,
                    &t,
                    &ns,
                    sp.as_deref(),
                    meta.as_ref(),
                    &chunks,
                    &cids,
                )
            })
            .await?;

            chunk_ids
        };

        // Non-hnsw path: let the database layer assign chunk ids.
        #[cfg(not(feature = "hnsw"))]
        {
            self.with_conn(move |conn| {
                documents::insert_document_with_chunks(
                    conn,
                    &did,
                    &t,
                    &ns,
                    sp.as_deref(),
                    meta.as_ref(),
                    &chunks,
                )
            })
            .await?;
        }

        #[cfg(feature = "hnsw")]
        {
            for (chunk_id, embedding) in chunk_ids.iter().zip(embeddings.iter()) {
                let key = format!("chunk:{}", chunk_id);
                self.inner.hnsw_index.read().unwrap().insert(key, embedding)?;
            }
        }

        Ok(doc_id)
    }
720
    /// Deletes a document together with its chunks; with the `hnsw` feature,
    /// also removes each chunk's ANN entry.
    ///
    /// NOTE(review): chunk ids are snapshotted before the delete in a
    /// separate call — not atomic with the deletion.
    pub async fn delete_document(&self, document_id: &str) -> Result<(), MemoryError> {
        #[cfg(feature = "hnsw")]
        let chunk_ids: Vec<String> = {
            let did = document_id.to_string();
            self.with_conn(move |conn| {
                let mut stmt =
                    conn.prepare("SELECT id FROM chunks WHERE document_id = ?1")?;
                let ids = stmt
                    .query_map(rusqlite::params![did], |row| row.get(0))?
                    .collect::<Result<Vec<String>, _>>()?;
                Ok(ids)
            })
            .await?
        };

        let did = document_id.to_string();
        self.with_conn(move |conn| documents::delete_document_with_chunks(conn, &did))
            .await?;

        #[cfg(feature = "hnsw")]
        {
            for chunk_id in &chunk_ids {
                let key = format!("chunk:{}", chunk_id);
                self.inner.hnsw_index.read().unwrap().delete(&key)?;
            }
        }

        Ok(())
    }
753
754 pub async fn list_documents(
756 &self,
757 namespace: &str,
758 limit: usize,
759 offset: usize,
760 ) -> Result<Vec<Document>, MemoryError> {
761 let ns = namespace.to_string();
762 self.with_conn(move |conn| documents::list_documents(conn, &ns, limit, offset))
763 .await
764 }
765
    /// Hybrid (full-text + vector) search over facts, chunks, and messages.
    ///
    /// `top_k` defaults to `search.default_top_k`; `namespaces` and
    /// `source_types` optionally narrow the results. With the `hnsw` feature,
    /// vector candidates come from the ANN index (3x over-fetch before
    /// fusion); otherwise the SQLite-based path is used.
    pub async fn search(
        &self,
        query: &str,
        top_k: Option<usize>,
        namespaces: Option<&[&str]>,
        source_types: Option<&[SearchSourceType]>,
    ) -> Result<Vec<SearchResult>, MemoryError> {
        let k = top_k.unwrap_or(self.inner.config.search.default_top_k);

        let query_embedding = self.inner.embedder.embed(query).await?;

        #[cfg(feature = "hnsw")]
        let hnsw_hits = {
            let guard = self.inner.hnsw_index.read().unwrap();
            if guard.needs_compaction() {
                tracing::warn!(
                    deleted_ratio = guard.deleted_ratio(),
                    "HNSW index has high tombstone ratio. Call compact_hnsw() to reclaim."
                );
            }
            // Over-fetch so post-filtering still leaves ~k usable results.
            let candidates = k * 3;
            guard.search(&query_embedding, candidates)?
        };

        let q = query.to_string();
        let config = self.inner.config.search.clone();
        let ns_owned = to_owned_string_vec(namespaces);
        let st_owned: Option<Vec<SearchSourceType>> = source_types.map(|s| s.to_vec());

        #[cfg(feature = "hnsw")]
        let hnsw_hits_owned = hnsw_hits;

        self.with_conn(move |conn| {
            // Search still runs on stale embeddings, but warn the operator.
            if db::is_embeddings_dirty(conn)? {
                tracing::warn!(
                    "Embeddings are stale after model change — search quality is degraded. \
                     Call reembed_all() to regenerate embeddings."
                );
            }
            let ns_refs = as_str_slice(&ns_owned);
            let ns_slice: Option<&[&str]> = ns_refs.as_deref();
            let st_slice: Option<&[SearchSourceType]> = st_owned.as_deref();

            #[cfg(feature = "hnsw")]
            {
                search::hybrid_search_with_hnsw(
                    conn,
                    &q,
                    &query_embedding,
                    &config,
                    k,
                    ns_slice,
                    st_slice,
                    None,
                    &hnsw_hits_owned,
                )
            }
            #[cfg(not(feature = "hnsw"))]
            {
                search::hybrid_search(
                    conn,
                    &q,
                    &query_embedding,
                    &config,
                    k,
                    ns_slice,
                    st_slice,
                    None,
                )
            }
        })
        .await
    }
843
844 pub async fn search_fts_only(
846 &self,
847 query: &str,
848 top_k: Option<usize>,
849 namespaces: Option<&[&str]>,
850 source_types: Option<&[SearchSourceType]>,
851 ) -> Result<Vec<SearchResult>, MemoryError> {
852 let k = top_k.unwrap_or(self.inner.config.search.default_top_k);
853 let q = query.to_string();
854 let config = self.inner.config.search.clone();
855 let ns_owned = to_owned_string_vec(namespaces);
856 let st_owned: Option<Vec<SearchSourceType>> = source_types.map(|s| s.to_vec());
857 self.with_conn(move |conn| {
858 let ns_refs = as_str_slice(&ns_owned);
859 let ns_slice: Option<&[&str]> = ns_refs.as_deref();
860 let st_slice: Option<&[SearchSourceType]> = st_owned.as_deref();
861 search::fts_only_search(conn, &q, &config, k, ns_slice, st_slice, None)
862 })
863 .await
864 }
865
    /// Vector-similarity search only — no full-text component.
    ///
    /// With the `hnsw` feature, candidates come from the ANN index (3x
    /// over-fetch); otherwise the SQLite-based scan is used.
    pub async fn search_vector_only(
        &self,
        query: &str,
        top_k: Option<usize>,
        namespaces: Option<&[&str]>,
        source_types: Option<&[SearchSourceType]>,
    ) -> Result<Vec<SearchResult>, MemoryError> {
        let k = top_k.unwrap_or(self.inner.config.search.default_top_k);
        let query_embedding = self.inner.embedder.embed(query).await?;

        #[cfg(feature = "hnsw")]
        let hnsw_hits = {
            // Over-fetch so post-filtering still leaves ~k usable results.
            let candidates = k * 3;
            self.inner.hnsw_index.read().unwrap().search(&query_embedding, candidates)?
        };

        let config = self.inner.config.search.clone();
        let ns_owned = to_owned_string_vec(namespaces);
        let st_owned: Option<Vec<SearchSourceType>> = source_types.map(|s| s.to_vec());

        #[cfg(feature = "hnsw")]
        let hnsw_hits_owned = hnsw_hits;

        self.with_conn(move |conn| {
            if db::is_embeddings_dirty(conn)? {
                tracing::warn!(
                    "Embeddings are stale after model change — search quality is degraded. \
                     Call reembed_all() to regenerate embeddings."
                );
            }
            let ns_refs = as_str_slice(&ns_owned);
            let ns_slice: Option<&[&str]> = ns_refs.as_deref();
            let st_slice: Option<&[SearchSourceType]> = st_owned.as_deref();

            #[cfg(feature = "hnsw")]
            {
                search::vector_only_search_with_hnsw(
                    conn, &config, k, ns_slice, st_slice, None, &hnsw_hits_owned,
                )
            }
            #[cfg(not(feature = "hnsw"))]
            {
                search::vector_only_search(conn, &query_embedding, &config, k, ns_slice, st_slice, None)
            }
        })
        .await
    }
915
    /// Appends a message to a session AND embeds its content so it becomes
    /// vector-searchable (embedding + q8 columns in SQLite and, with the
    /// `hnsw` feature, an ANN entry keyed `msg:<id>`).
    pub async fn add_message_embedded(
        &self,
        session_id: &str,
        role: Role,
        content: &str,
        token_count: Option<u32>,
        metadata: Option<serde_json::Value>,
    ) -> Result<i64, MemoryError> {
        // Caller's count wins; otherwise estimate from content.
        let effective_token_count =
            token_count.or_else(|| Some(self.inner.token_counter.count_tokens(content) as u32));

        let embedding = self.inner.embedder.embed(content).await?;
        let embedding_bytes = db::embedding_to_bytes(&embedding);

        // Best-effort int8 quantization; NULL q8 column on failure.
        let quantizer = Quantizer::new(self.inner.config.embedding.dimensions);
        let q8_bytes = quantizer.quantize(&embedding)
            .map(|qv| quantize::pack_quantized(&qv))
            .ok();

        let sid = session_id.to_string();
        let ct = content.to_string();
        let meta = metadata;
        let msg_id = self
            .with_conn(move |conn| {
                conversation::add_message_with_embedding_q8(
                    conn,
                    &sid,
                    role,
                    &ct,
                    effective_token_count,
                    meta.as_ref(),
                    &embedding_bytes,
                    q8_bytes.as_deref(),
                )
            })
            .await?;

        #[cfg(feature = "hnsw")]
        {
            let key = format!("msg:{}", msg_id);
            self.inner.hnsw_index.read().unwrap().insert(key, &embedding)?;
        }

        Ok(msg_id)
    }
966
967 pub async fn search_conversations(
969 &self,
970 query: &str,
971 top_k: Option<usize>,
972 session_ids: Option<&[&str]>,
973 ) -> Result<Vec<SearchResult>, MemoryError> {
974 let k = top_k.unwrap_or(self.inner.config.search.default_top_k);
975
976 let query_embedding = self.inner.embedder.embed(query).await?;
977
978 let q = query.to_string();
979 let config = self.inner.config.search.clone();
980 let sids_owned = to_owned_string_vec(session_ids);
981 self.with_conn(move |conn| {
982 let sids_refs = as_str_slice(&sids_owned);
983 let sids_slice: Option<&[&str]> = sids_refs.as_deref();
984 search::hybrid_search(
985 conn,
986 &q,
987 &query_embedding,
988 &config,
989 k,
990 None,
991 Some(&[SearchSourceType::Messages]),
992 sids_slice,
993 )
994 })
995 .await
996 }
997
998 pub fn chunk_text(&self, text: &str) -> Vec<TextChunk> {
1002 chunker::chunk_text(
1003 text,
1004 &self.inner.config.chunking,
1005 self.inner.token_counter.as_ref(),
1006 )
1007 }
1008
    /// Embeds a single text through the store's configured embedder.
    pub async fn embed(&self, text: &str) -> Result<Vec<f32>, MemoryError> {
        self.inner.embedder.embed(text).await
    }
1013
1014 pub async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, MemoryError> {
1016 let owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
1017 self.inner.embedder.embed_batch(owned).await
1018 }
1019
    /// Collects row counts, database file size, and embedding metadata into a
    /// `MemoryStats` snapshot.
    pub async fn stats(&self) -> Result<MemoryStats, MemoryError> {
        let db_path = self.inner.paths.sqlite_path.clone();
        self.with_conn(move |conn| {
            let total_facts: u64 =
                conn.query_row("SELECT COUNT(*) FROM facts", [], |r| r.get(0))?;
            let total_documents: u64 =
                conn.query_row("SELECT COUNT(*) FROM documents", [], |r| r.get(0))?;
            let total_chunks: u64 =
                conn.query_row("SELECT COUNT(*) FROM chunks", [], |r| r.get(0))?;
            let total_sessions: u64 =
                conn.query_row("SELECT COUNT(*) FROM sessions", [], |r| r.get(0))?;
            let total_messages: u64 =
                conn.query_row("SELECT COUNT(*) FROM messages", [], |r| r.get(0))?;

            // File size is best-effort; 0 if the metadata call fails.
            let db_size = std::fs::metadata(&db_path).map(|m| m.len()).unwrap_or(0);

            // Metadata row may not exist yet (fresh database) — treat any
            // query failure as "unknown".
            let (model, dims): (Option<String>, Option<usize>) = conn
                .query_row(
                    "SELECT model_name, dimensions FROM embedding_metadata WHERE id = 1",
                    [],
                    |r| Ok((Some(r.get(0)?), Some(r.get(1)?))),
                )
                .unwrap_or((None, None));

            Ok(MemoryStats {
                total_facts,
                total_documents,
                total_chunks,
                total_sessions,
                total_messages,
                database_size_bytes: db_size,
                embedding_model: model,
                embedding_dimensions: dims,
            })
        })
        .await
    }
1058
1059 pub async fn embeddings_are_dirty(&self) -> Result<bool, MemoryError> {
1061 self.with_conn(db::is_embeddings_dirty).await
1062 }
1063
    /// Re-embeds all facts, all chunks, and every message that already has an
    /// embedding, in batches of `embedding.batch_size`; each batch is written
    /// inside a transaction. Clears the dirty flag, and with the `hnsw`
    /// feature rebuilds the index afterwards. Returns the total item count.
    pub async fn reembed_all(&self) -> Result<usize, MemoryError> {
        let mut count = 0usize;
        let batch_size = self.inner.config.embedding.batch_size;
        let dims = self.inner.config.embedding.dimensions;

        let fact_contents: Vec<(String, String)> = self
            .with_conn(|conn| {
                let mut stmt = conn.prepare("SELECT id, content FROM facts")?;
                let result = stmt
                    .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(result)
            })
            .await?;

        let mut fact_count = 0usize;
        for batch in fact_contents.chunks(batch_size) {
            let texts: Vec<String> = batch.iter().map(|(_, c)| c.clone()).collect();
            let embeddings = self.inner.embedder.embed_batch(texts).await?;

            let quantizer = Quantizer::new(dims);
            let updates: Vec<(String, Vec<u8>, Option<Vec<u8>>)> = batch
                .iter()
                .zip(embeddings.iter())
                .map(|((id, _), emb)| {
                    // Best-effort q8 quantization; NULL column on failure.
                    let q8 = quantizer.quantize(emb)
                        .map(|qv| quantize::pack_quantized(&qv))
                        .ok();
                    (id.clone(), db::embedding_to_bytes(emb), q8)
                })
                .collect();

            self.with_conn(move |conn| {
                db::with_transaction(conn, |tx| {
                    for (fid, bytes, q8) in &updates {
                        tx.execute(
                            "UPDATE facts SET embedding = ?1, embedding_q8 = ?2, updated_at = datetime('now') WHERE id = ?3",
                            rusqlite::params![bytes, q8.as_deref(), fid],
                        )?;
                    }
                    Ok(())
                })
            })
            .await?;

            fact_count += batch.len();
            count += batch.len();
            // Progress log roughly every 100 items (approximate when
            // batch_size does not divide 100).
            if fact_count % 100 < batch_size {
                tracing::info!(fact_count, "Re-embedded {} facts so far", fact_count);
            }
        }

        let chunk_data: Vec<(String, String)> = self
            .with_conn(|conn| {
                let mut stmt = conn.prepare("SELECT id, content FROM chunks")?;
                let result = stmt
                    .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(result)
            })
            .await?;

        let mut chunk_count = 0usize;
        for batch in chunk_data.chunks(batch_size) {
            let texts: Vec<String> = batch.iter().map(|(_, c)| c.clone()).collect();
            let embeddings = self.inner.embedder.embed_batch(texts).await?;

            let quantizer = Quantizer::new(dims);
            let updates: Vec<(String, Vec<u8>, Option<Vec<u8>>)> = batch
                .iter()
                .zip(embeddings.iter())
                .map(|((id, _), emb)| {
                    let q8 = quantizer.quantize(emb)
                        .map(|qv| quantize::pack_quantized(&qv))
                        .ok();
                    (id.clone(), db::embedding_to_bytes(emb), q8)
                })
                .collect();

            self.with_conn(move |conn| {
                db::with_transaction(conn, |tx| {
                    for (cid, bytes, q8) in &updates {
                        tx.execute(
                            "UPDATE chunks SET embedding = ?1, embedding_q8 = ?2 WHERE id = ?3",
                            rusqlite::params![bytes, q8.as_deref(), cid],
                        )?;
                    }
                    Ok(())
                })
            })
            .await?;

            chunk_count += batch.len();
            count += batch.len();
            if chunk_count % 100 < batch_size {
                tracing::info!(chunk_count, "Re-embedded {} chunks so far", chunk_count);
            }
        }

        // Only messages that were embedded before are re-embedded.
        let message_data: Vec<(i64, String)> = self
            .with_conn(|conn| {
                let mut stmt =
                    conn.prepare("SELECT id, content FROM messages WHERE embedding IS NOT NULL")?;
                let result = stmt
                    .query_map([], |row| Ok((row.get(0)?, row.get(1)?)))?
                    .collect::<Result<Vec<_>, _>>()?;
                Ok(result)
            })
            .await?;

        let mut msg_count = 0usize;
        for batch in message_data.chunks(batch_size) {
            let texts: Vec<String> = batch.iter().map(|(_, c)| c.clone()).collect();
            let embeddings = self.inner.embedder.embed_batch(texts).await?;

            let quantizer = Quantizer::new(dims);
            let updates: Vec<(i64, Vec<u8>, Option<Vec<u8>>)> = batch
                .iter()
                .zip(embeddings.iter())
                .map(|((id, _), emb)| {
                    let q8 = quantizer.quantize(emb)
                        .map(|qv| quantize::pack_quantized(&qv))
                        .ok();
                    (*id, db::embedding_to_bytes(emb), q8)
                })
                .collect();

            self.with_conn(move |conn| {
                db::with_transaction(conn, |tx| {
                    for (mid, bytes, q8) in &updates {
                        tx.execute(
                            "UPDATE messages SET embedding = ?1, embedding_q8 = ?2 WHERE id = ?3",
                            rusqlite::params![bytes, q8.as_deref(), mid],
                        )?;
                    }
                    Ok(())
                })
            })
            .await?;

            msg_count += batch.len();
            count += batch.len();
            if msg_count % 100 < batch_size {
                tracing::info!(msg_count, "Re-embedded {} messages so far", msg_count);
            }
        }

        self.with_conn(db::clear_embeddings_dirty).await?;

        tracing::info!(
            facts = fact_count,
            chunks = chunk_count,
            messages = msg_count,
            total = count,
            "Re-embedding complete"
        );

        // The old ANN vectors are now stale; rebuild from the fresh rows.
        #[cfg(feature = "hnsw")]
        {
            tracing::info!("Rebuilding HNSW index after re-embedding...");
            self.rebuild_hnsw_index().await?;
        }

        Ok(count)
    }
1235
1236 pub async fn vacuum(&self) -> Result<(), MemoryError> {
1238 self.with_conn(|conn| {
1239 conn.execute_batch("VACUUM")?;
1240 Ok(())
1241 })
1242 .await
1243 }
1244
1245 #[cfg(any(test, feature = "testing"))]
1247 pub async fn raw_execute(&self, sql: &str, params: Vec<String>) -> Result<usize, MemoryError> {
1248 let sql = sql.to_string();
1249 self.with_conn(move |conn| {
1250 let param_refs: Vec<&dyn rusqlite::types::ToSql> = params
1251 .iter()
1252 .map(|s| s as &dyn rusqlite::types::ToSql)
1253 .collect();
1254 Ok(conn.execute(&sql, &*param_refs)?)
1255 })
1256 .await
1257 }
1258}