1use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use futures::{StreamExt as _, TryStreamExt as _};
8use zeph_llm::provider::{LlmProvider as _, Message};
9
10const CHARS_PER_TOKEN: usize = 4;
12
13const CHUNK_CHARS: usize = 400 * CHARS_PER_TOKEN;
15
16const CHUNK_OVERLAP_CHARS: usize = 80 * CHARS_PER_TOKEN;
18
19fn chunk_text(text: &str) -> Vec<&str> {
25 if text.len() <= CHUNK_CHARS {
26 return vec![text];
27 }
28
29 let mut chunks = Vec::new();
30 let mut start = 0;
31
32 while start < text.len() {
33 let end = if start + CHUNK_CHARS >= text.len() {
34 text.len()
35 } else {
36 let boundary = text.floor_char_boundary(start + CHUNK_CHARS);
38 let slice = &text[start..boundary];
40 if let Some(pos) = slice.rfind("\n\n") {
41 start + pos + 2
42 } else if let Some(pos) = slice.rfind('\n') {
43 start + pos + 1
44 } else if let Some(pos) = slice.rfind(' ') {
45 start + pos + 1
46 } else {
47 boundary
48 }
49 };
50
51 chunks.push(&text[start..end]);
52 if end >= text.len() {
53 break;
54 }
55 let next = end.saturating_sub(CHUNK_OVERLAP_CHARS);
59 let new_start = text.ceil_char_boundary(next);
60 start = if new_start > start { new_start } else { end };
61 }
62
63 chunks
64}
65
66use crate::admission::log_admission_decision;
67use crate::embedding_store::{MessageKind, SearchFilter};
68use crate::error::MemoryError;
69use crate::types::{ConversationId, MessageId};
70
71use super::SemanticMemory;
72use super::algorithms::{apply_mmr, apply_temporal_decay};
73
/// Optional tool-execution metadata attached to a message's embeddings.
///
/// When `tool_name` is set, the background embed task stores each chunk through
/// `store_with_tool_context`, forwarding `exit_code` and `timestamp` with it;
/// otherwise the plain `store` path is used and the other fields are ignored.
#[derive(Debug, Clone, Default)]
pub struct EmbedContext {
    /// Name of the tool; selects the tool-context storage path when present.
    pub tool_name: Option<String>,
    /// Exit code forwarded to the store (presumably the tool's exit status — confirm).
    pub exit_code: Option<i32>,
    /// Timestamp forwarded to the store as-is; its format is not visible in this file.
    pub timestamp: Option<String>,
}
83
/// A message returned by recall together with its final ranking score.
#[derive(Debug)]
pub struct RecalledMessage {
    /// The stored message, loaded from SQLite by id after ranking.
    pub message: Message,
    /// Final blended ranking score, narrowed from the internal `f64` to `f32`.
    pub score: f32,
}
89
/// Upper bound on tracked background embed tasks; `spawn_embed_bg` skips new work
/// once this many tasks are still pending.
const MAX_EMBED_BG_TASKS: usize = 64;
92
/// Owned inputs for a background embed-and-store task.
///
/// Everything is owned (or reference-counted) so the task can be spawned as a
/// `'static` future independent of the `SemanticMemory` that created it.
struct EmbedBgArgs {
    /// Vector store the chunk embeddings are written to.
    qdrant: Arc<crate::embedding_store::EmbeddingStore>,
    /// Provider used to embed the chunks (`embed_batch`).
    embed_provider: zeph_llm::any::AnyProvider,
    /// Embedding model name recorded with every stored chunk.
    embedding_model: String,
    /// Id of the message whose content is being embedded.
    message_id: MessageId,
    /// Conversation the message belongs to.
    conversation_id: ConversationId,
    /// Message role, stored alongside each chunk.
    role: String,
    /// Full message content; split with `chunk_text` before embedding.
    content: String,
    /// Unix-seconds timestamp of the last Qdrant `warn`, shared across tasks to
    /// rate-limit repeated collection failures.
    last_qdrant_warn: Arc<AtomicU64>,
}
104
105async fn embed_and_store_regular_bg(args: EmbedBgArgs) {
109 let EmbedBgArgs {
110 qdrant,
111 embed_provider,
112 embedding_model,
113 message_id,
114 conversation_id,
115 role,
116 content,
117 last_qdrant_warn,
118 } = args;
119 let chunks = chunk_text(&content);
120 let chunk_count = chunks.len();
121
122 let vectors = match embed_provider.embed_batch(&chunks).await {
123 Ok(v) => v,
124 Err(e) => {
125 tracing::warn!("bg embed_regular: failed to embed chunks for msg {message_id}: {e:#}");
126 return;
127 }
128 };
129
130 let Some(first) = vectors.first() else {
131 return;
132 };
133 let vector_size = first.len() as u64;
134 if let Err(e) = qdrant.ensure_collection(vector_size).await {
135 let now = std::time::SystemTime::now()
136 .duration_since(std::time::UNIX_EPOCH)
137 .unwrap_or_default()
138 .as_secs();
139 let last = last_qdrant_warn.load(Ordering::Relaxed);
140 if now.saturating_sub(last) >= 10 {
141 last_qdrant_warn.store(now, Ordering::Relaxed);
142 tracing::warn!("bg embed_regular: failed to ensure Qdrant collection: {e:#}");
143 } else {
144 tracing::debug!(
145 "bg embed_regular: failed to ensure Qdrant collection (suppressed): {e:#}"
146 );
147 }
148 return;
149 }
150
151 for (chunk_index, vector) in vectors.into_iter().enumerate() {
152 let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
153 if let Err(e) = qdrant
154 .store(
155 message_id,
156 conversation_id,
157 &role,
158 vector,
159 MessageKind::Regular,
160 &embedding_model,
161 chunk_index_u32,
162 )
163 .await
164 {
165 tracing::warn!(
166 "bg embed_regular: failed to store chunk {chunk_index}/{chunk_count} \
167 for msg {message_id}: {e:#}"
168 );
169 }
170 }
171}
172
173async fn embed_chunks_with_tool_context_bg(args: EmbedBgArgs, embed_ctx: EmbedContext) {
177 let EmbedBgArgs {
178 qdrant,
179 embed_provider,
180 embedding_model,
181 message_id,
182 conversation_id,
183 role,
184 content,
185 last_qdrant_warn,
186 } = args;
187 let chunks = chunk_text(&content);
188 let chunk_count = chunks.len();
189
190 let vectors = match embed_provider.embed_batch(&chunks).await {
191 Ok(v) => v,
192 Err(e) => {
193 tracing::warn!(
194 "bg embed_tool: failed to embed tool-output chunks for msg {message_id}: {e:#}"
195 );
196 return;
197 }
198 };
199
200 if let Some(first) = vectors.first() {
201 let vector_size = first.len() as u64;
202 if let Err(e) = qdrant.ensure_collection(vector_size).await {
203 let now = std::time::SystemTime::now()
204 .duration_since(std::time::UNIX_EPOCH)
205 .unwrap_or_default()
206 .as_secs();
207 let last = last_qdrant_warn.load(Ordering::Relaxed);
208 if now.saturating_sub(last) >= 10 {
209 last_qdrant_warn.store(now, Ordering::Relaxed);
210 tracing::warn!("bg embed_tool: failed to ensure Qdrant collection: {e:#}");
211 } else {
212 tracing::debug!(
213 "bg embed_tool: failed to ensure Qdrant collection (suppressed): {e:#}"
214 );
215 }
216 return;
217 }
218 }
219
220 for (chunk_index, vector) in vectors.into_iter().enumerate() {
221 let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
222 let result = if let Some(ref tool_name) = embed_ctx.tool_name {
223 qdrant
224 .store_with_tool_context(
225 message_id,
226 conversation_id,
227 &role,
228 vector,
229 MessageKind::Regular,
230 &embedding_model,
231 chunk_index_u32,
232 tool_name,
233 embed_ctx.exit_code,
234 embed_ctx.timestamp.as_deref(),
235 )
236 .await
237 .map(|_| ())
238 } else {
239 qdrant
240 .store(
241 message_id,
242 conversation_id,
243 &role,
244 vector,
245 MessageKind::Regular,
246 &embedding_model,
247 chunk_index_u32,
248 )
249 .await
250 .map(|_| ())
251 };
252 if let Err(e) = result {
253 tracing::warn!(
254 "bg embed_tool: failed to store chunk {chunk_index}/{chunk_count} \
255 for msg {message_id}: {e:#}"
256 );
257 }
258 }
259}
260
261async fn embed_and_store_with_category_bg(args: EmbedBgArgs, category: Option<String>) {
265 let EmbedBgArgs {
266 qdrant,
267 embed_provider,
268 embedding_model,
269 message_id,
270 conversation_id,
271 role,
272 content,
273 last_qdrant_warn,
274 } = args;
275 let chunks = chunk_text(&content);
276 let chunk_count = chunks.len();
277
278 let vectors = match embed_provider.embed_batch(&chunks).await {
279 Ok(v) => v,
280 Err(e) => {
281 tracing::warn!(
282 "bg embed_category: failed to embed categorized chunks for msg {message_id}: {e:#}"
283 );
284 return;
285 }
286 };
287
288 let Some(first) = vectors.first() else {
289 return;
290 };
291 let vector_size = first.len() as u64;
292 if let Err(e) = qdrant.ensure_collection(vector_size).await {
293 let now = std::time::SystemTime::now()
294 .duration_since(std::time::UNIX_EPOCH)
295 .unwrap_or_default()
296 .as_secs();
297 let last = last_qdrant_warn.load(Ordering::Relaxed);
298 if now.saturating_sub(last) >= 10 {
299 last_qdrant_warn.store(now, Ordering::Relaxed);
300 tracing::warn!("bg embed_category: failed to ensure Qdrant collection: {e:#}");
301 } else {
302 tracing::debug!(
303 "bg embed_category: failed to ensure Qdrant collection (suppressed): {e:#}"
304 );
305 }
306 return;
307 }
308
309 for (chunk_index, vector) in vectors.into_iter().enumerate() {
310 let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
311 if let Err(e) = qdrant
312 .store_with_category(
313 message_id,
314 conversation_id,
315 &role,
316 vector,
317 MessageKind::Regular,
318 &embedding_model,
319 chunk_index_u32,
320 category.as_deref(),
321 )
322 .await
323 {
324 tracing::warn!(
325 "bg embed_category: failed to store chunk {chunk_index}/{chunk_count} \
326 for msg {message_id}: {e:#}"
327 );
328 }
329 }
330}
331
332impl SemanticMemory {
333 #[cfg_attr(
343 feature = "profiling",
344 tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
345 )]
346 pub async fn remember(
347 &self,
348 conversation_id: ConversationId,
349 role: &str,
350 content: &str,
351 goal_text: Option<&str>,
352 ) -> Result<Option<MessageId>, MemoryError> {
353 if let Some(ref admission) = self.admission_control {
355 let decision = admission
356 .evaluate(
357 content,
358 role,
359 self.effective_embed_provider(),
360 self.qdrant.as_ref(),
361 goal_text,
362 )
363 .await;
364 let preview: String = content.chars().take(100).collect();
365 log_admission_decision(&decision, &preview, role, admission.threshold());
366 if !decision.admitted {
367 return Ok(None);
368 }
369 }
370
371 if let Some(gate) = &self.quality_gate
372 && gate
373 .evaluate(content, self.effective_embed_provider(), &[])
374 .await
375 .is_some()
376 {
377 return Ok(None);
378 }
379
380 let message_id = self
381 .sqlite
382 .save_message(conversation_id, role, content)
383 .await?;
384
385 self.embed_and_store_regular(message_id, conversation_id, role, content);
386
387 Ok(Some(message_id))
388 }
389
390 #[cfg_attr(
399 feature = "profiling",
400 tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
401 )]
402 pub async fn remember_with_parts(
403 &self,
404 conversation_id: ConversationId,
405 role: &str,
406 content: &str,
407 parts_json: &str,
408 goal_text: Option<&str>,
409 ) -> Result<(Option<MessageId>, bool), MemoryError> {
410 if let Some(ref admission) = self.admission_control {
412 let decision = admission
413 .evaluate(
414 content,
415 role,
416 self.effective_embed_provider(),
417 self.qdrant.as_ref(),
418 goal_text,
419 )
420 .await;
421 let preview: String = content.chars().take(100).collect();
422 log_admission_decision(&decision, &preview, role, admission.threshold());
423 if !decision.admitted {
424 return Ok((None, false));
425 }
426 }
427
428 if let Some(gate) = &self.quality_gate
429 && gate
430 .evaluate(content, self.effective_embed_provider(), &[])
431 .await
432 .is_some()
433 {
434 return Ok((None, false));
435 }
436
437 let message_id = self
438 .sqlite
439 .save_message_with_parts(conversation_id, role, content, parts_json)
440 .await?;
441
442 let embedding_stored =
443 self.embed_and_store_regular(message_id, conversation_id, role, content);
444
445 Ok((Some(message_id), embedding_stored))
446 }
447
448 #[cfg_attr(
460 feature = "profiling",
461 tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
462 )]
463 pub async fn remember_tool_output(
464 &self,
465 conversation_id: ConversationId,
466 role: &str,
467 content: &str,
468 parts_json: &str,
469 embed_ctx: EmbedContext,
470 ) -> Result<(Option<MessageId>, bool), MemoryError> {
471 if let Some(ref admission) = self.admission_control {
472 let decision = admission
473 .evaluate(
474 content,
475 role,
476 self.effective_embed_provider(),
477 self.qdrant.as_ref(),
478 None,
479 )
480 .await;
481 let preview: String = content.chars().take(100).collect();
482 log_admission_decision(&decision, &preview, role, admission.threshold());
483 if !decision.admitted {
484 return Ok((None, false));
485 }
486 }
487
488 let message_id = self
489 .sqlite
490 .save_message_with_parts(conversation_id, role, content, parts_json)
491 .await?;
492
493 let embedding_stored = self.embed_chunks_with_tool_context(
494 message_id,
495 conversation_id,
496 role,
497 content,
498 embed_ctx,
499 );
500
501 Ok((Some(message_id), embedding_stored))
502 }
503
504 #[cfg_attr(
515 feature = "profiling",
516 tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
517 )]
518 pub async fn remember_categorized(
519 &self,
520 conversation_id: ConversationId,
521 role: &str,
522 content: &str,
523 category: Option<&str>,
524 goal_text: Option<&str>,
525 ) -> Result<Option<MessageId>, MemoryError> {
526 if let Some(ref admission) = self.admission_control {
527 let decision = admission
528 .evaluate(
529 content,
530 role,
531 self.effective_embed_provider(),
532 self.qdrant.as_ref(),
533 goal_text,
534 )
535 .await;
536 let preview: String = content.chars().take(100).collect();
537 log_admission_decision(&decision, &preview, role, admission.threshold());
538 if !decision.admitted {
539 return Ok(None);
540 }
541 }
542
543 let message_id = self
544 .sqlite
545 .save_message_with_category(conversation_id, role, content, category)
546 .await?;
547
548 self.embed_and_store_with_category(message_id, conversation_id, role, content, category);
549
550 Ok(Some(message_id))
551 }
552
553 pub async fn recall_with_category(
561 &self,
562 query: &str,
563 limit: usize,
564 filter: Option<SearchFilter>,
565 category: Option<&str>,
566 ) -> Result<Vec<RecalledMessage>, MemoryError> {
567 let filter_with_category = filter.map(|mut f| {
568 f.category = category.map(str::to_owned);
569 f
570 });
571 self.recall(query, limit, filter_with_category).await
572 }
573
574 pub fn reap_embed_tasks(&self) {
578 if let Ok(mut tasks) = self.embed_tasks.lock() {
579 while tasks.try_join_next().is_some() {}
580 }
581 }
582
583 fn spawn_embed_bg<F>(&self, fut: F) -> bool
587 where
588 F: std::future::Future<Output = ()> + Send + 'static,
589 {
590 let Ok(mut tasks) = self.embed_tasks.lock() else {
591 return false;
592 };
593 while tasks.try_join_next().is_some() {}
595 if tasks.len() >= MAX_EMBED_BG_TASKS {
596 tracing::debug!("background embed task limit reached, skipping");
597 return false;
598 }
599 tasks.spawn(fut);
600 true
601 }
602
603 fn embed_and_store_with_category(
607 &self,
608 message_id: MessageId,
609 conversation_id: ConversationId,
610 role: &str,
611 content: &str,
612 category: Option<&str>,
613 ) -> bool {
614 let Some(qdrant) = self.qdrant.clone() else {
615 return false;
616 };
617 let embed_provider = self.effective_embed_provider().clone();
618 if !embed_provider.supports_embeddings() {
619 return false;
620 }
621 self.spawn_embed_bg(embed_and_store_with_category_bg(
622 EmbedBgArgs {
623 qdrant,
624 embed_provider,
625 embedding_model: self.embedding_model.clone(),
626 message_id,
627 conversation_id,
628 role: role.to_owned(),
629 content: content.to_owned(),
630 last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
631 },
632 category.map(str::to_owned),
633 ))
634 }
635
636 fn embed_and_store_regular(
640 &self,
641 message_id: MessageId,
642 conversation_id: ConversationId,
643 role: &str,
644 content: &str,
645 ) -> bool {
646 let Some(qdrant) = self.qdrant.clone() else {
647 return false;
648 };
649 let embed_provider = self.effective_embed_provider().clone();
650 if !embed_provider.supports_embeddings() {
651 return false;
652 }
653 self.spawn_embed_bg(embed_and_store_regular_bg(EmbedBgArgs {
654 qdrant,
655 embed_provider,
656 embedding_model: self.embedding_model.clone(),
657 message_id,
658 conversation_id,
659 role: role.to_owned(),
660 content: content.to_owned(),
661 last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
662 }))
663 }
664
665 fn embed_chunks_with_tool_context(
669 &self,
670 message_id: MessageId,
671 conversation_id: ConversationId,
672 role: &str,
673 content: &str,
674 embed_ctx: EmbedContext,
675 ) -> bool {
676 let Some(qdrant) = self.qdrant.clone() else {
677 return false;
678 };
679 let embed_provider = self.effective_embed_provider().clone();
680 if !embed_provider.supports_embeddings() {
681 return false;
682 }
683 self.spawn_embed_bg(embed_chunks_with_tool_context_bg(
684 EmbedBgArgs {
685 qdrant,
686 embed_provider,
687 embedding_model: self.embedding_model.clone(),
688 message_id,
689 conversation_id,
690 role: role.to_owned(),
691 content: content.to_owned(),
692 last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
693 },
694 embed_ctx,
695 ))
696 }
697
    /// Saves a message with its serialized parts to SQLite and returns its id.
    ///
    /// Unlike the `remember*` methods, this runs no admission control or quality
    /// gate and queues no embedding.
    ///
    /// # Errors
    ///
    /// Returns an error if the SQLite write fails.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
