use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};

use futures::{StreamExt as _, TryStreamExt as _};
use zeph_llm::provider::{LlmProvider as _, Message};

/// Rough heuristic for converting token budgets into byte lengths.
const CHARS_PER_TOKEN: usize = 4;

/// Maximum chunk size (~400 tokens) for embedding.
const CHUNK_CHARS: usize = 400 * CHARS_PER_TOKEN;

/// Overlap (~80 tokens) carried between consecutive chunks.
const CHUNK_OVERLAP_CHARS: usize = 80 * CHARS_PER_TOKEN;

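/// Splits `text` into chunks of at most `CHUNK_CHARS` bytes for embedding,
/// preferring to break on paragraph, then line, then word boundaries, and
/// overlapping consecutive chunks by up to `CHUNK_OVERLAP_CHARS` bytes so
/// context is preserved across chunk edges.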
fn chunk_text(text: &str) -> Vec<&str> {
    if text.len() <= CHUNK_CHARS {
        return vec![text];
    }

    let mut chunks = Vec::new();
    let mut start = 0;

    while start < text.len() {
        let end = if start + CHUNK_CHARS >= text.len() {
            text.len()
        } else {
            let boundary = text.floor_char_boundary(start + CHUNK_CHARS);
            let slice = &text[start..boundary];
            // Prefer paragraph breaks, then line breaks, then word breaks.
            if let Some(pos) = slice.rfind("\n\n") {
                start + pos + 2
            } else if let Some(pos) = slice.rfind('\n') {
                start + pos + 1
            } else if let Some(pos) = slice.rfind(' ') {
                start + pos + 1
            } else {
                boundary
            }
        };

        chunks.push(&text[start..end]);
        if end >= text.len() {
            break;
        }
        // Step back by the overlap, but never move the window backwards.
        let next = end.saturating_sub(CHUNK_OVERLAP_CHARS);
        let new_start = text.ceil_char_boundary(next);
        start = if new_start > start { new_start } else { end };
    }

    chunks
}

use crate::admission::log_admission_decision;
use crate::embedding_store::{MessageKind, SearchFilter};
use crate::error::MemoryError;
use crate::types::{ConversationId, MessageId};

use super::SemanticMemory;
use super::algorithms::{apply_mmr, apply_temporal_decay};

/// Optional tool metadata attached to an embedded tool-output message.
#[derive(Debug, Clone, Default)]
pub struct EmbedContext {
    pub tool_name: Option<String>,
    pub exit_code: Option<i32>,
    pub timestamp: Option<String>,
}

/// A recalled message paired with its final ranking score.
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    pub score: f32,
}

/// Upper bound on concurrently tracked background embedding tasks.
const MAX_EMBED_BG_TASKS: usize = 64;

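/// Owned arguments moved into a spawned background embedding task.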
struct EmbedBgArgs {
    qdrant: Arc<crate::embedding_store::EmbeddingStore>,
    embed_provider: zeph_llm::any::AnyProvider,
    embedding_model: String,
    message_id: MessageId,
    conversation_id: ConversationId,
    role: String,
    content: String,
    last_qdrant_warn: Arc<AtomicU64>,
}

async fn embed_and_store_regular_bg(args: EmbedBgArgs) {
    let EmbedBgArgs {
        qdrant,
        embed_provider,
        embedding_model,
        message_id,
        conversation_id,
        role,
        content,
        last_qdrant_warn,
    } = args;
    let chunks = chunk_text(&content);
    let chunk_count = chunks.len();

    let vectors = match embed_provider.embed_batch(&chunks).await {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!("bg embed_regular: failed to embed chunks for msg {message_id}: {e:#}");
            return;
        }
    };

    let Some(first) = vectors.first() else {
        return;
    };
    let vector_size = first.len() as u64;
    if let Err(e) = qdrant.ensure_collection(vector_size).await {
        // Rate-limit this warning to once per 10s across all embed tasks.
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let last = last_qdrant_warn.load(Ordering::Relaxed);
        if now.saturating_sub(last) >= 10 {
            last_qdrant_warn.store(now, Ordering::Relaxed);
            tracing::warn!("bg embed_regular: failed to ensure Qdrant collection: {e:#}");
        } else {
            tracing::debug!(
                "bg embed_regular: failed to ensure Qdrant collection (suppressed): {e:#}"
            );
        }
        return;
    }

    for (chunk_index, vector) in vectors.into_iter().enumerate() {
        let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
        if let Err(e) = qdrant
            .store(
                message_id,
                conversation_id,
                &role,
                vector,
                MessageKind::Regular,
                &embedding_model,
                chunk_index_u32,
            )
            .await
        {
            tracing::warn!(
                "bg embed_regular: failed to store chunk {chunk_index}/{chunk_count} \
                 for msg {message_id}: {e:#}"
            );
        }
    }
}

async fn embed_chunks_with_tool_context_bg(args: EmbedBgArgs, embed_ctx: EmbedContext) {
    let EmbedBgArgs {
        qdrant,
        embed_provider,
        embedding_model,
        message_id,
        conversation_id,
        role,
        content,
        last_qdrant_warn,
    } = args;
    let chunks = chunk_text(&content);
    let chunk_count = chunks.len();

    let vectors = match embed_provider.embed_batch(&chunks).await {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                "bg embed_tool: failed to embed tool-output chunks for msg {message_id}: {e:#}"
            );
            return;
        }
    };

    if let Some(first) = vectors.first() {
        let vector_size = first.len() as u64;
        if let Err(e) = qdrant.ensure_collection(vector_size).await {
            let now = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs();
            let last = last_qdrant_warn.load(Ordering::Relaxed);
            if now.saturating_sub(last) >= 10 {
                last_qdrant_warn.store(now, Ordering::Relaxed);
                tracing::warn!("bg embed_tool: failed to ensure Qdrant collection: {e:#}");
            } else {
                tracing::debug!(
                    "bg embed_tool: failed to ensure Qdrant collection (suppressed): {e:#}"
                );
            }
            return;
        }
    }

    for (chunk_index, vector) in vectors.into_iter().enumerate() {
        let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
        let result = if let Some(ref tool_name) = embed_ctx.tool_name {
            qdrant
                .store_with_tool_context(
                    message_id,
                    conversation_id,
                    &role,
                    vector,
                    MessageKind::Regular,
                    &embedding_model,
                    chunk_index_u32,
                    tool_name,
                    embed_ctx.exit_code,
                    embed_ctx.timestamp.as_deref(),
                )
                .await
                .map(|_| ())
        } else {
            qdrant
                .store(
                    message_id,
                    conversation_id,
                    &role,
                    vector,
                    MessageKind::Regular,
                    &embedding_model,
                    chunk_index_u32,
                )
                .await
                .map(|_| ())
        };
        if let Err(e) = result {
            tracing::warn!(
                "bg embed_tool: failed to store chunk {chunk_index}/{chunk_count} \
                 for msg {message_id}: {e:#}"
            );
        }
    }
}

async fn embed_and_store_with_category_bg(args: EmbedBgArgs, category: Option<String>) {
    let EmbedBgArgs {
        qdrant,
        embed_provider,
        embedding_model,
        message_id,
        conversation_id,
        role,
        content,
        last_qdrant_warn,
    } = args;
    let chunks = chunk_text(&content);
    let chunk_count = chunks.len();

    let vectors = match embed_provider.embed_batch(&chunks).await {
        Ok(v) => v,
        Err(e) => {
            tracing::warn!(
                "bg embed_category: failed to embed categorized chunks for msg {message_id}: {e:#}"
            );
            return;
        }
    };

    let Some(first) = vectors.first() else {
        return;
    };
    let vector_size = first.len() as u64;
    if let Err(e) = qdrant.ensure_collection(vector_size).await {
        let now = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        let last = last_qdrant_warn.load(Ordering::Relaxed);
        if now.saturating_sub(last) >= 10 {
            last_qdrant_warn.store(now, Ordering::Relaxed);
            tracing::warn!("bg embed_category: failed to ensure Qdrant collection: {e:#}");
        } else {
            tracing::debug!(
                "bg embed_category: failed to ensure Qdrant collection (suppressed): {e:#}"
            );
        }
        return;
    }

    for (chunk_index, vector) in vectors.into_iter().enumerate() {
        let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
        if let Err(e) = qdrant
            .store_with_category(
                message_id,
                conversation_id,
                &role,
                vector,
                MessageKind::Regular,
                &embedding_model,
                chunk_index_u32,
                category.as_deref(),
            )
            .await
        {
            tracing::warn!(
                "bg embed_category: failed to store chunk {chunk_index}/{chunk_count} \
                 for msg {message_id}: {e:#}"
            );
        }
    }
}

impl SemanticMemory {
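    /// Runs optional admission control, persists the message to SQLite, and
    /// spawns a background embedding task. Returns `Ok(None)` when the
    /// admission gate rejects the content.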
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        goal_text: Option<&str>,
    ) -> Result<Option<MessageId>, MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    goal_text,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok(None);
            }
        }

        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        self.embed_and_store_regular(message_id, conversation_id, role, content);

        Ok(Some(message_id))
    }

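    /// Like [`Self::remember`], but also persists the structured parts JSON.
    /// The returned flag reports whether a background embedding task was
    /// dispatched (not whether the embedding has been stored yet).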
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
        goal_text: Option<&str>,
    ) -> Result<(Option<MessageId>, bool), MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    goal_text,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok((None, false));
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let embedding_stored =
            self.embed_and_store_regular(message_id, conversation_id, role, content);

        Ok((Some(message_id), embedding_stored))
    }

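    /// Persists a tool-output message and embeds its chunks with the tool
    /// metadata from `embed_ctx` attached. Admission control is evaluated
    /// without goal context.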
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember_tool_output(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
        embed_ctx: EmbedContext,
    ) -> Result<(Option<MessageId>, bool), MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    None,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok((None, false));
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let embedding_stored = self.embed_chunks_with_tool_context(
            message_id,
            conversation_id,
            role,
            content,
            embed_ctx,
        );

        Ok((Some(message_id), embedding_stored))
    }

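    /// Persists a message under an optional category and embeds it with the
    /// category attached to each stored chunk.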
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember_categorized(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        category: Option<&str>,
        goal_text: Option<&str>,
    ) -> Result<Option<MessageId>, MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    goal_text,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok(None);
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_category(conversation_id, role, content, category)
            .await?;

        self.embed_and_store_with_category(message_id, conversation_id, role, content, category);

        Ok(Some(message_id))
    }

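    /// Recalls messages, constraining the vector search to `category` when a
    /// filter is supplied; without a filter the category is ignored.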
    pub async fn recall_with_category(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        category: Option<&str>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let filter_with_category = filter.map(|mut f| {
            f.category = category.map(str::to_owned);
            f
        });
        self.recall(query, limit, filter_with_category).await
    }

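    /// Drains any completed background embedding tasks without blocking.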
    pub fn reap_embed_tasks(&self) {
        if let Ok(mut tasks) = self.embed_tasks.lock() {
            while tasks.try_join_next().is_some() {}
        }
    }

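    /// Spawns `fut` onto the shared embed task set, reaping finished tasks
    /// first. Returns `false` if the mutex is poisoned or the task cap
    /// (`MAX_EMBED_BG_TASKS`) has been reached.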
    fn spawn_embed_bg<F>(&self, fut: F) -> bool
    where
        F: std::future::Future<Output = ()> + Send + 'static,
    {
        let Ok(mut tasks) = self.embed_tasks.lock() else {
            return false;
        };
        // Reap finished tasks before checking capacity.
        while tasks.try_join_next().is_some() {}
        if tasks.len() >= MAX_EMBED_BG_TASKS {
            tracing::debug!("background embed task limit reached, skipping");
            return false;
        }
        tasks.spawn(fut);
        true
    }

    fn embed_and_store_with_category(
        &self,
        message_id: MessageId,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        category: Option<&str>,
    ) -> bool {
        let Some(qdrant) = self.qdrant.clone() else {
            return false;
        };
        let embed_provider = self.effective_embed_provider().clone();
        if !embed_provider.supports_embeddings() {
            return false;
        }
        self.spawn_embed_bg(embed_and_store_with_category_bg(
            EmbedBgArgs {
                qdrant,
                embed_provider,
                embedding_model: self.embedding_model.clone(),
                message_id,
                conversation_id,
                role: role.to_owned(),
                content: content.to_owned(),
                last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
            },
            category.map(str::to_owned),
        ))
    }

    fn embed_and_store_regular(
        &self,
        message_id: MessageId,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> bool {
        let Some(qdrant) = self.qdrant.clone() else {
            return false;
        };
        let embed_provider = self.effective_embed_provider().clone();
        if !embed_provider.supports_embeddings() {
            return false;
        }
        self.spawn_embed_bg(embed_and_store_regular_bg(EmbedBgArgs {
            qdrant,
            embed_provider,
            embedding_model: self.embedding_model.clone(),
            message_id,
            conversation_id,
            role: role.to_owned(),
            content: content.to_owned(),
            last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
        }))
    }

    fn embed_chunks_with_tool_context(
        &self,
        message_id: MessageId,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        embed_ctx: EmbedContext,
    ) -> bool {
        let Some(qdrant) = self.qdrant.clone() else {
            return false;
        };
        let embed_provider = self.effective_embed_provider().clone();
        if !embed_provider.supports_embeddings() {
            return false;
        }
        self.spawn_embed_bg(embed_chunks_with_tool_context_bg(
            EmbedBgArgs {
                qdrant,
                embed_provider,
                embedding_model: self.embedding_model.clone(),
                message_id,
                conversation_id,
                role: role.to_owned(),
                content: content.to_owned(),
                last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
            },
            embed_ctx,
        ))
    }

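    /// Persists a message (with parts JSON) directly to SQLite, bypassing
    /// admission control and embedding.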
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }

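    /// Hybrid recall: merges FTS5 keyword results with Qdrant vector results,
    /// then ranks them via [`Self::recall_merge_and_rank`]. A keyword search
    /// failure degrades to vector-only search rather than erroring.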
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty, top_score = tracing::field::Empty))
    )]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        tracing::debug!(
            query_len = query.len(),
            limit,
            has_filter = filter.is_some(),
            conversation_id = conversation_id.map(|c| c.0),
            has_qdrant = self.qdrant.is_some(),
            "recall: starting hybrid search"
        );

        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.effective_embed_provider().supports_embeddings()
        {
            let query_vector = self.effective_embed_provider().embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        let results = self
            .recall_merge_and_rank(keyword_results, vector_results, limit)
            .await?;
        #[cfg(feature = "profiling")]
        {
            let span = tracing::Span::current();
            span.record("result_count", results.len());
            if let Some(top) = results.first() {
                span.record("top_score", top.score);
            }
        }
        Ok(results)
    }

    pub(super) async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
    }

    pub(super) async fn recall_vectors_raw(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.effective_embed_provider().supports_embeddings() {
            return Ok(Vec::new());
        }
        let query_vector = self.effective_embed_provider().embed(query).await?;
        let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
        qdrant.ensure_collection(vector_size).await?;
        qdrant.search(&query_vector, limit * 2, filter).await
    }

    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    pub(super) async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            tracing::debug!("recall: no results to merge");
            return Ok(Vec::new());
        }
848
849 let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
850 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
851
852 tracing::debug!(
853 merged = ranked.len(),
854 top_score = ranked.first().map(|r| r.1),
855 bottom_score = ranked.last().map(|r| r.1),
856 vector_weight = %self.vector_weight,
857 keyword_weight = %self.keyword_weight,
858 "recall: weighted merge complete"
859 );
860
861 if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
862 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
863 match self.sqlite.message_timestamps(&ids).await {
864 Ok(timestamps) => {
865 apply_temporal_decay(
866 &mut ranked,
867 ×tamps,
868 self.temporal_decay_half_life_days,
869 );
870 ranked
871 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
872 tracing::debug!(
873 half_life_days = self.temporal_decay_half_life_days,
874 top_score_after = ranked.first().map(|r| r.1),
875 "recall: temporal decay applied"
876 );
877 }
878 Err(e) => {
879 tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
880 }
881 }
882 }
883
884 if self.mmr_enabled && !vector_results.is_empty() {
885 if let Some(qdrant) = &self.qdrant {
886 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
887 match qdrant.get_vectors(&ids).await {
888 Ok(vec_map) if !vec_map.is_empty() => {
889 let ranked_len_before = ranked.len();
890 ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
891 tracing::debug!(
892 before = ranked_len_before,
893 after = ranked.len(),
894 lambda = %self.mmr_lambda,
895 "recall: mmr re-ranked"
896 );
897 }
898 Ok(_) => {
899 ranked.truncate(limit);
900 }
901 Err(e) => {
902 tracing::warn!("MMR: failed to fetch vectors: {e:#}");
903 ranked.truncate(limit);
904 }
905 }
906 } else {
907 ranked.truncate(limit);
908 }
909 } else {
910 ranked.truncate(limit);
911 }
912
913 if self.importance_enabled && !ranked.is_empty() {
914 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
915 match self.sqlite.fetch_importance_scores(&ids).await {
916 Ok(scores) => {
917 for (msg_id, score) in &mut ranked {
918 if let Some(&imp) = scores.get(msg_id) {
919 *score += imp * self.importance_weight;
920 }
921 }
922 ranked
923 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
924 tracing::debug!(
925 importance_weight = %self.importance_weight,
926 "recall: importance scores blended"
927 );
928 }
929 Err(e) => {
930 tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
931 }
932 }
933 }
934
935 if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
939 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
940 match self.sqlite.fetch_tiers(&ids).await {
941 Ok(tiers) => {
942 let bonus = self.tier_boost_semantic - 1.0;
943 let mut boosted = false;
944 for (msg_id, score) in &mut ranked {
945 if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
946 *score += bonus;
947 boosted = true;
948 }
949 }
950 if boosted {
951 ranked.sort_by(|a, b| {
952 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
953 });
954 tracing::debug!(
955 tier_boost = %self.tier_boost_semantic,
956 "recall: semantic tier boost applied"
957 );
958 }
959 }
960 Err(e) => {
961 tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
962 }
963 }
964 }
965
966 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
967
968 if !ids.is_empty()
969 && let Err(e) = self.batch_increment_access_count(ids.clone()).await
970 {
971 tracing::warn!("recall: failed to increment access counts: {e:#}");
972 }
973
974 if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
976 tracing::debug!(
977 error = %e,
978 "recall: failed to mark training data as recalled (non-fatal)"
979 );
980 }
981
982 let messages = self.sqlite.messages_by_ids(&ids).await?;
983 let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
984
985 let recalled: Vec<RecalledMessage> = ranked
986 .iter()
987 .filter_map(|(msg_id, score)| {
988 msg_map.get(msg_id).map(|msg| RecalledMessage {
989 message: msg.clone(),
990 #[expect(clippy::cast_possible_truncation)]
991 score: *score as f32,
992 })
993 })
994 .collect();
995
996 tracing::debug!(final_count = recalled.len(), "recall: final results");
997
998 Ok(recalled)
999 }
1000
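    /// Recall that first asks `router` which retrieval path to use: keyword,
    /// semantic, hybrid, episodic (time-ranged keyword search), or graph
    /// (which currently falls back to hybrid retrieval).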
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
    )]
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Episodic => {
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                tracing::debug!(
                    has_range = range.is_some(),
                    cleaned_query = %search_query,
                    keyword_count = kw.len(),
                    "recall: episodic path"
                );
                (kw, Vec::new())
            }
            MemoryRoute::Graph => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }

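    /// Async-router variant of [`Self::recall_routed`]; the routing decision
    /// also carries a confidence value that is logged for diagnostics.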
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
    )]
    pub async fn recall_routed_async(
        &self,
        query: &str,
        limit: usize,
        filter: Option<crate::embedding_store::SearchFilter>,
        router: &dyn crate::router::AsyncMemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let decision = router.route_async(query).await;
        let route = decision.route;
        tracing::debug!(
            ?route,
            confidence = decision.confidence,
            query_len = query.len(),
            "memory routing decision (async)"
        );

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(crate::types::MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Episodic => {
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                (kw, Vec::new())
            }
            MemoryRoute::Graph => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results (async)"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }

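    /// Recalls facts from the knowledge graph via multi-hop traversal.
    /// Returns an empty list when no graph store is configured.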
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
    )]
    pub async fn recall_graph(
        &self,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
        temporal_decay_rate: f64,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            max_hops,
            "graph: starting recall"
        );

        let results = crate::graph::retrieval::graph_recall(
            store,
            self.qdrant.as_deref(),
            &self.provider,
            query,
            limit,
            max_hops,
            at_timestamp,
            temporal_decay_rate,
            edge_types,
        )
        .await?;

        tracing::debug!(result_count = results.len(), "graph: recall complete");
        #[cfg(feature = "profiling")]
        tracing::Span::current().record("result_count", results.len());

        Ok(results)
    }

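    /// Graph recall using spreading activation with the supplied parameters.
    /// Returns an empty list when no graph store is configured.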
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
    )]
    pub async fn recall_graph_activated(
        &self,
        query: &str,
        limit: usize,
        params: crate::graph::SpreadingActivationParams,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            "spreading activation: starting graph recall"
        );

        let embeddings = self.qdrant.as_deref();
        let results = crate::graph::retrieval::graph_recall_activated(
            store,
            embeddings,
            &self.provider,
            query,
            limit,
            params,
            edge_types,
        )
        .await?;

        tracing::debug!(
            result_count = results.len(),
            "spreading activation: graph recall complete"
        );

        Ok(results)
    }

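    /// Increments recall access counters for the given messages; no-op for an
    /// empty id list.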
    async fn batch_increment_access_count(
        &self,
        message_ids: Vec<MessageId>,
    ) -> Result<(), MemoryError> {
        if message_ids.is_empty() {
            return Ok(());
        }
        self.sqlite.increment_access_counts(&message_ids).await
    }

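    /// Reports whether an embedding exists for `message_id`; always `false`
    /// when no Qdrant store is configured.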
    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
        match &self.qdrant {
            Some(qdrant) => qdrant.has_embedding(message_id).await,
            None => Ok(false),
        }
    }

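    /// Backfills embeddings for messages that lack them, in batches of 32,
    /// reporting progress through `progress_tx`. Returns the number of
    /// messages for which a background embedding task was dispatched.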
    pub async fn embed_missing(
        &self,
        progress_tx: Option<tokio::sync::watch::Sender<Option<super::BackfillProgress>>>,
    ) -> Result<usize, MemoryError> {
        if self.qdrant.is_none() || !self.effective_embed_provider().supports_embeddings() {
            return Ok(0);
        }

        let total = self.sqlite.count_unembedded_messages().await?;
        if total == 0 {
            return Ok(0);
        }

        if let Some(tx) = &progress_tx {
            let _ = tx.send(Some(super::BackfillProgress { done: 0, total }));
        }

        let mut done = 0usize;
        let mut succeeded = 0usize;

        loop {
            const BATCH_SIZE: usize = 32;
            const BATCH_SIZE_I64: i64 = 32;
            let rows: Vec<_> = self
                .sqlite
                .stream_unembedded_messages(BATCH_SIZE_I64)
                .try_collect()
                .await?;

            if rows.is_empty() {
                break;
            }

            let batch_len = rows.len();

            // Dispatch up to four embedding tasks concurrently per batch.
            let results: Vec<bool> = futures::stream::iter(rows)
                .map(|(msg_id, conv_id, role, content)| async move {
                    self.embed_and_store_regular(msg_id, conv_id, &role, &content)
                })
                .buffer_unordered(4)
                .collect()
                .await;

            for ok in &results {
                done += 1;
                if *ok {
                    succeeded += 1;
                }
                if let Some(tx) = &progress_tx {
                    let _ = tx.send(Some(super::BackfillProgress { done, total }));
                }
            }

            let batch_succeeded = results.iter().filter(|&&b| b).count();
            if batch_succeeded > 0 {
                tracing::debug!("Backfill batch: {batch_succeeded}/{batch_len} embedded");
            }

            if batch_len < BATCH_SIZE {
                break;
            }
        }

        if let Some(tx) = &progress_tx {
            let _ = tx.send(None);
        }

        if done > 0 {
            tracing::info!("Embedded {succeeded}/{total} missing messages");
        }
        Ok(succeeded)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn embed_context_default_all_none() {
        let ctx = EmbedContext::default();
        assert!(ctx.tool_name.is_none());
        assert!(ctx.exit_code.is_none());
        assert!(ctx.timestamp.is_none());
    }

    #[test]
    fn embed_context_fields_set_correctly() {
        let ctx = EmbedContext {
            tool_name: Some("shell".to_string()),
            exit_code: Some(0),
            timestamp: Some("2026-04-04T00:00:00Z".to_string()),
        };
        assert_eq!(ctx.tool_name.as_deref(), Some("shell"));
        assert_eq!(ctx.exit_code, Some(0));
        assert_eq!(ctx.timestamp.as_deref(), Some("2026-04-04T00:00:00Z"));
    }

    #[test]
    fn embed_context_non_zero_exit_code() {
        let ctx = EmbedContext {
            tool_name: Some("shell".to_string()),
            exit_code: Some(1),
            timestamp: None,
        };
        assert_eq!(ctx.exit_code, Some(1));
        assert!(ctx.timestamp.is_none());
    }

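    // A minimal sketch of the chunking invariants: short inputs pass through
    // unchanged, and long inputs are split into bounded, overlapping chunks.
    // Chunk offsets are recovered via pointer arithmetic on the borrowed
    // slices, which is valid because `chunk_text` returns subslices of `text`.
    #[test]
    fn chunk_text_short_input_is_single_chunk() {
        let text = "hello world";
        assert_eq!(chunk_text(text), vec![text]);
    }

    #[test]
    fn chunk_text_long_input_yields_bounded_overlapping_chunks() {
        // 5x CHUNK_CHARS of space-separated words forces multiple chunks.
        let text = "word ".repeat(CHUNK_CHARS);
        let chunks = chunk_text(&text);
        assert!(chunks.len() > 1, "long input must produce multiple chunks");
        assert!(chunks.iter().all(|c| c.len() <= CHUNK_CHARS));
        assert!(text.starts_with(chunks[0]));
        assert!(text.ends_with(chunks[chunks.len() - 1]));

        let base = text.as_ptr() as usize;
        for pair in chunks.windows(2) {
            let prev_end = pair[0].as_ptr() as usize - base + pair[0].len();
            let next_start = pair[1].as_ptr() as usize - base;
            assert!(next_start < prev_end, "consecutive chunks must overlap");
        }
    }
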
    async fn make_semantic_memory() -> crate::semantic::SemanticMemory {
        use std::sync::Arc;
        use std::sync::atomic::AtomicU64;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let provider = AnyProvider::Mock(MockProvider::default());
        let sqlite = crate::store::SqliteStore::new(":memory:").await.unwrap();
        crate::semantic::SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embed_provider: None,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            importance_enabled: false,
            importance_weight: 0.15,
            token_counter: Arc::new(crate::token_counter::TokenCounter::new()),
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
            last_qdrant_warn: Arc::new(AtomicU64::new(0)),
            tier_boost_semantic: 1.3,
            admission_control: None,
            key_facts_dedup_threshold: 0.95,
            embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
        }
    }

    #[tokio::test]
    async fn spawn_embed_bg_returns_true_when_capacity_available() {
        let memory = make_semantic_memory().await;
        let dispatched = memory.spawn_embed_bg(std::future::ready(()));
        assert!(
            dispatched,
            "spawn_embed_bg must return true when a task was successfully spawned"
        );
    }

    #[tokio::test]
    async fn spawn_embed_bg_returns_false_at_capacity() {
        let memory = make_semantic_memory().await;

        {
            // Fill the JoinSet to capacity with tasks that never complete.
            let mut tasks = memory.embed_tasks.lock().unwrap();
            for _ in 0..MAX_EMBED_BG_TASKS {
                tasks.spawn(std::future::pending::<()>());
            }
        }

        let dispatched = memory.spawn_embed_bg(std::future::ready(()));
        assert!(
            !dispatched,
            "spawn_embed_bg must return false when the task limit is reached"
        );
    }

    #[test]
    fn qdrant_warn_rate_limit_suppresses_within_window() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};

        let last_warn = Arc::new(AtomicU64::new(0));
        let window_secs = 10u64;

        // First check at t=100: outside the window, so it warns.
        let now1 = 100u64;
        let last1 = last_warn.load(Ordering::Relaxed);
        let should_warn1 = now1.saturating_sub(last1) >= window_secs;
        assert!(should_warn1, "first call must not be suppressed");
        if should_warn1 {
            last_warn.store(now1, Ordering::Relaxed);
        }

        // Second check at t=105: inside the 10s window, so it is suppressed.
        let now2 = 105u64;
        let last2 = last_warn.load(Ordering::Relaxed);
        let should_warn2 = now2.saturating_sub(last2) >= window_secs;
        assert!(!should_warn2, "call within 10s window must be suppressed");

        // Third check at t=110: the window has elapsed, so it warns again.
        let now3 = 110u64;
        let last3 = last_warn.load(Ordering::Relaxed);
        let should_warn3 = now3.saturating_sub(last3) >= window_secs;
        assert!(
            should_warn3,
            "call after window expiry must not be suppressed"
        );
    }

    #[test]
    fn qdrant_warn_rate_limit_shared_across_concurrent_sites() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};

        // Both call sites share one timestamp, mirroring how every background
        // embed task clones the same `last_qdrant_warn` counter.
        let shared = Arc::new(AtomicU64::new(0));
        let window_secs = 10u64;

        let site_a = Arc::clone(&shared);
        let site_b = Arc::clone(&shared);

        let now_a = 100u64;
        let last_a = site_a.load(Ordering::Relaxed);
        if now_a.saturating_sub(last_a) >= window_secs {
            site_a.store(now_a, Ordering::Relaxed);
        }

        let now_b = 105u64;
        let last_b = site_b.load(Ordering::Relaxed);
        let warn_b = now_b.saturating_sub(last_b) >= window_secs;
        assert!(
            !warn_b,
            "site B must be suppressed because site A already warned within the window"
        );
    }
}