1use std::sync::Arc;
5use std::sync::atomic::{AtomicU64, Ordering};
6
7use futures::{StreamExt as _, TryStreamExt as _};
8use zeph_llm::provider::{LlmProvider as _, Message};
9
/// Rough bytes-per-token estimate used to size chunks.
const CHARS_PER_TOKEN: usize = 4;

/// Target chunk size in bytes (~400 tokens).
const CHUNK_CHARS: usize = 400 * CHARS_PER_TOKEN;

/// Bytes of overlap carried between consecutive chunks (~80 tokens).
const CHUNK_OVERLAP_CHARS: usize = 80 * CHARS_PER_TOKEN;

/// Largest UTF-8 char boundary at or below `idx` (clamped to `text.len()`).
fn boundary_at_or_before(text: &str, idx: usize) -> usize {
    let mut b = idx.min(text.len());
    while !text.is_char_boundary(b) {
        b -= 1;
    }
    b
}

/// Smallest UTF-8 char boundary at or above `idx` (clamped to `text.len()`;
/// `text.len()` itself is always a boundary, so the scan terminates).
fn boundary_at_or_after(text: &str, idx: usize) -> usize {
    let mut b = idx.min(text.len());
    while !text.is_char_boundary(b) {
        b += 1;
    }
    b
}

/// Splits `text` into chunks of at most `CHUNK_CHARS` bytes, preferring to
/// break at a paragraph break, then a newline, then a space. Consecutive
/// chunks overlap by up to `CHUNK_OVERLAP_CHARS` bytes so content spanning a
/// cut is embedded twice. Inputs that already fit come back as one chunk.
fn chunk_text(text: &str) -> Vec<&str> {
    if text.len() <= CHUNK_CHARS {
        return vec![text];
    }

    let mut out = Vec::new();
    let mut begin = 0;

    while begin < text.len() {
        let cut = if begin + CHUNK_CHARS >= text.len() {
            text.len()
        } else {
            // Search the candidate window for the best natural break point;
            // fall back to the raw byte boundary when none exists.
            let hard = boundary_at_or_before(text, begin + CHUNK_CHARS);
            let window = &text[begin..hard];
            window
                .rfind("\n\n")
                .map(|p| begin + p + 2)
                .or_else(|| window.rfind('\n').map(|p| begin + p + 1))
                .or_else(|| window.rfind(' ').map(|p| begin + p + 1))
                .unwrap_or(hard)
        };

        out.push(&text[begin..cut]);
        if cut >= text.len() {
            break;
        }
        // Rewind by the overlap, but never move backwards overall — a chunk
        // shorter than the overlap would otherwise loop forever.
        let rewound = boundary_at_or_after(text, cut.saturating_sub(CHUNK_OVERLAP_CHARS));
        begin = if rewound > begin { rewound } else { cut };
    }

    out
}
65
66use crate::admission::log_admission_decision;
67use crate::embedding_store::{MessageKind, SearchFilter};
68use crate::error::MemoryError;
69use crate::types::{ConversationId, MessageId};
70
71use super::SemanticMemory;
72use super::algorithms::{apply_mmr, apply_temporal_decay};
73
/// Optional tool-execution metadata attached to an embedded message so
/// recall can filter/inspect tool outputs.
#[derive(Debug, Clone, Default)]
pub struct EmbedContext {
    /// Name of the tool that produced the output, when known.
    pub tool_name: Option<String>,
    /// Tool process exit code, when known.
    pub exit_code: Option<i32>,
    /// Timestamp of the tool run; stored as-is (format defined by the
    /// caller — presumably an RFC 3339 string, TODO confirm).
    pub timestamp: Option<String>,
}
83
/// A message returned by recall, paired with its final blended relevance
/// score (after normalization, decay, MMR, importance, and tier boosts).
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    /// Higher is more relevant; truncated from the internal f64 score.
    pub score: f32,
}
89
/// Cap on in-flight background embedding tasks; beyond this,
/// `spawn_embed_bg` drops new work instead of queueing it.
const MAX_EMBED_BG_TASKS: usize = 64;

