1use crate::config::SearchConfig;
4use crate::error::MemoryError;
5use crate::types::{SearchResult, SearchSource, SearchSourceType};
6use rusqlite::types::Value as SqlValue;
7use rusqlite::Connection;
8use std::collections::{HashMap, HashSet};
9
/// Row count above which a brute-force vector scan of a single table logs a warning.
///
/// `vector_search` compares the query embedding against every stored embedding in a
/// table, so its cost grows linearly with table size; past this many rows the warning
/// suggests partitioning or pruning.
const VECTOR_SCAN_WARN_THRESHOLD: usize = 50_000;
12
/// Turns free-form user text into a query string that SQLite FTS5 accepts in `MATCH`.
///
/// FTS5 only accepts unquoted *barewords* made of ASCII letters/digits, `_`, and
/// non-ASCII characters. Every other character is either query syntax (`"`, `*`, `:`,
/// `^`, `(`, …) or a syntax error (`.`, `'`, `?`, `,`, …), so each such character is
/// replaced with a space. The operator keywords `AND`, `OR`, `NOT` and `NEAR` are dropped
/// in any letter case, leaving plain terms that FTS5 combines with an implicit AND.
///
/// Returns `None` when no searchable term remains.
pub fn sanitize_fts_query(raw: &str) -> Option<String> {
    const OPERATOR_KEYWORDS: [&str; 4] = ["AND", "OR", "NOT", "NEAR"];

    let cleaned: String = raw
        .chars()
        .map(|c| {
            // Bareword characters per the FTS5 query syntax; everything else becomes a separator.
            if c.is_ascii_alphanumeric() || c == '_' || !c.is_ascii() {
                c
            } else {
                ' '
            }
        })
        .collect();

    let tokens: Vec<&str> = cleaned
        .split_whitespace()
        .filter(|t| !OPERATOR_KEYWORDS.iter().any(|kw| t.eq_ignore_ascii_case(kw)))
        .collect();

    if tokens.is_empty() {
        None
    } else {
        Some(tokens.join(" "))
    }
}
41
/// Cosine of the angle between two embeddings, in `[-1.0, 1.0]`.
///
/// Returns `0.0` when either vector has zero magnitude, instead of dividing by zero.
/// Both slices must have the same length (checked only in debug builds).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "embedding dimension mismatch");
    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }
    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
