// zeph_memory/semantic/recall.rs
1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use zeph_llm::provider::{LlmProvider as _, Message};
5
6use crate::embedding_store::{MessageKind, SearchFilter};
7use crate::error::MemoryError;
8use crate::types::{ConversationId, MessageId};
9
10use super::SemanticMemory;
11use super::algorithms::{apply_mmr, apply_temporal_decay};
12
13#[derive(Debug)]
14pub struct RecalledMessage {
15    pub message: Message,
16    pub score: f32,
17}
18
19impl SemanticMemory {
20    /// Save a message to `SQLite` and optionally embed and store in Qdrant.
21    ///
22    /// Returns the message ID assigned by `SQLite`.
23    ///
24    /// # Errors
25    ///
26    /// Returns an error if the `SQLite` save fails. Embedding failures are logged but not
27    /// propagated.
28    pub async fn remember(
29        &self,
30        conversation_id: ConversationId,
31        role: &str,
32        content: &str,
33    ) -> Result<MessageId, MemoryError> {
34        let message_id = self
35            .sqlite
36            .save_message(conversation_id, role, content)
37            .await?;
38
39        if let Some(qdrant) = &self.qdrant
40            && self.provider.supports_embeddings()
41        {
42            match self.provider.embed(content).await {
43                Ok(vector) => {
44                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
45                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
46                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
47                    } else if let Err(e) = qdrant
48                        .store(
49                            message_id,
50                            conversation_id,
51                            role,
52                            vector,
53                            MessageKind::Regular,
54                            &self.embedding_model,
55                        )
56                        .await
57                    {
58                        tracing::warn!("Failed to store embedding: {e:#}");
59                    }
60                }
61                Err(e) => {
62                    tracing::warn!("Failed to generate embedding: {e:#}");
63                }
64            }
65        }
66
67        Ok(message_id)
68    }
69
70    /// Save a message with pre-serialized parts JSON to `SQLite` and optionally embed in Qdrant.
71    ///
72    /// Returns `(message_id, embedding_stored)` tuple where `embedding_stored` is `true` if
73    /// an embedding was successfully generated and stored in Qdrant.
74    ///
75    /// # Errors
76    ///
77    /// Returns an error if the `SQLite` save fails.
78    pub async fn remember_with_parts(
79        &self,
80        conversation_id: ConversationId,
81        role: &str,
82        content: &str,
83        parts_json: &str,
84    ) -> Result<(MessageId, bool), MemoryError> {
85        let message_id = self
86            .sqlite
87            .save_message_with_parts(conversation_id, role, content, parts_json)
88            .await?;
89
90        let mut embedding_stored = false;
91
92        if let Some(qdrant) = &self.qdrant
93            && self.provider.supports_embeddings()
94        {
95            match self.provider.embed(content).await {
96                Ok(vector) => {
97                    let vector_size = u64::try_from(vector.len()).unwrap_or(896);
98                    if let Err(e) = qdrant.ensure_collection(vector_size).await {
99                        tracing::warn!("Failed to ensure Qdrant collection: {e:#}");
100                    } else if let Err(e) = qdrant
101                        .store(
102                            message_id,
103                            conversation_id,
104                            role,
105                            vector,
106                            MessageKind::Regular,
107                            &self.embedding_model,
108                        )
109                        .await
110                    {
111                        tracing::warn!("Failed to store embedding: {e:#}");
112                    } else {
113                        embedding_stored = true;
114                    }
115                }
116                Err(e) => {
117                    tracing::warn!("Failed to generate embedding: {e:#}");
118                }
119            }
120        }
121
122        Ok((message_id, embedding_stored))
123    }
124
125    /// Save a message to `SQLite` without generating an embedding.
126    ///
127    /// Use this when embedding is intentionally skipped (e.g. autosave disabled for assistant).
128    ///
129    /// # Errors
130    ///
131    /// Returns an error if the `SQLite` save fails.
132    pub async fn save_only(
133        &self,
134        conversation_id: ConversationId,
135        role: &str,
136        content: &str,
137        parts_json: &str,
138    ) -> Result<MessageId, MemoryError> {
139        self.sqlite
140            .save_message_with_parts(conversation_id, role, content, parts_json)
141            .await
142    }
143
144    /// Recall relevant messages using hybrid search (vector + FTS5 keyword).
145    ///
146    /// When Qdrant is available, runs both vector and keyword searches, then merges
147    /// results using weighted scoring. When Qdrant is unavailable, falls back to
148    /// FTS5-only keyword search.
149    ///
150    /// # Errors
151    ///
152    /// Returns an error if embedding generation, Qdrant search, or FTS5 query fails.
153    pub async fn recall(
154        &self,
155        query: &str,
156        limit: usize,
157        filter: Option<SearchFilter>,
158    ) -> Result<Vec<RecalledMessage>, MemoryError> {
159        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
160
161        tracing::debug!(
162            query_len = query.len(),
163            limit,
164            has_filter = filter.is_some(),
165            conversation_id = conversation_id.map(|c| c.0),
166            has_qdrant = self.qdrant.is_some(),
167            "recall: starting hybrid search"
168        );
169
170        let keyword_results = match self
171            .sqlite
172            .keyword_search(query, limit * 2, conversation_id)
173            .await
174        {
175            Ok(results) => results,
176            Err(e) => {
177                tracing::warn!("FTS5 keyword search failed: {e:#}");
178                Vec::new()
179            }
180        };
181
182        let vector_results = if let Some(qdrant) = &self.qdrant
183            && self.provider.supports_embeddings()
184        {
185            let query_vector = self.provider.embed(query).await?;
186            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
187            qdrant.ensure_collection(vector_size).await?;
188            qdrant.search(&query_vector, limit * 2, filter).await?
189        } else {
190            Vec::new()
191        };
192
193        self.recall_merge_and_rank(keyword_results, vector_results, limit)
194            .await
195    }
196
197    pub(super) async fn recall_fts5_raw(
198        &self,
199        query: &str,
200        limit: usize,
201        conversation_id: Option<ConversationId>,
202    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
203        self.sqlite
204            .keyword_search(query, limit * 2, conversation_id)
205            .await
206    }
207
208    pub(super) async fn recall_vectors_raw(
209        &self,
210        query: &str,
211        limit: usize,
212        filter: Option<SearchFilter>,
213    ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
214        let Some(qdrant) = &self.qdrant else {
215            return Ok(Vec::new());
216        };
217        if !self.provider.supports_embeddings() {
218            return Ok(Vec::new());
219        }
220        let query_vector = self.provider.embed(query).await?;
221        let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
222        qdrant.ensure_collection(vector_size).await?;
223        qdrant.search(&query_vector, limit * 2, filter).await
224    }
225
    /// Merge raw keyword and vector results, apply weighted scoring, temporal decay, and MMR
    /// re-ranking, then resolve to `RecalledMessage` objects.
    ///
    /// This is the shared post-processing step used by all recall paths.
    ///
    /// Pipeline (each optional stage is gated by a config flag on `self`):
    /// 1. Normalize each result set to its own batch maximum, blend with
    ///    `vector_weight` / `keyword_weight` (a message found by both accumulates both).
    /// 2. Sort by blended score, descending.
    /// 3. Temporal decay — best-effort; skipped if timestamp fetch fails.
    /// 4. MMR diversity re-ranking — best-effort; every branch leaves `ranked`
    ///    at most `limit` long.
    /// 5. Importance blending — best-effort additive boost, then re-sort.
    /// 6. Access-count bump (best-effort) and message resolution from `SQLite`.
    ///
    /// # Errors
    ///
    /// Returns an error if the `SQLite` `messages_by_ids` query fails.
    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
    pub(super) async fn recall_merge_and_rank(
        &self,
        keyword_results: Vec<(MessageId, f64)>,
        vector_results: Vec<crate::embedding_store::SearchResult>,
        limit: usize,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        tracing::debug!(
            vector_count = vector_results.len(),
            keyword_count = keyword_results.len(),
            limit,
            "recall: merging search results"
        );

        // Blended score per message id.
        let mut scores: std::collections::HashMap<MessageId, f64> =
            std::collections::HashMap::new();

        // Stage 1a: vector scores, normalized to the batch maximum so both score
        // scales contribute comparably before weighting.
        if !vector_results.is_empty() {
            let max_vs = vector_results
                .iter()
                .map(|r| r.score)
                .fold(f32::NEG_INFINITY, f32::max);
            // Guard against a non-positive maximum (avoids divide-by-zero / sign flip).
            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
            for r in &vector_results {
                let normalized = f64::from(r.score / norm);
                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
            }
        }

        // Stage 1b: keyword (FTS5/BM25) scores, normalized the same way.
        if !keyword_results.is_empty() {
            let max_ks = keyword_results
                .iter()
                .map(|r| r.1)
                .fold(f64::NEG_INFINITY, f64::max);
            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
            for &(msg_id, score) in &keyword_results {
                let normalized = score / norm;
                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
            }
        }

        // `scores` is empty only when both input result sets were empty.
        if scores.is_empty() {
            tracing::debug!("recall: empty merge, no overlapping scores");
            return Ok(Vec::new());
        }

        // Stage 2: rank by blended score, best first. NaN-safe sort: incomparable
        // pairs are treated as equal rather than panicking.
        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));

        tracing::debug!(
            merged = ranked.len(),
            top_score = ranked.first().map(|r| r.1),
            bottom_score = ranked.last().map(|r| r.1),
            vector_weight = %self.vector_weight,
            keyword_weight = %self.keyword_weight,
            "recall: weighted merge complete"
        );

        // Stage 3: temporal decay (optional, best-effort). Down-weights older
        // messages by half-life, then re-sorts.
        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.message_timestamps(&ids).await {
                Ok(timestamps) => {
                    apply_temporal_decay(
                        &mut ranked,
                        &timestamps,
                        self.temporal_decay_half_life_days,
                    );
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        half_life_days = self.temporal_decay_half_life_days,
                        top_score_after = ranked.first().map(|r| r.1),
                        "recall: temporal decay applied"
                    );
                }
                Err(e) => {
                    // Best-effort: ranking proceeds without decay.
                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
                }
            }
        }

        // Stage 4: MMR diversity re-ranking (optional). Only meaningful when vector
        // results exist; every fallback branch truncates `ranked` to `limit` so the
        // invariant "at most `limit` entries from here on" always holds.
        if self.mmr_enabled && !vector_results.is_empty() {
            if let Some(qdrant) = &self.qdrant {
                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
                match qdrant.get_vectors(&ids).await {
                    Ok(vec_map) if !vec_map.is_empty() => {
                        let ranked_len_before = ranked.len();
                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
                        tracing::debug!(
                            before = ranked_len_before,
                            after = ranked.len(),
                            lambda = %self.mmr_lambda,
                            "recall: mmr re-ranked"
                        );
                    }
                    Ok(_) => {
                        // No vectors came back: keep the score order, just cap.
                        ranked.truncate(limit);
                    }
                    Err(e) => {
                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
                        ranked.truncate(limit);
                    }
                }
            } else {
                ranked.truncate(limit);
            }
        } else {
            ranked.truncate(limit);
        }

        // Stage 5: importance blending (optional, best-effort) — additive boost
        // per message, then re-sort. Messages without a stored score are untouched.
        if self.importance_enabled && !ranked.is_empty() {
            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
            match self.sqlite.fetch_importance_scores(&ids).await {
                Ok(scores) => {
                    for (msg_id, score) in &mut ranked {
                        if let Some(&imp) = scores.get(msg_id) {
                            *score += imp * self.importance_weight;
                        }
                    }
                    ranked
                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
                    tracing::debug!(
                        importance_weight = %self.importance_weight,
                        "recall: importance scores blended"
                    );
                }
                Err(e) => {
                    tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
                }
            }
        }

        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();

        // Stage 6a: bump access counts (best-effort). The clone keeps `ids` alive
        // for the resolution query below.
        if !ids.is_empty()
            && let Err(e) = self.batch_increment_access_count(ids.clone()).await
        {
            tracing::warn!("recall: failed to increment access counts: {e:#}");
        }

        // Stage 6b: resolve ids to full messages in ranked order; ids missing from
        // SQLite are silently dropped by the filter_map.
        let messages = self.sqlite.messages_by_ids(&ids).await?;
        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();

        let recalled: Vec<RecalledMessage> = ranked
            .iter()
            .filter_map(|(msg_id, score)| {
                msg_map.get(msg_id).map(|msg| RecalledMessage {
                    message: msg.clone(),
                    #[expect(clippy::cast_possible_truncation)]
                    score: *score as f32,
                })
            })
            .collect();

        tracing::debug!(final_count = recalled.len(), "recall: final results");

        Ok(recalled)
    }