/// Owned inputs handed to a detached background embedding task — the task
/// outlives the call site, so everything is cloned rather than borrowed.
struct EmbedBgArgs {
    // Vector store the embedded chunks are written to.
    qdrant: Arc<crate::embedding_store::EmbeddingStore>,
    // Provider used to compute the embeddings.
    embed_provider: zeph_llm::any::AnyProvider,
    // Model identifier recorded alongside each stored vector.
    embedding_model: String,
    message_id: MessageId,
    conversation_id: ConversationId,
    role: String,
    content: String,
    // Unix-seconds timestamp of the last Qdrant warning, shared across
    // tasks to rate-limit repeated "ensure collection" warnings.
    last_qdrant_warn: Arc<AtomicU64>,
}
104
105async fn embed_and_store_regular_bg(args: EmbedBgArgs) {
109 let EmbedBgArgs {
110 qdrant,
111 embed_provider,
112 embedding_model,
113 message_id,
114 conversation_id,
115 role,
116 content,
117 last_qdrant_warn,
118 } = args;
119 let chunks = chunk_text(&content);
120 let chunk_count = chunks.len();
121
122 let vectors = match embed_provider.embed_batch(&chunks).await {
123 Ok(v) => v,
124 Err(e) => {
125 tracing::warn!("bg embed_regular: failed to embed chunks for msg {message_id}: {e:#}");
126 return;
127 }
128 };
129
130 let Some(first) = vectors.first() else {
131 return;
132 };
133 let vector_size = first.len() as u64;
134 if let Err(e) = qdrant.ensure_collection(vector_size).await {
135 let now = std::time::SystemTime::now()
136 .duration_since(std::time::UNIX_EPOCH)
137 .unwrap_or_default()
138 .as_secs();
139 let last = last_qdrant_warn.load(Ordering::Relaxed);
140 if now.saturating_sub(last) >= 10 {
141 last_qdrant_warn.store(now, Ordering::Relaxed);
142 tracing::warn!("bg embed_regular: failed to ensure Qdrant collection: {e:#}");
143 } else {
144 tracing::debug!(
145 "bg embed_regular: failed to ensure Qdrant collection (suppressed): {e:#}"
146 );
147 }
148 return;
149 }
150
151 for (chunk_index, vector) in vectors.into_iter().enumerate() {
152 let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
153 if let Err(e) = qdrant
154 .store(
155 message_id,
156 conversation_id,
157 &role,
158 vector,
159 MessageKind::Regular,
160 &embedding_model,
161 chunk_index_u32,
162 )
163 .await
164 {
165 tracing::warn!(
166 "bg embed_regular: failed to store chunk {chunk_index}/{chunk_count} \
167 for msg {message_id}: {e:#}"
168 );
169 }
170 }
171}
172
173async fn embed_chunks_with_tool_context_bg(args: EmbedBgArgs, embed_ctx: EmbedContext) {
177 let EmbedBgArgs {
178 qdrant,
179 embed_provider,
180 embedding_model,
181 message_id,
182 conversation_id,
183 role,
184 content,
185 last_qdrant_warn,
186 } = args;
187 let chunks = chunk_text(&content);
188 let chunk_count = chunks.len();
189
190 let vectors = match embed_provider.embed_batch(&chunks).await {
191 Ok(v) => v,
192 Err(e) => {
193 tracing::warn!(
194 "bg embed_tool: failed to embed tool-output chunks for msg {message_id}: {e:#}"
195 );
196 return;
197 }
198 };
199
200 if let Some(first) = vectors.first() {
201 let vector_size = first.len() as u64;
202 if let Err(e) = qdrant.ensure_collection(vector_size).await {
203 let now = std::time::SystemTime::now()
204 .duration_since(std::time::UNIX_EPOCH)
205 .unwrap_or_default()
206 .as_secs();
207 let last = last_qdrant_warn.load(Ordering::Relaxed);
208 if now.saturating_sub(last) >= 10 {
209 last_qdrant_warn.store(now, Ordering::Relaxed);
210 tracing::warn!("bg embed_tool: failed to ensure Qdrant collection: {e:#}");
211 } else {
212 tracing::debug!(
213 "bg embed_tool: failed to ensure Qdrant collection (suppressed): {e:#}"
214 );
215 }
216 return;
217 }
218 }
219
220 for (chunk_index, vector) in vectors.into_iter().enumerate() {
221 let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
222 let result = if let Some(ref tool_name) = embed_ctx.tool_name {
223 qdrant
224 .store_with_tool_context(
225 message_id,
226 conversation_id,
227 &role,
228 vector,
229 MessageKind::Regular,
230 &embedding_model,
231 chunk_index_u32,
232 tool_name,
233 embed_ctx.exit_code,
234 embed_ctx.timestamp.as_deref(),
235 )
236 .await
237 .map(|_| ())
238 } else {
239 qdrant
240 .store(
241 message_id,
242 conversation_id,
243 &role,
244 vector,
245 MessageKind::Regular,
246 &embedding_model,
247 chunk_index_u32,
248 )
249 .await
250 .map(|_| ())
251 };
252 if let Err(e) = result {
253 tracing::warn!(
254 "bg embed_tool: failed to store chunk {chunk_index}/{chunk_count} \
255 for msg {message_id}: {e:#}"
256 );
257 }
258 }
259}
260
261async fn embed_and_store_with_category_bg(args: EmbedBgArgs, category: Option<String>) {
265 let EmbedBgArgs {
266 qdrant,
267 embed_provider,
268 embedding_model,
269 message_id,
270 conversation_id,
271 role,
272 content,
273 last_qdrant_warn,
274 } = args;
275 let chunks = chunk_text(&content);
276 let chunk_count = chunks.len();
277
278 let vectors = match embed_provider.embed_batch(&chunks).await {
279 Ok(v) => v,
280 Err(e) => {
281 tracing::warn!(
282 "bg embed_category: failed to embed categorized chunks for msg {message_id}: {e:#}"
283 );
284 return;
285 }
286 };
287
288 let Some(first) = vectors.first() else {
289 return;
290 };
291 let vector_size = first.len() as u64;
292 if let Err(e) = qdrant.ensure_collection(vector_size).await {
293 let now = std::time::SystemTime::now()
294 .duration_since(std::time::UNIX_EPOCH)
295 .unwrap_or_default()
296 .as_secs();
297 let last = last_qdrant_warn.load(Ordering::Relaxed);
298 if now.saturating_sub(last) >= 10 {
299 last_qdrant_warn.store(now, Ordering::Relaxed);
300 tracing::warn!("bg embed_category: failed to ensure Qdrant collection: {e:#}");
301 } else {
302 tracing::debug!(
303 "bg embed_category: failed to ensure Qdrant collection (suppressed): {e:#}"
304 );
305 }
306 return;
307 }
308
309 for (chunk_index, vector) in vectors.into_iter().enumerate() {
310 let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
311 if let Err(e) = qdrant
312 .store_with_category(
313 message_id,
314 conversation_id,
315 &role,
316 vector,
317 MessageKind::Regular,
318 &embedding_model,
319 chunk_index_u32,
320 category.as_deref(),
321 )
322 .await
323 {
324 tracing::warn!(
325 "bg embed_category: failed to store chunk {chunk_index}/{chunk_count} \
326 for msg {message_id}: {e:#}"
327 );
328 }
329 }
330}
331
332impl SemanticMemory {
333 #[cfg_attr(
343 feature = "profiling",
344 tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
345 )]
346 pub async fn remember(
347 &self,
348 conversation_id: ConversationId,
349 role: &str,
350 content: &str,
351 goal_text: Option<&str>,
352 ) -> Result<Option<MessageId>, MemoryError> {
353 if let Some(ref admission) = self.admission_control {
355 let decision = admission
356 .evaluate(
357 content,
358 role,
359 self.effective_embed_provider(),
360 self.qdrant.as_ref(),
361 goal_text,
362 )
363 .await;
364 let preview: String = content.chars().take(100).collect();
365 log_admission_decision(&decision, &preview, role, admission.threshold());
366 if !decision.admitted {
367 return Ok(None);
368 }
369 }
370
371 if let Some(gate) = &self.quality_gate
372 && gate
373 .evaluate(content, self.effective_embed_provider(), &[])
374 .await
375 .is_some()
376 {
377 return Ok(None);
378 }
379
380 let message_id = self
381 .sqlite
382 .save_message(conversation_id, role, content)
383 .await?;
384
385 self.embed_and_store_regular(message_id, conversation_id, role, content);
386
387 Ok(Some(message_id))
388 }
389
    /// Like [`Self::remember`], but persists the message together with its
    /// structured `parts_json`, and additionally reports whether a
    /// background embedding task was actually scheduled (second tuple
    /// element). `(None, false)` means the message was rejected by
    /// admission control or the quality gate and nothing was stored.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
        goal_text: Option<&str>,
    ) -> Result<(Option<MessageId>, bool), MemoryError> {
        // Admission control may veto storage; the decision is logged with a
        // 100-char content preview.
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    goal_text,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok((None, false));
            }
        }

        // Quality gate: `Some(_)` from evaluate means "reject".
        if let Some(gate) = &self.quality_gate
            && gate
                .evaluate(content, self.effective_embed_provider(), &[])
                .await
                .is_some()
        {
            return Ok((None, false));
        }

        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        // True only when a background embed task was scheduled (Qdrant
        // configured, provider supports embeddings, task slot available).
        let embedding_stored =
            self.embed_and_store_regular(message_id, conversation_id, role, content);

        Ok((Some(message_id), embedding_stored))
    }
