Skip to main content

zeph_memory/semantic/
recall.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use futures::{StreamExt as _, TryStreamExt as _};
5use zeph_llm::provider::{LlmProvider as _, Message};
6
7/// Approximate characters per token (conservative estimate for mixed content).
8const CHARS_PER_TOKEN: usize = 4;
9
10/// Target chunk size in characters (~400 tokens).
11const CHUNK_CHARS: usize = 400 * CHARS_PER_TOKEN;
12
13/// Overlap between adjacent chunks in characters (~80 tokens).
14const CHUNK_OVERLAP_CHARS: usize = 80 * CHARS_PER_TOKEN;
15
16/// Split `text` into overlapping chunks suitable for embedding.
17///
18/// For text shorter than `CHUNK_CHARS`, returns a single chunk.
19/// Splits at UTF-8 character boundaries on paragraph (`\n\n`), line (`\n`),
20/// space (` `), or raw character boundaries as a last resort.
21fn chunk_text(text: &str) -> Vec<&str> {
22    if text.len() <= CHUNK_CHARS {
23        return vec![text];
24    }
25
26    let mut chunks = Vec::new();
27    let mut start = 0;
28
29    while start < text.len() {
30        let end = if start + CHUNK_CHARS >= text.len() {
31            text.len()
32        } else {
33            // Find a clean UTF-8 char boundary at or before start + CHUNK_CHARS.
34            let boundary = text.floor_char_boundary(start + CHUNK_CHARS);
35            // Prefer to split at a paragraph or line break for cleaner chunks.
36            let slice = &text[start..boundary];
37            if let Some(pos) = slice.rfind("\n\n") {
38                start + pos + 2
39            } else if let Some(pos) = slice.rfind('\n') {
40                start + pos + 1
41            } else if let Some(pos) = slice.rfind(' ') {
42                start + pos + 1
43            } else {
44                boundary
45            }
46        };
47
48        chunks.push(&text[start..end]);
49        if end >= text.len() {
50            break;
51        }
52        // Next chunk starts with overlap.
53        let next = end.saturating_sub(CHUNK_OVERLAP_CHARS);
54        start = text.ceil_char_boundary(next);
55        if start >= end {
56            start = end; // safeguard against infinite loop
57        }
58    }
59
60    chunks
61}
62
63use crate::admission::log_admission_decision;
64use crate::embedding_store::{MessageKind, SearchFilter};
65use crate::error::MemoryError;
66use crate::types::{ConversationId, MessageId};
67
68use super::SemanticMemory;
69use super::algorithms::{apply_mmr, apply_temporal_decay};
70
71/// Tool execution metadata stored as Qdrant payload fields alongside embeddings.
72///
73/// Stored as payload — NOT prepended to content — to avoid corrupting embedding vectors.
74#[derive(Debug, Clone, Default)]
75pub struct EmbedContext {
76    pub tool_name: Option<String>,
77    pub exit_code: Option<i32>,
78    pub timestamp: Option<String>,
79}
80
81#[derive(Debug)]
82pub struct RecalledMessage {
83    pub message: Message,
84    pub score: f32,
85}
86
87/// Maximum number of concurrent background embed tasks per `SemanticMemory` instance.
88const MAX_EMBED_BG_TASKS: usize = 64;
89
90/// Shared arguments for background embed tasks.
91struct EmbedBgArgs {
92    qdrant: std::sync::Arc<crate::embedding_store::EmbeddingStore>,
93    embed_provider: zeph_llm::any::AnyProvider,
94    embedding_model: String,
95    message_id: MessageId,
96    conversation_id: ConversationId,
97    role: String,
98    content: String,
99}
100
101/// Background task: embed chunks and store as regular message vectors.
102///
103/// All errors are logged as warnings; the function never panics.
104async fn embed_and_store_regular_bg(args: EmbedBgArgs) {
105    let EmbedBgArgs {
106        qdrant,
107        embed_provider,
108        embedding_model,
109        message_id,
110        conversation_id,
111        role,
112        content,
113    } = args;
114    let chunks = chunk_text(&content);
115    let chunk_count = chunks.len();
116
117    let vectors = match embed_provider.embed_batch(&chunks).await {
118        Ok(v) => v,
119        Err(e) => {
120            tracing::warn!("bg embed_regular: failed to embed chunks for msg {message_id}: {e:#}");
121            return;
122        }
123    };
124
125    let Some(first) = vectors.first() else {
126        return;
127    };
128    let vector_size = first.len() as u64;
129    if let Err(e) = qdrant.ensure_collection(vector_size).await {
130        tracing::warn!("bg embed_regular: failed to ensure Qdrant collection: {e:#}");
131        return;
132    }
133
134    for (chunk_index, vector) in vectors.into_iter().enumerate() {
135        let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
136        if let Err(e) = qdrant
137            .store(
138                message_id,
139                conversation_id,
140                &role,
141                vector,
142                MessageKind::Regular,
143                &embedding_model,
144                chunk_index_u32,
145            )
146            .await
147        {
148            tracing::warn!(
149                "bg embed_regular: failed to store chunk {chunk_index}/{chunk_count} \
150                 for msg {message_id}: {e:#}"
151            );
152        }
153    }
154}
155
156/// Background task: embed chunks with tool context metadata and store in Qdrant.
157///
158/// All errors are logged as warnings; the function never panics.
159async fn embed_chunks_with_tool_context_bg(args: EmbedBgArgs, embed_ctx: EmbedContext) {
160    let EmbedBgArgs {
161        qdrant,
162        embed_provider,
163        embedding_model,
164        message_id,
165        conversation_id,
166        role,
167        content,
168    } = args;
169    let chunks = chunk_text(&content);
170    let chunk_count = chunks.len();
171
172    let vectors = match embed_provider.embed_batch(&chunks).await {
173        Ok(v) => v,
174        Err(e) => {
175            tracing::warn!(
176                "bg embed_tool: failed to embed tool-output chunks for msg {message_id}: {e:#}"
177            );
178            return;
179        }
180    };
181
182    if let Some(first) = vectors.first() {
183        let vector_size = first.len() as u64;
184        if let Err(e) = qdrant.ensure_collection(vector_size).await {
185            tracing::warn!("bg embed_tool: failed to ensure Qdrant collection: {e:#}");
186            return;
187        }
188    }
189
190    for (chunk_index, vector) in vectors.into_iter().enumerate() {
191        let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
192        let result = if let Some(ref tool_name) = embed_ctx.tool_name {
193            qdrant
194                .store_with_tool_context(
195                    message_id,
196                    conversation_id,
197                    &role,
198                    vector,
199                    MessageKind::Regular,
200                    &embedding_model,
201                    chunk_index_u32,
202                    tool_name,
203                    embed_ctx.exit_code,
204                    embed_ctx.timestamp.as_deref(),
205                )
206                .await
207                .map(|_| ())
208        } else {
209            qdrant
210                .store(
211                    message_id,
212                    conversation_id,
213                    &role,
214                    vector,
215                    MessageKind::Regular,
216                    &embedding_model,
217                    chunk_index_u32,
218                )
219                .await
220                .map(|_| ())
221        };
222        if let Err(e) = result {
223            tracing::warn!(
224                "bg embed_tool: failed to store chunk {chunk_index}/{chunk_count} \
225                 for msg {message_id}: {e:#}"
226            );
227        }
228    }
229}
230
231/// Background task: embed chunks with optional category and store in Qdrant.
232///
233/// All errors are logged as warnings; the function never panics.
234async fn embed_and_store_with_category_bg(args: EmbedBgArgs, category: Option<String>) {
235    let EmbedBgArgs {
236        qdrant,
237        embed_provider,
238        embedding_model,
239        message_id,
240        conversation_id,
241        role,
242        content,
243    } = args;
244    let chunks = chunk_text(&content);
245    let chunk_count = chunks.len();
246
247    let vectors = match embed_provider.embed_batch(&chunks).await {
248        Ok(v) => v,
249        Err(e) => {
250            tracing::warn!(
251                "bg embed_category: failed to embed categorized chunks for msg {message_id}: {e:#}"
252            );
253            return;
254        }
255    };
256
257    let Some(first) = vectors.first() else {
258        return;
259    };
260    let vector_size = first.len() as u64;
261    if let Err(e) = qdrant.ensure_collection(vector_size).await {
262        tracing::warn!("bg embed_category: failed to ensure Qdrant collection: {e:#}");
263        return;
264    }
265
266    for (chunk_index, vector) in vectors.into_iter().enumerate() {
267        let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
268        if let Err(e) = qdrant
269            .store_with_category(
270                message_id,
271                conversation_id,
272                &role,
273                vector,
274                MessageKind::Regular,
275                &embedding_model,
276                chunk_index_u32,
277                category.as_deref(),
278            )
279            .await
280        {
281            tracing::warn!(
282                "bg embed_category: failed to store chunk {chunk_index}/{chunk_count} \
283                 for msg {message_id}: {e:#}"
284            );
285        }
286    }
287}
288
289impl SemanticMemory {
290    /// Save a message to `SQLite` and optionally embed and store in Qdrant.
291    ///
292    /// Returns `Ok(Some(message_id))` when admitted and persisted.
293    /// Returns `Ok(None)` when A-MAC admission control rejects the message (not an error).
294    ///
295    /// # Errors
296    ///
297    /// Returns an error if the `SQLite` save fails. Embedding failures are logged but not
298    /// propagated.
299    pub async fn remember(
300        &self,
301        conversation_id: ConversationId,
302        role: &str,
303        content: &str,
304        goal_text: Option<&str>,
305    ) -> Result<Option<MessageId>, MemoryError> {
306        // A-MAC admission gate.
307        if let Some(ref admission) = self.admission_control {
308            let decision = admission
309                .evaluate(
310                    content,
311                    role,
312                    &self.provider,
313                    self.qdrant.as_ref(),
314                    goal_text,
315                )
316                .await;
317            let preview: String = content.chars().take(100).collect();
318            log_admission_decision(&decision, &preview, role, admission.threshold());
319            if !decision.admitted {
320                return Ok(None);
321            }
322        }
323
324        let message_id = self
325            .sqlite
326            .save_message(conversation_id, role, content)
327            .await?;
328
329        self.embed_and_store_regular(message_id, conversation_id, role, content);
330
331        Ok(Some(message_id))
332    }
333
334    /// Save a message with pre-serialized parts JSON to `SQLite` and optionally embed in Qdrant.
335    ///
336    /// Returns `Ok((Some(message_id), embedding_stored))` when admitted and persisted.
337    /// Returns `Ok((None, false))` when A-MAC admission control rejects the message.
338    ///
339    /// # Errors
340    ///
341    /// Returns an error if the `SQLite` save fails.
342    pub async fn remember_with_parts(
343        &self,
344        conversation_id: ConversationId,
345        role: &str,
346        content: &str,
347        parts_json: &str,
348        goal_text: Option<&str>,
349    ) -> Result<(Option<MessageId>, bool), MemoryError> {
350        // A-MAC admission gate.
351        if let Some(ref admission) = self.admission_control {
352            let decision = admission
353                .evaluate(
354                    content,
355                    role,
356                    &self.provider,
357                    self.qdrant.as_ref(),
358                    goal_text,
359                )
360                .await;
361            let preview: String = content.chars().take(100).collect();
362            log_admission_decision(&decision, &preview, role, admission.threshold());
363            if !decision.admitted {
364                return Ok((None, false));
365            }
366        }
367
368        let message_id = self
369            .sqlite
370            .save_message_with_parts(conversation_id, role, content, parts_json)
371            .await?;
372
373        let embedding_stored =
374            self.embed_and_store_regular(message_id, conversation_id, role, content);
375
376        Ok((Some(message_id), embedding_stored))
377    }
378
379    /// Save a tool output to `SQLite` and embed with tool metadata in Qdrant payload.
380    ///
381    /// Tool metadata (`tool_name`, `exit_code`, `timestamp`) is stored as Qdrant payload fields
382    /// so it is available for filtering without corrupting the embedding vector.
383    ///
384    /// Returns `Ok(Some(message_id))` when admitted and persisted.
385    /// Returns `Ok(None)` when A-MAC admission control rejects the message.
386    ///
387    /// # Errors
388    ///
389    /// Returns an error if the `SQLite` save fails.
390    pub async fn remember_tool_output(
391        &self,
392        conversation_id: ConversationId,
393        role: &str,
394        content: &str,
395        parts_json: &str,
396        embed_ctx: EmbedContext,
397    ) -> Result<(Option<MessageId>, bool), MemoryError> {
398        if let Some(ref admission) = self.admission_control {
399            let decision = admission
400                .evaluate(content, role, &self.provider, self.qdrant.as_ref(), None)
401                .await;
402            let preview: String = content.chars().take(100).collect();
403            log_admission_decision(&decision, &preview, role, admission.threshold());
404            if !decision.admitted {
405                return Ok((None, false));
406            }
407        }
408
409        let message_id = self
410            .sqlite
411            .save_message_with_parts(conversation_id, role, content, parts_json)
412            .await?;
413
414        let embedding_stored = self.embed_chunks_with_tool_context(
415            message_id,
416            conversation_id,
417            role,
418            content,
419            embed_ctx,
420        );
421
422        Ok((Some(message_id), embedding_stored))
423    }
424
425    /// Save a categorized message to `SQLite` and embed with category payload in Qdrant.
426    ///
427    /// The `category` is stored in both the `messages.category` column and as a Qdrant payload
428    /// field for recall filtering. Uses A-MAC admission gate.
429    ///
430    /// Returns `Ok(Some(message_id))` when admitted; `Ok(None)` when rejected.
431    ///
432    /// # Errors
433    ///
434    /// Returns an error if the `SQLite` save fails.
435    pub async fn remember_categorized(
436        &self,
437        conversation_id: ConversationId,
438        role: &str,
439        content: &str,
440        category: Option<&str>,
441        goal_text: Option<&str>,
442    ) -> Result<Option<MessageId>, MemoryError> {
443        if let Some(ref admission) = self.admission_control {
444            let decision = admission
445                .evaluate(
446                    content,
447                    role,
448                    &self.provider,
449                    self.qdrant.as_ref(),
450                    goal_text,
451                )
452                .await;
453            let preview: String = content.chars().take(100).collect();
454            log_admission_decision(&decision, &preview, role, admission.threshold());
455            if !decision.admitted {
456                return Ok(None);
457            }
458        }
459
460        let message_id = self
461            .sqlite
462            .save_message_with_category(conversation_id, role, content, category)
463            .await?;
464
465        self.embed_and_store_with_category(message_id, conversation_id, role, content, category);
466
467        Ok(Some(message_id))
468    }
469
470    /// Recall messages filtered by category.
471    ///
472    /// When `category` is `None`, behaves identically to [`recall`].
473    ///
474    /// # Errors
475    ///
476    /// Returns an error if the search fails.
477    pub async fn recall_with_category(
478        &self,
479        query: &str,
480        limit: usize,
481        filter: Option<SearchFilter>,
482        category: Option<&str>,
483    ) -> Result<Vec<RecalledMessage>, MemoryError> {
484        let filter_with_category = filter.map(|mut f| {
485            f.category = category.map(str::to_owned);
486            f
487        });
488        self.recall(query, limit, filter_with_category).await
489    }
490
491    /// Reap completed background embed tasks (non-blocking).
492    ///
493    /// Call at turn boundaries to release handles for finished tasks.
494    pub fn reap_embed_tasks(&self) {
495        if let Ok(mut tasks) = self.embed_tasks.lock() {
496            while tasks.try_join_next().is_some() {}
497        }
498    }
499
500    /// Spawn `fut` as a bounded background embed task.
501    ///
502    /// If the task limit is reached, the task is dropped and a debug message is logged.
503    fn spawn_embed_bg<F>(&self, fut: F) -> bool
504    where
505        F: std::future::Future<Output = ()> + Send + 'static,
506    {
507        let Ok(mut tasks) = self.embed_tasks.lock() else {
508            return false;
509        };
510        // Reap any finished tasks before checking capacity.
511        while tasks.try_join_next().is_some() {}
512        if tasks.len() >= MAX_EMBED_BG_TASKS {
513            tracing::debug!("background embed task limit reached, skipping");
514            return false;
515        }
516        tasks.spawn(fut);
517        // embedding dispatched to background; metric not incremented
518        false
519    }
520
521    /// Embed content chunks and store each with an optional category payload field.
522    ///
523    /// Spawns a bounded background task; returns immediately.
524    fn embed_and_store_with_category(
525        &self,
526        message_id: MessageId,
527        conversation_id: ConversationId,
528        role: &str,
529        content: &str,
530        category: Option<&str>,
531    ) -> bool {
532        let Some(qdrant) = self.qdrant.clone() else {
533            return false;
534        };
535        let embed_provider = self.effective_embed_provider().clone();
536        if !embed_provider.supports_embeddings() {
537            return false;
538        }
539        self.spawn_embed_bg(embed_and_store_with_category_bg(
540            EmbedBgArgs {
541                qdrant,
542                embed_provider,
543                embedding_model: self.embedding_model.clone(),
544                message_id,
545                conversation_id,
546                role: role.to_owned(),
547                content: content.to_owned(),
548            },
549            category.map(str::to_owned),
550        ))
551    }
552
553    /// Embed content chunks and store each as a regular (non-tool) message vector.
554    ///
555    /// Spawns a bounded background task; returns immediately.
556    fn embed_and_store_regular(
557        &self,
558        message_id: MessageId,
559        conversation_id: ConversationId,
560        role: &str,
561        content: &str,
562    ) -> bool {
563        let Some(qdrant) = self.qdrant.clone() else {
564            return false;
565        };
566        let embed_provider = self.effective_embed_provider().clone();
567        if !embed_provider.supports_embeddings() {
568            return false;
569        }
570        self.spawn_embed_bg(embed_and_store_regular_bg(EmbedBgArgs {
571            qdrant,
572            embed_provider,
573            embedding_model: self.embedding_model.clone(),
574            message_id,
575            conversation_id,
576            role: role.to_owned(),
577            content: content.to_owned(),
578        }))
579    }
580
581    /// Embed content chunks, enriching Qdrant payload with tool metadata when present.
582    ///
583    /// Spawns a bounded background task; returns immediately.
584    fn embed_chunks_with_tool_context(
585        &self,
586        message_id: MessageId,
587        conversation_id: ConversationId,
588        role: &str,
589        content: &str,
590        embed_ctx: EmbedContext,
591    ) -> bool {
592        let Some(qdrant) = self.qdrant.clone() else {
593            return false;
594        };
595        let embed_provider = self.effective_embed_provider().clone();
596        if !embed_provider.supports_embeddings() {
597            return false;
598        }
599        self.spawn_embed_bg(embed_chunks_with_tool_context_bg(
600            EmbedBgArgs {
601                qdrant,
602                embed_provider,
603                embedding_model: self.embedding_model.clone(),
604                message_id,
605                conversation_id,
606                role: role.to_owned(),
607                content: content.to_owned(),
608            },
609            embed_ctx,
610        ))
611    }
612
613    /// Save a message to `SQLite` without generating an embedding.
614    ///
615    /// Use this when embedding is intentionally skipped (e.g. autosave disabled for assistant).
616    ///
617    /// # Errors
618    ///
619    /// Returns an error if the `SQLite` save fails.
620    pub async fn save_only(
621        &self,
622        conversation_id: ConversationId,
623        role: &str,
624        content: &str,
625        parts_json: &str,
626    ) -> Result<MessageId, MemoryError> {
627        self.sqlite
628            .save_message_with_parts(conversation_id, role, content, parts_json)
629            .await
630    }
631
632    /// Recall relevant messages using hybrid search (vector + FTS5 keyword).
633    ///
634    /// When Qdrant is available, runs both vector and keyword searches, then merges
635    /// results using weighted scoring. When Qdrant is unavailable, falls back to
636    /// FTS5-only keyword search.
637    ///
638    /// # Errors
639    ///
640    /// Returns an error if embedding generation, Qdrant search, or FTS5 query fails.
641    pub async fn recall(
642        &self,
643        query: &str,
644        limit: usize,
645        filter: Option<SearchFilter>,
646    ) -> Result<Vec<RecalledMessage>, MemoryError> {
647        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
648
649        tracing::debug!(
650            query_len = query.len(),
651            limit,
652            has_filter = filter.is_some(),
653            conversation_id = conversation_id.map(|c| c.0),
654            has_qdrant = self.qdrant.is_some(),
655            "recall: starting hybrid search"
656        );
657
658        let keyword_results = match self
659            .sqlite
660            .keyword_search(query, limit * 2, conversation_id)
661            .await
662        {
663            Ok(results) => results,
664            Err(e) => {
665                tracing::warn!("FTS5 keyword search failed: {e:#}");
666                Vec::new()
667            }
668        };
669
670        let vector_results = if let Some(qdrant) = &self.qdrant
671            && self.provider.supports_embeddings()
672        {
673            let query_vector = self.provider.embed(query).await?;
674            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
675            qdrant.ensure_collection(vector_size).await?;
676            qdrant.search(&query_vector, limit * 2, filter).await?
677        } else {
678            Vec::new()
679        };
680
681        self.recall_merge_and_rank(keyword_results, vector_results, limit)
682            .await
683    }
684
685    pub(super) async fn recall_fts5_raw(
686        &self,
687        query: &str,
688        limit: usize,
689        conversation_id: Option<ConversationId>,
690    ) -> Result<Vec<(MessageId, f64)>, MemoryError> {
691        self.sqlite
692            .keyword_search(query, limit * 2, conversation_id)
693            .await
694    }
695
696    pub(super) async fn recall_vectors_raw(
697        &self,
698        query: &str,
699        limit: usize,
700        filter: Option<SearchFilter>,
701    ) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
702        let Some(qdrant) = &self.qdrant else {
703            return Ok(Vec::new());
704        };
705        if !self.provider.supports_embeddings() {
706            return Ok(Vec::new());
707        }
708        let query_vector = self.provider.embed(query).await?;
709        let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
710        qdrant.ensure_collection(vector_size).await?;
711        qdrant.search(&query_vector, limit * 2, filter).await
712    }
713
714    /// Merge raw keyword and vector results, apply weighted scoring, temporal decay, and MMR
715    /// re-ranking, then resolve to `RecalledMessage` objects.
716    ///
717    /// This is the shared post-processing step used by all recall paths.
718    ///
719    /// # Errors
720    ///
721    /// Returns an error if the `SQLite` `messages_by_ids` query fails.
722    #[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
723    pub(super) async fn recall_merge_and_rank(
724        &self,
725        keyword_results: Vec<(MessageId, f64)>,
726        vector_results: Vec<crate::embedding_store::SearchResult>,
727        limit: usize,
728    ) -> Result<Vec<RecalledMessage>, MemoryError> {
729        tracing::debug!(
730            vector_count = vector_results.len(),
731            keyword_count = keyword_results.len(),
732            limit,
733            "recall: merging search results"
734        );
735
736        let mut scores: std::collections::HashMap<MessageId, f64> =
737            std::collections::HashMap::new();
738
739        if !vector_results.is_empty() {
740            let max_vs = vector_results
741                .iter()
742                .map(|r| r.score)
743                .fold(f32::NEG_INFINITY, f32::max);
744            let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
745            for r in &vector_results {
746                let normalized = f64::from(r.score / norm);
747                *scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
748            }
749        }
750
751        if !keyword_results.is_empty() {
752            let max_ks = keyword_results
753                .iter()
754                .map(|r| r.1)
755                .fold(f64::NEG_INFINITY, f64::max);
756            let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
757            for &(msg_id, score) in &keyword_results {
758                let normalized = score / norm;
759                *scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
760            }
761        }
762
763        if scores.is_empty() {
764            tracing::debug!("recall: empty merge, no overlapping scores");
765            return Ok(Vec::new());
766        }
767
768        let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
769        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
770
771        tracing::debug!(
772            merged = ranked.len(),
773            top_score = ranked.first().map(|r| r.1),
774            bottom_score = ranked.last().map(|r| r.1),
775            vector_weight = %self.vector_weight,
776            keyword_weight = %self.keyword_weight,
777            "recall: weighted merge complete"
778        );
779
780        if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
781            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
782            match self.sqlite.message_timestamps(&ids).await {
783                Ok(timestamps) => {
784                    apply_temporal_decay(
785                        &mut ranked,
786                        &timestamps,
787                        self.temporal_decay_half_life_days,
788                    );
789                    ranked
790                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
791                    tracing::debug!(
792                        half_life_days = self.temporal_decay_half_life_days,
793                        top_score_after = ranked.first().map(|r| r.1),
794                        "recall: temporal decay applied"
795                    );
796                }
797                Err(e) => {
798                    tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
799                }
800            }
801        }
802
803        if self.mmr_enabled && !vector_results.is_empty() {
804            if let Some(qdrant) = &self.qdrant {
805                let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
806                match qdrant.get_vectors(&ids).await {
807                    Ok(vec_map) if !vec_map.is_empty() => {
808                        let ranked_len_before = ranked.len();
809                        ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
810                        tracing::debug!(
811                            before = ranked_len_before,
812                            after = ranked.len(),
813                            lambda = %self.mmr_lambda,
814                            "recall: mmr re-ranked"
815                        );
816                    }
817                    Ok(_) => {
818                        ranked.truncate(limit);
819                    }
820                    Err(e) => {
821                        tracing::warn!("MMR: failed to fetch vectors: {e:#}");
822                        ranked.truncate(limit);
823                    }
824                }
825            } else {
826                ranked.truncate(limit);
827            }
828        } else {
829            ranked.truncate(limit);
830        }
831
832        if self.importance_enabled && !ranked.is_empty() {
833            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
834            match self.sqlite.fetch_importance_scores(&ids).await {
835                Ok(scores) => {
836                    for (msg_id, score) in &mut ranked {
837                        if let Some(&imp) = scores.get(msg_id) {
838                            *score += imp * self.importance_weight;
839                        }
840                    }
841                    ranked
842                        .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
843                    tracing::debug!(
844                        importance_weight = %self.importance_weight,
845                        "recall: importance scores blended"
846                    );
847                }
848                Err(e) => {
849                    tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
850                }
851            }
852        }
853
854        // Apply tier boost: semantic-tier messages receive an additive bonus so distilled facts
855        // rank above episodic messages with the same base score. Additive (not multiplicative)
856        // so the effect is consistent regardless of base score magnitude.
857        if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
858            let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
859            match self.sqlite.fetch_tiers(&ids).await {
860                Ok(tiers) => {
861                    let bonus = self.tier_boost_semantic - 1.0;
862                    let mut boosted = false;
863                    for (msg_id, score) in &mut ranked {
864                        if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
865                            *score += bonus;
866                            boosted = true;
867                        }
868                    }
869                    if boosted {
870                        ranked.sort_by(|a, b| {
871                            b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
872                        });
873                        tracing::debug!(
874                            tier_boost = %self.tier_boost_semantic,
875                            "recall: semantic tier boost applied"
876                        );
877                    }
878                }
879                Err(e) => {
880                    tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
881                }
882            }
883        }
884
885        let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
886
887        if !ids.is_empty()
888            && let Err(e) = self.batch_increment_access_count(ids.clone()).await
889        {
890            tracing::warn!("recall: failed to increment access counts: {e:#}");
891        }
892
893        // Update RL admission training data: mark recalled messages as positive examples.
894        if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
895            tracing::debug!(
896                error = %e,
897                "recall: failed to mark training data as recalled (non-fatal)"
898            );
899        }
900
901        let messages = self.sqlite.messages_by_ids(&ids).await?;
902        let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
903
904        let recalled: Vec<RecalledMessage> = ranked
905            .iter()
906            .filter_map(|(msg_id, score)| {
907                msg_map.get(msg_id).map(|msg| RecalledMessage {
908                    message: msg.clone(),
909                    #[expect(clippy::cast_possible_truncation)]
910                    score: *score as f32,
911                })
912            })
913            .collect();
914
915        tracing::debug!(final_count = recalled.len(), "recall: final results");
916
917        Ok(recalled)
918    }
919
920    /// Recall messages using query-aware routing.
921    ///
922    /// Delegates to FTS5-only, vector-only, or hybrid search based on the router decision,
923    /// then runs the shared merge and ranking pipeline.
924    ///
925    /// # Errors
926    ///
927    /// Returns an error if any underlying search or database operation fails.
928    pub async fn recall_routed(
929        &self,
930        query: &str,
931        limit: usize,
932        filter: Option<SearchFilter>,
933        router: &dyn crate::router::MemoryRouter,
934    ) -> Result<Vec<RecalledMessage>, MemoryError> {
935        use crate::router::MemoryRoute;
936
937        let route = router.route(query);
938        tracing::debug!(?route, query_len = query.len(), "memory routing decision");
939
940        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
941
942        let (keyword_results, vector_results): (
943            Vec<(MessageId, f64)>,
944            Vec<crate::embedding_store::SearchResult>,
945        ) = match route {
946            MemoryRoute::Keyword => {
947                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
948                (kw, Vec::new())
949            }
950            MemoryRoute::Semantic => {
951                let vr = self.recall_vectors_raw(query, limit, filter).await?;
952                (Vec::new(), vr)
953            }
954            MemoryRoute::Hybrid => {
955                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
956                    Ok(r) => r,
957                    Err(e) => {
958                        tracing::warn!("FTS5 keyword search failed: {e:#}");
959                        Vec::new()
960                    }
961                };
962                let vr = self.recall_vectors_raw(query, limit, filter).await?;
963                (kw, vr)
964            }
965            // Episodic: FTS5 keyword search with an optional timestamp-range filter.
966            // Temporal keywords are stripped from the query before passing to FTS5 to
967            // prevent BM25 score distortion (e.g. "yesterday" matching messages that
968            // literally contain the word "yesterday" regardless of actual relevance).
969            // Vector search is skipped for speed; temporal decay in recall_merge_and_rank
970            // provides recency boosting for the FTS5 results.
971            // Known trade-off (MVP): semantically similar but lexically different messages
972            // may be missed. See issue #1629 for a future hybrid_temporal mode.
973            MemoryRoute::Episodic => {
974                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
975                let cleaned = crate::router::strip_temporal_keywords(query);
976                let search_query = if cleaned.is_empty() { query } else { &cleaned };
977                let kw = if let Some(ref r) = range {
978                    self.sqlite
979                        .keyword_search_with_time_range(
980                            search_query,
981                            limit,
982                            conversation_id,
983                            r.after.as_deref(),
984                            r.before.as_deref(),
985                        )
986                        .await?
987                } else {
988                    self.recall_fts5_raw(search_query, limit, conversation_id)
989                        .await?
990                };
991                tracing::debug!(
992                    has_range = range.is_some(),
993                    cleaned_query = %search_query,
994                    keyword_count = kw.len(),
995                    "recall: episodic path"
996                );
997                (kw, Vec::new())
998            }
999            // Graph routing triggers graph_recall separately in agent/context.rs.
1000            // For the message-based recall, behave like Hybrid.
1001            MemoryRoute::Graph => {
1002                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
1003                    Ok(r) => r,
1004                    Err(e) => {
1005                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
1006                        Vec::new()
1007                    }
1008                };
1009                let vr = self.recall_vectors_raw(query, limit, filter).await?;
1010                (kw, vr)
1011            }
1012        };
1013
1014        tracing::debug!(
1015            keyword_count = keyword_results.len(),
1016            vector_count = vector_results.len(),
1017            "recall: routed search results"
1018        );
1019
1020        self.recall_merge_and_rank(keyword_results, vector_results, limit)
1021            .await
1022    }
1023
1024    /// Async variant of [`recall_routed`](Self::recall_routed) that uses
1025    /// [`AsyncMemoryRouter::route_async`](crate::router::AsyncMemoryRouter::route_async) when
1026    /// available, enabling LLM-based routing for `LlmRouter` and `HybridRouter`.
1027    ///
1028    /// Falls back to [`recall_routed`](Self::recall_routed) for routers that only implement
1029    /// the sync `MemoryRouter` trait (e.g. `HeuristicRouter`).
1030    ///
1031    /// # Errors
1032    ///
1033    /// Returns an error if any underlying search or database operation fails.
1034    pub async fn recall_routed_async(
1035        &self,
1036        query: &str,
1037        limit: usize,
1038        filter: Option<crate::embedding_store::SearchFilter>,
1039        router: &dyn crate::router::AsyncMemoryRouter,
1040    ) -> Result<Vec<RecalledMessage>, MemoryError> {
1041        use crate::router::MemoryRoute;
1042
1043        let decision = router.route_async(query).await;
1044        let route = decision.route;
1045        tracing::debug!(
1046            ?route,
1047            confidence = decision.confidence,
1048            query_len = query.len(),
1049            "memory routing decision (async)"
1050        );
1051
1052        let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
1053
1054        let (keyword_results, vector_results): (
1055            Vec<(crate::types::MessageId, f64)>,
1056            Vec<crate::embedding_store::SearchResult>,
1057        ) = match route {
1058            MemoryRoute::Keyword => {
1059                let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
1060                (kw, Vec::new())
1061            }
1062            MemoryRoute::Semantic => {
1063                let vr = self.recall_vectors_raw(query, limit, filter).await?;
1064                (Vec::new(), vr)
1065            }
1066            MemoryRoute::Hybrid => {
1067                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
1068                    Ok(r) => r,
1069                    Err(e) => {
1070                        tracing::warn!("FTS5 keyword search failed: {e:#}");
1071                        Vec::new()
1072                    }
1073                };
1074                let vr = self.recall_vectors_raw(query, limit, filter).await?;
1075                (kw, vr)
1076            }
1077            MemoryRoute::Episodic => {
1078                let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
1079                let cleaned = crate::router::strip_temporal_keywords(query);
1080                let search_query = if cleaned.is_empty() { query } else { &cleaned };
1081                let kw = if let Some(ref r) = range {
1082                    self.sqlite
1083                        .keyword_search_with_time_range(
1084                            search_query,
1085                            limit,
1086                            conversation_id,
1087                            r.after.as_deref(),
1088                            r.before.as_deref(),
1089                        )
1090                        .await?
1091                } else {
1092                    self.recall_fts5_raw(search_query, limit, conversation_id)
1093                        .await?
1094                };
1095                (kw, Vec::new())
1096            }
1097            MemoryRoute::Graph => {
1098                let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
1099                    Ok(r) => r,
1100                    Err(e) => {
1101                        tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
1102                        Vec::new()
1103                    }
1104                };
1105                let vr = self.recall_vectors_raw(query, limit, filter).await?;
1106                (kw, vr)
1107            }
1108        };
1109
1110        tracing::debug!(
1111            keyword_count = keyword_results.len(),
1112            vector_count = vector_results.len(),
1113            "recall: routed search results (async)"
1114        );
1115
1116        self.recall_merge_and_rank(keyword_results, vector_results, limit)
1117            .await
1118    }
1119
1120    /// Retrieve graph facts relevant to `query` via BFS traversal.
1121    ///
1122    /// Returns an empty `Vec` if no `graph_store` is configured.
1123    ///
1124    /// # Parameters
1125    ///
1126    /// - `at_timestamp`: when `Some`, only edges valid at that `SQLite` datetime string are returned.
1127    ///   When `None`, only currently active edges are used.
1128    /// - `temporal_decay_rate`: non-negative decay rate (1/day). `0.0` preserves original ordering.
1129    ///
1130    /// # Errors
1131    ///
1132    /// Returns an error if the underlying graph query fails.
1133    pub async fn recall_graph(
1134        &self,
1135        query: &str,
1136        limit: usize,
1137        max_hops: u32,
1138        at_timestamp: Option<&str>,
1139        temporal_decay_rate: f64,
1140        edge_types: &[crate::graph::EdgeType],
1141    ) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
1142        let Some(store) = &self.graph_store else {
1143            return Ok(Vec::new());
1144        };
1145
1146        tracing::debug!(
1147            query_len = query.len(),
1148            limit,
1149            max_hops,
1150            "graph: starting recall"
1151        );
1152
1153        let results = crate::graph::retrieval::graph_recall(
1154            store,
1155            self.qdrant.as_deref(),
1156            &self.provider,
1157            query,
1158            limit,
1159            max_hops,
1160            at_timestamp,
1161            temporal_decay_rate,
1162            edge_types,
1163        )
1164        .await?;
1165
1166        tracing::debug!(result_count = results.len(), "graph: recall complete");
1167
1168        Ok(results)
1169    }
1170
1171    /// Retrieve graph facts via SYNAPSE spreading activation.
1172    ///
1173    /// Delegates to [`crate::graph::retrieval::graph_recall_activated`].
1174    /// Used in place of [`recall_graph`] when `spreading_activation.enabled = true`.
1175    ///
1176    /// # Errors
1177    ///
1178    /// Returns an error if the underlying graph query fails.
1179    pub async fn recall_graph_activated(
1180        &self,
1181        query: &str,
1182        limit: usize,
1183        params: crate::graph::SpreadingActivationParams,
1184        edge_types: &[crate::graph::EdgeType],
1185    ) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
1186        let Some(store) = &self.graph_store else {
1187            return Ok(Vec::new());
1188        };
1189
1190        tracing::debug!(
1191            query_len = query.len(),
1192            limit,
1193            "spreading activation: starting graph recall"
1194        );
1195
1196        let embeddings = self.qdrant.as_deref();
1197        let results = crate::graph::retrieval::graph_recall_activated(
1198            store,
1199            embeddings,
1200            &self.provider,
1201            query,
1202            limit,
1203            params,
1204            edge_types,
1205        )
1206        .await?;
1207
1208        tracing::debug!(
1209            result_count = results.len(),
1210            "spreading activation: graph recall complete"
1211        );
1212
1213        Ok(results)
1214    }
1215
1216    /// Increment access count and update `last_accessed` for a batch of message IDs.
1217    ///
1218    /// Skips the update if `message_ids` is empty to avoid an invalid `IN ()` clause.
1219    ///
1220    /// # Errors
1221    ///
1222    /// Returns an error if the `SQLite` update fails.
1223    async fn batch_increment_access_count(
1224        &self,
1225        message_ids: Vec<MessageId>,
1226    ) -> Result<(), MemoryError> {
1227        if message_ids.is_empty() {
1228            return Ok(());
1229        }
1230        self.sqlite.increment_access_counts(&message_ids).await
1231    }
1232
1233    /// Check whether an embedding exists for a given message ID.
1234    ///
1235    /// # Errors
1236    ///
1237    /// Returns an error if the `SQLite` query fails.
1238    pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
1239        match &self.qdrant {
1240            Some(qdrant) => qdrant.has_embedding(message_id).await,
1241            None => Ok(false),
1242        }
1243    }
1244
1245    /// Embed all messages that do not yet have embeddings.
1246    ///
1247    /// Processes unembedded messages in micro-batches of 32, using `buffer_unordered(4)` for
1248    /// concurrent embedding within each batch. Bounded peak memory: at most 32 messages of content
1249    /// plus their embedding vectors are live at any time.
1250    ///
1251    /// When `progress_tx` is `Some`, sends `Some(BackfillProgress)` after each message and
1252    /// `None` on completion (or on timeout/error in the caller).
1253    ///
1254    /// Returns the count of successfully embedded messages.
1255    ///
1256    /// # Errors
1257    ///
1258    /// Returns an error if collection initialization or the streaming query setup fails.
1259    /// Individual embedding failures are logged but do not stop processing.
1260    pub async fn embed_missing(
1261        &self,
1262        progress_tx: Option<tokio::sync::watch::Sender<Option<super::BackfillProgress>>>,
1263    ) -> Result<usize, MemoryError> {
1264        if self.qdrant.is_none() || !self.effective_embed_provider().supports_embeddings() {
1265            return Ok(0);
1266        }
1267
1268        let total = self.sqlite.count_unembedded_messages().await?;
1269        if total == 0 {
1270            return Ok(0);
1271        }
1272
1273        if let Some(tx) = &progress_tx {
1274            let _ = tx.send(Some(super::BackfillProgress { done: 0, total }));
1275        }
1276
1277        let mut done = 0usize;
1278        let mut succeeded = 0usize;
1279
1280        loop {
1281            const BATCH_SIZE: usize = 32;
1282            const BATCH_SIZE_I64: i64 = 32;
1283            let rows: Vec<_> = self
1284                .sqlite
1285                .stream_unembedded_messages(BATCH_SIZE_I64)
1286                .try_collect()
1287                .await?;
1288
1289            if rows.is_empty() {
1290                break;
1291            }
1292
1293            let batch_len = rows.len();
1294
1295            let results: Vec<bool> = futures::stream::iter(rows)
1296                .map(|(msg_id, conv_id, role, content)| async move {
1297                    self.embed_and_store_regular(msg_id, conv_id, &role, &content)
1298                })
1299                .buffer_unordered(4)
1300                .collect()
1301                .await;
1302
1303            for ok in &results {
1304                done += 1;
1305                if *ok {
1306                    succeeded += 1;
1307                }
1308                if let Some(tx) = &progress_tx {
1309                    let _ = tx.send(Some(super::BackfillProgress { done, total }));
1310                }
1311            }
1312
1313            let batch_succeeded = results.iter().filter(|&&b| b).count();
1314            if batch_succeeded > 0 {
1315                tracing::debug!("Backfill batch: {batch_succeeded}/{batch_len} embedded");
1316            }
1317
1318            if batch_len < BATCH_SIZE {
1319                break;
1320            }
1321        }
1322
1323        if let Some(tx) = &progress_tx {
1324            let _ = tx.send(None);
1325        }
1326
1327        if done > 0 {
1328            tracing::info!("Embedded {succeeded}/{total} missing messages");
1329        }
1330        Ok(succeeded)
1331    }
1332}
1333
1334#[cfg(test)]
1335mod tests {
1336    use super::*;
1337
1338    #[test]
1339    fn embed_context_default_all_none() {
1340        let ctx = EmbedContext::default();
1341        assert!(ctx.tool_name.is_none());
1342        assert!(ctx.exit_code.is_none());
1343        assert!(ctx.timestamp.is_none());
1344    }
1345
1346    #[test]
1347    fn embed_context_fields_set_correctly() {
1348        let ctx = EmbedContext {
1349            tool_name: Some("shell".to_string()),
1350            exit_code: Some(0),
1351            timestamp: Some("2026-04-04T00:00:00Z".to_string()),
1352        };
1353        assert_eq!(ctx.tool_name.as_deref(), Some("shell"));
1354        assert_eq!(ctx.exit_code, Some(0));
1355        assert_eq!(ctx.timestamp.as_deref(), Some("2026-04-04T00:00:00Z"));
1356    }
1357
1358    #[test]
1359    fn embed_context_non_zero_exit_code() {
1360        let ctx = EmbedContext {
1361            tool_name: Some("shell".to_string()),
1362            exit_code: Some(1),
1363            timestamp: None,
1364        };
1365        assert_eq!(ctx.exit_code, Some(1));
1366        assert!(ctx.timestamp.is_none());
1367    }
1368}