use zeph_llm::provider::{LlmProvider as _, Message};

use crate::admission::log_admission_decision;
use crate::embedding_store::{MessageKind, SearchFilter};
use crate::error::MemoryError;
use crate::types::{ConversationId, MessageId};

use super::SemanticMemory;
use super::algorithms::{apply_mmr, apply_temporal_decay};

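/// A message returned by recall, together with its final relevance score.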
#[derive(Debug)]
pub struct RecalledMessage {
    pub message: Message,
    pub score: f32,
}

impl SemanticMemory {
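    /// Persists a message to SQLite and, when embeddings are available, stores
    /// its vector in Qdrant. Returns `Ok(None)` if admission control rejects the
    /// message; embedding failures are logged and do not fail the write.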
    pub async fn remember(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
    ) -> Result<Option<MessageId>, MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(content, role, &self.provider, self.qdrant.as_ref(), None)
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok(None);
            }
        }

        let message_id = self
            .sqlite
            .save_message(conversation_id, role, content)
            .await?;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok(Some(message_id))
    }

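    /// Like [`remember`](Self::remember), but also persists the structured
    /// message parts as JSON. Returns the new message id (or `None` if rejected
    /// by admission control) and a flag indicating whether an embedding was stored.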
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(Option<MessageId>, bool), MemoryError> {
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(content, role, &self.provider, self.qdrant.as_ref(), None)
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok((None, false));
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((Some(message_id), embedding_stored))
    }

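    /// Persists a message with parts directly to SQLite, bypassing admission
    /// control and embedding generation.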
    pub async fn save_only(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<MessageId, MemoryError> {
        self.sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await
    }

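    /// Hybrid recall: runs FTS5 keyword search and (when available) Qdrant
    /// vector search, then merges and ranks the combined results. Keyword
    /// failures are logged and degrade to vector-only results.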
    pub async fn recall(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        tracing::debug!(
            query_len = query.len(),
            limit,
            has_filter = filter.is_some(),
            conversation_id = conversation_id.map(|c| c.0),
            has_qdrant = self.qdrant.is_some(),
            "recall: starting hybrid search"
        );

        let keyword_results = match self
            .sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
        {
            Ok(results) => results,
            Err(e) => {
                tracing::warn!("FTS5 keyword search failed: {e:#}");
                Vec::new()
            }
        };

        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            let query_vector = self.provider.embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
        } else {
            Vec::new()
        };

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }

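    /// Raw FTS5 keyword search, returning `(message_id, score)` pairs.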
    pub(super) async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
    }

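    /// Raw vector search against Qdrant; returns an empty list when Qdrant is
    /// not configured or the provider does not support embeddings.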
    pub(super) async fn recall_vectors_raw(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
    ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(Vec::new());
        };
        if !self.provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        let query_vector = self.provider.embed(query).await?;
        let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
        qdrant.ensure_collection(vector_size).await?;
        qdrant.search(&query_vector, limit * 2, filter).await
    }

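    /// Merges keyword and vector results into a single weighted ranking, then
    /// optionally applies temporal decay, MMR re-ranking, importance blending,
    /// and a semantic-tier boost before loading the final messages.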
    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    pub(super) async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            tracing::debug!("recall: empty merge, no overlapping scores");
            return Ok(Vec::new());
        }

        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        tracing::debug!(
            merged = ranked.len(),
            top_score = ranked.first().map(|r| r.1),
            bottom_score = ranked.last().map(|r| r.1),
            vector_weight = %self.vector_weight,
            keyword_weight = %self.keyword_weight,
            "recall: weighted merge complete"
        );

        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        half_life_days = self.temporal_decay_half_life_days,
                        top_score_after = ranked.first().map(|r| r.1),
                        "recall: temporal decay applied"
                    );
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        let ranked_len_before = ranked.len();
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                        tracing::debug!(
                            before = ranked_len_before,
                            after = ranked.len(),
                            lambda = %self.mmr_lambda,
                            "recall: mmr re-ranked"
                        );
                    }
                    Ok(_) => {
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        if self.importance_enabled && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_importance_scores(&ids).await {
                Ok(scores) => {
                    for (msg_id, score) in &mut ranked {
                        if let Some(&imp) = scores.get(msg_id) {
                            *score += imp * self.importance_weight;
                        }
                    }
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        importance_weight = %self.importance_weight,
                        "recall: importance scores blended"
                    );
                }
                Err(e) => {
                    tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
                }
            }
        }

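        // Boost messages that were promoted to the semantic tier, when a boost
        // factor other than 1.0 is configured.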
        if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_tiers(&ids).await {
                Ok(tiers) => {
                    let bonus = self.tier_boost_semantic - 1.0;
                    let mut boosted = false;
                    for (msg_id, score) in &mut ranked {
                        if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
                            *score += bonus;
                            boosted = true;
                        }
                    }
                    if boosted {
                        ranked.sort_by(|a, b| {
                            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                        });
                        tracing::debug!(
                            tier_boost = %self.tier_boost_semantic,
                            "recall: semantic tier boost applied"
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
                }
            }
        }

        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();

        if !ids.is_empty()
            && let Err(e) = self.batch_increment_access_count(ids.clone()).await
        {
            tracing::warn!("recall: failed to increment access counts: {e:#}");
        }

        if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
            tracing::debug!(
                error = %e,
                "recall: failed to mark training data as recalled (non-fatal)"
            );
        }

        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled: Vec<RecalledMessage> = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        tracing::debug!(final_count = recalled.len(), "recall: final results");

        Ok(recalled)
    }

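    /// Recall that first asks the [`MemoryRouter`](crate::router::MemoryRouter)
    /// which retrieval path to take (keyword, semantic, hybrid, episodic, or
    /// graph), then merges and ranks whatever that path returns.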
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            MemoryRoute::Episodic => {
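                // Resolve a concrete time range from the query text, strip the
                // temporal keywords, and keyword-search within that range when
                // one is found.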
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                tracing::debug!(
                    has_range = range.is_some(),
                    cleaned_query = %search_query,
                    keyword_count = kw.len(),
                    "recall: episodic path"
                );
                (kw, Vec::new())
            }
            MemoryRoute::Graph => {
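                // The graph route currently falls back to hybrid keyword + vector search.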
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }

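    /// Recalls facts from the graph store using multi-hop retrieval, optionally
    /// evaluated at a given timestamp with temporal decay. Returns an empty list
    /// when no graph store is configured.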
    pub async fn recall_graph(
        &self,
        query: &str,
        limit: usize,
        max_hops: u32,
        at_timestamp: Option<&str>,
        temporal_decay_rate: f64,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            max_hops,
            "graph: starting recall"
        );

        let results = crate::graph::retrieval::graph_recall(
            store,
            self.qdrant.as_deref(),
            &self.provider,
            query,
            limit,
            max_hops,
            at_timestamp,
            temporal_decay_rate,
            edge_types,
        )
        .await?;

        tracing::debug!(result_count = results.len(), "graph: recall complete");

        Ok(results)
    }

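    /// Graph recall using spreading activation instead of fixed-hop traversal.
    /// Returns an empty list when no graph store is configured.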
    pub async fn recall_graph_activated(
        &self,
        query: &str,
        limit: usize,
        params: crate::graph::SpreadingActivationParams,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            "spreading activation: starting graph recall"
        );

        let embeddings = self.qdrant.as_deref();
        let results = crate::graph::retrieval::graph_recall_activated(
            store,
            embeddings,
            &self.provider,
            query,
            limit,
            params,
            edge_types,
        )
        .await?;

        tracing::debug!(
            result_count = results.len(),
            "spreading activation: graph recall complete"
        );

        Ok(results)
    }

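    /// Increments the access counters for the given messages in one batch;
    /// a no-op when the list is empty.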
    async fn batch_increment_access_count(
        &self,
        message_ids: Vec<MessageId>,
    ) -> Result<(), MemoryError> {
        if message_ids.is_empty() {
            return Ok(());
        }
        self.sqlite.increment_access_counts(&message_ids).await
    }

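    /// Returns whether an embedding is stored in Qdrant for the given message;
    /// always `false` when Qdrant is not configured.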
    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
        match &self.qdrant {
            Some(qdrant) => qdrant.has_embedding(message_id).await,
            None => Ok(false),
        }
    }

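    /// Backfills embeddings for messages that were saved without one (up to
    /// 1000 per call) and returns how many were embedded successfully.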
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
}