447
    /// Persists a tool-output message (with parts JSON) and schedules
    /// embedding that carries the tool metadata from `embed_ctx`.
    ///
    /// Admission control is applied with no goal text; unlike
    /// [`Self::remember`] there is no quality-gate check here — presumably
    /// intentional for tool output, TODO confirm. Returns the saved id and
    /// whether a background embed task was scheduled.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember_tool_output(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
        embed_ctx: EmbedContext,
    ) -> Result<(Option<MessageId>, bool), MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    None,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok((None, false));
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        // Embedding includes tool name / exit code / timestamp when present.
        let embedding_stored = self.embed_chunks_with_tool_context(
            message_id,
            conversation_id,
            role,
            content,
            embed_ctx,
        );

        Ok((Some(message_id), embedding_stored))
    }
503
    /// Persists a message tagged with an optional category, then schedules
    /// background embedding that stores the category alongside each chunk.
    ///
    /// NOTE(review): unlike [`Self::remember`], no quality-gate check runs
    /// here — confirm that is intended. Returns `Ok(None)` on admission
    /// rejection.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
    )]
    pub async fn remember_categorized(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        category: Option<&str>,
        goal_text: Option<&str>,
    ) -> Result<Option<MessageId>, MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(
                    content,
                    role,
                    self.effective_embed_provider(),
                    self.qdrant.as_ref(),
                    goal_text,
                )
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok(None);
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_category(conversation_id, role, content, category)
            .await?;

        // Fire-and-forget background embedding with the category attached.
        self.embed_and_store_with_category(message_id, conversation_id, role, content, category);

        Ok(Some(message_id))
    }
552
    /// Hybrid recall constrained to a category: writes `category` into the
    /// supplied filter and delegates to [`Self::recall`].
    ///
    /// NOTE(review): when `filter` is `None` the category is dropped
    /// entirely (no filter is synthesized), and when `category` is `None`
    /// any category already set on `filter` is cleared — confirm both
    /// behaviors are intended.
    pub async fn recall_with_category(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        category: Option<&str>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let filter_with_category = filter.map(|mut f| {
            f.category = category.map(str::to_owned);
            f
        });
        self.recall(query, limit, filter_with_category).await
    }