716
717 #[cfg_attr(
727 feature = "profiling",
728 tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty, top_score = tracing::field::Empty))
729 )]
730 pub async fn recall(
731 &self,
732 query: &str,
733 limit: usize,
734 filter: Option<SearchFilter>,
735 ) -> Result<Vec<RecalledMessage>, MemoryError> {
736 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
737
738 tracing::debug!(
739 query_len = query.len(),
740 limit,
741 has_filter = filter.is_some(),
742 conversation_id = conversation_id.map(|c| c.0),
743 has_qdrant = self.qdrant.is_some(),
744 "recall: starting hybrid search"
745 );
746
747 let keyword_results = match self
748 .sqlite
749 .keyword_search(query, self.effective_depth(limit), conversation_id)
750 .await
751 {
752 Ok(results) => results,
753 Err(e) => {
754 tracing::warn!("FTS5 keyword search failed: {e:#}");
755 Vec::new()
756 }
757 };
758
759 let vector_results = if let Some(qdrant) = &self.qdrant
760 && self.effective_embed_provider().supports_embeddings()
761 {
762 let embed_input = self.apply_search_prompt(query);
763 let query_vector = match tokio::time::timeout(
764 self.embed_timeout,
765 self.effective_embed_provider().embed(&embed_input),
766 )
767 .await
768 {
769 Ok(Ok(v)) => v,
770 Ok(Err(e)) => return Err(e.into()),
771 Err(_) => {
772 tracing::warn!("recall_semantic: embed timed out, returning empty results");
773 return Ok(Vec::new());
774 }
775 };
776 let query_vector = self.apply_query_bias(query, query_vector).await;
777 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
778 qdrant.ensure_collection(vector_size).await?;
779 qdrant
780 .search(&query_vector, self.effective_depth(limit), filter)
781 .await?
782 } else {
783 Vec::new()
784 };
785
786 let results = self
787 .recall_merge_and_rank(keyword_results, vector_results, limit, None)
788 .await?;
789 #[cfg(feature = "profiling")]
790 {
791 let span = tracing::Span::current();
792 span.record("result_count", results.len());
793 if let Some(top) = results.first() {
794 span.record("top_score", top.score);
795 }
796 }
797 Ok(results)
798 }
799
    /// Runs the SQLite keyword search and returns raw `(message id, score)` pairs.
    ///
    /// The requested `limit` is widened through `effective_depth` before querying;
    /// `conversation_id` optionally restricts the search to one conversation.
    ///
    /// # Errors
    ///
    /// Returns an error if the keyword search fails.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall.fts5", skip_all, fields(query_len = %query.len()))
    )]
    pub(super) async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, self.effective_depth(limit), conversation_id)
            .await
    }
