1use zeph_llm::provider::{LlmProvider as _, Message};
5
6use crate::admission::log_admission_decision;
7use crate::embedding_store::{MessageKind, SearchFilter};
8use crate::error::MemoryError;
9use crate::types::{ConversationId, MessageId};
10
11use super::SemanticMemory;
12use super::algorithms::{apply_mmr, apply_temporal_decay};
13
/// A message returned by recall together with its final relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    /// The recalled conversation message.
    pub message: Message,
    /// Blended hybrid ranking score (higher is more relevant); computed as
    /// an `f64` during merge/rank and truncated to `f32` here.
    pub score: f32,
}
19
20impl SemanticMemory {
21 pub async fn remember(
31 &self,
32 conversation_id: ConversationId,
33 role: &str,
34 content: &str,
35 goal_text: Option<&str>,
36 ) -> Result<Option<MessageId>, MemoryError> {
37 if let Some(ref admission) = self.admission_control {
39 let decision = admission
40 .evaluate(
41 content,
42 role,
43 &self.provider,
44 self.qdrant.as_ref(),
45 goal_text,
46 )
47 .await;
48 let preview: String = content.chars().take(100).collect();
49 log_admission_decision(&decision, &preview, role, admission.threshold());
50 if !decision.admitted {
51 return Ok(None);
52 }
53 }
54
55 let message_id = self
56 .sqlite
57 .save_message(conversation_id, role, content)
58 .await?;
59
60 if let Some(qdrant) = &self.qdrant
61 && self.provider.supports_embeddings()
62 {
63 match self.provider.embed(content).await {
64 Ok(vector) => {
65 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
66 if let Err(e) = qdrant.ensure_collection(vector_size).await {
67 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
68 } else if let Err(e) = qdrant
69 .store(
70 message_id,
71 conversation_id,
72 role,
73 vector,
74 MessageKind::Regular,
75 &self.embedding_model,
76 )
77 .await
78 {
79 tracing::warn!("Failed to store embedding: {e:#}");
80 }
81 }
82 Err(e) => {
83 tracing::warn!("Failed to generate embedding: {e:#}");
84 }
85 }
86 }
87
88 Ok(Some(message_id))
89 }
90
91 pub async fn remember_with_parts(
100 &self,
101 conversation_id: ConversationId,
102 role: &str,
103 content: &str,
104 parts_json: &str,
105 goal_text: Option<&str>,
106 ) -> Result<(Option<MessageId>, bool), MemoryError> {
107 if let Some(ref admission) = self.admission_control {
109 let decision = admission
110 .evaluate(
111 content,
112 role,
113 &self.provider,
114 self.qdrant.as_ref(),
115 goal_text,
116 )
117 .await;
118 let preview: String = content.chars().take(100).collect();
119 log_admission_decision(&decision, &preview, role, admission.threshold());
120 if !decision.admitted {
121 return Ok((None, false));
122 }
123 }
124
125 let message_id = self
126 .sqlite
127 .save_message_with_parts(conversation_id, role, content, parts_json)
128 .await?;
129
130 let mut embedding_stored = false;
131
132 if let Some(qdrant) = &self.qdrant
133 && self.provider.supports_embeddings()
134 {
135 match self.provider.embed(content).await {
136 Ok(vector) => {
137 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
138 if let Err(e) = qdrant.ensure_collection(vector_size).await {
139 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
140 } else if let Err(e) = qdrant
141 .store(
142 message_id,
143 conversation_id,
144 role,
145 vector,
146 MessageKind::Regular,
147 &self.embedding_model,
148 )
149 .await
150 {
151 tracing::warn!("Failed to store embedding: {e:#}");
152 } else {
153 embedding_stored = true;
154 }
155 }
156 Err(e) => {
157 tracing::warn!("Failed to generate embedding: {e:#}");
158 }
159 }
160 }
161
162 Ok((Some(message_id), embedding_stored))
163 }
164
165 pub async fn save_only(
173 &self,
174 conversation_id: ConversationId,
175 role: &str,
176 content: &str,
177 parts_json: &str,
178 ) -> Result<MessageId, MemoryError> {
179 self.sqlite
180 .save_message_with_parts(conversation_id, role, content, parts_json)
181 .await
182 }
183
184 pub async fn recall(
194 &self,
195 query: &str,
196 limit: usize,
197 filter: Option<SearchFilter>,
198 ) -> Result<Vec<RecalledMessage>, MemoryError> {
199 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
200
201 tracing::debug!(
202 query_len = query.len(),
203 limit,
204 has_filter = filter.is_some(),
205 conversation_id = conversation_id.map(|c| c.0),
206 has_qdrant = self.qdrant.is_some(),
207 "recall: starting hybrid search"
208 );
209
210 let keyword_results = match self
211 .sqlite
212 .keyword_search(query, limit * 2, conversation_id)
213 .await
214 {
215 Ok(results) => results,
216 Err(e) => {
217 tracing::warn!("FTS5 keyword search failed: {e:#}");
218 Vec::new()
219 }
220 };
221
222 let vector_results = if let Some(qdrant) = &self.qdrant
223 && self.provider.supports_embeddings()
224 {
225 let query_vector = self.provider.embed(query).await?;
226 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
227 qdrant.ensure_collection(vector_size).await?;
228 qdrant.search(&query_vector, limit * 2, filter).await?
229 } else {
230 Vec::new()
231 };
232
233 self.recall_merge_and_rank(keyword_results, vector_results, limit)
234 .await
235 }
236
237 pub(super) async fn recall_fts5_raw(
238 &self,
239 query: &str,
240 limit: usize,
241 conversation_id: Option<ConversationId>,
242 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
243 self.sqlite
244 .keyword_search(query, limit * 2, conversation_id)
245 .await
246 }
247
248 pub(super) async fn recall_vectors_raw(
249 &self,
250 query: &str,
251 limit: usize,
252 filter: Option<SearchFilter>,
253 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
254 let Some(qdrant) = &self.qdrant else {
255 return Ok(Vec::new());
256 };
257 if !self.provider.supports_embeddings() {
258 return Ok(Vec::new());
259 }
260 let query_vector = self.provider.embed(query).await?;
261 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
262 qdrant.ensure_collection(vector_size).await?;
263 qdrant.search(&query_vector, limit * 2, filter).await
264 }
265
    /// Merges keyword (FTS5) and vector search results into one weighted
    /// ranking, applies the optional re-ranking stages in order — temporal
    /// decay, MMR diversification, importance blending, semantic-tier boost —
    /// records access/training bookkeeping, then hydrates the surviving ids
    /// into [`RecalledMessage`]s.
    ///
    /// Each source's scores are max-normalized before weighting so the two
    /// scales are comparable; a message found by both sources accumulates
    /// both weighted contributions.
    ///
    /// # Errors
    /// Only message hydration (`messages_by_ids`) is fatal; every optional
    /// stage degrades with a warning instead of failing the call.
    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    pub(super) async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        // Combined weighted score per message id.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Max-normalize vector scores so the best hit contributes 1.0.
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against non-positive maxima (avoids divide-by-zero/sign flip).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            // Same max-normalization for the FTS5 scores.
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        // Both inputs empty: nothing to rank.
        if scores.is_empty() {
            tracing::debug!("recall: empty merge, no overlapping scores");
            return Ok(Vec::new());
        }

        // Rank best-first; NaN-safe comparison falls back to Equal.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        tracing::debug!(
            merged = ranked.len(),
            top_score = ranked.first().map(|r| r.1),
            bottom_score = ranked.last().map(|r| r.1),
            vector_weight = %self.vector_weight,
            keyword_weight = %self.keyword_weight,
            "recall: weighted merge complete"
        );

        // Stage 1 (optional): decay scores by message age, then re-sort.
        // Timestamp-fetch failure skips the stage rather than failing recall.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        half_life_days = self.temporal_decay_half_life_days,
                        top_score_after = ranked.first().map(|r| r.1),
                        "recall: temporal decay applied"
                    );
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Stage 2 (optional): MMR diversification; `apply_mmr` also cuts the
        // list to `limit`. Every fallback path truncates instead, so the
        // result size is bounded either way.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        let ranked_len_before = ranked.len();
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                        tracing::debug!(
                            before = ranked_len_before,
                            after = ranked.len(),
                            lambda = %self.mmr_lambda,
                            "recall: mmr re-ranked"
                        );
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Stage 3 (optional): blend stored importance scores into the ranking.
        if self.importance_enabled && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_importance_scores(&ids).await {
                Ok(scores) => {
                    for (msg_id, score) in &mut ranked {
                        if let Some(&imp) = scores.get(msg_id) {
                            *score += imp * self.importance_weight;
                        }
                    }
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        importance_weight = %self.importance_weight,
                        "recall: importance scores blended"
                    );
                }
                Err(e) => {
                    tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
                }
            }
        }

        // Stage 4 (optional): flat bonus for messages in the "semantic" tier;
        // skipped entirely when the boost factor is effectively 1.0.
        if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_tiers(&ids).await {
                Ok(tiers) => {
                    let bonus = self.tier_boost_semantic - 1.0;
                    let mut boosted = false;
                    for (msg_id, score) in &mut ranked {
                        if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
                            *score += bonus;
                            boosted = true;
                        }
                    }
                    // Only re-sort (and log) when at least one score changed.
                    if boosted {
                        ranked.sort_by(|a, b| {
                            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                        });
                        tracing::debug!(
                            tier_boost = %self.tier_boost_semantic,
                            "recall: semantic tier boost applied"
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
                }
            }
        }

        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();

        // Bookkeeping side effects; both are best-effort and non-fatal.
        if !ids.is_empty()
            && let Err(e) = self.batch_increment_access_count(ids.clone()).await
        {
            tracing::warn!("recall: failed to increment access counts: {e:#}");
        }

        if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
            tracing::debug!(
                error = %e,
                "recall: failed to mark training data as recalled (non-fatal)"
            );
        }

        // Hydrate messages and re-attach scores in ranked order; ids missing
        // from SQLite are silently dropped by the filter_map.
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled: Vec<RecalledMessage> = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        tracing::debug!(final_count = recalled.len(), "recall: final results");

        Ok(recalled)
    }