392
    /// Recall messages using query-aware routing.
    ///
    /// Delegates to FTS5-only, vector-only, or hybrid search based on the router decision,
    /// then runs the shared merge and ranking pipeline.
    ///
    /// Error-tolerance differs deliberately per route: single-source routes
    /// (`Keyword`, `Semantic`, `Episodic`) propagate their search errors because
    /// there is no second signal to fall back on; two-source routes (`Hybrid`,
    /// `Graph`) tolerate FTS5 failure and degrade to vector-only.
    ///
    /// # Errors
    ///
    /// Returns an error if any underlying search or database operation fails.
    pub async fn recall_routed(
        &self,
        query: &str,
        limit: usize,
        filter: Option<SearchFilter>,
        router: &dyn crate::router::MemoryRouter,
    ) -> Result<Vec<RecalledMessage>, MemoryError> {
        use crate::router::MemoryRoute;

        let route = router.route(query);
        tracing::debug!(?route, query_len = query.len(), "memory routing decision");

        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);

        let (keyword_results, vector_results): (
            Vec<(MessageId, f64)>,
            Vec<crate::embedding_store::SearchResult>,
        ) = match route {
            // FTS5 only; errors propagate — there is no other signal to fall back on.
            MemoryRoute::Keyword => {
                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
                (kw, Vec::new())
            }
            // Vector only; returns empty when Qdrant/embeddings are unavailable.
            MemoryRoute::Semantic => {
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (Vec::new(), vr)
            }
            // Both sources: FTS5 failure degrades to vector-only, vector errors propagate.
            MemoryRoute::Hybrid => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed: {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
            // Episodic: FTS5 keyword search with an optional timestamp-range filter.
            // Temporal keywords are stripped from the query before passing to FTS5 to
            // prevent BM25 score distortion (e.g. "yesterday" matching messages that
            // literally contain the word "yesterday" regardless of actual relevance).
            // Vector search is skipped for speed; temporal decay in recall_merge_and_rank
            // provides recency boosting for the FTS5 results.
            // Known trade-off (MVP): semantically similar but lexically different messages
            // may be missed. See issue #1629 for a future hybrid_temporal mode.
            MemoryRoute::Episodic => {
                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
                let cleaned = crate::router::strip_temporal_keywords(query);
                // If stripping removed everything, fall back to the raw query.
                let search_query = if cleaned.is_empty() { query } else { &cleaned };
                let kw = if let Some(ref r) = range {
                    self.sqlite
                        .keyword_search_with_time_range(
                            search_query,
                            limit,
                            conversation_id,
                            r.after.as_deref(),
                            r.before.as_deref(),
                        )
                        .await?
                } else {
                    self.recall_fts5_raw(search_query, limit, conversation_id)
                        .await?
                };
                tracing::debug!(
                    has_range = range.is_some(),
                    cleaned_query = %search_query,
                    keyword_count = kw.len(),
                    "recall: episodic path"
                );
                (kw, Vec::new())
            }
            // Graph routing triggers graph_recall separately in agent/context.rs.
            // For the message-based recall, behave like Hybrid.
            MemoryRoute::Graph => {
                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
                    Ok(r) => r,
                    Err(e) => {
                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
                        Vec::new()
                    }
                };
                let vr = self.recall_vectors_raw(query, limit, filter).await?;
                (kw, vr)
            }
        };

        tracing::debug!(
            keyword_count = keyword_results.len(),
            vector_count = vector_results.len(),
            "recall: routed search results"
        );

        // Shared post-processing: weighted merge, decay, MMR, importance, resolution.
        self.recall_merge_and_rank(keyword_results, vector_results, limit)
            .await
    }