814
815 #[cfg_attr(
816 feature = "profiling",
817 tracing::instrument(name = "memory.recall.vectors", skip_all, fields(query_len = %query.len()))
818 )]
819 pub(super) async fn recall_vectors_raw(
820 &self,
821 query: &str,
822 limit: usize,
823 filter: Option<SearchFilter>,
824 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
825 let Some(qdrant) = &self.qdrant else {
826 return Ok(Vec::new());
827 };
828 if !self.effective_embed_provider().supports_embeddings() {
829 return Ok(Vec::new());
830 }
831 let embed_input = self.apply_search_prompt(query);
832 let query_vector = match tokio::time::timeout(
833 self.embed_timeout,
834 self.effective_embed_provider().embed(&embed_input),
835 )
836 .await
837 {
838 Ok(Ok(v)) => v,
839 Ok(Err(e)) => return Err(e.into()),
840 Err(_) => {
841 tracing::warn!("recall_vectors_raw: embed timed out, returning empty results");
842 return Ok(Vec::new());
843 }
844 };
845 let query_vector = self.apply_query_bias(query, query_vector).await;
846 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
847 qdrant.ensure_collection(vector_size).await?;
848 qdrant
849 .search(&query_vector, self.effective_depth(limit), filter)
850 .await
851 }
852
853 #[cfg_attr(
862 feature = "profiling",
863 tracing::instrument(name = "memory.recall.merge_and_rank", skip_all, fields(kw_count = keyword_results.len(), vec_count = vector_results.len()))
864 )]
865 #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
866 pub(super) async fn recall_merge_and_rank(
867 &self,
868 keyword_results: Vec<(MessageId, f64)>,
869 vector_results: Vec<crate::embedding_store::SearchResult>,
870 limit: usize,
871 goal_entity_id: Option<i64>,
872 ) -> Result<Vec<RecalledMessage>, MemoryError> {
873 tracing::debug!(
874 vector_count = vector_results.len(),
875 keyword_count = keyword_results.len(),
876 limit,
877 "recall: merging search results"
878 );
879
880 let mut scores: std::collections::HashMap<MessageId, f64> =
881 std::collections::HashMap::new();
882
883 if !vector_results.is_empty() {
884 let max_vs = vector_results
885 .iter()
886 .map(|r| r.score)
887 .fold(f32::NEG_INFINITY, f32::max);
888 let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
889 for r in &vector_results {
890 let normalized = f64::from(r.score / norm);
891 *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
892 }
893 }
894
895 if !keyword_results.is_empty() {
896 let max_ks = keyword_results
897 .iter()
898 .map(|r| r.1)
899 .fold(f64::NEG_INFINITY, f64::max);
900 let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
901 for &(msg_id, score) in &keyword_results {
902 let normalized = score / norm;
903 *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
904 }
905 }
906
907 if scores.is_empty() {
908 tracing::debug!("recall: empty merge, no overlapping scores");
909 return Ok(Vec::new());
910 }
911
912 let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
913 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
914
915 tracing::debug!(
916 merged = ranked.len(),
917 top_score = ranked.first().map(|r| r.1),
918 bottom_score = ranked.last().map(|r| r.1),
919 vector_weight = %self.vector_weight,
920 keyword_weight = %self.keyword_weight,
921 "recall: weighted merge complete"
922 );
923
924 if self.temporal_decay.is_enabled() && self.temporal_decay_half_life_days > 0 {
925 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
926 match self.sqlite.message_timestamps(&ids).await {
927 Ok(timestamps) => {
928 apply_temporal_decay(
929 &mut ranked,
930 ×tamps,
931 self.temporal_decay_half_life_days,
932 );
933 ranked
934 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
935 tracing::debug!(
936 half_life_days = self.temporal_decay_half_life_days,
937 top_score_after = ranked.first().map(|r| r.1),
938 "recall: temporal decay applied"
939 );
940 }
941 Err(e) => {
942 tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
943 }
944 }
945 }
946
947 if self.mmr_reranking.is_enabled() && !vector_results.is_empty() {
948 if let Some(qdrant) = &self.qdrant {
949 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
950 match qdrant.get_vectors(&ids).await {
951 Ok(vec_map) if !vec_map.is_empty() => {
952 let ranked_len_before = ranked.len();
953 ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
954 tracing::debug!(
955 before = ranked_len_before,
956 after = ranked.len(),
957 lambda = %self.mmr_lambda,
958 "recall: mmr re-ranked"
959 );
960 }
961 Ok(_) => {
962 ranked.truncate(limit);
963 }
964 Err(e) => {
965 tracing::warn!("MMR: failed to fetch vectors: {e:#}");
966 ranked.truncate(limit);
967 }
968 }
969 } else {
970 ranked.truncate(limit);
971 }
972 } else {
973 ranked.truncate(limit);
974 }
975
976 if self.importance_scoring.is_enabled() && !ranked.is_empty() {
977 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
978 match self.sqlite.fetch_importance_scores(&ids).await {
979 Ok(scores) => {
980 for (msg_id, score) in &mut ranked {
981 if let Some(&imp) = scores.get(msg_id) {
982 *score += imp * self.importance_weight;
983 }
984 }
985 ranked
986 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
987 tracing::debug!(
988 importance_weight = %self.importance_weight,
989 "recall: importance scores blended"
990 );
991 }
992 Err(e) => {
993 tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
994 }
995 }
996 }
997
998 if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
1002 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
1003 match self.sqlite.fetch_tiers(&ids).await {
1004 Ok(tiers) => {
1005 let bonus = self.tier_boost_semantic - 1.0;
1006 let mut boosted = false;
1007 for (msg_id, score) in &mut ranked {
1008 if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
1009 *score += bonus;
1010 boosted = true;
1011 }
1012 }
1013 if boosted {
1014 ranked.sort_by(|a, b| {
1015 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
1016 });
1017 tracing::debug!(
1018 tier_boost = %self.tier_boost_semantic,
1019 "recall: semantic tier boost applied"
1020 );
1021 }
1022 }
1023 Err(e) => {
1024 tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
1025 }
1026 }
1027 }
1028
1029 if let Some(fs) = &self.five_signal
1031 && !fs.weights.is_baseline()
1032 {
1033 self.apply_five_signal_scoring(&mut ranked, fs, goal_entity_id)
1034 .await;
1035 }
1036
1037 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
1038
1039 if let Some(fs) = &self.five_signal {
1041 for id in &ids {
1042 fs.access_cache
1043 .log_access(*id, "message", &fs.session_id)
1044 .await;
1045 }
1046 fs.metrics.inc_recall();
1047 }
1048
1049 if !ids.is_empty()
1050 && let Err(e) = self.batch_increment_access_count(ids.clone()).await
1051 {
1052 tracing::warn!("recall: failed to increment access counts: {e:#}");
1053 }
1054
1055 if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
1057 tracing::debug!(
1058 error = %e,
1059 "recall: failed to mark training data as recalled (non-fatal)"
1060 );
1061 }
1062
1063 let messages = self.sqlite.messages_by_ids(&ids).await?;
1064 let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
1065
1066 let recalled: Vec<RecalledMessage> = ranked
1067 .iter()
1068 .filter_map(|(msg_id, score)| {
1069 msg_map.get(msg_id).map(|msg| RecalledMessage {
1070 message: msg.clone(),
1071 #[expect(clippy::cast_possible_truncation)]
1072 score: *score as f32,
1073 })
1074 })
1075 .collect();
1076
1077 tracing::debug!(final_count = recalled.len(), "recall: final results");
1078
1079 Ok(recalled)
1080 }
1081
1082 async fn apply_five_signal_scoring(
1088 &self,
1089 ranked: &mut [(MessageId, f64)],
1090 fs: &crate::five_signal::FiveSignalRuntime,
1091 goal_entity_id: Option<i64>,
1092 ) {
1093 use crate::five_signal::causal_distance::CausalDistanceComputer;
1094 use crate::five_signal::scoring::{CandidateSignals, apply_five_signal_scoring};
1095 use sqlx::Row as _;
1096
1097 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
1098
1099 let freq_map = match fs
1101 .access_cache
1102 .load_for_candidates(&fs.session_id, &ids)
1103 .await
1104 {
1105 Ok(m) => m,
1106 Err(e) => {
1107 tracing::warn!(error = %e, "five_signal: failed to load access frequencies (skipping)");
1108 return;
1109 }
1110 };
1111
1112 let created_at_map: std::collections::HashMap<MessageId, i64> = {
1114 let id_vals: Vec<i64> = ids.iter().map(|id| id.0).collect();
1115 let placeholders: String = id_vals
1116 .iter()
1117 .enumerate()
1118 .map(|(i, _)| format!("?{}", i + 1))
1119 .collect::<Vec<_>>()
1120 .join(", ");
1121 let sql = format!(
1122 "SELECT id, created_at FROM messages WHERE id IN ({placeholders}) AND deleted_at IS NULL"
1123 );
1124 let mut q = sqlx::query(&sql);
1125 for id in &id_vals {
1126 q = q.bind(id);
1127 }
1128 match q.fetch_all(&fs.pool).await {
1129 Ok(rows) => rows
1130 .iter()
1131 .map(|row| {
1132 (
1133 MessageId(row.get::<i64, _>("id")),
1134 row.get::<i64, _>("created_at"),
1135 )
1136 })
1137 .collect(),
1138 Err(e) => {
1139 tracing::warn!(error = %e, "five_signal: failed to fetch created_at (skipping novelty)");
1140 std::collections::HashMap::new()
1141 }
1142 }
1143 };
1144
1145 let causal_distance_map: std::collections::HashMap<i64, u32> = {
1149 let entity_ids: Vec<i64> = ids.iter().map(|id| id.0).collect();
1150 let mut computer = fs.causal_computer.lock().await;
1151 match computer.compute(goal_entity_id, &entity_ids).await {
1152 Ok(m) => m,
1153 Err(e) => {
1154 tracing::warn!(error = %e, "five_signal: causal BFS failed (using neutral)");
1155 std::collections::HashMap::new()
1156 }
1157 }
1158 };
1159 let neutral_causal_score =
1160 CausalDistanceComputer::distance_to_score(fs.config.neutral_causal_distance);
1161
1162 let mut signals_map = std::collections::HashMap::with_capacity(ids.len());
1163 for &(msg_id, base_score) in ranked.iter() {
1164 let frequency = freq_map.get(&msg_id).copied().unwrap_or(0.0);
1165 let half = base_score / 2.0;
1168 let fact_created_at = created_at_map
1169 .get(&msg_id)
1170 .copied()
1171 .unwrap_or(fs.session_start);
1172 let novelty = fs.novelty_computer.compute(fact_created_at);
1173 let causal = causal_distance_map
1174 .get(&msg_id.0)
1175 .map_or(neutral_causal_score, |&d| {
1176 CausalDistanceComputer::distance_to_score(d)
1177 });
1178 signals_map.insert(
1179 msg_id,
1180 CandidateSignals {
1181 recency: half,
1182 relevance: half,
1183 frequency,
1184 causal,
1185 novelty,
1186 },
1187 );
1188 }
1189
1190 apply_five_signal_scoring(ranked, &fs.weights, &signals_map);
1191
1192 tracing::debug!(
1193 candidate_count = ids.len(),
1194 "recall: five-signal scoring applied"
1195 );
1196 }
1197
    /// Recalls messages using the search strategy chosen by `router` for `query`.
    ///
    /// Route handling:
    ///
    /// * `Keyword` — keyword search only; a search failure is returned as an error.
    /// * `Hybrid` — keyword plus vector search; a keyword failure is logged and
    ///   treated as "no keyword hits", a vector failure is returned as an error.
    /// * `Episodic` — keyword search only, restricted to the temporal range resolved
    ///   from the query when there is one, with temporal keywords stripped from the
    ///   search text.
    /// * `Graph` — handled exactly like `Hybrid` in this method.
    /// * any other route — vector search only.
    ///
    /// The collected hits are then passed to `recall_merge_and_rank` together with
    /// `goal_entity_id`.
    ///
    /// # Errors
    ///
    /// Returns an error when a non-best-effort search (see above) or the final
    /// merge/rank step fails.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
    )]
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
        goal_entity_id: Option<i64>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        // Keyword searches can only be scoped by conversation, not by the full filter.
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Hybrid => {
                // Keyword search is best-effort here; vector search is not.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Episodic => {
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                // Fall back to the original query if stripping left nothing to search for.
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    // NOTE(review): `limit` is passed as-is here, whereas
                    // `recall_fts5_raw` widens it via `effective_depth` — confirm
                    // the time-range search is meant to use the raw limit.
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                tracing::debug!(
                    has_range = range.is_some(),
                    cleaned_query = %search_query,
                    keyword_count = kw.len(),
                    "recall: episodic path"
                );
                (kw, Vec::new())
            }
            MemoryRoute::Graph => {
                // No graph-specific retrieval in this method: same behaviour as Hybrid.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            // Every remaining route (presumably the semantic one — confirm) is vector-only.
            _ => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit, goal_entity_id)
            .await
    }
