1use zeph_llm::provider::{LlmProvider as _, Message};
5
6use crate::embedding_store::{MessageKind, SearchFilter};
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10use super::SemanticMemory;
11use super::algorithms::{apply_mmr, apply_temporal_decay};
12
/// A single message returned by the recall pipeline, paired with its final
/// relevance score.
#[derive(Debug)]
pub struct RecalledMessage {
    /// The stored chat message, loaded back from SQLite.
    pub message: Message,
    /// Combined ranking score: weighted keyword/vector merge, possibly
    /// adjusted by temporal decay, MMR re-ranking, and importance blending
    /// (computed as `f64`, truncated to `f32` at the end).
    pub score: f32,
}
18
19impl SemanticMemory {
20 pub async fn remember(
29 &self,
30 conversation_id: ConversationId,
31 role: &str,
32 content: &str,
33 ) -> Result<MessageId, MemoryError> {
34 let message_id = self
35 .sqlite
36 .save_message(conversation_id, role, content)
37 .await?;
38
39 if let Some(qdrant) = &self.qdrant
40 && self.provider.supports_embeddings()
41 {
42 match self.provider.embed(content).await {
43 Ok(vector) => {
44 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
45 if let Err(e) = qdrant.ensure_collection(vector_size).await {
46 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
47 } else if let Err(e) = qdrant
48 .store(
49 message_id,
50 conversation_id,
51 role,
52 vector,
53 MessageKind::Regular,
54 &self.embedding_model,
55 )
56 .await
57 {
58 tracing::warn!("Failed to store embedding: {e:#}");
59 }
60 }
61 Err(e) => {
62 tracing::warn!("Failed to generate embedding: {e:#}");
63 }
64 }
65 }
66
67 Ok(message_id)
68 }
69
70 pub async fn remember_with_parts(
79 &self,
80 conversation_id: ConversationId,
81 role: &str,
82 content: &str,
83 parts_json: &str,
84 ) -> Result<(MessageId, bool), MemoryError> {
85 let message_id = self
86 .sqlite
87 .save_message_with_parts(conversation_id, role, content, parts_json)
88 .await?;
89
90 let mut embedding_stored = false;
91
92 if let Some(qdrant) = &self.qdrant
93 && self.provider.supports_embeddings()
94 {
95 match self.provider.embed(content).await {
96 Ok(vector) => {
97 let vector_size = u64::try_from(vector.len()).unwrap_or(896);
98 if let Err(e) = qdrant.ensure_collection(vector_size).await {
99 tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
100 } else if let Err(e) = qdrant
101 .store(
102 message_id,
103 conversation_id,
104 role,
105 vector,
106 MessageKind::Regular,
107 &self.embedding_model,
108 )
109 .await
110 {
111 tracing::warn!("Failed to store embedding: {e:#}");
112 } else {
113 embedding_stored = true;
114 }
115 }
116 Err(e) => {
117 tracing::warn!("Failed to generate embedding: {e:#}");
118 }
119 }
120 }
121
122 Ok((message_id, embedding_stored))
123 }
124
125 pub async fn save_only(
133 &self,
134 conversation_id: ConversationId,
135 role: &str,
136 content: &str,
137 parts_json: &str,
138 ) -> Result<MessageId, MemoryError> {
139 self.sqlite
140 .save_message_with_parts(conversation_id, role, content, parts_json)
141 .await
142 }
143
144 pub async fn recall(
154 &self,
155 query: &str,
156 limit: usize,
157 filter: Option<SearchFilter>,
158 ) -> Result<Vec<RecalledMessage>, MemoryError> {
159 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
160
161 tracing::debug!(
162 query_len = query.len(),
163 limit,
164 has_filter = filter.is_some(),
165 conversation_id = conversation_id.map(|c| c.0),
166 has_qdrant = self.qdrant.is_some(),
167 "recall: starting hybrid search"
168 );
169
170 let keyword_results = match self
171 .sqlite
172 .keyword_search(query, limit * 2, conversation_id)
173 .await
174 {
175 Ok(results) => results,
176 Err(e) => {
177 tracing::warn!("FTS5 keyword search failed: {e:#}");
178 Vec::new()
179 }
180 };
181
182 let vector_results = if let Some(qdrant) = &self.qdrant
183 && self.provider.supports_embeddings()
184 {
185 let query_vector = self.provider.embed(query).await?;
186 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
187 qdrant.ensure_collection(vector_size).await?;
188 qdrant.search(&query_vector, limit * 2, filter).await?
189 } else {
190 Vec::new()
191 };
192
193 self.recall_merge_and_rank(keyword_results, vector_results, limit)
194 .await
195 }
196
197 pub(super) async fn recall_fts5_raw(
198 &self,
199 query: &str,
200 limit: usize,
201 conversation_id: Option<ConversationId>,
202 ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
203 self.sqlite
204 .keyword_search(query, limit * 2, conversation_id)
205 .await
206 }
207
208 pub(super) async fn recall_vectors_raw(
209 &self,
210 query: &str,
211 limit: usize,
212 filter: Option<SearchFilter>,
213 ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
214 let Some(qdrant) = &self.qdrant else {
215 return Ok(Vec::new());
216 };
217 if !self.provider.supports_embeddings() {
218 return Ok(Vec::new());
219 }
220 let query_vector = self.provider.embed(query).await?;
221 let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
222 qdrant.ensure_collection(vector_size).await?;
223 qdrant.search(&query_vector, limit * 2, filter).await
224 }
225
226 #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
235 pub(super) async fn recall_merge_and_rank(
236 &self,
237 keyword_results: Vec<(MessageId, f64)>,
238 vector_results: Vec<crate::embedding_store::SearchResult>,
239 limit: usize,
240 ) -> Result<Vec<RecalledMessage>, MemoryError> {
241 tracing::debug!(
242 vector_count = vector_results.len(),
243 keyword_count = keyword_results.len(),
244 limit,
245 "recall: merging search results"
246 );
247
248 let mut scores: std::collections::HashMap<MessageId, f64> =
249 std::collections::HashMap::new();
250
251 if !vector_results.is_empty() {
252 let max_vs = vector_results
253 .iter()
254 .map(|r| r.score)
255 .fold(f32::NEG_INFINITY, f32::max);
256 let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
257 for r in &vector_results {
258 let normalized = f64::from(r.score / norm);
259 *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
260 }
261 }
262
263 if !keyword_results.is_empty() {
264 let max_ks = keyword_results
265 .iter()
266 .map(|r| r.1)
267 .fold(f64::NEG_INFINITY, f64::max);
268 let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
269 for &(msg_id, score) in &keyword_results {
270 let normalized = score / norm;
271 *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
272 }
273 }
274
275 if scores.is_empty() {
276 tracing::debug!("recall: empty merge, no overlapping scores");
277 return Ok(Vec::new());
278 }
279
280 let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
281 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
282
283 tracing::debug!(
284 merged = ranked.len(),
285 top_score = ranked.first().map(|r| r.1),
286 bottom_score = ranked.last().map(|r| r.1),
287 vector_weight = %self.vector_weight,
288 keyword_weight = %self.keyword_weight,
289 "recall: weighted merge complete"
290 );
291
292 if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
293 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
294 match self.sqlite.message_timestamps(&ids).await {
295 Ok(timestamps) => {
296 apply_temporal_decay(
297 &mut ranked,
298 ×tamps,
299 self.temporal_decay_half_life_days,
300 );
301 ranked
302 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
303 tracing::debug!(
304 half_life_days = self.temporal_decay_half_life_days,
305 top_score_after = ranked.first().map(|r| r.1),
306 "recall: temporal decay applied"
307 );
308 }
309 Err(e) => {
310 tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
311 }
312 }
313 }
314
315 if self.mmr_enabled && !vector_results.is_empty() {
316 if let Some(qdrant) = &self.qdrant {
317 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
318 match qdrant.get_vectors(&ids).await {
319 Ok(vec_map) if !vec_map.is_empty() => {
320 let ranked_len_before = ranked.len();
321 ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
322 tracing::debug!(
323 before = ranked_len_before,
324 after = ranked.len(),
325 lambda = %self.mmr_lambda,
326 "recall: mmr re-ranked"
327 );
328 }
329 Ok(_) => {
330 ranked.truncate(limit);
331 }
332 Err(e) => {
333 tracing::warn!("MMR: failed to fetch vectors: {e:#}");
334 ranked.truncate(limit);
335 }
336 }
337 } else {
338 ranked.truncate(limit);
339 }
340 } else {
341 ranked.truncate(limit);
342 }
343
344 if self.importance_enabled && !ranked.is_empty() {
345 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
346 match self.sqlite.fetch_importance_scores(&ids).await {
347 Ok(scores) => {
348 for (msg_id, score) in &mut ranked {
349 if let Some(&imp) = scores.get(msg_id) {
350 *score += imp * self.importance_weight;
351 }
352 }
353 ranked
354 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
355 tracing::debug!(
356 importance_weight = %self.importance_weight,
357 "recall: importance scores blended"
358 );
359 }
360 Err(e) => {
361 tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
362 }
363 }
364 }
365
366 let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
367
368 if !ids.is_empty()
369 && let Err(e) = self.batch_increment_access_count(ids.clone()).await
370 {
371 tracing::warn!("recall: failed to increment access counts: {e:#}");
372 }
373
374 let messages = self.sqlite.messages_by_ids(&ids).await?;
375 let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
376
377 let recalled: Vec<RecalledMessage> = ranked
378 .iter()
379 .filter_map(|(msg_id, score)| {
380 msg_map.get(msg_id).map(|msg| RecalledMessage {
381 message: msg.clone(),
382 #[expect(clippy::cast_possible_truncation)]
383 score: *score as f32,
384 })
385 })
386 .collect();
387
388 tracing::debug!(final_count = recalled.len(), "recall: final results");
389
390 Ok(recalled)
391 }
392
393 pub async fn recall_routed(
402 &self,
403 query: &str,
404 limit: usize,
405 filter: Option<SearchFilter>,
406 router: &dyn crate::router::MemoryRouter,
407 ) -> Result<Vec<RecalledMessage>, MemoryError> {
408 use crate::router::MemoryRoute;
409
410 let route = router.route(query);
411 tracing::debug!(?route, query_len = query.len(), "memory routing decision");
412
413 let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
414
415 let (keyword_results, vector_results): (
416 Vec<(MessageId, f64)>,
417 Vec<crate::embedding_store::SearchResult>,
418 ) = match route {
419 MemoryRoute::Keyword => {
420 let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
421 (kw, Vec::new())
422 }
423 MemoryRoute::Semantic => {
424 let vr = self.recall_vectors_raw(query, limit, filter).await?;
425 (Vec::new(), vr)
426 }
427 MemoryRoute::Hybrid => {
428 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
429 Ok(r) => r,
430 Err(e) => {
431 tracing::warn!("FTS5 keyword search failed: {e:#}");
432 Vec::new()
433 }
434 };
435 let vr = self.recall_vectors_raw(query, limit, filter).await?;
436 (kw, vr)
437 }
438 MemoryRoute::Episodic => {
447 let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
448 let cleaned = crate::router::strip_temporal_keywords(query);
449 let search_query = if cleaned.is_empty() { query } else { &cleaned };
450 let kw = if let Some(ref r) = range {
451 self.sqlite
452 .keyword_search_with_time_range(
453 search_query,
454 limit,
455 conversation_id,
456 r.after.as_deref(),
457 r.before.as_deref(),
458 )
459 .await?
460 } else {
461 self.recall_fts5_raw(search_query, limit, conversation_id)
462 .await?
463 };
464 tracing::debug!(
465 has_range = range.is_some(),
466 cleaned_query = %search_query,
467 keyword_count = kw.len(),
468 "recall: episodic path"
469 );
470 (kw, Vec::new())
471 }
472 MemoryRoute::Graph => {
475 let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
476 Ok(r) => r,
477 Err(e) => {
478 tracing::warn!("FTS5 keyword search failed (graphâhybrid fallback): {e:#}");
479 Vec::new()
480 }
481 };
482 let vr = self.recall_vectors_raw(query, limit, filter).await?;
483 (kw, vr)
484 }
485 };
486
487 tracing::debug!(
488 keyword_count = keyword_results.len(),
489 vector_count = vector_results.len(),
490 "recall: routed search results"
491 );
492
493 self.recall_merge_and_rank(keyword_results, vector_results, limit)
494 .await
495 }
496
497 pub async fn recall_graph(
511 &self,
512 query: &str,
513 limit: usize,
514 max_hops: u32,
515 at_timestamp: Option<&str>,
516 temporal_decay_rate: f64,
517 edge_types: &[crate::graph::EdgeType],
518 ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
519 let Some(store) = &self.graph_store else {
520 return Ok(Vec::new());
521 };
522
523 tracing::debug!(
524 query_len = query.len(),
525 limit,
526 max_hops,
527 "graph: starting recall"
528 );
529
530 let results = crate::graph::retrieval::graph_recall(
531 store,
532 self.qdrant.as_deref(),
533 &self.provider,
534 query,
535 limit,
536 max_hops,
537 at_timestamp,
538 temporal_decay_rate,
539 edge_types,
540 )
541 .await?;
542
543 tracing::debug!(result_count = results.len(), "graph: recall complete");
544
545 Ok(results)
546 }
547
548 pub async fn recall_graph_activated(
557 &self,
558 query: &str,
559 limit: usize,
560 params: crate::graph::SpreadingActivationParams,
561 edge_types: &[crate::graph::EdgeType],
562 ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
563 let Some(store) = &self.graph_store else {
564 return Ok(Vec::new());
565 };
566
567 tracing::debug!(
568 query_len = query.len(),
569 limit,
570 "spreading activation: starting graph recall"
571 );
572
573 let results = crate::graph::retrieval::graph_recall_activated(
574 store, query, limit, params, edge_types,
575 )
576 .await?;
577
578 tracing::debug!(
579 result_count = results.len(),
580 "spreading activation: graph recall complete"
581 );
582
583 Ok(results)
584 }
585
586 async fn batch_increment_access_count(
594 &self,
595 message_ids: Vec<MessageId>,
596 ) -> Result<(), MemoryError> {
597 if message_ids.is_empty() {
598 return Ok(());
599 }
600 self.sqlite.increment_access_counts(&message_ids).await
601 }
602
603 pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
609 match &self.qdrant {
610 Some(qdrant) => qdrant.has_embedding(message_id).await,
611 None => Ok(false),
612 }
613 }
614
615 pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
624 let Some(qdrant) = &self.qdrant else {
625 return Ok(0);
626 };
627 if !self.provider.supports_embeddings() {
628 return Ok(0);
629 }
630
631 let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
632
633 if unembedded.is_empty() {
634 return Ok(0);
635 }
636
637 let probe = self.provider.embed("probe").await?;
638 let vector_size = u64::try_from(probe.len())?;
639 qdrant.ensure_collection(vector_size).await?;
640
641 let mut count = 0;
642 for (msg_id, conversation_id, role, content) in &unembedded {
643 match self.provider.embed(content).await {
644 Ok(vector) => {
645 if let Err(e) = qdrant
646 .store(
647 *msg_id,
648 *conversation_id,
649 role,
650 vector,
651 MessageKind::Regular,
652 &self.embedding_model,
653 )
654 .await
655 {
656 tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
657 continue;
658 }
659 count += 1;
660 }
661 Err(e) => {
662 tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
663 }
664 }
665 }
666
667 tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
668 Ok(count)
669 }
670}