573
574 pub fn reap_embed_tasks(&self) {
578 if let Ok(mut tasks) = self.embed_tasks.lock() {
579 while tasks.try_join_next().is_some() {}
580 }
581 }
582
583 fn spawn_embed_bg<F>(&self, fut: F) -> bool
587 where
588 F: std::future::Future<Output = ()> + Send + 'static,
589 {
590 let Ok(mut tasks) = self.embed_tasks.lock() else {
591 return false;
592 };
593 while tasks.try_join_next().is_some() {}
595 if tasks.len() >= MAX_EMBED_BG_TASKS {
596 tracing::debug!("background embed task limit reached, skipping");
597 return false;
598 }
599 tasks.spawn(fut);
600 true
601 }
602
603 fn embed_and_store_with_category(
607 &self,
608 message_id: MessageId,
609 conversation_id: ConversationId,
610 role: &str,
611 content: &str,
612 category: Option<&str>,
613 ) -> bool {
614 let Some(qdrant) = self.qdrant.clone() else {
615 return false;
616 };
617 let embed_provider = self.effective_embed_provider().clone();
618 if !embed_provider.supports_embeddings() {
619 return false;
620 }
621 self.spawn_embed_bg(embed_and_store_with_category_bg(
622 EmbedBgArgs {
623 qdrant,
624 embed_provider,
625 embedding_model: self.embedding_model.clone(),
626 message_id,
627 conversation_id,
628 role: role.to_owned(),
629 content: content.to_owned(),
630 last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
631 },
632 category.map(str::to_owned),
633 ))
634 }
635
636 fn embed_and_store_regular(
640 &self,
641 message_id: MessageId,
642 conversation_id: ConversationId,
643 role: &str,
644 content: &str,
645 ) -> bool {
646 let Some(qdrant) = self.qdrant.clone() else {
647 return false;
648 };
649 let embed_provider = self.effective_embed_provider().clone();
650 if !embed_provider.supports_embeddings() {
651 return false;
652 }
653 self.spawn_embed_bg(embed_and_store_regular_bg(EmbedBgArgs {
654 qdrant,
655 embed_provider,
656 embedding_model: self.embedding_model.clone(),
657 message_id,
658 conversation_id,
659 role: role.to_owned(),
660 content: content.to_owned(),
661 last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
662 }))
663 }
664
665 fn embed_chunks_with_tool_context(
669 &self,
670 message_id: MessageId,
671 conversation_id: ConversationId,
672 role: &str,
673 content: &str,
674 embed_ctx: EmbedContext,
675 ) -> bool {
676 let Some(qdrant) = self.qdrant.clone() else {
677 return false;
678 };
679 let embed_provider = self.effective_embed_provider().clone();
680 if !embed_provider.supports_embeddings() {
681 return false;
682 }
683 self.spawn_embed_bg(embed_chunks_with_tool_context_bg(
684 EmbedBgArgs {
685 qdrant,
686 embed_provider,
687 embedding_model: self.embedding_model.clone(),
688 message_id,
689 conversation_id,
690 role: role.to_owned(),
691 content: content.to_owned(),
692 last_qdrant_warn: Arc::clone(&self.last_qdrant_warn),
693 },
694 embed_ctx,
695 ))
696 }
697
    /// Persists a message (with its structured parts JSON) to SQLite only —
    /// no admission control, no quality gate, and no embedding is
    /// scheduled. Callers that need vector recall must embed separately.
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }
716
    /// Hybrid recall: FTS5 keyword search plus Qdrant vector search (when
    /// configured and the provider supports embeddings), merged and
    /// re-ranked by [`Self::recall_merge_and_rank`].
    ///
    /// Keyword-search failures are tolerated (logged, treated as empty);
    /// embedding/vector-store errors propagate as `Err`.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty, top_score = tracing::field::Empty))
    )]
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        tracing::debug!(
            query_len = query.len(),
            limit,
            has_filter = filter.is_some(),
            conversation_id = conversation_id.map(|c| c.0),
            has_qdrant = self.qdrant.is_some(),
            "recall: starting hybrid search"
        );

        // Keyword search is best-effort: a failure degrades to vector-only.
        // Both legs over-fetch via effective_depth so merging has slack.
        let keyword_results = match self
            .sqlite
            .keyword_search(query, self.effective_depth(limit), conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.effective_embed_provider().supports_embeddings()
        {
            // Query text may be wrapped in a model-specific search prompt
            // and biased before the vector search.
            let embed_input = self.apply_search_prompt(query);
            let query_vector = self.effective_embed_provider().embed(&embed_input).await?;
            let query_vector = self.apply_query_bias(query, query_vector).await;
            // 896 is the fallback dimension if the length conversion fails
            // — presumably the default embedding size, TODO confirm.
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant
                .search(&query_vector, self.effective_depth(limit), filter)
                .await?
        } else {
            Vec::new()
        };

        let results = self
            .recall_merge_and_rank(keyword_results, vector_results, limit)
            .await?;
        #[cfg(feature = "profiling")]
        {
            // Record span fields declared Empty in the instrument attribute.
            let span = tracing::Span::current();
            span.record("result_count", results.len());
            if let Some(top) = results.first() {
                span.record("top_score", top.score);
            }
        }
        Ok(results)
    }