1309
1310 #[cfg_attr(
1324 feature = "profiling",
1325 tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
1326 )]
1327 pub async fn recall_routed_async(
1328 &self,
1329 query: &str,
1330 limit: usize,
1331 filter: Option<crate::embedding_store::SearchFilter>,
1332 router: &dyn crate::router::AsyncMemoryRouter,
1333 goal_entity_id: Option<i64>,
1334 ) -> Result<Vec<RecalledMessage>, MemoryError> {
1335 use crate::router::MemoryRoute;
1336
1337 let decision = router.route_async(query).await;
1338 let route = decision.route;
1339 tracing::debug!(
1340 ?route,
1341 confidence = decision.confidence,
1342 query_len = query.len(),
1343 "memory routing decision (async)"
1344 );
1345
1346 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
1347
1348 let (keyword_results, vector_results): (
1349 Vec<(crate::types::MessageId, f64)>,
1350 Vec<crate::embedding_store::SearchResult>,
1351 ) = match route {
1352 MemoryRoute::Keyword => {
1353 let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
1354 (kw, Vec::new())
1355 }
1356 MemoryRoute::Hybrid => {
1357 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
1358 Ok(r) => r,
1359 Err(e) => {
1360 tracing::warn!("FTS5 keyword search failed: {e:#}");
1361 Vec::new()
1362 }
1363 };
1364 let vr = self.recall_vectors_raw(query, limit, filter).await?;
1365 (kw, vr)
1366 }
1367 MemoryRoute::Episodic => {
1368 let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
1369 let cleaned = crate::router::strip_temporal_keywords(query);
1370 let search_query = if cleaned.is_empty() { query } else { &cleaned };
1371 let kw = if let Some(ref r) = range {
1372 self.sqlite
1373 .keyword_search_with_time_range(
1374 search_query,
1375 limit,
1376 conversation_id,
1377 r.after.as_deref(),
1378 r.before.as_deref(),
1379 )
1380 .await?
1381 } else {
1382 self.recall_fts5_raw(search_query, limit, conversation_id)
1383 .await?
1384 };
1385 (kw, Vec::new())
1386 }
1387 MemoryRoute::Graph => {
1388 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
1389 Ok(r) => r,
1390 Err(e) => {
1391 tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
1392 Vec::new()
1393 }
1394 };
1395 let vr = self.recall_vectors_raw(query, limit, filter).await?;
1396 (kw, vr)
1397 }
1398 _ => {
1399 let vr = self.recall_vectors_raw(query, limit, filter).await?;
1400 (Vec::new(), vr)
1401 }
1402 };
1403
1404 tracing::debug!(
1405 keyword_count = keyword_results.len(),
1406 vector_count = vector_results.len(),
1407 "recall: routed search results (async)"
1408 );
1409
1410 self.recall_merge_and_rank(keyword_results, vector_results, limit, goal_entity_id)
1411 .await
1412 }
1413
1414 #[cfg_attr(
1428 feature = "profiling",
1429 tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
1430 )]
1431 pub async fn recall_graph(
1432 &self,
1433 query: &str,
1434 limit: usize,
1435 max_hops: u32,
1436 at_timestamp: Option<&str>,
1437 temporal_decay_rate: f64,
1438 edge_types: &[crate::graph::EdgeType],
1439 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1440 let Some(store) = &self.graph_store else {
1441 return Ok(Vec::new());
1442 };
1443
1444 tracing::debug!(
1445 query_len = query.len(),
1446 limit,
1447 max_hops,
1448 "graph: starting recall"
1449 );
1450
1451 let results = crate::graph::retrieval::graph_recall(
1452 store,
1453 self.qdrant.as_deref(),
1454 &self.provider,
1455 query,
1456 limit,
1457 max_hops,
1458 at_timestamp,
1459 temporal_decay_rate,
1460 edge_types,
1461 self.hebbian_reinforcement.is_enabled(),
1462 self.hebbian_lr,
1463 self.embed_timeout,
1464 )
1465 .await?;
1466
1467 tracing::debug!(result_count = results.len(), "graph: recall complete");
1468 #[cfg(feature = "profiling")]
1469 tracing::Span::current().record("result_count", results.len());
1470
1471 Ok(results)
1472 }
1473
1474 #[cfg_attr(
1483 feature = "profiling",
1484 tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
1485 )]
1486 pub async fn recall_graph_activated(
1487 &self,
1488 query: &str,
1489 limit: usize,
1490 params: crate::graph::SpreadingActivationParams,
1491 edge_types: &[crate::graph::EdgeType],
1492 ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
1493 let Some(store) = &self.graph_store else {
1494 return Ok(Vec::new());
1495 };
1496
1497 tracing::debug!(
1498 query_len = query.len(),
1499 limit,
1500 "spreading activation: starting graph recall"
1501 );
1502
1503 let embeddings = self.qdrant.as_deref();
1504 let results = crate::graph::retrieval::graph_recall_activated(
1505 store,
1506 embeddings,
1507 &self.provider,
1508 query,
1509 limit,
1510 params,
1511 edge_types,
1512 self.hebbian_reinforcement.is_enabled(),
1513 self.hebbian_lr,
1514 self.embed_timeout,
1515 )
1516 .await?;
1517
1518 tracing::debug!(
1519 result_count = results.len(),
1520 "spreading activation: graph recall complete"
1521 );
1522
1523 Ok(results)
1524 }
1525
1526 #[allow(clippy::too_many_arguments, clippy::too_many_lines)] #[cfg_attr(
1559 feature = "profiling",
1560 tracing::instrument(
1561 name = "memory.recall.graph_view",
1562 skip_all,
1563 fields(view = ?view, result_count = tracing::field::Empty)
1564 )
1565 )]
1566 pub async fn recall_graph_view(
1567 &self,
1568 query: &str,
1569 limit: usize,
1570 view: crate::recall_view::RecallView,
1571 neighbor_cap: usize,
1572 bfs_max_hops: u32,
1573 temporal_decay_rate: f64,
1574 edge_types: &[crate::graph::EdgeType],
1575 sa_params: Option<crate::graph::SpreadingActivationParams>,
1576 ) -> Result<Vec<crate::recall_view::RecalledFact>, MemoryError> {
1577 use crate::recall_view::{RecallView, RecalledFact};
1578
1579 let mut recalled: Vec<RecalledFact> = if let Some(params) = sa_params {
1581 let activated = self
1582 .recall_graph_activated(query, limit, params, edge_types)
1583 .await?;
1584 activated
1585 .into_iter()
1586 .map(|af| {
1587 let activation_score = af.activation_score;
1590 let edge = &af.edge;
1591 let fact = crate::graph::types::GraphFact {
1592 entity_name: String::new(), relation: edge.canonical_relation.clone(),
1594 target_name: String::new(),
1595 fact: edge.fact.clone(),
1596 entity_match_score: activation_score,
1597 hop_distance: 0,
1598 confidence: edge.confidence,
1599 valid_from: if edge.valid_from.is_empty() {
1600 None
1601 } else {
1602 Some(edge.valid_from.clone())
1603 },
1604 edge_type: edge.edge_type,
1605 retrieval_count: edge.retrieval_count,
1606 edge_id: Some(edge.id),
1607 };
1608 RecalledFact {
1609 fact,
1610 activation_score: Some(activation_score),
1611 provenance_message_id: edge.source_message_id,
1612 provenance_snippet: None,
1613 neighbors: Vec::new(),
1614 }
1615 })
1616 .collect()
1617 } else {
1618 let facts = self
1619 .recall_graph(
1620 query,
1621 limit,
1622 bfs_max_hops,
1623 None,
1624 temporal_decay_rate,
1625 edge_types,
1626 )
1627 .await?;
1628 facts
1629 .into_iter()
1630 .map(RecalledFact::from_graph_fact)
1631 .collect()
1632 };
1633
1634 if view == RecallView::Head {
1636 #[cfg(feature = "profiling")]
1637 tracing::Span::current().record("result_count", recalled.len());
1638 return Ok(recalled);
1639 }
1640
1641 if matches!(view, RecallView::ZoomIn | RecallView::ZoomOut) {
1643 let edge_ids: Vec<i64> = recalled.iter().filter_map(|r| r.fact.edge_id).collect();
1644
1645 if !edge_ids.is_empty()
1646 && let Some(ref store) = self.graph_store
1647 {
1648 const MAX_IDS: usize = 490;
1650 let mut edge_to_msg: std::collections::HashMap<i64, MessageId> =
1651 std::collections::HashMap::new();
1652 for chunk in edge_ids.chunks(MAX_IDS) {
1653 match store.source_message_ids_for_edges(chunk).await {
1654 Ok(pairs) => {
1655 for (eid, mid) in pairs {
1656 edge_to_msg.insert(eid, mid);
1657 }
1658 }
1659 Err(e) => {
1660 tracing::warn!(error = %e, "recall_graph_view: provenance fetch failed");
1661 }
1662 }
1663 }
1664
1665 for rf in &mut recalled {
1667 if rf.provenance_message_id.is_none()
1668 && let Some(eid) = rf.fact.edge_id
1669 {
1670 rf.provenance_message_id = edge_to_msg.get(&eid).copied();
1671 }
1672 }
1673
1674 let msg_ids: Vec<MessageId> = recalled
1676 .iter()
1677 .filter_map(|r| r.provenance_message_id)
1678 .collect::<std::collections::HashSet<_>>()
1679 .into_iter()
1680 .collect();
1681
1682 if !msg_ids.is_empty() {
1683 match self.sqlite.messages_by_ids(&msg_ids).await {
1684 Ok(messages) => {
1685 let mut mid_to_snippet: std::collections::HashMap<MessageId, String> =
1686 messages
1687 .into_iter()
1688 .map(|(id, msg)| {
1689 let raw = &msg.content;
1690 let scrubbed: String = raw
1691 .chars()
1692 .map(|c| match c {
1693 '\n' | '\r' | '<' | '>' => ' ',
1694 other => other,
1695 })
1696 .take(200)
1697 .collect();
1698 (id, scrubbed)
1699 })
1700 .collect();
1701 for rf in &mut recalled {
1702 if let Some(mid) = rf.provenance_message_id {
1703 rf.provenance_snippet = mid_to_snippet.remove(&mid);
1704 }
1705 }
1706 }
1707 Err(e) => {
1708 tracing::warn!(error = %e, "recall_graph_view: message snippet fetch failed");
1709 }
1710 }
1711 }
1712 }
1713 }
1714
1715 if view == RecallView::ZoomOut
1717 && let Some(ref store) = self.graph_store
1718 {
1719 type DedupeKey = (String, String, String, crate::graph::EdgeType);
1723 let make_key = |f: &crate::graph::types::GraphFact| -> DedupeKey {
1724 if f.entity_name.is_empty() || f.target_name.is_empty() {
1725 (
1726 f.fact.clone(),
1727 f.relation.clone(),
1728 String::new(),
1729 f.edge_type,
1730 )
1731 } else {
1732 (
1733 f.entity_name.clone(),
1734 f.relation.clone(),
1735 f.target_name.clone(),
1736 f.edge_type,
1737 )
1738 }
1739 };
1740 let mut seen: std::collections::HashSet<DedupeKey> =
1741 recalled.iter().map(|r| make_key(&r.fact)).collect();
1742
1743 let total_neighbor_cap = limit * neighbor_cap;
1744 let mut total_neighbors = 0usize;
1745
1746 for rf in &mut recalled {
1747 if total_neighbors >= total_neighbor_cap {
1748 break;
1749 }
1750 let source_entity_id = match rf.fact.edge_id {
1753 Some(eid) => match store.source_entity_id_for_edge(eid).await {
1754 Ok(Some(id)) => id,
1755 _ => continue,
1756 },
1757 None => continue,
1758 };
1759
1760 let neighbors = match store
1761 .bfs_edges_at_depth(source_entity_id, 1, edge_types)
1762 .await
1763 {
1764 Ok(edges) => edges,
1765 Err(e) => {
1766 tracing::warn!(error = %e, "recall_graph_view: zoom_out bfs failed");
1767 continue;
1768 }
1769 };
1770
1771 let mut added = 0usize;
1772 for n_edge in neighbors {
1773 if added >= neighbor_cap || total_neighbors >= total_neighbor_cap {
1774 break;
1775 }
1776 let key = make_key(&n_edge.fact);
1777 if seen.insert(key) {
1778 rf.neighbors.push(n_edge.fact);
1779 added += 1;
1780 total_neighbors += 1;
1781 }
1782 }
1783 }
1784 }
1785
1786 #[cfg(feature = "profiling")]
1787 tracing::Span::current().record("result_count", recalled.len());
1788 Ok(recalled)
1789 }
1790
1791 pub async fn recall_graph_astar(
1799 &self,
1800 query: &str,
1801 limit: usize,
1802 max_hops: u32,
1803 temporal_decay_rate: f64,
1804 edge_types: &[crate::graph::EdgeType],
1805 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1806 let Some(store) = &self.graph_store else {
1807 return Ok(Vec::new());
1808 };
1809 crate::graph::retrieval_astar::graph_recall_astar(
1810 store,
1811 self.qdrant.as_deref(),
1812 &self.provider,
1813 query,
1814 limit,
1815 max_hops,
1816 edge_types,
1817 temporal_decay_rate,
1818 self.hebbian_reinforcement.is_enabled(),
1819 self.hebbian_lr,
1820 self.query_sensitive_cost,
1821 self.embed_timeout,
1822 )
1823 .await
1824 }
1825
1826 pub async fn recall_graph_watercircles(
1834 &self,
1835 query: &str,
1836 limit: usize,
1837 max_hops: u32,
1838 ring_limit: usize,
1839 temporal_decay_rate: f64,
1840 edge_types: &[crate::graph::EdgeType],
1841 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1842 let Some(store) = &self.graph_store else {
1843 return Ok(Vec::new());
1844 };
1845 crate::graph::retrieval_watercircles::graph_recall_watercircles(
1846 store,
1847 self.qdrant.as_deref(),
1848 &self.provider,
1849 query,
1850 limit,
1851 max_hops,
1852 ring_limit,
1853 edge_types,
1854 temporal_decay_rate,
1855 self.hebbian_reinforcement.is_enabled(),
1856 self.hebbian_lr,
1857 self.embed_timeout,
1858 )
1859 .await
1860 }
1861
1862 pub async fn recall_graph_beam(
1870 &self,
1871 query: &str,
1872 limit: usize,
1873 beam_width: usize,
1874 max_hops: u32,
1875 temporal_decay_rate: f64,
1876 edge_types: &[crate::graph::EdgeType],
1877 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1878 let Some(store) = &self.graph_store else {
1879 return Ok(Vec::new());
1880 };
1881 crate::graph::retrieval_beam::graph_recall_beam(
1882 store,
1883 self.qdrant.as_deref(),
1884 &self.provider,
1885 query,
1886 limit,
1887 beam_width,
1888 max_hops,
1889 edge_types,
1890 temporal_decay_rate,
1891 self.hebbian_reinforcement.is_enabled(),
1892 self.hebbian_lr,
1893 self.embed_timeout,
1894 )
1895 .await
1896 }
1897
1898 pub async fn classify_graph_strategy(&self, query: &str) -> String {
1903 crate::graph::strategy_classifier::classify_retrieval_strategy(&self.provider, query).await
1904 }
1905
1906 #[cfg_attr(
1916 feature = "profiling",
1917 tracing::instrument(
1918 name = "memory.recall_graph_hela",
1919 skip_all,
1920 fields(result_count = tracing::field::Empty)
1921 )
1922 )]
1923 pub async fn recall_graph_hela(
1924 &self,
1925 query: &str,
1926 limit: usize,
1927 params: crate::graph::HelaSpreadParams,
1928 ) -> Result<Vec<crate::graph::HelaFact>, MemoryError> {
1929 let Some(store) = &self.graph_store else {
1930 return Ok(Vec::new());
1931 };
1932 let Some(embeddings) = &self.qdrant else {
1933 return Ok(Vec::new());
1934 };
1935
1936 let store = Arc::clone(store);
1937 let embeddings = Arc::clone(embeddings);
1938 let provider = self.provider.clone();
1939 let hebbian_enabled = self.hebbian_reinforcement.is_enabled();
1940 let hebbian_lr = self.hebbian_lr;
1941
1942 let results = tokio::time::timeout(
1943 std::time::Duration::from_millis(200),
1944 crate::graph::hela_spreading_recall(
1945 &store,
1946 &embeddings,
1947 &provider,
1948 query,
1949 limit,
1950 ¶ms,
1951 hebbian_enabled,
1952 hebbian_lr,
1953 ),
1954 )
1955 .await
1956 .unwrap_or_else(|_| {
1957 tracing::warn!("memory.recall_graph_hela: outer 200ms timeout exceeded");
1958 Ok(Vec::new())
1959 })?;
1960
1961 #[cfg(feature = "profiling")]
1962 tracing::Span::current().record("result_count", results.len());
1963
1964 Ok(results)
1965 }
1966
1967 async fn batch_increment_access_count(
1975 &self,
1976 message_ids: Vec<MessageId>,
1977 ) -> Result<(), MemoryError> {
1978 if message_ids.is_empty() {
1979 return Ok(());
1980 }
1981 self.sqlite.increment_access_counts(&message_ids).await
1982 }
1983
1984 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
1990 match &self.qdrant {
1991 Some(qdrant) => qdrant.has_embedding(message_id).await,
1992 None => Ok(false),
1993 }
1994 }
1995
1996 pub async fn embed_missing(
2012 &self,
2013 progress_tx: Option<tokio::sync::watch::Sender<Option<super::BackfillProgress>>>,
2014 ) -> Result<usize, MemoryError> {
2015 if self.qdrant.is_none() || !self.effective_embed_provider().supports_embeddings() {
2016 return Ok(0);
2017 }
2018
2019 let total = self.sqlite.count_unembedded_messages().await?;
2020 if total == 0 {
2021 return Ok(0);
2022 }
2023
2024 if let Some(tx) = &progress_tx {
2025 let _ = tx.send(Some(super::BackfillProgress { done: 0, total }));
2026 }
2027
2028 let mut done = 0usize;
2029 let mut succeeded = 0usize;
2030
2031 loop {
2032 const BATCH_SIZE: usize = 32;
2033 const BATCH_SIZE_I64: i64 = 32;
2034 let rows: Vec<_> = self
2035 .sqlite
2036 .stream_unembedded_messages(BATCH_SIZE_I64)
2037 .try_collect()
2038 .await?;
2039
2040 if rows.is_empty() {
2041 break;
2042 }
2043
2044 let batch_len = rows.len();
2045
2046 let results: Vec<bool> = futures::stream::iter(rows)
2047 .map(|(msg_id, conv_id, role, content)| async move {
2048 self.embed_and_store_regular(msg_id, conv_id, &role, &content)
2049 })
2050 .buffer_unordered(4)
2051 .collect()
2052 .await;
2053
2054 for ok in &results {
2055 done += 1;
2056 if *ok {
2057 succeeded += 1;
2058 }
2059 if let Some(tx) = &progress_tx {
2060 let _ = tx.send(Some(super::BackfillProgress { done, total }));
2061 }
2062 }
2063
2064 let batch_succeeded = results.iter().filter(|&&b| b).count();
2065 if batch_succeeded > 0 {
2066 tracing::debug!("Backfill batch: {batch_succeeded}/{batch_len} embedded");
2067 }
2068
2069 if batch_len < BATCH_SIZE {
2070 break;
2071 }
2072 }
2073
2074 if let Some(tx) = &progress_tx {
2075 let _ = tx.send(None);
2076 }
2077
2078 if done > 0 {
2079 tracing::info!("Embedded {succeeded}/{total} missing messages");
2080 }
2081 Ok(succeeded)
2082 }
2083}
2084
#[cfg(test)]
mod tests {
    use super::*;

