1use zeph_llm::provider::{LlmProvider as _, Message};
5
6use crate::admission::log_admission_decision;
7use crate::embedding_store::{MessageKind, SearchFilter};
8use crate::error::MemoryError;
9use crate::types::{ConversationId, MessageId};
10
11use super::SemanticMemory;
12use super::algorithms::{apply_mmr, apply_temporal_decay};
13
14#[derive(Debug)]
15pub struct RecalledMessage {
16 pub message: Message,
17 pub score: f32,
18}
19
20impl SemanticMemory {
21 pub async fn remember(
31 &self,
32 conversation_id: ConversationId,
33 role: &str,
34 content: &str,
35 ) -> Result<Option<MessageId>, MemoryError> {
36 if let Some(ref admission) = self.admission_control {
38 let decision = admission
39 .evaluate(content, role, &self.provider, self.qdrant.as_ref())
40 .await;
41 let preview: String = content.chars().take(100).collect();
42 log_admission_decision(&decision, &preview, role, admission.threshold());
43 if !decision.admitted {
44 return Ok(None);
45 }
46 }
47
48 let message_id = self
49 .sqlite
50 .save_message(conversation_id, role, content)
51 .await?;
52
53 if let Some(qdrant) = &self.qdrant
54 && self.provider.supports_embeddings()
55 {
56 match self.provider.embed(content).await {
57 Ok(vector) => {
58 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
59 if let Err(e) = qdrant.ensure_collection(vector_size).await {
60 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
61 } else if let Err(e) = qdrant
62 .store(
63 message_id,
64 conversation_id,
65 role,
66 vector,
67 MessageKind::Regular,
68 &self.embedding_model,
69 )
70 .await
71 {
72 tracing::warn!("Failed to store embedding: {e:#}");
73 }
74 }
75 Err(e) => {
76 tracing::warn!("Failed to generate embedding: {e:#}");
77 }
78 }
79 }
80
81 Ok(Some(message_id))
82 }
83
    /// Persist a message plus its structured parts JSON, gated by admission
    /// control, and index its embedding in Qdrant when possible.
    ///
    /// Returns `(message_id, embedding_stored)`:
    /// - `(None, false)` when admission control rejects the content;
    /// - `(Some(id), true)` only when the embedding was actually stored.
    ///
    /// Embedding failures are logged, never propagated.
    ///
    /// # Errors
    /// Fails only if the SQLite save itself fails.
    pub async fn remember_with_parts(
        &self,
        conversation_id: ConversationId,
        role: &str,
        content: &str,
        parts_json: &str,
    ) -> Result<(Option<MessageId>, bool), MemoryError> {
        // Admission gate first: rejected content is never persisted.
        if let Some(ref admission) = self.admission_control {
            let decision = admission
                .evaluate(content, role, &self.provider, self.qdrant.as_ref())
                .await;
            let preview: String = content.chars().take(100).collect();
            log_admission_decision(&decision, &preview, role, admission.threshold());
            if !decision.admitted {
                return Ok((None, false));
            }
        }

        let message_id = self
            .sqlite
            .save_message_with_parts(conversation_id, role, content, parts_json)
            .await?;

        // Tracks whether the vector actually made it into Qdrant.
        let mut embedding_stored = false;

        if let Some(qdrant) = &self.qdrant
            && self.provider.supports_embeddings()
        {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    // Collection size is inferred from the embedding length.
                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
                    } else if let Err(e) = qdrant
                        .store(
                            message_id,
                            conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        tracing::warn!("Failed to store embedding: {e:#}");
                    } else {
                        embedding_stored = true;
                    }
                }
                Err(e) => {
                    tracing::warn!("Failed to generate embedding: {e:#}");
                }
            }
        }

        Ok((Some(message_id), embedding_stored))
    }
