Skip to main content

semantic_memory/
search.rs

1//! Hybrid search engine: BM25 + vector similarity + Reciprocal Rank Fusion.
2
3use crate::config::SearchConfig;
4use crate::error::MemoryError;
5use crate::types::{SearchResult, SearchSource, SearchSourceType};
6use rusqlite::types::Value as SqlValue;
7use rusqlite::Connection;
8use std::collections::{HashMap, HashSet};
9
/// Row count scanned from a single table above which the brute-force vector
/// search logs a warning suggesting partitioning or pruning.
const VECTOR_SCAN_WARN_THRESHOLD: usize = 50_000;
12
/// Sanitize a raw query string for FTS5 MATCH syntax.
///
/// Inside an unquoted FTS5 bareword only ASCII letters, ASCII digits, `_` and
/// non-ASCII characters are allowed. Every other ASCII character is either
/// query syntax (`"`, `*`, `(`, `:`, `^`, `-`, `+`, …) or a syntax error
/// (`.`, `,`, `'`, `?`, `/`, …), so each one is replaced with a space. The
/// result is then split on whitespace, and the FTS5 keyword operators `AND`,
/// `OR`, `NOT` and `NEAR` (matched case-insensitively) are dropped. The
/// remaining tokens are joined with single spaces, which FTS5 treats as an
/// implicit AND.
///
/// Returns `None` if no searchable token remains.
pub fn sanitize_fts_query(raw: &str) -> Option<String> {
    // Characters FTS5 accepts inside an unquoted bareword.
    let is_bareword_char = |c: char| c.is_ascii_alphanumeric() || c == '_' || !c.is_ascii();
    let cleaned: String = raw
        .chars()
        .map(|c| if is_bareword_char(c) { c } else { ' ' })
        .collect();

    // Drop bare keyword operators; left in place they would be parsed as
    // (possibly dangling) boolean / NEAR operators and cause query errors.
    let tokens: Vec<&str> = cleaned
        .split_whitespace()
        .filter(|t| {
            !["AND", "OR", "NOT", "NEAR"]
                .iter()
                .any(|op| t.eq_ignore_ascii_case(op))
        })
        .collect();

    if tokens.is_empty() {
        None
    } else {
        Some(tokens.join(" "))
    }
}
41
/// Cosine similarity of two vectors that are expected to have equal length.
///
/// Returns `0.0` when either vector has zero magnitude instead of dividing by
/// zero. The length check is a `debug_assert`, so in release builds a length
/// mismatch is not detected: the dot product then only covers the shorter
/// slice, while each magnitude covers its full slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "embedding dimension mismatch");
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|x| x * x).sum::<f32>().sqrt() };
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        0.0
    } else {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (mag_a * mag_b)
    }
}
53
54/// Compute the number of days since an ISO 8601 timestamp (SQLite format).
55fn days_since(iso_timestamp: &str) -> Option<f64> {
56    let dt = chrono::NaiveDateTime::parse_from_str(iso_timestamp, "%Y-%m-%d %H:%M:%S").ok()?;
57    let now = chrono::Utc::now().naive_utc();
58    let duration = now - dt;
59    Some(duration.num_seconds() as f64 / 86400.0)
60}
61
62/// An RRF candidate accumulating scores from BM25 and vector search.
63struct RrfCandidate {
64    content: String,
65    source: SearchSource,
66    bm25_rank: Option<usize>,
67    vector_rank: Option<usize>,
68    cosine_similarity: Option<f64>,
69    updated_at: Option<String>,
70}
71
72impl RrfCandidate {
73    fn score(&self, config: &SearchConfig) -> f64 {
74        let bm25_score = self
75            .bm25_rank
76            .map(|r| config.bm25_weight / (config.rrf_k + r as f64))
77            .unwrap_or(0.0);
78        let vector_score = self
79            .vector_rank
80            .map(|r| config.vector_weight / (config.rrf_k + r as f64))
81            .unwrap_or(0.0);
82
83        let recency_score = match (config.recency_half_life_days, &self.updated_at) {
84            (Some(half_life), Some(ts)) if half_life > 0.0 => {
85                let age_days = days_since(ts).unwrap_or(0.0).max(0.0);
86                let decay = 2.0_f64.powf(-age_days / half_life);
87                config.recency_weight * decay / (config.rrf_k + 1.0)
88            }
89            (Some(half_life), _) if half_life <= 0.0 => {
90                tracing::warn!("recency_half_life_days <= 0, ignoring recency boost");
91                0.0
92            }
93            _ => 0.0,
94        };
95
96        bm25_score + vector_score + recency_score
97    }
98}
99
100/// A BM25 search hit from FTS5.
101pub struct Bm25Hit {
102    /// Item ID (fact_id or chunk_id).
103    pub id: String,
104    /// Text content.
105    pub content: String,
106    /// Source info.
107    pub source: SearchSource,
108    /// Timestamp for recency scoring.
109    pub updated_at: Option<String>,
110}
111
112/// A vector search hit.
113pub struct VectorHit {
114    /// Item ID (fact_id or chunk_id).
115    pub id: String,
116    /// Text content.
117    pub content: String,
118    /// Source info.
119    pub source: SearchSource,
120    /// Cosine similarity score.
121    pub similarity: f64,
122    /// Timestamp for recency scoring.
123    pub updated_at: Option<String>,
124}
125
126/// Row data extracted from a SQLite query for vector similarity scoring.
127struct VectorRow {
128    id: String,
129    content: String,
130    blob: Vec<u8>,
131    updated_at: Option<String>,
132    source: SearchSource,
133}
134
135/// Decode embedding BLOBs and compute cosine similarity for a set of rows.
136///
137/// Shared logic for the facts, chunks, and messages vector search loops.
138/// Returns the matching hits and the total row count scanned.
139fn scan_vector_rows(
140    rows: impl Iterator<Item = Result<VectorRow, rusqlite::Error>>,
141    query_embedding: &[f32],
142    min_similarity: f64,
143    table_label: &str,
144) -> Result<(Vec<VectorHit>, usize), MemoryError> {
145    let expected_dims = query_embedding.len();
146    let mut hits = Vec::new();
147    let mut row_count = 0usize;
148
149    for row in rows {
150        let row = row?;
151        row_count += 1;
152
153        if row.blob.len() % 4 != 0 {
154            tracing::warn!(
155                "Skipping {} with invalid embedding length: {}",
156                table_label,
157                row.blob.len()
158            );
159            continue;
160        }
161        let stored_embedding: &[f32] =
162            bytemuck::try_cast_slice(&row.blob).map_err(|_| MemoryError::InvalidEmbedding {
163                expected_bytes: row.blob.len() - (row.blob.len() % 4),
164                actual_bytes: row.blob.len(),
165            })?;
166        if stored_embedding.len() != expected_dims {
167            tracing::warn!(
168                expected = expected_dims,
169                actual = stored_embedding.len(),
170                "Skipping {} with wrong embedding dimensions",
171                table_label
172            );
173            continue;
174        }
175
176        let sim = cosine_similarity(query_embedding, stored_embedding) as f64;
177        if sim >= min_similarity {
178            hits.push(VectorHit {
179                id: row.id,
180                content: row.content,
181                source: row.source,
182                similarity: sim,
183                updated_at: row.updated_at,
184            });
185        }
186    }
187
188    Ok((hits, row_count))
189}
190
191/// Run BM25 search over facts_fts, chunks_fts, and optionally messages_fts.
192pub(crate) fn bm25_search(
193    conn: &Connection,
194    sanitized_query: &str,
195    pool_size: usize,
196    namespaces: Option<&[&str]>,
197    source_types: Option<&[SearchSourceType]>,
198    session_ids: Option<&[&str]>,
199) -> Result<Vec<Bm25Hit>, MemoryError> {
200    let mut hits = Vec::new();
201
202    let search_facts = source_types
203        .map(|st| st.contains(&SearchSourceType::Facts))
204        .unwrap_or(true);
205    let search_chunks = source_types
206        .map(|st| st.contains(&SearchSourceType::Chunks))
207        .unwrap_or(true);
208    // Messages are NOT included by default — only when explicitly requested
209    let search_messages = source_types
210        .map(|st| st.contains(&SearchSourceType::Messages))
211        .unwrap_or(false);
212
213    // Search facts
214    if search_facts {
215        let (ns_clause, ns_params) = build_namespace_clause("f.namespace", namespaces, 3);
216        let sql = format!(
217            "SELECT fm.fact_id, f.content, f.namespace, bm25(facts_fts) AS score, f.updated_at
218             FROM facts_fts
219             JOIN facts_rowid_map fm ON facts_fts.rowid = fm.rowid
220             JOIN facts f ON f.id = fm.fact_id
221             WHERE facts_fts MATCH ?1 {}
222             ORDER BY bm25(facts_fts)
223             LIMIT ?2",
224            ns_clause
225        );
226
227        let mut all_params: Vec<SqlValue> = vec![
228            SqlValue::Text(sanitized_query.to_string()),
229            SqlValue::Integer(pool_size as i64),
230        ];
231        all_params.extend(ns_params.clone());
232
233        let mut stmt = conn.prepare(&sql)?;
234        let rows = stmt.query_map(rusqlite::params_from_iter(&all_params), |row| {
235            let fact_id: String = row.get(0)?;
236            let content: String = row.get(1)?;
237            let namespace: String = row.get(2)?;
238            let updated_at: Option<String> = row.get(4)?;
239            Ok(Bm25Hit {
240                id: fact_id.clone(),
241                content,
242                source: SearchSource::Fact { fact_id, namespace },
243                updated_at,
244            })
245        })?;
246
247        for row in rows {
248            hits.push(row?);
249        }
250    }
251
252    // Search chunks
253    if search_chunks {
254        let (ns_clause, ns_params) = build_namespace_clause("d.namespace", namespaces, 3);
255        let sql = format!(
256            "SELECT cm.chunk_id, c.content, c.document_id, d.title, c.chunk_index, bm25(chunks_fts) AS score, c.created_at
257             FROM chunks_fts
258             JOIN chunks_rowid_map cm ON chunks_fts.rowid = cm.rowid
259             JOIN chunks c ON c.id = cm.chunk_id
260             JOIN documents d ON d.id = c.document_id
261             WHERE chunks_fts MATCH ?1 {}
262             ORDER BY bm25(chunks_fts)
263             LIMIT ?2",
264            ns_clause
265        );
266
267        let mut all_params: Vec<SqlValue> = vec![
268            SqlValue::Text(sanitized_query.to_string()),
269            SqlValue::Integer(pool_size as i64),
270        ];
271        all_params.extend(ns_params.clone());
272
273        let mut stmt = conn.prepare(&sql)?;
274        let rows = stmt.query_map(rusqlite::params_from_iter(&all_params), |row| {
275            let chunk_id: String = row.get(0)?;
276            let content: String = row.get(1)?;
277            let document_id: String = row.get(2)?;
278            let document_title: String = row.get(3)?;
279            let chunk_index: i64 = row.get(4)?;
280            let updated_at: Option<String> = row.get(6)?;
281            Ok(Bm25Hit {
282                id: chunk_id.clone(),
283                content,
284                source: SearchSource::Chunk {
285                    chunk_id,
286                    document_id,
287                    document_title,
288                    chunk_index: chunk_index as usize,
289                },
290                updated_at,
291            })
292        })?;
293
294        for row in rows {
295            hits.push(row?);
296        }
297    }
298
299    // Search messages (only when explicitly requested)
300    if search_messages {
301        let (sid_clause, sid_params) = build_namespace_clause("m.session_id", session_ids, 3);
302        let sql = format!(
303            "SELECT mm.message_id, m.content, m.session_id, m.role, bm25(messages_fts) AS score, m.created_at
304             FROM messages_fts
305             JOIN messages_rowid_map mm ON messages_fts.rowid = mm.rowid
306             JOIN messages m ON m.id = mm.message_id
307             WHERE messages_fts MATCH ?1 {}
308             ORDER BY bm25(messages_fts)
309             LIMIT ?2",
310            sid_clause
311        );
312
313        let mut all_params: Vec<SqlValue> = vec![
314            SqlValue::Text(sanitized_query.to_string()),
315            SqlValue::Integer(pool_size as i64),
316        ];
317        all_params.extend(sid_params.clone());
318
319        let mut stmt = conn.prepare(&sql)?;
320        let rows = stmt.query_map(rusqlite::params_from_iter(&all_params), |row| {
321            let message_id: i64 = row.get(0)?;
322            let content: String = row.get(1)?;
323            let session_id: String = row.get(2)?;
324            let role: String = row.get(3)?;
325            let updated_at: Option<String> = row.get(5)?;
326            Ok(Bm25Hit {
327                id: format!("msg:{}", message_id),
328                content,
329                source: SearchSource::Message {
330                    message_id,
331                    session_id,
332                    role,
333                },
334                updated_at,
335            })
336        })?;
337
338        for row in rows {
339            hits.push(row?);
340        }
341    }
342
343    Ok(hits)
344}
345
346/// Run vector similarity search over facts, chunks, and optionally messages.
347///
348/// Uses buffer reuse to avoid per-row allocations during the brute-force scan.
349pub(crate) fn vector_search(
350    conn: &Connection,
351    query_embedding: &[f32],
352    pool_size: usize,
353    min_similarity: f64,
354    namespaces: Option<&[&str]>,
355    source_types: Option<&[SearchSourceType]>,
356    session_ids: Option<&[&str]>,
357) -> Result<Vec<VectorHit>, MemoryError> {
358    let mut hits = Vec::new();
359
360    let search_facts = source_types
361        .map(|st| st.contains(&SearchSourceType::Facts))
362        .unwrap_or(true);
363    let search_chunks = source_types
364        .map(|st| st.contains(&SearchSourceType::Chunks))
365        .unwrap_or(true);
366    let search_messages = source_types
367        .map(|st| st.contains(&SearchSourceType::Messages))
368        .unwrap_or(false);
369
370    // Vector search over facts
371    if search_facts {
372        let (ns_clause, ns_params) = build_namespace_clause("namespace", namespaces, 1);
373        let sql = format!(
374            "SELECT id, content, namespace, embedding, updated_at FROM facts WHERE embedding IS NOT NULL {}",
375            ns_clause
376        );
377        let mut stmt = conn.prepare(&sql)?;
378        let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
379            let id: String = row.get(0)?;
380            let content: String = row.get(1)?;
381            let namespace: String = row.get(2)?;
382            let blob: Vec<u8> = row.get(3)?;
383            let updated_at: Option<String> = row.get(4)?;
384            Ok(VectorRow {
385                id: id.clone(),
386                content,
387                blob,
388                updated_at,
389                source: SearchSource::Fact {
390                    fact_id: id,
391                    namespace,
392                },
393            })
394        })?;
395
396        let (fact_hits, fact_count) =
397            scan_vector_rows(rows, query_embedding, min_similarity, "fact")?;
398        hits.extend(fact_hits);
399
400        if fact_count > VECTOR_SCAN_WARN_THRESHOLD {
401            tracing::warn!(
402                count = fact_count,
403                "facts table exceeds vector scan threshold ({} rows). \
404                 Consider namespace partitioning or pruning old data.",
405                fact_count
406            );
407        }
408    }
409
410    // Vector search over chunks
411    if search_chunks {
412        let (ns_clause, ns_params) = build_namespace_clause("d.namespace", namespaces, 1);
413        let sql = format!(
414            "SELECT c.id, c.content, c.document_id, d.title, c.chunk_index, c.embedding, c.created_at
415             FROM chunks c
416             JOIN documents d ON d.id = c.document_id
417             WHERE c.embedding IS NOT NULL {}",
418            ns_clause
419        );
420        let mut stmt = conn.prepare(&sql)?;
421        let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
422            let id: String = row.get(0)?;
423            let content: String = row.get(1)?;
424            let document_id: String = row.get(2)?;
425            let document_title: String = row.get(3)?;
426            let chunk_index: i64 = row.get(4)?;
427            let blob: Vec<u8> = row.get(5)?;
428            let updated_at: Option<String> = row.get(6)?;
429            Ok(VectorRow {
430                id: id.clone(),
431                content,
432                blob,
433                updated_at,
434                source: SearchSource::Chunk {
435                    chunk_id: id,
436                    document_id,
437                    document_title,
438                    chunk_index: chunk_index as usize,
439                },
440            })
441        })?;
442
443        let (chunk_hits, chunk_count) =
444            scan_vector_rows(rows, query_embedding, min_similarity, "chunk")?;
445        hits.extend(chunk_hits);
446
447        if chunk_count > VECTOR_SCAN_WARN_THRESHOLD {
448            tracing::warn!(
449                count = chunk_count,
450                "chunks table exceeds vector scan threshold ({} rows). \
451                 Consider namespace partitioning or pruning old data.",
452                chunk_count
453            );
454        }
455    }
456
457    // Vector search over messages (only when explicitly requested)
458    if search_messages {
459        let (sid_clause, sid_params) = build_namespace_clause("m.session_id", session_ids, 1);
460        let sql = format!(
461            "SELECT m.id, m.content, m.session_id, m.role, m.embedding, m.created_at
462             FROM messages m
463             WHERE m.embedding IS NOT NULL {}",
464            sid_clause
465        );
466        let mut stmt = conn.prepare(&sql)?;
467        let rows = stmt.query_map(rusqlite::params_from_iter(&sid_params), |row| {
468            let message_id: i64 = row.get(0)?;
469            let content: String = row.get(1)?;
470            let session_id: String = row.get(2)?;
471            let role: String = row.get(3)?;
472            let blob: Vec<u8> = row.get(4)?;
473            let updated_at: Option<String> = row.get(5)?;
474            Ok(VectorRow {
475                id: format!("msg:{}", message_id),
476                content,
477                blob,
478                updated_at,
479                source: SearchSource::Message {
480                    message_id,
481                    session_id,
482                    role,
483                },
484            })
485        })?;
486
487        let (msg_hits, msg_count) = scan_vector_rows(
488            rows,
489            query_embedding,
490            min_similarity,
491            "message",
492        )?;
493        hits.extend(msg_hits);
494
495        if msg_count > VECTOR_SCAN_WARN_THRESHOLD {
496            tracing::warn!(
497                count = msg_count,
498                "messages table exceeds vector scan threshold ({} rows). \
499                 Consider pruning old sessions.",
500                msg_count
501            );
502        }
503    }
504
505    // Sort by similarity descending, take top pool_size
506    hits.sort_by(|a, b| {
507        b.similarity
508            .partial_cmp(&a.similarity)
509            .unwrap_or(std::cmp::Ordering::Equal)
510    });
511    hits.truncate(pool_size);
512
513    Ok(hits)
514}
515
516/// Fuse BM25 and vector results via Reciprocal Rank Fusion.
517pub fn rrf_fuse(
518    bm25_hits: &[Bm25Hit],
519    vector_hits: &[VectorHit],
520    config: &SearchConfig,
521    top_k: usize,
522) -> Vec<SearchResult> {
523    let mut candidates: HashMap<String, RrfCandidate> = HashMap::new();
524
525    // Walk BM25 results (ranks are 1-based)
526    for (rank_0, hit) in bm25_hits.iter().enumerate() {
527        let rank = rank_0 + 1;
528        candidates
529            .entry(hit.id.clone())
530            .and_modify(|c| {
531                c.bm25_rank = Some(rank);
532                // Prefer the most recent timestamp if both sources provide one
533                if c.updated_at.is_none() {
534                    c.updated_at = hit.updated_at.clone();
535                }
536            })
537            .or_insert(RrfCandidate {
538                content: hit.content.clone(),
539                source: hit.source.clone(),
540                bm25_rank: Some(rank),
541                vector_rank: None,
542                cosine_similarity: None,
543                updated_at: hit.updated_at.clone(),
544            });
545    }
546
547    // Walk vector results (ranks are 1-based)
548    for (rank_0, hit) in vector_hits.iter().enumerate() {
549        let rank = rank_0 + 1;
550        candidates
551            .entry(hit.id.clone())
552            .and_modify(|c| {
553                c.vector_rank = Some(rank);
554                c.cosine_similarity = Some(hit.similarity);
555                if c.updated_at.is_none() {
556                    c.updated_at = hit.updated_at.clone();
557                }
558            })
559            .or_insert(RrfCandidate {
560                content: hit.content.clone(),
561                source: hit.source.clone(),
562                bm25_rank: None,
563                vector_rank: Some(rank),
564                cosine_similarity: Some(hit.similarity),
565                updated_at: hit.updated_at.clone(),
566            });
567    }
568
569    // Score, sort, truncate
570    let mut results: Vec<SearchResult> = candidates
571        .into_values()
572        .map(|c| {
573            let score = c.score(config);
574            SearchResult {
575                content: c.content,
576                source: c.source,
577                score,
578                bm25_rank: c.bm25_rank,
579                vector_rank: c.vector_rank,
580                cosine_similarity: c.cosine_similarity,
581            }
582        })
583        .collect();
584
585    results.sort_by(|a, b| {
586        b.score
587            .partial_cmp(&a.score)
588            .unwrap_or(std::cmp::Ordering::Equal)
589    });
590    results.truncate(top_k);
591    results
592}
593
594/// Perform a hybrid search (BM25 + vector + RRF).
595///
596/// This is the main search entry point. Embed query first (async outside this fn),
597/// then call this with the connection locked.
598#[allow(clippy::too_many_arguments)]
599pub fn hybrid_search(
600    conn: &Connection,
601    query: &str,
602    query_embedding: &[f32],
603    config: &SearchConfig,
604    top_k: usize,
605    namespaces: Option<&[&str]>,
606    source_types: Option<&[SearchSourceType]>,
607    session_ids: Option<&[&str]>,
608) -> Result<Vec<SearchResult>, MemoryError> {
609    // BM25 search
610    let bm25_hits = match sanitize_fts_query(query) {
611        Some(sanitized) => bm25_search(
612            conn,
613            &sanitized,
614            config.candidate_pool_size,
615            namespaces,
616            source_types,
617            session_ids,
618        )?,
619        None => Vec::new(),
620    };
621
622    // Vector search
623    let vector_hits = vector_search(
624        conn,
625        query_embedding,
626        config.candidate_pool_size,
627        config.min_similarity,
628        namespaces,
629        source_types,
630        session_ids,
631    )?;
632
633    // RRF fusion + dedup
634    let results = rrf_fuse(&bm25_hits, &vector_hits, config, top_k);
635    Ok(deduplicate_results(results))
636}
637
/// Perform a hybrid search whose vector leg comes from pre-computed HNSW hits
/// instead of a brute-force table scan.
///
/// `hnsw_hits` are resolved to stored rows by `resolve_hnsw_hits_batched`.
/// The BM25 leg, RRF fusion and the final `deduplicate_results` pass are the
/// same as in [`hybrid_search`]. `_query_embedding` is accepted for signature
/// parity with `hybrid_search` but is not read.
///
/// # Errors
/// Propagates SQLite errors from the BM25 queries or the HNSW row lookups.
#[cfg(feature = "hnsw")]
#[allow(clippy::too_many_arguments)] // keeps parity with `hybrid_search`'s option list
pub fn hybrid_search_with_hnsw(
    conn: &Connection,
    query: &str,
    _query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<SearchResult>, MemoryError> {
    // Keyword leg: no searchable terms means no BM25 candidates.
    let keyword_hits = sanitize_fts_query(query)
        .map(|fts_query| {
            bm25_search(
                conn,
                &fts_query,
                config.candidate_pool_size,
                namespaces,
                source_types,
                session_ids,
            )
        })
        .transpose()?
        .unwrap_or_default();

    // Semantic leg: look up the HNSW neighbours in SQLite.
    let neighbour_hits = resolve_hnsw_hits_batched(
        conn,
        config,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;

    Ok(deduplicate_results(rrf_fuse(
        &keyword_hits,
        &neighbour_hits,
        config,
        top_k,
    )))
}
678
/// Turn HNSW nearest-neighbour hits into [`VectorHit`]s by loading their rows
/// from SQLite in bulk.
///
/// HNSW keys have the form `fact:<id>`, `chunk:<id>` or `msg:<integer id>`.
/// The following hits are dropped before any query runs:
/// - hits below `config.min_similarity`;
/// - hits for domains excluded by `source_types` (messages are excluded
///   unless explicitly requested);
/// - keys of any other shape;
/// - `msg:` keys whose id is not an integer.
///
/// The remaining hits are grouped by domain, and each group is fetched with a
/// single `IN (...)` query, so at most three queries run instead of one per
/// hit. After loading, rows outside `namespaces` (facts, and chunks through
/// their document) or `session_ids` (messages) are discarded. Ids with no
/// matching row are absent from the result.
///
/// Returns at most `config.candidate_pool_size` hits sorted by descending
/// similarity.
///
/// # Errors
/// Propagates SQLite errors from preparing or stepping the lookups.
#[cfg(feature = "hnsw")]
fn resolve_hnsw_hits_batched(
    conn: &Connection,
    config: &SearchConfig,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<VectorHit>, MemoryError> {
    /// Numbered SQLite placeholders `?1, ?2, …, ?n` for an `IN (...)` list.
    fn numbered_placeholders(n: usize) -> String {
        (1..=n)
            .map(|i| format!("?{}", i))
            .collect::<Vec<_>>()
            .join(", ")
    }

    let wanted = |kind: SearchSourceType, default: bool| {
        source_types.map_or(default, |st| st.contains(&kind))
    };
    let (want_facts, want_chunks, want_messages) = (
        wanted(SearchSourceType::Facts, true),
        wanted(SearchSourceType::Chunks, true),
        wanted(SearchSourceType::Messages, false),
    );

    // Split the neighbours by domain, keeping each one's similarity.
    let mut fact_sims: Vec<(String, f64)> = Vec::new();
    let mut chunk_sims: Vec<(String, f64)> = Vec::new();
    let mut message_sims: Vec<(i64, f64)> = Vec::new();

    for neighbour in hnsw_hits {
        let similarity = neighbour.similarity() as f64;
        if similarity < config.min_similarity {
            continue;
        }
        match neighbour.key.split_once(':') {
            Some(("fact", id)) if want_facts => fact_sims.push((id.to_string(), similarity)),
            Some(("chunk", id)) if want_chunks => chunk_sims.push((id.to_string(), similarity)),
            Some(("msg", id)) if want_messages => {
                if let Ok(message_id) = id.parse::<i64>() {
                    message_sims.push((message_id, similarity));
                }
            }
            _ => {}
        }
    }

    let mut resolved: Vec<VectorHit> = Vec::new();

    // Facts: one IN query for every requested fact id.
    if !fact_sims.is_empty() {
        let similarity_of: HashMap<String, f64> = fact_sims.iter().cloned().collect();
        let sql = format!(
            "SELECT id, content, namespace, updated_at FROM facts WHERE id IN ({})",
            numbered_placeholders(fact_sims.len())
        );
        let params: Vec<SqlValue> = fact_sims
            .iter()
            .map(|(id, _)| SqlValue::Text(id.clone()))
            .collect();

        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map(rusqlite::params_from_iter(&params), |r| {
            Ok((
                r.get::<_, String>(0)?,
                r.get::<_, String>(1)?,
                r.get::<_, String>(2)?,
                r.get::<_, Option<String>>(3)?,
            ))
        })?;

        for fetched in rows {
            let (fact_id, content, namespace, updated_at) = fetched?;
            let in_scope =
                namespaces.map_or(true, |allowed| allowed.contains(&namespace.as_str()));
            if !in_scope {
                continue;
            }
            if let Some(&similarity) = similarity_of.get(&fact_id) {
                resolved.push(VectorHit {
                    id: fact_id.clone(),
                    content,
                    source: SearchSource::Fact { fact_id, namespace },
                    similarity,
                    updated_at,
                });
            }
        }
    }

    // Chunks: joined with their documents for title and namespace.
    if !chunk_sims.is_empty() {
        let similarity_of: HashMap<String, f64> = chunk_sims.iter().cloned().collect();
        let sql = format!(
            "SELECT c.id, c.content, c.document_id, d.title, c.chunk_index, c.created_at, d.namespace
             FROM chunks c JOIN documents d ON d.id = c.document_id
             WHERE c.id IN ({})",
            numbered_placeholders(chunk_sims.len())
        );
        let params: Vec<SqlValue> = chunk_sims
            .iter()
            .map(|(id, _)| SqlValue::Text(id.clone()))
            .collect();

        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map(rusqlite::params_from_iter(&params), |r| {
            Ok((
                r.get::<_, String>(0)?,
                r.get::<_, String>(1)?,
                r.get::<_, String>(2)?,
                r.get::<_, String>(3)?,
                r.get::<_, i64>(4)?,
                r.get::<_, Option<String>>(5)?,
                r.get::<_, String>(6)?,
            ))
        })?;

        for fetched in rows {
            let (chunk_id, content, document_id, document_title, chunk_index, updated_at, doc_ns) =
                fetched?;
            let in_scope = namespaces.map_or(true, |allowed| allowed.contains(&doc_ns.as_str()));
            if !in_scope {
                continue;
            }
            if let Some(&similarity) = similarity_of.get(&chunk_id) {
                resolved.push(VectorHit {
                    id: chunk_id.clone(),
                    content,
                    source: SearchSource::Chunk {
                        chunk_id,
                        document_id,
                        document_title,
                        chunk_index: chunk_index as usize,
                    },
                    similarity,
                    updated_at,
                });
            }
        }
    }

    // Messages: integer ids, filtered by session afterwards.
    if !message_sims.is_empty() {
        let similarity_of: HashMap<i64, f64> = message_sims.iter().cloned().collect();
        let sql = format!(
            "SELECT id, content, session_id, role, created_at FROM messages WHERE id IN ({})",
            numbered_placeholders(message_sims.len())
        );
        let params: Vec<SqlValue> = message_sims
            .iter()
            .map(|(id, _)| SqlValue::Integer(*id))
            .collect();

        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map(rusqlite::params_from_iter(&params), |r| {
            Ok((
                r.get::<_, i64>(0)?,
                r.get::<_, String>(1)?,
                r.get::<_, String>(2)?,
                r.get::<_, String>(3)?,
                r.get::<_, Option<String>>(4)?,
            ))
        })?;

        for fetched in rows {
            let (message_id, content, session_id, role, updated_at) = fetched?;
            let in_scope =
                session_ids.map_or(true, |allowed| allowed.contains(&session_id.as_str()));
            if !in_scope {
                continue;
            }
            if let Some(&similarity) = similarity_of.get(&message_id) {
                resolved.push(VectorHit {
                    id: format!("msg:{}", message_id),
                    content,
                    source: SearchSource::Message {
                        message_id,
                        session_id,
                        role,
                    },
                    similarity,
                    updated_at,
                });
            }
        }
    }

    // Most similar first, then keep the candidate pool.
    resolved.sort_by(|a, b| {
        b.similarity
            .partial_cmp(&a.similarity)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    resolved.truncate(config.candidate_pool_size);
    Ok(resolved)
}
885
886/// Full-text search only (no embeddings needed). Synchronous.
887pub fn fts_only_search(
888    conn: &Connection,
889    query: &str,
890    config: &SearchConfig,
891    top_k: usize,
892    namespaces: Option<&[&str]>,
893    source_types: Option<&[SearchSourceType]>,
894    session_ids: Option<&[&str]>,
895) -> Result<Vec<SearchResult>, MemoryError> {
896    let sanitized = match sanitize_fts_query(query) {
897        Some(s) => s,
898        None => return Ok(Vec::new()),
899    };
900
901    let hits = bm25_search(
902        conn,
903        &sanitized,
904        top_k,
905        namespaces,
906        source_types,
907        session_ids,
908    )?;
909
910    let results: Vec<SearchResult> = hits
911        .into_iter()
912        .enumerate()
913        .map(|(rank_0, hit)| SearchResult {
914            content: hit.content,
915            source: hit.source,
916            score: config.bm25_weight / (config.rrf_k + (rank_0 + 1) as f64),
917            bm25_rank: Some(rank_0 + 1),
918            vector_rank: None,
919            cosine_similarity: None,
920        })
921        .collect();
922
923    Ok(deduplicate_results(results))
924}
925
926/// Vector-only search. Called after embedding the query.
927pub fn vector_only_search(
928    conn: &Connection,
929    query_embedding: &[f32],
930    config: &SearchConfig,
931    top_k: usize,
932    namespaces: Option<&[&str]>,
933    source_types: Option<&[SearchSourceType]>,
934    session_ids: Option<&[&str]>,
935) -> Result<Vec<SearchResult>, MemoryError> {
936    let hits = vector_search(
937        conn,
938        query_embedding,
939        top_k,
940        config.min_similarity,
941        namespaces,
942        source_types,
943        session_ids,
944    )?;
945
946    let results: Vec<SearchResult> = hits
947        .into_iter()
948        .enumerate()
949        .map(|(rank_0, hit)| SearchResult {
950            content: hit.content,
951            source: hit.source,
952            score: config.vector_weight / (config.rrf_k + (rank_0 + 1) as f64),
953            bm25_rank: None,
954            vector_rank: Some(rank_0 + 1),
955            cosine_similarity: Some(hit.similarity),
956        })
957        .collect();
958
959    Ok(deduplicate_results(results))
960}
961
/// Vector-only search using pre-computed HNSW hits.
///
/// Skips BM25 entirely. `hnsw_hits` are turned into full rows with batched SQL lookups
/// via `resolve_hnsw_hits_batched`, which also applies the optional namespace /
/// source-type / session filters.
///
/// Duplicate sources (same source type and id) are dropped *before* the list is cut
/// to `top_k`, keeping the first occurrence in the resolved order. The caller therefore
/// gets up to `top_k` distinct sources, numbered with contiguous 1-based `vector_rank`s.
/// Each result is scored as `config.vector_weight / (config.rrf_k + rank)`, and its raw
/// similarity is exposed in `cosine_similarity`.
///
/// # Errors
/// Propagates any `MemoryError` returned by `resolve_hnsw_hits_batched`.
#[cfg(feature = "hnsw")]
#[allow(clippy::too_many_arguments)]
pub fn vector_only_search_with_hnsw(
    conn: &Connection,
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<SearchResult>, MemoryError> {
    let vector_hits = resolve_hnsw_hits_batched(
        conn, config, namespaces, source_types, session_ids, hnsw_hits,
    )?;

    // Deduplicate before truncating. Cutting to `top_k` first would return fewer than
    // `top_k` results whenever the candidate pool repeats a source. Ranking after dedup
    // also keeps `vector_rank` free of gaps.
    let mut seen = HashSet::new();
    let results: Vec<SearchResult> = vector_hits
        .into_iter()
        .filter(|hit| seen.insert(source_dedup_key(&hit.source)))
        .take(top_k)
        .enumerate()
        .map(|(rank_0, hit)| SearchResult {
            content: hit.content,
            source: hit.source,
            score: config.vector_weight / (config.rrf_k + (rank_0 + 1) as f64),
            bm25_rank: None,
            vector_rank: Some(rank_0 + 1),
            cosine_similarity: Some(hit.similarity),
        })
        .collect();

    Ok(results)
}
996
997/// Extract a dedupe key from a search source: (source type discriminant, primary ID).
998///
999/// This preserves provenance — the same text in a fact and a chunk are kept,
1000/// but the same source appearing via both BM25 and vector paths is deduplicated.
1001fn source_dedup_key(source: &SearchSource) -> (u8, String) {
1002    match source {
1003        SearchSource::Fact { fact_id, .. } => (0, fact_id.clone()),
1004        SearchSource::Chunk { chunk_id, .. } => (1, chunk_id.clone()),
1005        SearchSource::Message { message_id, .. } => (2, message_id.to_string()),
1006    }
1007}
1008
1009/// Deduplicate results by (source_type, source_id), keeping the first (highest-scored) occurrence.
1010fn deduplicate_results(results: Vec<SearchResult>) -> Vec<SearchResult> {
1011    let mut seen = HashSet::new();
1012    results
1013        .into_iter()
1014        .filter(|r| seen.insert(source_dedup_key(&r.source)))
1015        .collect()
1016}
1017
1018/// Build a parameterized namespace filter SQL fragment.
1019///
1020/// Returns a tuple of (SQL clause, parameter values). The `param_offset` sets
1021/// the starting numbered placeholder (e.g., if existing query uses ?1 and ?2,
1022/// pass `param_offset = 3`).
1023fn build_namespace_clause(
1024    column: &str,
1025    namespaces: Option<&[&str]>,
1026    param_offset: usize,
1027) -> (String, Vec<SqlValue>) {
1028    match namespaces {
1029        Some(ns) if !ns.is_empty() => {
1030            let placeholders: Vec<String> = (0..ns.len())
1031                .map(|i| format!("?{}", param_offset + i))
1032                .collect();
1033            let clause = format!("AND {} IN ({})", column, placeholders.join(", "));
1034            let values: Vec<SqlValue> = ns.iter().map(|n| SqlValue::Text(n.to_string())).collect();
1035            (clause, values)
1036        }
1037        _ => (String::new(), vec![]),
1038    }
1039}