    /// A default `EmbedContext` carries no tool metadata.
    #[test]
    fn embed_context_default_all_none() {
        let EmbedContext {
            tool_name,
            exit_code,
            timestamp,
        } = EmbedContext::default();
        assert!(tool_name.is_none());
        assert!(exit_code.is_none());
        assert!(timestamp.is_none());
    }

    /// Every field of `EmbedContext` keeps the value it was built with.
    #[test]
    fn embed_context_fields_set_correctly() {
        let tool = String::from("shell");
        let stamp = String::from("2026-04-04T00:00:00Z");
        let ctx = EmbedContext {
            tool_name: Some(tool),
            exit_code: Some(0),
            timestamp: Some(stamp),
        };
        assert_eq!(ctx.tool_name.as_deref(), Some("shell"));
        assert_eq!(ctx.exit_code, Some(0));
        assert_eq!(ctx.timestamp.as_deref(), Some("2026-04-04T00:00:00Z"));
    }

    /// A failing exit code is stored verbatim and the timestamp may stay unset.
    #[test]
    fn embed_context_non_zero_exit_code() {
        let ctx = EmbedContext {
            tool_name: Some(String::from("shell")),
            exit_code: Some(1),
            timestamp: None,
        };
        assert_eq!(ctx.exit_code, Some(1));
        assert!(ctx.timestamp.is_none());
    }