53
54fn days_since(iso_timestamp: &str) -> Option<f64> {
56 let dt = chrono::NaiveDateTime::parse_from_str(iso_timestamp, "%Y-%m-%d %H:%M:%S").ok()?;
57 let now = chrono::Utc::now().naive_utc();
58 let duration = now - dt;
59 Some(duration.num_seconds() as f64 / 86400.0)
60}
61
62struct RrfCandidate {
64 content: String,
65 source: SearchSource,
66 bm25_rank: Option<usize>,
67 vector_rank: Option<usize>,
68 cosine_similarity: Option<f64>,
69 updated_at: Option<String>,
70}
71
72impl RrfCandidate {
73 fn score(&self, config: &SearchConfig) -> f64 {
74 let bm25_score = self
75 .bm25_rank
76 .map(|r| config.bm25_weight / (config.rrf_k + r as f64))
77 .unwrap_or(0.0);
78 let vector_score = self
79 .vector_rank
80 .map(|r| config.vector_weight / (config.rrf_k + r as f64))
81 .unwrap_or(0.0);
82
83 let recency_score = match (config.recency_half_life_days, &self.updated_at) {
84 (Some(half_life), Some(ts)) if half_life > 0.0 => {
85 let age_days = days_since(ts).unwrap_or(0.0).max(0.0);
86 let decay = 2.0_f64.powf(-age_days / half_life);
87 config.recency_weight * decay / (config.rrf_k + 1.0)
88 }
89 (Some(half_life), _) if half_life <= 0.0 => {
90 tracing::warn!("recency_half_life_days <= 0, ignoring recency boost");
91 0.0
92 }
93 _ => 0.0,
94 };
95
96 bm25_score + vector_score + recency_score
97 }
98}
99
100pub struct Bm25Hit {
102 pub id: String,
104 pub content: String,
106 pub source: SearchSource,
108 pub updated_at: Option<String>,
110}
111
112pub struct VectorHit {
114 pub id: String,
116 pub content: String,
118 pub source: SearchSource,
120 pub similarity: f64,
122 pub updated_at: Option<String>,
124}
125
126struct VectorRow {
128 id: String,
129 content: String,
130 blob: Vec<u8>,
131 updated_at: Option<String>,
132 source: SearchSource,
133}
134
135fn scan_vector_rows(
140 rows: impl Iterator<Item = Result<VectorRow, rusqlite::Error>>,
141 query_embedding: &[f32],
142 min_similarity: f64,
143 table_label: &str,
144) -> Result<(Vec<VectorHit>, usize), MemoryError> {
145 let expected_dims = query_embedding.len();
146 let mut hits = Vec::new();
147 let mut row_count = 0usize;
148
149 for row in rows {
150 let row = row?;
151 row_count += 1;
152
153 if row.blob.len() % 4 != 0 {
154 tracing::warn!(
155 "Skipping {} with invalid embedding length: {}",
156 table_label,
157 row.blob.len()
158 );
159 continue;
160 }
161 let stored_embedding: &[f32] =
162 bytemuck::try_cast_slice(&row.blob).map_err(|_| MemoryError::InvalidEmbedding {
163 expected_bytes: row.blob.len() - (row.blob.len() % 4),
164 actual_bytes: row.blob.len(),
165 })?;
166 if stored_embedding.len() != expected_dims {
167 tracing::warn!(
168 expected = expected_dims,
169 actual = stored_embedding.len(),
170 "Skipping {} with wrong embedding dimensions",
171 table_label
172 );
173 continue;
174 }
175
176 let sim = cosine_similarity(query_embedding, stored_embedding) as f64;
177 if sim >= min_similarity {
178 hits.push(VectorHit {
179 id: row.id,
180 content: row.content,
181 source: row.source,
182 similarity: sim,
183 updated_at: row.updated_at,
184 });
185 }
186 }
187
188 Ok((hits, row_count))
189}
190
191pub(crate) fn bm25_search(
193 conn: &Connection,
194 sanitized_query: &str,
195 pool_size: usize,
196 namespaces: Option<&[&str]>,
197 source_types: Option<&[SearchSourceType]>,
198 session_ids: Option<&[&str]>,
199) -> Result<Vec<Bm25Hit>, MemoryError> {
200 let mut hits = Vec::new();
201
202 let search_facts = source_types
203 .map(|st| st.contains(&SearchSourceType::Facts))
204 .unwrap_or(true);
205 let search_chunks = source_types
206 .map(|st| st.contains(&SearchSourceType::Chunks))
207 .unwrap_or(true);
208 let search_messages = source_types
210 .map(|st| st.contains(&SearchSourceType::Messages))
211 .unwrap_or(false);
212
213 if search_facts {
215 let (ns_clause, ns_params) = build_namespace_clause("f.namespace", namespaces, 3);
216 let sql = format!(
217 "SELECT fm.fact_id, f.content, f.namespace, bm25(facts_fts) AS score, f.updated_at
218 FROM facts_fts
219 JOIN facts_rowid_map fm ON facts_fts.rowid = fm.rowid
220 JOIN facts f ON f.id = fm.fact_id
221 WHERE facts_fts MATCH ?1 {}
222 ORDER BY bm25(facts_fts)
223 LIMIT ?2",
224 ns_clause
225 );
226
227 let mut all_params: Vec<SqlValue> = vec![
228 SqlValue::Text(sanitized_query.to_string()),
229 SqlValue::Integer(pool_size as i64),
230 ];
231 all_params.extend(ns_params.clone());
232
233 let mut stmt = conn.prepare(&sql)?;
234 let rows = stmt.query_map(rusqlite::params_from_iter(&all_params), |row| {
235 let fact_id: String = row.get(0)?;
236 let content: String = row.get(1)?;
237 let namespace: String = row.get(2)?;
238 let updated_at: Option<String> = row.get(4)?;
239 Ok(Bm25Hit {
240 id: fact_id.clone(),
241 content,
242 source: SearchSource::Fact { fact_id, namespace },
243 updated_at,
244 })
245 })?;
246
247 for row in rows {
248 hits.push(row?);
249 }
250 }
251
252 if search_chunks {
254 let (ns_clause, ns_params) = build_namespace_clause("d.namespace", namespaces, 3);
255 let sql = format!(
256 "SELECT cm.chunk_id, c.content, c.document_id, d.title, c.chunk_index, bm25(chunks_fts) AS score, c.created_at
257 FROM chunks_fts
258 JOIN chunks_rowid_map cm ON chunks_fts.rowid = cm.rowid
259 JOIN chunks c ON c.id = cm.chunk_id
260 JOIN documents d ON d.id = c.document_id
261 WHERE chunks_fts MATCH ?1 {}
262 ORDER BY bm25(chunks_fts)
263 LIMIT ?2",
264 ns_clause
265 );
266
267 let mut all_params: Vec<SqlValue> = vec![
268 SqlValue::Text(sanitized_query.to_string()),
269 SqlValue::Integer(pool_size as i64),
270 ];
271 all_params.extend(ns_params.clone());
272
273 let mut stmt = conn.prepare(&sql)?;
274 let rows = stmt.query_map(rusqlite::params_from_iter(&all_params), |row| {
275 let chunk_id: String = row.get(0)?;
276 let content: String = row.get(1)?;
277 let document_id: String = row.get(2)?;
278 let document_title: String = row.get(3)?;
279 let chunk_index: i64 = row.get(4)?;
280 let updated_at: Option<String> = row.get(6)?;
281 Ok(Bm25Hit {
282 id: chunk_id.clone(),
283 content,
284 source: SearchSource::Chunk {
285 chunk_id,
286 document_id,
287 document_title,
288 chunk_index: chunk_index as usize,
289 },
290 updated_at,
291 })
292 })?;
293
294 for row in rows {
295 hits.push(row?);
296 }
297 }
298
299 if search_messages {
301 let (sid_clause, sid_params) = build_namespace_clause("m.session_id", session_ids, 3);
302 let sql = format!(
303 "SELECT mm.message_id, m.content, m.session_id, m.role, bm25(messages_fts) AS score, m.created_at
304 FROM messages_fts
305 JOIN messages_rowid_map mm ON messages_fts.rowid = mm.rowid
306 JOIN messages m ON m.id = mm.message_id
307 WHERE messages_fts MATCH ?1 {}
308 ORDER BY bm25(messages_fts)
309 LIMIT ?2",
310 sid_clause
311 );
312
313 let mut all_params: Vec<SqlValue> = vec![
314 SqlValue::Text(sanitized_query.to_string()),
315 SqlValue::Integer(pool_size as i64),
316 ];
317 all_params.extend(sid_params.clone());
318
319 let mut stmt = conn.prepare(&sql)?;
320 let rows = stmt.query_map(rusqlite::params_from_iter(&all_params), |row| {
321 let message_id: i64 = row.get(0)?;
322 let content: String = row.get(1)?;
323 let session_id: String = row.get(2)?;
324 let role: String = row.get(3)?;
325 let updated_at: Option<String> = row.get(5)?;
326 Ok(Bm25Hit {
327 id: format!("msg:{}", message_id),
328 content,
329 source: SearchSource::Message {
330 message_id,
331 session_id,
332 role,
333 },
334 updated_at,
335 })
336 })?;
337
338 for row in rows {
339 hits.push(row?);
340 }
341 }
342
343 Ok(hits)
344}
345
346pub(crate) fn vector_search(
350 conn: &Connection,
351 query_embedding: &[f32],
352 pool_size: usize,
353 min_similarity: f64,
354 namespaces: Option<&[&str]>,
355 source_types: Option<&[SearchSourceType]>,
356 session_ids: Option<&[&str]>,
357) -> Result<Vec<VectorHit>, MemoryError> {
358 let mut hits = Vec::new();
359
360 let search_facts = source_types
361 .map(|st| st.contains(&SearchSourceType::Facts))
362 .unwrap_or(true);
363 let search_chunks = source_types
364 .map(|st| st.contains(&SearchSourceType::Chunks))
365 .unwrap_or(true);
366 let search_messages = source_types
367 .map(|st| st.contains(&SearchSourceType::Messages))
368 .unwrap_or(false);
369
370 if search_facts {
372 let (ns_clause, ns_params) = build_namespace_clause("namespace", namespaces, 1);
373 let sql = format!(
374 "SELECT id, content, namespace, embedding, updated_at FROM facts WHERE embedding IS NOT NULL {}",
375 ns_clause
376 );
377 let mut stmt = conn.prepare(&sql)?;
378 let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
379 let id: String = row.get(0)?;
380 let content: String = row.get(1)?;
381 let namespace: String = row.get(2)?;
382 let blob: Vec<u8> = row.get(3)?;
383 let updated_at: Option<String> = row.get(4)?;
384 Ok(VectorRow {
385 id: id.clone(),
386 content,
387 blob,
388 updated_at,
389 source: SearchSource::Fact {
390 fact_id: id,
391 namespace,
392 },
393 })
394 })?;
395
396 let (fact_hits, fact_count) =
397 scan_vector_rows(rows, query_embedding, min_similarity, "fact")?;
398 hits.extend(fact_hits);
399
400 if fact_count > VECTOR_SCAN_WARN_THRESHOLD {
401 tracing::warn!(
402 count = fact_count,
403 "facts table exceeds vector scan threshold ({} rows). \
404 Consider namespace partitioning or pruning old data.",
405 fact_count
406 );
407 }
408 }
409
410 if search_chunks {
412 let (ns_clause, ns_params) = build_namespace_clause("d.namespace", namespaces, 1);
413 let sql = format!(
414 "SELECT c.id, c.content, c.document_id, d.title, c.chunk_index, c.embedding, c.created_at
415 FROM chunks c
416 JOIN documents d ON d.id = c.document_id
417 WHERE c.embedding IS NOT NULL {}",
418 ns_clause
419 );
420 let mut stmt = conn.prepare(&sql)?;
421 let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
422 let id: String = row.get(0)?;
423 let content: String = row.get(1)?;
424 let document_id: String = row.get(2)?;
425 let document_title: String = row.get(3)?;
426 let chunk_index: i64 = row.get(4)?;
427 let blob: Vec<u8> = row.get(5)?;
428 let updated_at: Option<String> = row.get(6)?;
429 Ok(VectorRow {
430 id: id.clone(),
431 content,
432 blob,
433 updated_at,
434 source: SearchSource::Chunk {
435 chunk_id: id,
436 document_id,
437 document_title,
438 chunk_index: chunk_index as usize,
439 },
440 })
441 })?;
442
443 let (chunk_hits, chunk_count) =
444 scan_vector_rows(rows, query_embedding, min_similarity, "chunk")?;
445 hits.extend(chunk_hits);
446
447 if chunk_count > VECTOR_SCAN_WARN_THRESHOLD {
448 tracing::warn!(
449 count = chunk_count,
450 "chunks table exceeds vector scan threshold ({} rows). \
451 Consider namespace partitioning or pruning old data.",
452 chunk_count
453 );
454 }
455 }
456
457 if search_messages {
459 let (sid_clause, sid_params) = build_namespace_clause("m.session_id", session_ids, 1);
460 let sql = format!(
461 "SELECT m.id, m.content, m.session_id, m.role, m.embedding, m.created_at
462 FROM messages m
463 WHERE m.embedding IS NOT NULL {}",
464 sid_clause
465 );
466 let mut stmt = conn.prepare(&sql)?;
467 let rows = stmt.query_map(rusqlite::params_from_iter(&sid_params), |row| {
468 let message_id: i64 = row.get(0)?;
469 let content: String = row.get(1)?;
470 let session_id: String = row.get(2)?;
471 let role: String = row.get(3)?;
472 let blob: Vec<u8> = row.get(4)?;
473 let updated_at: Option<String> = row.get(5)?;
474 Ok(VectorRow {
475 id: format!("msg:{}", message_id),
476 content,
477 blob,
478 updated_at,
479 source: SearchSource::Message {
480 message_id,
481 session_id,
482 role,
483 },
484 })
485 })?;
486
487 let (msg_hits, msg_count) = scan_vector_rows(
488 rows,
489 query_embedding,
490 min_similarity,
491 "message",
492 )?;
493 hits.extend(msg_hits);
494
495 if msg_count > VECTOR_SCAN_WARN_THRESHOLD {
496 tracing::warn!(
497 count = msg_count,
498 "messages table exceeds vector scan threshold ({} rows). \
499 Consider pruning old sessions.",
500 msg_count
501 );
502 }
503 }
504
505 hits.sort_by(|a, b| {
507 b.similarity
508 .partial_cmp(&a.similarity)
509 .unwrap_or(std::cmp::Ordering::Equal)
510 });
511 hits.truncate(pool_size);
512
513 Ok(hits)
514}
515
516pub fn rrf_fuse(
518 bm25_hits: &[Bm25Hit],
519 vector_hits: &[VectorHit],
520 config: &SearchConfig,
521 top_k: usize,
522) -> Vec<SearchResult> {
523 let mut candidates: HashMap<String, RrfCandidate> = HashMap::new();
524
525 for (rank_0, hit) in bm25_hits.iter().enumerate() {
527 let rank = rank_0 + 1;
528 candidates
529 .entry(hit.id.clone())
530 .and_modify(|c| {
531 c.bm25_rank = Some(rank);
532 if c.updated_at.is_none() {
534 c.updated_at = hit.updated_at.clone();
535 }
536 })
537 .or_insert(RrfCandidate {
538 content: hit.content.clone(),
539 source: hit.source.clone(),
540 bm25_rank: Some(rank),
541 vector_rank: None,
542 cosine_similarity: None,
543 updated_at: hit.updated_at.clone(),
544 });
545 }
546
547 for (rank_0, hit) in vector_hits.iter().enumerate() {
549 let rank = rank_0 + 1;
550 candidates
551 .entry(hit.id.clone())
552 .and_modify(|c| {
553 c.vector_rank = Some(rank);
554 c.cosine_similarity = Some(hit.similarity);
555 if c.updated_at.is_none() {
556 c.updated_at = hit.updated_at.clone();
557 }
558 })
559 .or_insert(RrfCandidate {
560 content: hit.content.clone(),
561 source: hit.source.clone(),
562 bm25_rank: None,
563 vector_rank: Some(rank),
564 cosine_similarity: Some(hit.similarity),
565 updated_at: hit.updated_at.clone(),
566 });
567 }
568
569 let mut results: Vec<SearchResult> = candidates
571 .into_values()
572 .map(|c| {
573 let score = c.score(config);
574 SearchResult {
575 content: c.content,
576 source: c.source,
577 score,
578 bm25_rank: c.bm25_rank,
579 vector_rank: c.vector_rank,
580 cosine_similarity: c.cosine_similarity,
581 }
582 })
583 .collect();
584
585 results.sort_by(|a, b| {
586 b.score
587 .partial_cmp(&a.score)
588 .unwrap_or(std::cmp::Ordering::Equal)
589 });
590 results.truncate(top_k);
591 results
592}
593
594#[allow(clippy::too_many_arguments)]
599pub fn hybrid_search(
600 conn: &Connection,
601 query: &str,
602 query_embedding: &[f32],
603 config: &SearchConfig,
604 top_k: usize,
605 namespaces: Option<&[&str]>,
606 source_types: Option<&[SearchSourceType]>,
607 session_ids: Option<&[&str]>,
608) -> Result<Vec<SearchResult>, MemoryError> {
609 let bm25_hits = match sanitize_fts_query(query) {
611 Some(sanitized) => bm25_search(
612 conn,
613 &sanitized,
614 config.candidate_pool_size,
615 namespaces,
616 source_types,
617 session_ids,
618 )?,
619 None => Vec::new(),
620 };
621
622 let vector_hits = vector_search(
624 conn,
625 query_embedding,
626 config.candidate_pool_size,
627 config.min_similarity,
628 namespaces,
629 source_types,
630 session_ids,
631 )?;
632
633 let results = rrf_fuse(&bm25_hits, &vector_hits, config, top_k);
635 Ok(deduplicate_results(results))
636}
637
/// Hybrid search that takes its vector candidates from a pre-computed HNSW query
/// instead of a brute-force scan.
///
/// `_query_embedding` is unused here; the embedding was already consumed to produce
/// `hnsw_hits`. BM25 candidates and filter semantics are the same as in [`hybrid_search`].
///
/// # Errors
/// Propagates database errors from the BM25 search or from resolving the HNSW hits.
#[cfg(feature = "hnsw")]
// Mirrors `hybrid_search`'s argument list plus the HNSW hits, hence the allow.
#[allow(clippy::too_many_arguments)]
pub fn hybrid_search_with_hnsw(
    conn: &Connection,
    query: &str,
    _query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<SearchResult>, MemoryError> {
    let lexical = sanitize_fts_query(query)
        .map(|fts_query| {
            bm25_search(
                conn,
                &fts_query,
                config.candidate_pool_size,
                namespaces,
                source_types,
                session_ids,
            )
        })
        .transpose()?
        .unwrap_or_default();

    let semantic = resolve_hnsw_hits_batched(
        conn,
        config,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;

    let fused = rrf_fuse(&lexical, &semantic, config, top_k);
    Ok(deduplicate_results(fused))
}
678
/// Turns raw HNSW neighbours into [`VectorHit`]s by loading their rows from SQLite.
///
/// A hit is ignored when any of the following holds:
/// - its key does not look like `fact:<id>`, `chunk:<id>` or `msg:<integer id>`;
/// - its source type is excluded by `source_types` (messages are excluded by default);
/// - its similarity is below `config.min_similarity`.
///
/// Each source is then loaded with a single `WHERE id IN (...)` query; ids that no longer
/// exist in the database produce no hit.
///
/// `namespaces` filters facts and chunks (chunks by their document's namespace), and
/// `session_ids` filters messages. As in `build_namespace_clause`, `None` or an empty slice
/// means "no restriction", so this path agrees with the brute-force `vector_search` path.
///
/// The result is sorted by descending similarity and cut to `config.candidate_pool_size`.
///
/// # Errors
/// Returns an error if a statement cannot be prepared or a row cannot be decoded.
#[cfg(feature = "hnsw")]
fn resolve_hnsw_hits_batched(
    conn: &Connection,
    config: &SearchConfig,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<VectorHit>, MemoryError> {
    let search_facts = source_types
        .map(|st| st.contains(&SearchSourceType::Facts))
        .unwrap_or(true);
    let search_chunks = source_types
        .map(|st| st.contains(&SearchSourceType::Chunks))
        .unwrap_or(true);
    let search_messages = source_types
        .map(|st| st.contains(&SearchSourceType::Messages))
        .unwrap_or(false);

    // An empty filter list means "no restriction" everywhere else in this module
    // (see build_namespace_clause); without this, `Some(&[])` would drop every row.
    let namespaces = namespaces.filter(|ns| !ns.is_empty());
    let session_ids = session_ids.filter(|sids| !sids.is_empty());

    // Partition the HNSW keys by source so each table is queried once.
    let mut fact_entries: Vec<(String, f64)> = Vec::new();
    let mut chunk_entries: Vec<(String, f64)> = Vec::new();
    let mut msg_entries: Vec<(i64, f64)> = Vec::new();

    for hit in hnsw_hits {
        let similarity = hit.similarity() as f64;
        if similarity < config.min_similarity {
            continue;
        }
        match hit.key.split_once(':') {
            Some(("fact", id)) if search_facts => fact_entries.push((id.to_string(), similarity)),
            Some(("chunk", id)) if search_chunks => {
                chunk_entries.push((id.to_string(), similarity))
            }
            Some(("msg", id)) if search_messages => {
                if let Ok(mid) = id.parse::<i64>() {
                    msg_entries.push((mid, similarity));
                }
            }
            _ => continue,
        }
    }

    let mut vector_hits = Vec::new();

    if !fact_entries.is_empty() {
        let sim_map: HashMap<String, f64> = fact_entries.iter().cloned().collect();
        let sql = format!(
            "SELECT id, content, namespace, updated_at FROM facts WHERE id IN ({})",
            numbered_placeholders(fact_entries.len())
        );
        let params: Vec<SqlValue> = fact_entries
            .iter()
            .map(|(id, _)| SqlValue::Text(id.clone()))
            .collect();

        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map(rusqlite::params_from_iter(&params), |row| {
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                row.get::<_, Option<String>>(3)?,
            ))
        })?;

        for row in rows {
            let (fact_id, content, namespace, updated_at) = row?;
            if let Some(ns) = namespaces {
                if !ns.contains(&namespace.as_str()) {
                    continue;
                }
            }
            if let Some(&similarity) = sim_map.get(&fact_id) {
                vector_hits.push(VectorHit {
                    id: fact_id.clone(),
                    content,
                    source: SearchSource::Fact { fact_id, namespace },
                    similarity,
                    updated_at,
                });
            }
        }
    }

    if !chunk_entries.is_empty() {
        let sim_map: HashMap<String, f64> = chunk_entries.iter().cloned().collect();
        let sql = format!(
            "SELECT c.id, c.content, c.document_id, d.title, c.chunk_index, c.created_at, d.namespace
             FROM chunks c JOIN documents d ON d.id = c.document_id
             WHERE c.id IN ({})",
            numbered_placeholders(chunk_entries.len())
        );
        let params: Vec<SqlValue> = chunk_entries
            .iter()
            .map(|(id, _)| SqlValue::Text(id.clone()))
            .collect();

        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map(rusqlite::params_from_iter(&params), |row| {
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                row.get::<_, String>(3)?,
                row.get::<_, i64>(4)?,
                row.get::<_, Option<String>>(5)?,
                row.get::<_, String>(6)?,
            ))
        })?;

        for row in rows {
            let (chunk_id, content, document_id, document_title, chunk_index, updated_at, doc_ns) =
                row?;
            if let Some(ns) = namespaces {
                if !ns.contains(&doc_ns.as_str()) {
                    continue;
                }
            }
            if let Some(&similarity) = sim_map.get(&chunk_id) {
                vector_hits.push(VectorHit {
                    id: chunk_id.clone(),
                    content,
                    source: SearchSource::Chunk {
                        chunk_id,
                        document_id,
                        document_title,
                        chunk_index: chunk_index as usize,
                    },
                    similarity,
                    updated_at,
                });
            }
        }
    }

    if !msg_entries.is_empty() {
        let sim_map: HashMap<i64, f64> = msg_entries.iter().cloned().collect();
        let sql = format!(
            "SELECT id, content, session_id, role, created_at FROM messages WHERE id IN ({})",
            numbered_placeholders(msg_entries.len())
        );
        let params: Vec<SqlValue> = msg_entries
            .iter()
            .map(|(id, _)| SqlValue::Integer(*id))
            .collect();

        let mut stmt = conn.prepare(&sql)?;
        let rows = stmt.query_map(rusqlite::params_from_iter(&params), |row| {
            Ok((
                row.get::<_, i64>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                row.get::<_, String>(3)?,
                row.get::<_, Option<String>>(4)?,
            ))
        })?;

        for row in rows {
            let (message_id, content, session_id, role, updated_at) = row?;
            if let Some(sids) = session_ids {
                if !sids.contains(&session_id.as_str()) {
                    continue;
                }
            }
            if let Some(&similarity) = sim_map.get(&message_id) {
                vector_hits.push(VectorHit {
                    id: format!("msg:{}", message_id),
                    content,
                    source: SearchSource::Message {
                        message_id,
                        session_id,
                        role,
                    },
                    similarity,
                    updated_at,
                });
            }
        }
    }

    vector_hits.sort_by(|a, b| {
        b.similarity
            .partial_cmp(&a.similarity)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    vector_hits.truncate(config.candidate_pool_size);

    Ok(vector_hits)
}

/// Builds `?1, ?2, ..., ?n` for an `IN (...)` list bound to `n` parameters.
#[cfg(feature = "hnsw")]
fn numbered_placeholders(n: usize) -> String {
    (1..=n)
        .map(|i| format!("?{}", i))
        .collect::<Vec<_>>()
        .join(", ")
}
885
886pub fn fts_only_search(
888 conn: &Connection,
889 query: &str,
890 config: &SearchConfig,
891 top_k: usize,
892 namespaces: Option<&[&str]>,
893 source_types: Option<&[SearchSourceType]>,
894 session_ids: Option<&[&str]>,
895) -> Result<Vec<SearchResult>, MemoryError> {
896 let sanitized = match sanitize_fts_query(query) {
897 Some(s) => s,
898 None => return Ok(Vec::new()),
899 };
900
901 let hits = bm25_search(
902 conn,
903 &sanitized,
904 top_k,
905 namespaces,
906 source_types,
907 session_ids,
908 )?;
909
910 let results: Vec<SearchResult> = hits
911 .into_iter()
912 .enumerate()
913 .map(|(rank_0, hit)| SearchResult {
914 content: hit.content,
915 source: hit.source,
916 score: config.bm25_weight / (config.rrf_k + (rank_0 + 1) as f64),
917 bm25_rank: Some(rank_0 + 1),
918 vector_rank: None,
919 cosine_similarity: None,
920 })
921 .collect();
922
923 Ok(deduplicate_results(results))
924}
925
926pub fn vector_only_search(
928 conn: &Connection,
929 query_embedding: &[f32],
930 config: &SearchConfig,
931 top_k: usize,
932 namespaces: Option<&[&str]>,
933 source_types: Option<&[SearchSourceType]>,
934 session_ids: Option<&[&str]>,
935) -> Result<Vec<SearchResult>, MemoryError> {
936 let hits = vector_search(
937 conn,
938 query_embedding,
939 top_k,
940 config.min_similarity,
941 namespaces,
942 source_types,
943 session_ids,
944 )?;
945
946 let results: Vec<SearchResult> = hits
947 .into_iter()
948 .enumerate()
949 .map(|(rank_0, hit)| SearchResult {
950 content: hit.content,
951 source: hit.source,
952 score: config.vector_weight / (config.rrf_k + (rank_0 + 1) as f64),
953 bm25_rank: None,
954 vector_rank: Some(rank_0 + 1),
955 cosine_similarity: Some(hit.similarity),
956 })
957 .collect();
958
959 Ok(deduplicate_results(results))
960}
961
/// Embedding-only search driven by pre-computed HNSW neighbours.
///
/// The neighbours are resolved against the database with the same filters as the other
/// search paths, then cut to `top_k`. Each result is scored as
/// `vector_weight / (rrf_k + rank)`.
///
/// # Errors
/// Propagates database errors from resolving the HNSW hits.
#[cfg(feature = "hnsw")]
// Every filter is a separate optional argument of this public API, hence the allow.
#[allow(clippy::too_many_arguments)]
pub fn vector_only_search_with_hnsw(
    conn: &Connection,
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<SearchResult>, MemoryError> {
    let resolved = resolve_hnsw_hits_batched(
        conn,
        config,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;

    let results: Vec<SearchResult> = resolved
        .into_iter()
        .take(top_k)
        .zip(1usize..)
        .map(|(hit, rank)| SearchResult {
            content: hit.content,
            source: hit.source,
            score: config.vector_weight / (config.rrf_k + rank as f64),
            bm25_rank: None,
            vector_rank: Some(rank),
            cosine_similarity: Some(hit.similarity),
        })
        .collect();

    Ok(deduplicate_results(results))
}
996
997fn source_dedup_key(source: &SearchSource) -> (u8, String) {
1002 match source {
1003 SearchSource::Fact { fact_id, .. } => (0, fact_id.clone()),
1004 SearchSource::Chunk { chunk_id, .. } => (1, chunk_id.clone()),
1005 SearchSource::Message { message_id, .. } => (2, message_id.to_string()),
1006 }
1007}
1008
1009fn deduplicate_results(results: Vec<SearchResult>) -> Vec<SearchResult> {
1011 let mut seen = HashSet::new();
1012 results
1013 .into_iter()
1014 .filter(|r| seen.insert(source_dedup_key(&r.source)))
1015 .collect()
1016}
1017
1018fn build_namespace_clause(
1024 column: &str,
1025 namespaces: Option<&[&str]>,
1026 param_offset: usize,
1027) -> (String, Vec<SqlValue>) {
1028 match namespaces {
1029 Some(ns) if !ns.is_empty() => {
1030 let placeholders: Vec<String> = (0..ns.len())
1031 .map(|i| format!("?{}", param_offset + i))
1032 .collect();
1033 let clause = format!("AND {} IN ({})", column, placeholders.join(", "));
1034 let values: Vec<SqlValue> = ns.iter().map(|n| SqlValue::Text(n.to_string())).collect();
1035 (clause, values)
1036 }
1037 _ => (String::new(), vec![]),
1038 }
1039}