787
788 pub(super) async fn recall_fts5_raw(
789 &self,
790 query: &str,
791 limit: usize,
792 conversation_id: Option<ConversationId>,
793 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
794 self.sqlite
795 .keyword_search(query, self.effective_depth(limit), conversation_id)
796 .await
797 }
798
799 pub(super) async fn recall_vectors_raw(
800 &self,
801 query: &str,
802 limit: usize,
803 filter: Option<SearchFilter>,
804 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
805 let Some(qdrant) = &self.qdrant else {
806 return Ok(Vec::new());
807 };
808 if !self.effective_embed_provider().supports_embeddings() {
809 return Ok(Vec::new());
810 }
811 let embed_input = self.apply_search_prompt(query);
812 let query_vector = self.effective_embed_provider().embed(&embed_input).await?;
813 let query_vector = self.apply_query_bias(query, query_vector).await;
814 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
815 qdrant.ensure_collection(vector_size).await?;
816 qdrant
817 .search(&query_vector, self.effective_depth(limit), filter)
818 .await
819 }
820
    /// Merges keyword and vector hits into a single ranking, then applies
    /// the optional re-ranking stages in order: temporal decay, MMR
    /// diversification (which also truncates to `limit`), importance-score
    /// blending, and semantic-tier boosting. Finally records access
    /// bookkeeping and hydrates the surviving ids into full messages.
    ///
    /// Each leg's scores are max-normalized before being weighted by
    /// `vector_weight` / `keyword_weight`. Per-stage failures (timestamps,
    /// vectors, importance, tiers) are logged and that stage is skipped;
    /// only message hydration errors propagate.
    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    pub(super) async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        // Blended score per message; both legs add into the same entry.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Max-normalize vector scores so both legs are comparable.
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against a non-positive max (would flip signs / div by 0).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            // Same max-normalization for the keyword leg.
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            tracing::debug!("recall: empty merge, no overlapping scores");
            return Ok(Vec::new());
        }

        // Sort descending by blended score; NaNs compare Equal (kept stable).
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        tracing::debug!(
            merged = ranked.len(),
            top_score = ranked.first().map(|r| r.1),
            bottom_score = ranked.last().map(|r| r.1),
            vector_weight = %self.vector_weight,
            keyword_weight = %self.keyword_weight,
            "recall: weighted merge complete"
        );

        // Stage: temporal decay — down-weight older messages, then re-sort.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        half_life_days = self.temporal_decay_half_life_days,
                        top_score_after = ranked.first().map(|r| r.1),
                        "recall: temporal decay applied"
                    );
                }
                Err(e) => {
                    // Non-fatal: skip decay, keep the merged ranking.
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Stage: MMR diversification — needs the stored vectors; every path
        // through here leaves `ranked` truncated to `limit`.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        let ranked_len_before = ranked.len();
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                        tracing::debug!(
                            before = ranked_len_before,
                            after = ranked.len(),
                            lambda = %self.mmr_lambda,
                            "recall: mmr re-ranked"
                        );
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Stage: blend stored per-message importance into the score.
        if self.importance_enabled && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_importance_scores(&ids).await {
                Ok(scores) => {
                    for (msg_id, score) in &mut ranked {
                        if let Some(&imp) = scores.get(msg_id) {
                            *score += imp * self.importance_weight;
                        }
                    }
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        importance_weight = %self.importance_weight,
                        "recall: importance scores blended"
                    );
                }
                Err(e) => {
                    tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
                }
            }
        }

        // Stage: flat bonus for messages in the "semantic" tier; only runs
        // when the configured boost differs from the neutral 1.0.
        if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_tiers(&ids).await {
                Ok(tiers) => {
                    let bonus = self.tier_boost_semantic - 1.0;
                    let mut boosted = false;
                    for (msg_id, score) in &mut ranked {
                        if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
                            *score += bonus;
                            boosted = true;
                        }
                    }
                    if boosted {
                        ranked.sort_by(|a, b| {
                            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                        });
                        tracing::debug!(
                            tier_boost = %self.tier_boost_semantic,
                            "recall: semantic tier boost applied"
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
                }
            }
        }

        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();

        // Bookkeeping: bump access counters (best-effort).
        if !ids.is_empty()
            && let Err(e) = self.batch_increment_access_count(ids.clone()).await
        {
            tracing::warn!("recall: failed to increment access counts: {e:#}");
        }

        // Bookkeeping: mark recalled training data (best-effort).
        if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
            tracing::debug!(
                error = %e,
                "recall: failed to mark training data as recalled (non-fatal)"
            );
        }

        // Hydrate ids into messages; ids missing from SQLite are dropped.
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled: Vec<RecalledMessage> = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        tracing::debug!(final_count = recalled.len(), "recall: final results");

        Ok(recalled)
    }
1026
    /// Recall that first asks a synchronous [`crate::router::MemoryRouter`]
    /// which retrieval strategy to use, then runs only the legs that route
    /// requires before merging via [`Self::recall_merge_and_rank`].
    ///
    /// Error tolerance differs per route: on `Hybrid`/`Graph` a keyword
    /// failure is logged and degraded to empty, while on `Keyword` and
    /// `Episodic` it propagates; vector errors always propagate.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
    )]
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                // Keyword leg is best-effort here; vector leg is required.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Episodic => {
                // Time-scoped keyword search: strip temporal phrases from
                // the query and, when a range resolves, constrain by it.
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                tracing::debug!(
                    has_range = range.is_some(),
                    cleaned_query = %search_query,
                    keyword_count = kw.len(),
                    "recall: episodic path"
                );
                (kw, Vec::new())
            }
            MemoryRoute::Graph => {
                // Graph route currently falls back to hybrid retrieval here.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
1134
    /// Async-router variant of [`Self::recall_routed`]: the
    /// [`crate::router::AsyncMemoryRouter`] returns a decision (route plus
    /// confidence) that is awaited before dispatching the same per-route
    /// retrieval legs and merging via [`Self::recall_merge_and_rank`].
    ///
    /// Same per-route error tolerance as the sync variant: keyword failures
    /// are tolerated only on `Hybrid`/`Graph`; vector errors propagate.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
    )]
    pub async fn recall_routed_async(
        &self,
        query: &str,
        limit: usize,
        filter: Option<crate::embedding_store::SearchFilter>,
        router: &dyn crate::router::AsyncMemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let decision = router.route_async(query).await;
        let route = decision.route;
        tracing::debug!(
            ?route,
            confidence = decision.confidence,
            query_len = query.len(),
            "memory routing decision (async)"
        );

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(crate::types::MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                // Keyword leg is best-effort here; vector leg is required.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Episodic => {
                // Time-scoped keyword search: strip temporal phrases and,
                // when a range resolves, constrain the search by it.
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                (kw, Vec::new())
            }
            MemoryRoute::Graph => {
                // Graph route currently falls back to hybrid retrieval here.
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results (async)"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
1234
    /// Knowledge-graph recall: delegates to
    /// [`crate::graph::retrieval::graph_recall`] with this memory's vector
    /// store, provider, and Hebbian settings. Returns empty when no graph
    /// store is configured.
    ///
    /// `at_timestamp` scopes the query to a point in time and
    /// `temporal_decay_rate` down-weights older facts — exact semantics are
    /// defined by `graph_recall` (not visible here).
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
    )]
    pub async fn recall_graph(
        &self,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
        temporal_decay_rate: f64,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
        // Graph retrieval is optional; absence is not an error.
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            max_hops,
            "graph: starting recall"
        );

        let results = crate::graph::retrieval::graph_recall(
            store,
            self.qdrant.as_deref(),
            &self.provider,
            query,
            limit,
            max_hops,
            at_timestamp,
            temporal_decay_rate,
            edge_types,
            self.hebbian_enabled,
            self.hebbian_lr,
        )
        .await?;

        tracing::debug!(result_count = results.len(), "graph: recall complete");
        #[cfg(feature = "profiling")]
        tracing::Span::current().record("result_count", results.len());

        Ok(results)
    }