150
151 pub async fn save_only(
159 &self,
160 conversation_id: ConversationId,
161 role: &str,
162 content: &str,
163 parts_json: &str,
164 ) -> Result<MessageId, MemoryError> {
165 self.sqlite
166 .save_message_with_parts(conversation_id, role, content, parts_json)
167 .await
168 }
169
170 pub async fn recall(
180 &self,
181 query: &str,
182 limit: usize,
183 filter: Option<SearchFilter>,
184 ) -> Result<Vec<RecalledMessage>, MemoryError> {
185 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
186
187 tracing::debug!(
188 query_len = query.len(),
189 limit,
190 has_filter = filter.is_some(),
191 conversation_id = conversation_id.map(|c| c.0),
192 has_qdrant = self.qdrant.is_some(),
193 "recall: starting hybrid search"
194 );
195
196 let keyword_results = match self
197 .sqlite
198 .keyword_search(query, limit * 2, conversation_id)
199 .await
200 {
201 Ok(results) => results,
202 Err(e) => {
203 tracing::warn!("FTS5 keyword search failed: {e:#}");
204 Vec::new()
205 }
206 };
207
208 let vector_results = if let Some(qdrant) = &self.qdrant
209 && self.provider.supports_embeddings()
210 {
211 let query_vector = self.provider.embed(query).await?;
212 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
213 qdrant.ensure_collection(vector_size).await?;
214 qdrant.search(&query_vector, limit * 2, filter).await?
215 } else {
216 Vec::new()
217 };
218
219 self.recall_merge_and_rank(keyword_results, vector_results, limit)
220 .await
221 }
222
    /// Raw FTS5 keyword search used by the recall paths.
    ///
    /// Fetches `limit * 2` candidates so the downstream merge/rank step has
    /// headroom to re-rank before truncating to `limit`.
    ///
    /// # Errors
    /// Propagates any SQLite/FTS5 query error.
    pub(super) async fn recall_fts5_raw(
        &self,
        query: &str,
        limit: usize,
        conversation_id: Option<ConversationId>,
    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
        self.sqlite
            .keyword_search(query, limit * 2, conversation_id)
            .await
    }
233
234 pub(super) async fn recall_vectors_raw(
235 &self,
236 query: &str,
237 limit: usize,
238 filter: Option<SearchFilter>,
239 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
240 let Some(qdrant) = &self.qdrant else {
241 return Ok(Vec::new());
242 };
243 if !self.provider.supports_embeddings() {
244 return Ok(Vec::new());
245 }
246 let query_vector = self.provider.embed(query).await?;
247 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
248 qdrant.ensure_collection(vector_size).await?;
249 qdrant.search(&query_vector, limit * 2, filter).await
250 }
251
    /// Merge keyword and vector results into one ranked, hydrated list.
    ///
    /// Pipeline, in order: weighted score merge -> temporal decay ->
    /// MMR diversification (this stage also truncates to `limit`) ->
    /// importance blending -> semantic-tier boost -> access-count bump ->
    /// hydration from SQLite. Stage order matters: truncation happens at the
    /// MMR stage, so the later boosts only re-order the surviving `limit`
    /// entries.
    ///
    /// # Errors
    /// Fails only if the final message hydration from SQLite fails; all
    /// intermediate stage failures are logged and skipped.
    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    pub(super) async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        // Accumulated per-message score across both sources.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        if !vector_results.is_empty() {
            // Normalize each source by its own max so both land on a
            // comparable scale before the configured weights are applied.
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against non-positive maxima (avoid division blow-up).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                // Messages found by both sources accumulate both weights.
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        if scores.is_empty() {
            tracing::debug!("recall: empty merge, no overlapping scores");
            return Ok(Vec::new());
        }

        // Descending by merged score; partial_cmp fallback makes NaN inert.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        tracing::debug!(
            merged = ranked.len(),
            top_score = ranked.first().map(|r| r.1),
            bottom_score = ranked.last().map(|r| r.1),
            vector_weight = %self.vector_weight,
            keyword_weight = %self.keyword_weight,
            "recall: weighted merge complete"
        );

        // Stage: temporal decay — down-weight old messages by half-life.
        // Skipped (with a warning) if timestamps cannot be fetched.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        half_life_days = self.temporal_decay_half_life_days,
                        top_score_after = ranked.first().map(|r| r.1),
                        "recall: temporal decay applied"
                    );
                }
                Err(e) => {
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Stage: MMR diversification. Every branch below ends with the list
        // truncated to `limit` (apply_mmr truncates internally).
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        let ranked_len_before = ranked.len();
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                        tracing::debug!(
                            before = ranked_len_before,
                            after = ranked.len(),
                            lambda = %self.mmr_lambda,
                            "recall: mmr re-ranked"
                        );
                    }
                    Ok(_) => {
                        // No vectors available: fall back to plain truncation.
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Stage: blend stored importance scores into the ranking.
        if self.importance_enabled && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_importance_scores(&ids).await {
                Ok(scores) => {
                    for (msg_id, score) in &mut ranked {
                        if let Some(&imp) = scores.get(msg_id) {
                            *score += imp * self.importance_weight;
                        }
                    }
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        importance_weight = %self.importance_weight,
                        "recall: importance scores blended"
                    );
                }
                Err(e) => {
                    tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
                }
            }
        }

        // Stage: flat bonus for messages in the "semantic" memory tier.
        // A boost factor of exactly 1.0 means the stage is disabled.
        if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_tiers(&ids).await {
                Ok(tiers) => {
                    let bonus = self.tier_boost_semantic - 1.0;
                    let mut boosted = false;
                    for (msg_id, score) in &mut ranked {
                        if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
                            *score += bonus;
                            boosted = true;
                        }
                    }
                    // Only re-sort if at least one score actually changed.
                    if boosted {
                        ranked.sort_by(|a, b| {
                            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
                        });
                        tracing::debug!(
                            tier_boost = %self.tier_boost_semantic,
                            "recall: semantic tier boost applied"
                        );
                    }
                }
                Err(e) => {
                    tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
                }
            }
        }

        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();

        // Best-effort access accounting for the recalled set.
        if !ids.is_empty()
            && let Err(e) = self.batch_increment_access_count(ids.clone()).await
        {
            tracing::warn!("recall: failed to increment access counts: {e:#}");
        }

        // Hydrate full messages; ids missing from SQLite are silently dropped
        // by the filter_map below.
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled: Vec<RecalledMessage> = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        tracing::debug!(final_count = recalled.len(), "recall: final results");

        Ok(recalled)
    }