471
    /// Recall driven by a synchronous routing decision.
    ///
    /// The router inspects the query and selects a retrieval path:
    /// keyword-only, vector-only, hybrid, episodic (time-scoped keyword
    /// search), or graph (currently served by the hybrid path). Whatever the
    /// route produces is merged and ranked by `recall_merge_and_rank`.
    ///
    /// # Errors
    /// Error policy differs per route (deliberately — see arm comments):
    /// some arms propagate FTS5 failures, others degrade with a warning.
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            // Keyword-only: there is no vector side to fall back to, so
            // FTS5 errors propagate.
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            // Hybrid: keyword failures degrade to vector-only (warned);
            // vector errors still propagate.
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            // Episodic: strip temporal keywords from the query and, when a
            // concrete time range resolves, constrain the keyword search to it.
            MemoryRoute::Episodic => {
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                // Fall back to the raw query if stripping removed everything.
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                // NOTE(review): the time-ranged branch fetches `limit`
                // candidates while recall_fts5_raw fetches `limit * 2` —
                // confirm the asymmetry is intended.
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                tracing::debug!(
                    has_range = range.is_some(),
                    cleaned_query = %search_query,
                    keyword_count = kw.len(),
                    "recall: episodic path"
                );
                (kw, Vec::new())
            }
            // Graph: currently falls back to the hybrid behavior.
            MemoryRoute::Graph => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