1293
    /// Graph recall using spreading activation: delegates to
    /// [`crate::graph::retrieval::graph_recall_activated`] with the given
    /// activation parameters. Returns empty when no graph store is
    /// configured.
    #[cfg_attr(
        feature = "profiling",
        tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
    )]
    pub async fn recall_graph_activated(
        &self,
        query: &str,
        limit: usize,
        params: crate::graph::SpreadingActivationParams,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
        // Graph retrieval is optional; absence is not an error.
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            "spreading activation: starting graph recall"
        );

        let embeddings = self.qdrant.as_deref();
        let results = crate::graph::retrieval::graph_recall_activated(
            store,
            embeddings,
            &self.provider,
            query,
            limit,
            params,
            edge_types,
            self.hebbian_enabled,
            self.hebbian_lr,
        )
        .await?;

        tracing::debug!(
            result_count = results.len(),
            "spreading activation: graph recall complete"
        );

        Ok(results)
    }
1344
1345 pub async fn recall_graph_astar(
1353 &self,
1354 query: &str,
1355 limit: usize,
1356 max_hops: u32,
1357 temporal_decay_rate: f64,
1358 edge_types: &[crate::graph::EdgeType],
1359 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1360 let Some(store) = &self.graph_store else {
1361 return Ok(Vec::new());
1362 };
1363 crate::graph::retrieval_astar::graph_recall_astar(
1364 store,
1365 self.qdrant.as_deref(),
1366 &self.provider,
1367 query,
1368 limit,
1369 max_hops,
1370 edge_types,
1371 temporal_decay_rate,
1372 self.hebbian_enabled,
1373 self.hebbian_lr,
1374 )
1375 .await
1376 }
1377
1378 pub async fn recall_graph_watercircles(
1386 &self,
1387 query: &str,
1388 limit: usize,
1389 max_hops: u32,
1390 ring_limit: usize,
1391 temporal_decay_rate: f64,
1392 edge_types: &[crate::graph::EdgeType],
1393 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1394 let Some(store) = &self.graph_store else {
1395 return Ok(Vec::new());
1396 };
1397 crate::graph::retrieval_watercircles::graph_recall_watercircles(
1398 store,
1399 self.qdrant.as_deref(),
1400 &self.provider,
1401 query,
1402 limit,
1403 max_hops,
1404 ring_limit,
1405 edge_types,
1406 temporal_decay_rate,
1407 self.hebbian_enabled,
1408 self.hebbian_lr,
1409 )
1410 .await
1411 }
1412
1413 pub async fn recall_graph_beam(
1421 &self,
1422 query: &str,
1423 limit: usize,
1424 beam_width: usize,
1425 max_hops: u32,
1426 temporal_decay_rate: f64,
1427 edge_types: &[crate::graph::EdgeType],
1428 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1429 let Some(store) = &self.graph_store else {
1430 return Ok(Vec::new());
1431 };
1432 crate::graph::retrieval_beam::graph_recall_beam(
1433 store,
1434 self.qdrant.as_deref(),
1435 &self.provider,
1436 query,
1437 limit,
1438 beam_width,
1439 max_hops,
1440 edge_types,
1441 temporal_decay_rate,
1442 self.hebbian_enabled,
1443 self.hebbian_lr,
1444 )
1445 .await
1446 }
1447
    /// Ask the configured LLM provider which graph retrieval strategy best
    /// fits `query`, delegating to
    /// `graph::strategy_classifier::classify_retrieval_strategy`.
    /// The returned string is whatever the classifier produces (a strategy
    /// label); its exact vocabulary lives in the classifier module.
    pub async fn classify_graph_strategy(&self, query: &str) -> String {
        crate::graph::strategy_classifier::classify_retrieval_strategy(&self.provider, query).await
    }