449
450 pub async fn recall_routed(
459 &self,
460 query: &str,
461 limit: usize,
462 filter: Option<SearchFilter>,
463 router: &dyn crate::router::MemoryRouter,
464 ) -> Result<Vec<RecalledMessage>, MemoryError> {
465 use crate::router::MemoryRoute;
466
467 let route = router.route(query);
468 tracing::debug!(?route, query_len = query.len(), "memory routing decision");
469
470 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
471
472 let (keyword_results, vector_results): (
473 Vec<(MessageId, f64)>,
474 Vec<crate::embedding_store::SearchResult>,
475 ) = match route {
476 MemoryRoute::Keyword => {
477 let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
478 (kw, Vec::new())
479 }
480 MemoryRoute::Semantic => {
481 let vr = self.recall_vectors_raw(query, limit, filter).await?;
482 (Vec::new(), vr)
483 }
484 MemoryRoute::Hybrid => {
485 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
486 Ok(r) => r,
487 Err(e) => {
488 tracing::warn!("FTS5 keyword search failed: {e:#}");
489 Vec::new()
490 }
491 };
492 let vr = self.recall_vectors_raw(query, limit, filter).await?;
493 (kw, vr)
494 }
495 MemoryRoute::Episodic => {
504 let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
505 let cleaned = crate::router::strip_temporal_keywords(query);
506 let search_query = if cleaned.is_empty() { query } else { &cleaned };
507 let kw = if let Some(ref r) = range {
508 self.sqlite
509 .keyword_search_with_time_range(
510 search_query,
511 limit,
512 conversation_id,
513 r.after.as_deref(),
514 r.before.as_deref(),
515 )
516 .await?
517 } else {
518 self.recall_fts5_raw(search_query, limit, conversation_id)
519 .await?
520 };
521 tracing::debug!(
522 has_range = range.is_some(),
523 cleaned_query = %search_query,
524 keyword_count = kw.len(),
525 "recall: episodic path"
526 );
527 (kw, Vec::new())
528 }
529 MemoryRoute::Graph => {
532 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
533 Ok(r) => r,
534 Err(e) => {
535 tracing::warn!("FTS5 keyword search failed (graphâhybrid fallback): {e:#}");
536 Vec::new()
537 }
538 };
539 let vr = self.recall_vectors_raw(query, limit, filter).await?;
540 (kw, vr)
541 }
542 };
543
544 tracing::debug!(
545 keyword_count = keyword_results.len(),
546 vector_count = vector_results.len(),
547 "recall: routed search results"
548 );
549
550 self.recall_merge_and_rank(keyword_results, vector_results, limit)
551 .await
552 }
553
554 pub async fn recall_graph(
568 &self,
569 query: &str,
570 limit: usize,
571 max_hops: u32,
572 at_timestamp: Option<&str>,
573 temporal_decay_rate: f64,
574 edge_types: &[crate::graph::EdgeType],
575 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
576 let Some(store) = &self.graph_store else {
577 return Ok(Vec::new());
578 };
579
580 tracing::debug!(
581 query_len = query.len(),
582 limit,
583 max_hops,
584 "graph: starting recall"
585 );
586
587 let results = crate::graph::retrieval::graph_recall(
588 store,
589 self.qdrant.as_deref(),
590 &self.provider,
591 query,
592 limit,
593 max_hops,
594 at_timestamp,
595 temporal_decay_rate,
596 edge_types,
597 )
598 .await?;
599
600 tracing::debug!(result_count = results.len(), "graph: recall complete");
601
602 Ok(results)
603 }
604
    /// Query the knowledge graph using spreading activation, controlled by
    /// `params`, filtering traversal by `edge_types`.
    ///
    /// Returns an empty list when no graph store is configured.
    ///
    /// # Errors
    /// Propagates failures from the activation-based retrieval pipeline.
    pub async fn recall_graph_activated(
        &self,
        query: &str,
        limit: usize,
        params: crate::graph::SpreadingActivationParams,
        edge_types: &[crate::graph::EdgeType],
    ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
        // Activation recall is a no-op without a configured store.
        let Some(store) = &self.graph_store else {
            return Ok(Vec::new());
        };

        tracing::debug!(
            query_len = query.len(),
            limit,
            "spreading activation: starting graph recall"
        );

        // Embeddings are optional: retrieval degrades gracefully without Qdrant.
        let embeddings = self.qdrant.as_deref();
        let results = crate::graph::retrieval::graph_recall_activated(
            store,
            embeddings,
            &self.provider,
            query,
            limit,
            params,
            edge_types,
        )
        .await?;

        tracing::debug!(
            result_count = results.len(),
            "spreading activation: graph recall complete"
        );

        Ok(results)
    }