    /// Builds a `SemanticMemory` backed by an in-memory SQLite store and a mock
    /// provider, with no vector store and no graph store.
    async fn make_semantic_memory() -> crate::semantic::SemanticMemory {
        use std::sync::Arc;
        use std::sync::atomic::AtomicU64;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        // Small factories for the repeated shared counters and flags.
        let zero_counter = || Arc::new(AtomicU64::new(0));
        let cleared_flag = || Arc::new(std::sync::atomic::AtomicBool::new(false));

        let provider = AnyProvider::Mock(MockProvider::default());
        let sqlite = crate::store::SqliteStore::new(":memory:").await.unwrap();
        crate::semantic::SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embed_provider: None,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay: crate::semantic::TemporalDecay::Disabled,
            temporal_decay_half_life_days: 30,
            mmr_reranking: crate::semantic::MmrReranking::Disabled,
            mmr_lambda: 0.7,
            importance_scoring: crate::semantic::ImportanceScoring::Disabled,
            importance_weight: 0.15,
            token_counter: Arc::new(crate::token_counter::TokenCounter::new()),
            graph_store: None,
            experience: None,
            community_detection_failures: zero_counter(),
            graph_extraction_count: zero_counter(),
            graph_extraction_failures: zero_counter(),
            last_qdrant_warn: zero_counter(),
            tier_boost_semantic: 1.3,
            admission_control: None,
            quality_gate: None,
            key_facts_dedup_threshold: 0.95,
            embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
            retrieval_depth: 0,
            search_prompt_template: String::new(),
            depth_below_limit_warned: cleared_flag(),
            missing_placeholder_warned: cleared_flag(),
            reasoning: None,
            query_bias_correction: crate::semantic::QueryBiasCorrection::Disabled,
            query_bias_profile_weight: 0.25,
            profile_centroid: tokio::sync::RwLock::new(None),
            profile_centroid_ttl_secs: 300,
            hebbian_reinforcement: crate::semantic::HebbianReinforcement::Disabled,
            hebbian_lr: 0.1,
            hebbian_spread: crate::HelaSpreadRuntime::default(),
            retrieval_failure_logger: None,
            summarization_llm_timeout_secs: 60,
            query_sensitive_cost: false,
            five_signal: None,
            embed_timeout: std::time::Duration::from_secs(5),
            graph_cancel: std::sync::Mutex::new(None),
        }
    }

    /// With an empty task set, a background embed task is accepted.
    #[tokio::test]
    async fn spawn_embed_bg_returns_true_when_capacity_available() {
        let memory = make_semantic_memory().await;
        let accepted = memory.spawn_embed_bg(std::future::ready(()));
        assert!(
            accepted,
            "spawn_embed_bg must return true when a task was successfully spawned"
        );
    }

    /// Once `MAX_EMBED_BG_TASKS` never-finishing tasks occupy the set, further
    /// tasks are rejected.
    #[tokio::test]
    async fn spawn_embed_bg_returns_false_at_capacity() {
        let memory = make_semantic_memory().await;

        let mut occupied = memory.embed_tasks.lock().unwrap();
        (0..MAX_EMBED_BG_TASKS).for_each(|_| {
            occupied.spawn(std::future::pending::<()>());
        });
        // Release the lock before calling into `spawn_embed_bg`.
        drop(occupied);

        let accepted = memory.spawn_embed_bg(std::future::ready(()));
        assert!(
            !accepted,
            "spawn_embed_bg must return false when the task limit is reached"
        );
    }

    /// Models the warn rate limit with a local atomic: a second call inside the
    /// window is suppressed, a call at the window boundary is not.
    #[test]
    fn qdrant_warn_rate_limit_suppresses_within_window() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};

        let last_warn = Arc::new(AtomicU64::new(0));
        let window_secs = 10u64;
        let window_elapsed =
            |now: u64, last: u64| -> bool { now.saturating_sub(last) >= window_secs };

        // t = 100: first call, nothing recorded yet.
        let first = window_elapsed(100, last_warn.load(Ordering::Relaxed));
        assert!(first, "first call must not be suppressed");
        if first {
            last_warn.store(100, Ordering::Relaxed);
        }

        // t = 105: still inside the window opened at t = 100.
        let second = window_elapsed(105, last_warn.load(Ordering::Relaxed));
        assert!(!second, "call within 10s window must be suppressed");

        // t = 110: exactly one window later.
        let third = window_elapsed(110, last_warn.load(Ordering::Relaxed));
        assert!(third, "call after window expiry must not be suppressed");
    }

    /// Two handles to the same atomic share one rate-limit window.
    #[test]
    fn qdrant_warn_rate_limit_shared_across_concurrent_sites() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};

        let shared = Arc::new(AtomicU64::new(0));
        let window_secs = 10u64;
        let window_elapsed =
            |now: u64, last: u64| -> bool { now.saturating_sub(last) >= window_secs };

        let site_a = Arc::clone(&shared);
        let site_b = Arc::clone(&shared);

        // Site A warns at t = 100 and records the timestamp.
        if window_elapsed(100, site_a.load(Ordering::Relaxed)) {
            site_a.store(100, Ordering::Relaxed);
        }

        // Site B checks at t = 105 through its own handle.
        let warn_b = window_elapsed(105, site_b.load(Ordering::Relaxed));
        assert!(
            !warn_b,
            "site B must be suppressed because site A already warned within the window"
        );
    }
}