1455
1456 #[cfg_attr(
1466 feature = "profiling",
1467 tracing::instrument(
1468 name = "memory.recall_graph_hela",
1469 skip_all,
1470 fields(result_count = tracing::field::Empty)
1471 )
1472 )]
1473 pub async fn recall_graph_hela(
1474 &self,
1475 query: &str,
1476 limit: usize,
1477 params: crate::graph::HelaSpreadParams,
1478 ) -> Result<Vec<crate::graph::HelaFact>, MemoryError> {
1479 let Some(store) = &self.graph_store else {
1480 return Ok(Vec::new());
1481 };
1482 let Some(embeddings) = &self.qdrant else {
1483 return Ok(Vec::new());
1484 };
1485
1486 let store = Arc::clone(store);
1487 let embeddings = Arc::clone(embeddings);
1488 let provider = self.provider.clone();
1489 let hebbian_enabled = self.hebbian_enabled;
1490 let hebbian_lr = self.hebbian_lr;
1491
1492 let results = tokio::time::timeout(
1493 std::time::Duration::from_millis(200),
1494 crate::graph::hela_spreading_recall(
1495 &store,
1496 &embeddings,
1497 &provider,
1498 query,
1499 limit,
1500 ¶ms,
1501 hebbian_enabled,
1502 hebbian_lr,
1503 ),
1504 )
1505 .await
1506 .unwrap_or_else(|_| {
1507 tracing::warn!("memory.recall_graph_hela: outer 200ms timeout exceeded");
1508 Ok(Vec::new())
1509 })?;
1510
1511 #[cfg(feature = "profiling")]
1512 tracing::Span::current().record("result_count", results.len());
1513
1514 Ok(results)
1515 }
1516
1517 async fn batch_increment_access_count(
1525 &self,
1526 message_ids: Vec<MessageId>,
1527 ) -> Result<(), MemoryError> {
1528 if message_ids.is_empty() {
1529 return Ok(());
1530 }
1531 self.sqlite.increment_access_counts(&message_ids).await
1532 }
1533
1534 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
1540 match &self.qdrant {
1541 Some(qdrant) => qdrant.has_embedding(message_id).await,
1542 None => Ok(false),
1543 }
1544 }
1545
1546 pub async fn embed_missing(
1562 &self,
1563 progress_tx: Option<tokio::sync::watch::Sender<Option<super::BackfillProgress>>>,
1564 ) -> Result<usize, MemoryError> {
1565 if self.qdrant.is_none() || !self.effective_embed_provider().supports_embeddings() {
1566 return Ok(0);
1567 }
1568
1569 let total = self.sqlite.count_unembedded_messages().await?;
1570 if total == 0 {
1571 return Ok(0);
1572 }
1573
1574 if let Some(tx) = &progress_tx {
1575 let _ = tx.send(Some(super::BackfillProgress { done: 0, total }));
1576 }
1577
1578 let mut done = 0usize;
1579 let mut succeeded = 0usize;
1580
1581 loop {
1582 const BATCH_SIZE: usize = 32;
1583 const BATCH_SIZE_I64: i64 = 32;
1584 let rows: Vec<_> = self
1585 .sqlite
1586 .stream_unembedded_messages(BATCH_SIZE_I64)
1587 .try_collect()
1588 .await?;
1589
1590 if rows.is_empty() {
1591 break;
1592 }
1593
1594 let batch_len = rows.len();
1595
1596 let results: Vec<bool> = futures::stream::iter(rows)
1597 .map(|(msg_id, conv_id, role, content)| async move {
1598 self.embed_and_store_regular(msg_id, conv_id, &role, &content)
1599 })
1600 .buffer_unordered(4)
1601 .collect()
1602 .await;
1603
1604 for ok in &results {
1605 done += 1;
1606 if *ok {
1607 succeeded += 1;
1608 }
1609 if let Some(tx) = &progress_tx {
1610 let _ = tx.send(Some(super::BackfillProgress { done, total }));
1611 }
1612 }
1613
1614 let batch_succeeded = results.iter().filter(|&&b| b).count();
1615 if batch_succeeded > 0 {
1616 tracing::debug!("Backfill batch: {batch_succeeded}/{batch_len} embedded");
1617 }
1618
1619 if batch_len < BATCH_SIZE {
1620 break;
1621 }
1622 }
1623
1624 if let Some(tx) = &progress_tx {
1625 let _ = tx.send(None);
1626 }
1627
1628 if done > 0 {
1629 tracing::info!("Embedded {succeeded}/{total} missing messages");
1630 }
1631 Ok(succeeded)
1632 }
1633}
1634
#[cfg(test)]
mod tests {
    use super::*;

    // `EmbedContext::default()` must leave every optional metadata field unset.
    #[test]
    fn embed_context_default_all_none() {
        let ctx = EmbedContext::default();
        assert!(ctx.tool_name.is_none());
        assert!(ctx.exit_code.is_none());
        assert!(ctx.timestamp.is_none());
    }

    // Explicitly-set fields round-trip unchanged through the struct.
    #[test]
    fn embed_context_fields_set_correctly() {
        let ctx = EmbedContext {
            tool_name: Some("shell".to_string()),
            exit_code: Some(0),
            timestamp: Some("2026-04-04T00:00:00Z".to_string()),
        };
        assert_eq!(ctx.tool_name.as_deref(), Some("shell"));
        assert_eq!(ctx.exit_code, Some(0));
        assert_eq!(ctx.timestamp.as_deref(), Some("2026-04-04T00:00:00Z"));
    }

    // A failing tool run (non-zero exit code) is representable alongside a
    // missing timestamp.
    #[test]
    fn embed_context_non_zero_exit_code() {
        let ctx = EmbedContext {
            tool_name: Some("shell".to_string()),
            exit_code: Some(1),
            timestamp: None,
        };
        assert_eq!(ctx.exit_code, Some(1));
        assert!(ctx.timestamp.is_none());
    }