496
497    /// Retrieve graph facts relevant to `query` via BFS traversal.
498    ///
499    /// Returns an empty `Vec` if no `graph_store` is configured.
500    ///
501    /// # Parameters
502    ///
503    /// - `at_timestamp`: when `Some`, only edges valid at that `SQLite` datetime string are returned.
504    ///   When `None`, only currently active edges are used.
505    /// - `temporal_decay_rate`: non-negative decay rate (1/day). `0.0` preserves original ordering.
506    ///
507    /// # Errors
508    ///
509    /// Returns an error if the underlying graph query fails.
510    pub async fn recall_graph(
511        &self,
512        query: &str,
513        limit: usize,
514        max_hops: u32,
515        at_timestamp: Option<&str>,
516        temporal_decay_rate: f64,
517        edge_types: &[crate::graph::EdgeType],
518    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
519        let Some(store) = &self.graph_store else {
520            return Ok(Vec::new());
521        };
522
523        tracing::debug!(
524            query_len = query.len(),
525            limit,
526            max_hops,
527            "graph: starting recall"
528        );
529
530        let results = crate::graph::retrieval::graph_recall(
531            store,
532            self.qdrant.as_deref(),
533            &self.provider,
534            query,
535            limit,
536            max_hops,
537            at_timestamp,
538            temporal_decay_rate,
539            edge_types,
540        )
541        .await?;
542
543        tracing::debug!(result_count = results.len(), "graph: recall complete");
544
545        Ok(results)
546    }
547
548    /// Retrieve graph facts via SYNAPSE spreading activation.
549    ///
550    /// Delegates to [`crate::graph::retrieval::graph_recall_activated`].
551    /// Used in place of [`recall_graph`] when `spreading_activation.enabled = true`.
552    ///
553    /// # Errors
554    ///
555    /// Returns an error if the underlying graph query fails.
556    pub async fn recall_graph_activated(
557        &self,
558        query: &str,
559        limit: usize,
560        params: crate::graph::SpreadingActivationParams,
561        edge_types: &[crate::graph::EdgeType],
562    ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
563        let Some(store) = &self.graph_store else {
564            return Ok(Vec::new());
565        };
566
567        tracing::debug!(
568            query_len = query.len(),
569            limit,
570            "spreading activation: starting graph recall"
571        );
572
573        let results = crate::graph::retrieval::graph_recall_activated(
574            store, query, limit, params, edge_types,
575        )
576        .await?;
577
578        tracing::debug!(
579            result_count = results.len(),
580            "spreading activation: graph recall complete"
581        );
582
583        Ok(results)
584    }
585
586    /// Increment access count and update `last_accessed` for a batch of message IDs.
587    ///
588    /// Skips the update if `message_ids` is empty to avoid an invalid `IN ()` clause.
589    ///
590    /// # Errors
591    ///
592    /// Returns an error if the `SQLite` update fails.
593    async fn batch_increment_access_count(
594        &self,
595        message_ids: Vec<MessageId>,
596    ) -> Result<(), MemoryError> {
597        if message_ids.is_empty() {
598            return Ok(());
599        }
600        self.sqlite.increment_access_counts(&message_ids).await
601    }
602
603    /// Check whether an embedding exists for a given message ID.
604    ///
605    /// # Errors
606    ///
607    /// Returns an error if the `SQLite` query fails.
608    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
609        match &self.qdrant {
610            Some(qdrant) => qdrant.has_embedding(message_id).await,
611            None => Ok(false),
612        }
613    }
614
615    /// Embed all messages that do not yet have embeddings.
616    ///
617    /// Returns the count of successfully embedded messages.
618    ///
619    /// # Errors
620    ///
621    /// Returns an error if collection initialization or database query fails.
622    /// Individual embedding failures are logged but do not stop processing.
623    pub async fn embed_missing(&self) -> Result<usize, MemoryError> {
624        let Some(qdrant) = &self.qdrant else {
625            return Ok(0);
626        };
627        if !self.provider.supports_embeddings() {
628            return Ok(0);
629        }
630
631        let unembedded = self.sqlite.unembedded_message_ids(Some(1000)).await?;
632
633        if unembedded.is_empty() {
634            return Ok(0);
635        }
636
637        let probe = self.provider.embed("probe").await?;
638        let vector_size = u64::try_from(probe.len())?;
639        qdrant.ensure_collection(vector_size).await?;
640
641        let mut count = 0;
642        for (msg_id, conversation_id, role, content) in &unembedded {
643            match self.provider.embed(content).await {
644                Ok(vector) => {
645                    if let Err(e) = qdrant
646                        .store(
647                            *msg_id,
648                            *conversation_id,
649                            role,
650                            vector,
651                            MessageKind::Regular,
652                            &self.embedding_model,
653                        )
654                        .await
655                    {
656                        tracing::warn!("Failed to store embedding for msg {msg_id}: {e:#}");
657                        continue;
658                    }
659                    count += 1;
660                }
661                Err(e) => {
662                    tracing::warn!("Failed to embed msg {msg_id}: {e:#}");
663                }
664            }
665        }
666
667        tracing::info!("Embedded {count}/{} missing messages", unembedded.len());
668        Ok(count)
669    }
670}