649
650 async fn batch_increment_access_count(
658 &self,
659 message_ids: Vec<MessageId>,
660 ) -> Result<(), MemoryError> {
661 if message_ids.is_empty() {
662 return Ok(());
663 }
664 self.sqlite.increment_access_counts(&message_ids).await
665 }
666
667 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
673 match &self.qdrant {
674 Some(qdrant) => qdrant.has_embedding(message_id).await,
675 None => Ok(false),
676 }
677 }
678
    /// Backfill embeddings for messages saved without one.
    ///
    /// Processes at most 1000 unembedded messages per call and returns the
    /// number of embeddings successfully stored. Per-message failures are
    /// logged and skipped so one bad message cannot stall the batch.
    ///
    /// # Errors
    /// Fails if the unembedded-message scan, the probe embedding, the
    /// dimension conversion, or collection setup fails.
    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
        let Some(qdrant) = &self.qdrant else {
            return Ok(0);
        };
        if !self.provider.supports_embeddings() {
            return Ok(0);
        }

        // Batch cap of 1000 keeps one call bounded in time and memory.
        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;

        if unembedded.is_empty() {
            return Ok(0);
        }

        // Probe embedding discovers the vector dimensionality up front.
        // NOTE(review): the `?` here propagates a conversion error, while the
        // save paths use `unwrap_or(896)` — confirm the asymmetry is intended.
        let probe = self.provider.embed("probe").await?;
        let vector_size = u64::try_from(probe.len())?;
        qdrant.ensure_collection(vector_size).await?;

        let mut count = 0;
        for (msg_id, conversation_id, role, content) in &unembedded {
            match self.provider.embed(content).await {
                Ok(vector) => {
                    if let Err(e) = qdrant
                        .store(
                            *msg_id,
                            *conversation_id,
                            role,
                            vector,
                            MessageKind::Regular,
                            &self.embedding_model,
                        )
                        .await
                    {
                        // Store failed: skip without counting this message.
                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
                        continue;
                    }
                    count += 1;
                }
                Err(e) => {
                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
                }
            }
        }

        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
        Ok(count)
    }
734}