    // Test fixture: a minimal `SemanticMemory` backed by an in-memory SQLite
    // store and a mock LLM provider, with every optional subsystem (vector
    // store, graph store, experience, admission control, quality gate, ...)
    // disabled so individual features can be exercised in isolation.
    async fn make_semantic_memory() -> crate::semantic::SemanticMemory {
        use std::sync::Arc;
        use std::sync::atomic::AtomicU64;
        use zeph_llm::any::AnyProvider;
        use zeph_llm::mock::MockProvider;

        let provider = AnyProvider::Mock(MockProvider::default());
        let sqlite = crate::store::SqliteStore::new(":memory:").await.unwrap();
        crate::semantic::SemanticMemory {
            sqlite,
            qdrant: None,
            provider,
            embed_provider: None,
            embedding_model: "test-model".into(),
            vector_weight: 0.7,
            keyword_weight: 0.3,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            importance_enabled: false,
            importance_weight: 0.15,
            token_counter: Arc::new(crate::token_counter::TokenCounter::new()),
            graph_store: None,
            experience: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
            last_qdrant_warn: Arc::new(AtomicU64::new(0)),
            tier_boost_semantic: 1.3,
            admission_control: None,
            quality_gate: None,
            key_facts_dedup_threshold: 0.95,
            embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
            retrieval_depth: 0,
            search_prompt_template: String::new(),
            depth_below_limit_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            missing_placeholder_warned: Arc::new(std::sync::atomic::AtomicBool::new(false)),
            reasoning: None,
            query_bias_correction: false,
            query_bias_profile_weight: 0.25,
            profile_centroid: tokio::sync::RwLock::new(None),
            profile_centroid_ttl_secs: 300,
            hebbian_enabled: false,
            hebbian_lr: 0.1,
            hebbian_spread: crate::HelaSpreadRuntime::default(),
        }
    }

    // An empty JoinSet has spare capacity, so the background embed task must
    // be accepted.
    #[tokio::test]
    async fn spawn_embed_bg_returns_true_when_capacity_available() {
        let memory = make_semantic_memory().await;
        let dispatched = memory.spawn_embed_bg(std::future::ready(()));
        assert!(
            dispatched,
            "spawn_embed_bg must return true when a task was successfully spawned"
        );
    }

    // Saturate the JoinSet with MAX_EMBED_BG_TASKS never-completing tasks;
    // the next spawn attempt must be rejected.
    #[tokio::test]
    async fn spawn_embed_bg_returns_false_at_capacity() {
        let memory = make_semantic_memory().await;

        {
            // Scope the guard so the lock is released before spawn_embed_bg
            // re-acquires it.
            let mut tasks = memory.embed_tasks.lock().unwrap();
            for _ in 0..MAX_EMBED_BG_TASKS {
                tasks.spawn(std::future::pending::<()>());
            }
        }

        let dispatched = memory.spawn_embed_bg(std::future::ready(()));
        assert!(
            !dispatched,
            "spawn_embed_bg must return false when the task limit is reached"
        );
    }

    // Models the rate-limit arithmetic used with `last_qdrant_warn`: warn only
    // when at least `window_secs` have elapsed since the recorded timestamp.
    #[test]
    fn qdrant_warn_rate_limit_suppresses_within_window() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};

        let last_warn = Arc::new(AtomicU64::new(0));
        let window_secs = 10u64;

        // t=100: 100 - 0 >= 10, so the first call warns and records t.
        let now1 = 100u64;
        let last1 = last_warn.load(Ordering::Relaxed);
        let should_warn1 = now1.saturating_sub(last1) >= window_secs;
        assert!(should_warn1, "first call must not be suppressed");
        if should_warn1 {
            last_warn.store(now1, Ordering::Relaxed);
        }

        // t=105: only 5s elapsed, inside the window -> suppressed.
        let now2 = 105u64;
        let last2 = last_warn.load(Ordering::Relaxed);
        let should_warn2 = now2.saturating_sub(last2) >= window_secs;
        assert!(!should_warn2, "call within 10s window must be suppressed");

        // t=110: exactly the window boundary -> allowed again.
        let now3 = 110u64;
        let last3 = last_warn.load(Ordering::Relaxed);
        let should_warn3 = now3.saturating_sub(last3) >= window_secs;
        assert!(
            should_warn3,
            "call after window expiry must not be suppressed"
        );
    }

    // Two call sites sharing one atomic: a warn recorded by site A inside the
    // window must also suppress site B.
    #[test]
    fn qdrant_warn_rate_limit_shared_across_concurrent_sites() {
        use std::sync::Arc;
        use std::sync::atomic::{AtomicU64, Ordering};

        let shared = Arc::new(AtomicU64::new(0));
        let window_secs = 10u64;

        let site_a = Arc::clone(&shared);
        let site_b = Arc::clone(&shared);

        // Site A warns at t=100 and records the timestamp in the shared atomic.
        let now_a = 100u64;
        let last_a = site_a.load(Ordering::Relaxed);
        if now_a.saturating_sub(last_a) >= window_secs {
            site_a.store(now_a, Ordering::Relaxed);
        }

        // Site B at t=105 sees A's record and stays quiet.
        let now_b = 105u64;
        let last_b = site_b.load(Ordering::Relaxed);
        let warn_b = now_b.saturating_sub(last_b) >= window_secs;
        assert!(
            !warn_b,
            "site B must be suppressed because site A already warned within the window"
        );
    }
}