575
    /// Recall driven by an asynchronous routing decision.
    ///
    /// Mirrors [`Self::recall_routed`] but the router is awaited and returns
    /// a decision carrying a confidence value (only logged here). The
    /// episodic arm skips the extra debug log that the sync variant emits;
    /// otherwise the per-route behavior is the same.
    ///
    /// # Errors
    /// Error policy differs per route: some arms propagate FTS5 failures,
    /// others degrade with a warning (see arm comments).
    pub async fn recall_routed_async(
        &self,
        query: &str,
        limit: usize,
        filter: Option<crate::embedding_store::SearchFilter>,
        router: &dyn crate::router::AsyncMemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let decision = router.route_async(query).await;
        let route = decision.route;
        tracing::debug!(
            ?route,
            confidence = decision.confidence,
            query_len = query.len(),
            "memory routing decision (async)"
        );

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(crate::types::MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            // Keyword-only: no vector fallback, so FTS5 errors propagate.
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            // Hybrid: keyword failures degrade to vector-only (warned).
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            // Episodic: time-scoped keyword search on a temporally-cleaned query.
            MemoryRoute::Episodic => {
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                // Fall back to the raw query if stripping removed everything.
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                (kw, Vec::new())
            }
            // Graph: currently falls back to the hybrid behavior.
            MemoryRoute::Graph => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results (async)"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
671
672 pub async fn recall_graph(
686 &self,
687 query: &str,
688 limit: usize,
689 max_hops: u32,
690 at_timestamp: Option<&str>,
691 temporal_decay_rate: f64,
692 edge_types: &[crate::graph::EdgeType],
693 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
694 let Some(store) = &self.graph_store else {
695 return Ok(Vec::new());
696 };
697
698 tracing::debug!(
699 query_len = query.len(),
700 limit,
701 max_hops,
702 "graph: starting recall"
703 );
704
705 let results = crate::graph::retrieval::graph_recall(
706 store,
707 self.qdrant.as_deref(),
708 &self.provider,
709 query,
710 limit,
711 max_hops,
712 at_timestamp,
713 temporal_decay_rate,
714 edge_types,
715 )
716 .await?;
717
718 tracing::debug!(result_count = results.len(), "graph: recall complete");
719
720 Ok(results)
721 }
722
723 pub async fn recall_graph_activated(
732 &self,
733 query: &str,
734 limit: usize,
735 params: crate::graph::SpreadingActivationParams,
736 edge_types: &[crate::graph::EdgeType],
737 ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
738 let Some(store) = &self.graph_store else {
739 return Ok(Vec::new());
740 };
741
742 tracing::debug!(
743 query_len = query.len(),
744 limit,
745 "spreading activation: starting graph recall"
746 );
747
748 let embeddings = self.qdrant.as_deref();
749 let results = crate::graph::retrieval::graph_recall_activated(
750 store,
751 embeddings,
752 &self.provider,
753 query,
754 limit,
755 params,
756 edge_types,
757 )
758 .await?;
759
760 tracing::debug!(
761 result_count = results.len(),
762 "spreading activation: graph recall complete"
763 );
764
765 Ok(results)
766 }
767
768 async fn batch_increment_access_count(
776 &self,
777 message_ids: Vec<MessageId>,
778 ) -> Result<(), MemoryError> {
779 if message_ids.is_empty() {
780 return Ok(());
781 }
782 self.sqlite.increment_access_counts(&message_ids).await
783 }
784
785 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
791 match &self.qdrant {
792 Some(qdrant) => qdrant.has_embedding(message_id).await,
793 None => Ok(false),
794 }
795 }
796
797 pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
806 let Some(qdrant) = &self.qdrant else {
807 return Ok(0);
808 };
809 if !self.provider.supports_embeddings() {
810 return Ok(0);
811 }
812
813 let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
814
815 if unembedded.is_empty() {
816 return Ok(0);
817 }
818
819 let probe = self.provider.embed("probe").await?;
820 let vector_size = u64::try_from(probe.len())?;
821 qdrant.ensure_collection(vector_size).await?;
822
823 let mut count = 0;
824 for (msg_id, conversation_id, role, content) in &unembedded {
825 match self.provider.embed(content).await {
826 Ok(vector) => {
827 if let Err(e) = qdrant
828 .store(
829 *msg_id,
830 *conversation_id,
831 role,
832 vector,
833 MessageKind::Regular,
834 &self.embedding_model,
835 )
836 .await
837 {
838 tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
839 continue;
840 }
841 count += 1;
842 }
843 Err(e) => {
844 tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
845 }
846 }
847 }
848
849 tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
850 Ok(count)
851 }
852}