1use crate::config::SearchConfig;
4use crate::episodes;
5use crate::error::MemoryError;
6use crate::types::{ExplainedResult, ScoreBreakdown, SearchResult, SearchSource, SearchSourceType};
7use rusqlite::types::Value as SqlValue;
8use rusqlite::Connection;
9use std::collections::{HashMap, HashSet};
10
/// Per-table row count above which `vector_search` logs a warning that its
/// brute-force embedding scan is getting large.
const VECTOR_SCAN_WARN_THRESHOLD: usize = 50_000;
13
/// Turns free-form user text into a safe FTS5 `MATCH` expression.
///
/// Every character that is not alphanumeric, whitespace, or `_` becomes a
/// space. The text is then split on whitespace. Tokens that spell the FTS5
/// operators `AND`, `OR`, `NOT` or `NEAR` (in any case) are dropped. Each
/// remaining token is wrapped in double quotes, and the tokens are joined with
/// ` OR `, so a row matching any single term qualifies.
///
/// Returns `None` when nothing searchable remains.
pub fn sanitize_fts_query(raw: &str) -> Option<String> {
    const FTS_OPERATORS: [&str; 4] = ["AND", "OR", "NOT", "NEAR"];

    let scrubbed: String = raw
        .chars()
        .map(|ch| {
            let keep = ch.is_alphanumeric() || ch.is_whitespace() || ch == '_';
            if keep {
                ch
            } else {
                ' '
            }
        })
        .collect();

    let mut quoted: Vec<String> = Vec::new();
    for word in scrubbed.split_whitespace() {
        let upper = word.to_uppercase();
        if FTS_OPERATORS.contains(&upper.as_str()) {
            continue;
        }
        // `"` was already mapped to a space above, so no escaping is needed.
        quoted.push(format!("\"{word}\""));
    }

    if quoted.is_empty() {
        return None;
    }
    Some(quoted.join(" OR "))
}
52
/// Cosine similarity between two embedding vectors.
///
/// Returns `0.0` when either vector has zero magnitude. Equal lengths are only
/// asserted in debug builds. In release builds a length mismatch silently
/// truncates the dot product to the shorter slice, while each magnitude still
/// uses its whole vector.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "embedding dimension mismatch");

    fn dot(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(p, q)| p * q).sum()
    }

    let magnitude_a = dot(a, a).sqrt();
    let magnitude_b = dot(b, b).sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        0.0
    } else {
        dot(a, b) / (magnitude_a * magnitude_b)
    }
}
64
65fn days_since(iso_timestamp: &str) -> Option<f64> {
66 let dt = chrono::NaiveDateTime::parse_from_str(iso_timestamp, "%Y-%m-%d %H:%M:%S").ok()?;
67 let now = chrono::Utc::now().naive_utc();
68 let duration = now - dt;
69 Some(duration.num_seconds() as f64 / 86_400.0)
70}
71
72fn recency_contribution(
73 config: &SearchConfig,
74 updated_at: Option<&str>,
75 best_rank: Option<usize>,
76) -> Option<f64> {
77 match (config.recency_half_life_days, updated_at) {
78 (Some(half_life), Some(ts)) if half_life > 0.0 => {
79 let age_days = days_since(ts).unwrap_or(0.0).max(0.0);
80 let decay = 2.0_f64.powf(-age_days / half_life);
81 let rank = best_rank.unwrap_or(1).max(1) as f64;
82 Some(config.recency_weight * decay / (config.rrf_k + rank))
83 }
84 _ => None,
85 }
86}
87
88pub fn source_dedup_key(source: &SearchSource) -> (u8, String) {
89 match source {
90 SearchSource::Fact { fact_id, .. } => (0, fact_id.clone()),
91 SearchSource::Chunk { chunk_id, .. } => (1, chunk_id.clone()),
92 SearchSource::Message {
93 message_id,
94 session_id,
95 ..
96 } => (2, format!("{session_id}:{message_id}")),
97 SearchSource::Episode { episode_id, .. } => (3, episode_id.clone()),
98 SearchSource::Projection { projection_id, .. } => (4, projection_id.clone()),
99 }
100}
101
/// One full-text (FTS5 `bm25()`) match produced by `bm25_search`.
#[derive(Debug, Clone)]
pub struct Bm25Hit {
    /// Prefixed item key: `fact:<id>`, `chunk:<id>`, `msg:<id>`, or the key
    /// returned by `episodes::episode_item_key` for episodes.
    pub id: String,
    /// Text of the matched row (the episode's `search_text` for episodes).
    pub content: String,
    /// Where the hit came from, with the identifiers needed to look it up.
    pub source: SearchSource,
    /// Raw value of SQLite's `bm25()`. Queries order it ascending, so lower
    /// means a better match.
    pub raw_score: f64,
    /// Timestamp used for recency scoring. Facts and episodes supply their
    /// `updated_at`; chunks and messages supply their `created_at`.
    pub updated_at: Option<String>,
}
116
/// One embedding-similarity match from the vector stage. It comes either from
/// the linear scan or, with the `hnsw` feature, from an HNSW lookup.
#[derive(Debug, Clone)]
pub struct VectorHit {
    /// Prefixed item key, using the same scheme as [`Bm25Hit::id`].
    pub id: String,
    /// Text of the matched row.
    pub content: String,
    /// Where the hit came from.
    pub source: SearchSource,
    /// Cosine similarity to the query embedding, used for ordering.
    pub similarity: f64,
    /// Timestamp used for recency scoring (`created_at` for chunks and
    /// messages).
    pub updated_at: Option<String>,
    /// 1-based position of this hit in the candidate list it was drawn from;
    /// `None` until the hit has been ranked.
    pub source_rank: Option<usize>,
    /// Similarity as reported by that original candidate list.
    pub source_similarity: Option<f64>,
    /// Flags hits whose similarity came from the HNSW path's f32 rerank step.
    /// Linear-scan hits always leave this `false`.
    pub reranked_from_f32: bool,
}
137
/// A raw row read from an embedding-bearing table, before it is scored by
/// `scan_vector_rows`.
struct VectorRow {
    /// Prefixed item key (same scheme as `VectorHit::id`).
    id: String,
    /// Text of the row.
    content: String,
    /// Stored embedding bytes, decoded as packed `f32` values when scored.
    blob: Vec<u8>,
    /// Timestamp used for recency scoring.
    updated_at: Option<String>,
    /// Where the row came from.
    source: SearchSource,
}
145
/// Accumulator for one deduplicated item during reciprocal-rank fusion. It
/// holds whatever the BM25 and vector stages reported for that item.
struct RrfCandidate {
    /// Text of the item, taken from the first list it appeared in.
    content: String,
    /// Source of the item, taken from the first list it appeared in.
    source: SearchSource,
    /// First non-`None` timestamp seen across both lists.
    updated_at: Option<String>,
    /// Raw `bm25()` score, if the item appeared in the BM25 list.
    bm25_score: Option<f64>,
    /// 1-based position in the BM25 list, if present.
    bm25_rank: Option<usize>,
    /// Cosine similarity, if the item appeared in the vector list.
    vector_score: Option<f64>,
    /// 1-based position in the vector list, if present.
    vector_rank: Option<usize>,
    /// Rank in the vector stage's original candidate list.
    vector_source_rank: Option<usize>,
    /// Similarity reported by the vector stage's original candidate list.
    vector_source_score: Option<f64>,
    /// Copied from `VectorHit::reranked_from_f32`.
    vector_reranked_from_f32: bool,
}
158
159impl RrfCandidate {
160 fn explained(self, config: &SearchConfig) -> ExplainedResult {
161 let bm25_contribution = self
162 .bm25_rank
163 .map(|rank| config.bm25_weight / (config.rrf_k + rank as f64));
164 let vector_contribution = self
165 .vector_rank
166 .map(|rank| config.vector_weight / (config.rrf_k + rank as f64));
167 let best_rank = match (self.bm25_rank, self.vector_rank) {
168 (Some(a), Some(b)) => Some(a.min(b)),
169 (Some(a), None) | (None, Some(a)) => Some(a),
170 (None, None) => None,
171 };
172 let recency_score = recency_contribution(config, self.updated_at.as_deref(), best_rank);
173 let rrf_score = bm25_contribution.unwrap_or(0.0)
174 + vector_contribution.unwrap_or(0.0)
175 + recency_score.unwrap_or(0.0);
176
177 let breakdown = ScoreBreakdown {
178 rrf_score,
179 bm25_score: self.bm25_score,
180 vector_score: self.vector_score,
181 recency_score,
182 bm25_rank: self.bm25_rank,
183 vector_rank: self.vector_rank,
184 vector_source_rank: self.vector_source_rank,
185 vector_source_score: self.vector_source_score,
186 bm25_contribution,
187 vector_contribution,
188 vector_reranked_from_f32: self.vector_reranked_from_f32,
189 bm25_weight: config.bm25_weight,
190 vector_weight: config.vector_weight,
191 recency_weight: config.recency_half_life_days.map(|_| config.recency_weight),
192 rrf_k: config.rrf_k,
193 };
194
195 ExplainedResult {
196 result: SearchResult {
197 content: self.content,
198 source: self.source,
199 score: rrf_score,
200 bm25_rank: breakdown.bm25_rank,
201 vector_rank: breakdown.vector_rank,
202 cosine_similarity: breakdown.vector_score,
203 },
204 breakdown,
205 }
206 }
207}
208
209fn scan_vector_rows(
210 rows: impl Iterator<Item = Result<VectorRow, rusqlite::Error>>,
211 query_embedding: &[f32],
212 min_similarity: f64,
213 table_label: &str,
214) -> Result<(Vec<VectorHit>, usize), MemoryError> {
215 let expected_dims = query_embedding.len();
216 let mut hits = Vec::new();
217 let mut row_count = 0usize;
218
219 for row in rows {
220 let row = row?;
221 row_count += 1;
222
223 if row.blob.len() % 4 != 0 {
224 tracing::warn!(
225 "Skipping {} with invalid embedding length: {}",
226 table_label,
227 row.blob.len()
228 );
229 continue;
230 }
231
232 let stored_embedding = bytemuck::try_cast_slice::<u8, f32>(&row.blob).map_err(|e| {
233 tracing::warn!(error = %e, blob_len = row.blob.len(), "embedding cast failed");
234 MemoryError::InvalidEmbedding {
235 expected_bytes: row.blob.len() - (row.blob.len() % 4),
236 actual_bytes: row.blob.len(),
237 }
238 })?;
239
240 if stored_embedding.len() != expected_dims {
241 tracing::warn!(
242 expected = expected_dims,
243 actual = stored_embedding.len(),
244 "Skipping {} with wrong embedding dimensions",
245 table_label
246 );
247 continue;
248 }
249
250 let similarity = cosine_similarity(query_embedding, stored_embedding) as f64;
251 if similarity >= min_similarity {
252 hits.push(VectorHit {
253 id: row.id,
254 content: row.content,
255 source: row.source,
256 similarity,
257 updated_at: row.updated_at,
258 source_rank: None,
259 source_similarity: None,
260 reranked_from_f32: false,
261 });
262 }
263 }
264
265 Ok((hits, row_count))
266}
267
268fn rank_vector_hits(mut hits: Vec<VectorHit>, pool_size: usize) -> Vec<VectorHit> {
269 hits.sort_by(|a, b| {
270 b.similarity.partial_cmp(&a.similarity).unwrap_or_else(|| {
271 if a.similarity.is_nan() {
272 std::cmp::Ordering::Greater
273 } else {
274 std::cmp::Ordering::Less
275 }
276 })
277 });
278
279 for (idx, hit) in hits.iter_mut().enumerate() {
280 hit.source_rank = Some(idx + 1);
281 hit.source_similarity = Some(hit.similarity);
282 }
283
284 hits.truncate(pool_size);
285 hits
286}
287
288pub(crate) fn bm25_search(
290 conn: &Connection,
291 sanitized_query: &str,
292 pool_size: usize,
293 namespaces: Option<&[&str]>,
294 source_types: Option<&[SearchSourceType]>,
295 session_ids: Option<&[&str]>,
296) -> Result<Vec<Bm25Hit>, MemoryError> {
297 let mut hits = Vec::new();
298
299 let search_facts = source_types
300 .map(|st| st.contains(&SearchSourceType::Facts))
301 .unwrap_or(true);
302 let search_chunks = source_types
303 .map(|st| st.contains(&SearchSourceType::Chunks))
304 .unwrap_or(true);
305 let search_messages = source_types
306 .map(|st| st.contains(&SearchSourceType::Messages))
307 .unwrap_or(false);
308 let search_episodes = source_types
309 .map(|st| st.contains(&SearchSourceType::Episodes))
310 .unwrap_or(true);
311
312 if search_facts {
313 let (ns_clause, ns_params) = build_filter_clause("f.namespace", namespaces, 3);
314 let sql = format!(
315 "SELECT fm.fact_id, f.content, f.namespace, bm25(facts_fts) AS score, f.updated_at
316 FROM facts_fts
317 JOIN facts_rowid_map fm ON facts_fts.rowid = fm.rowid
318 JOIN facts f ON f.id = fm.fact_id
319 WHERE facts_fts MATCH ?1 {}
320 ORDER BY score ASC
321 LIMIT ?2",
322 ns_clause
323 );
324
325 let mut params = vec![
326 SqlValue::Text(sanitized_query.to_string()),
327 SqlValue::Integer(pool_size as i64),
328 ];
329 params.extend(ns_params);
330
331 let mut stmt = conn.prepare(&sql)?;
332 let rows = stmt.query_map(rusqlite::params_from_iter(¶ms), |row| {
333 let fact_id: String = row.get(0)?;
334 let content: String = row.get(1)?;
335 let namespace: String = row.get(2)?;
336 let raw_score: f64 = row.get(3)?;
337 let updated_at: Option<String> = row.get(4)?;
338 Ok(Bm25Hit {
339 id: format!("fact:{fact_id}"),
340 content,
341 source: SearchSource::Fact { fact_id, namespace },
342 raw_score,
343 updated_at,
344 })
345 })?;
346
347 for row in rows {
348 hits.push(row?);
349 }
350 }
351
352 if search_chunks {
353 let (ns_clause, ns_params) = build_filter_clause("d.namespace", namespaces, 3);
354 let sql = format!(
355 "SELECT cm.chunk_id, c.content, c.document_id, d.title, c.chunk_index,
356 bm25(chunks_fts) AS score, c.created_at
357 FROM chunks_fts
358 JOIN chunks_rowid_map cm ON chunks_fts.rowid = cm.rowid
359 JOIN chunks c ON c.id = cm.chunk_id
360 JOIN documents d ON d.id = c.document_id
361 WHERE chunks_fts MATCH ?1 {}
362 ORDER BY score ASC
363 LIMIT ?2",
364 ns_clause
365 );
366
367 let mut params = vec![
368 SqlValue::Text(sanitized_query.to_string()),
369 SqlValue::Integer(pool_size as i64),
370 ];
371 params.extend(ns_params);
372
373 let mut stmt = conn.prepare(&sql)?;
374 let rows = stmt.query_map(rusqlite::params_from_iter(¶ms), |row| {
375 let chunk_id: String = row.get(0)?;
376 let content: String = row.get(1)?;
377 let document_id: String = row.get(2)?;
378 let document_title: String = row.get(3)?;
379 let chunk_index: i64 = row.get(4)?;
380 let raw_score: f64 = row.get(5)?;
381 let updated_at: Option<String> = row.get(6)?;
382 Ok(Bm25Hit {
383 id: format!("chunk:{chunk_id}"),
384 content,
385 source: SearchSource::Chunk {
386 chunk_id,
387 document_id,
388 document_title,
389 chunk_index: chunk_index as usize,
390 },
391 raw_score,
392 updated_at,
393 })
394 })?;
395
396 for row in rows {
397 hits.push(row?);
398 }
399 }
400
401 if search_messages {
402 let (sid_clause, sid_params) = build_filter_clause("m.session_id", session_ids, 3);
403 let sql = format!(
404 "SELECT mm.message_id, m.content, m.session_id, m.role,
405 bm25(messages_fts) AS score, m.created_at
406 FROM messages_fts
407 JOIN messages_rowid_map mm ON messages_fts.rowid = mm.rowid
408 JOIN messages m ON m.id = mm.message_id
409 WHERE messages_fts MATCH ?1 {}
410 ORDER BY score ASC
411 LIMIT ?2",
412 sid_clause
413 );
414
415 let mut params = vec![
416 SqlValue::Text(sanitized_query.to_string()),
417 SqlValue::Integer(pool_size as i64),
418 ];
419 params.extend(sid_params);
420
421 let mut stmt = conn.prepare(&sql)?;
422 let rows = stmt.query_map(rusqlite::params_from_iter(¶ms), |row| {
423 let message_id: i64 = row.get(0)?;
424 let content: String = row.get(1)?;
425 let session_id: String = row.get(2)?;
426 let role: String = row.get(3)?;
427 let raw_score: f64 = row.get(4)?;
428 let updated_at: Option<String> = row.get(5)?;
429 Ok(Bm25Hit {
430 id: format!("msg:{message_id}"),
431 content,
432 source: SearchSource::Message {
433 message_id,
434 session_id,
435 role,
436 },
437 raw_score,
438 updated_at,
439 })
440 })?;
441
442 for row in rows {
443 hits.push(row?);
444 }
445 }
446
447 if search_episodes {
448 let (ns_clause, ns_params) = build_filter_clause("d.namespace", namespaces, 3);
449 let sql = format!(
450 "SELECT e.episode_id, e.document_id, e.search_text, e.effect_type, e.outcome,
451 bm25(episodes_fts) AS score, e.updated_at
452 FROM episodes_fts
453 JOIN episodes_rowid_map rm ON episodes_fts.rowid = rm.rowid
454 JOIN episodes e ON e.episode_id = rm.episode_id
455 JOIN documents d ON d.id = e.document_id
456 WHERE episodes_fts MATCH ?1 {}
457 ORDER BY score ASC
458 LIMIT ?2",
459 ns_clause
460 );
461
462 let mut params = vec![
463 SqlValue::Text(sanitized_query.to_string()),
464 SqlValue::Integer(pool_size as i64),
465 ];
466 params.extend(ns_params);
467
468 let mut stmt = conn.prepare(&sql)?;
469 let rows = stmt.query_map(rusqlite::params_from_iter(¶ms), |row| {
470 let episode_id: String = row.get(0)?;
471 let document_id: String = row.get(1)?;
472 let content: String = row.get(2)?;
473 let effect_type: String = row.get(3)?;
474 let outcome: String = row.get(4)?;
475 let raw_score: f64 = row.get(5)?;
476 let updated_at: Option<String> = row.get(6)?;
477 Ok(Bm25Hit {
478 id: episodes::episode_item_key(&episode_id),
479 content,
480 source: SearchSource::Episode {
481 episode_id,
482 document_id,
483 effect_type,
484 outcome,
485 },
486 raw_score,
487 updated_at,
488 })
489 })?;
490
491 for row in rows {
492 hits.push(row?);
493 }
494 }
495
496 Ok(hits)
497}
498
499pub(crate) fn vector_search(
501 conn: &Connection,
502 query_embedding: &[f32],
503 pool_size: usize,
504 min_similarity: f64,
505 namespaces: Option<&[&str]>,
506 source_types: Option<&[SearchSourceType]>,
507 session_ids: Option<&[&str]>,
508) -> Result<Vec<VectorHit>, MemoryError> {
509 let mut hits = Vec::new();
510
511 let search_facts = source_types
512 .map(|st| st.contains(&SearchSourceType::Facts))
513 .unwrap_or(true);
514 let search_chunks = source_types
515 .map(|st| st.contains(&SearchSourceType::Chunks))
516 .unwrap_or(true);
517 let search_messages = source_types
518 .map(|st| st.contains(&SearchSourceType::Messages))
519 .unwrap_or(false);
520 let search_episodes = source_types
521 .map(|st| st.contains(&SearchSourceType::Episodes))
522 .unwrap_or(true);
523
524 if search_facts {
525 let (ns_clause, ns_params) = build_filter_clause("namespace", namespaces, 1);
526 let sql = format!(
527 "SELECT id, content, namespace, embedding, updated_at
528 FROM facts
529 WHERE embedding IS NOT NULL {}",
530 ns_clause
531 );
532
533 let mut stmt = conn.prepare(&sql)?;
534 let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
535 let id: String = row.get(0)?;
536 let content: String = row.get(1)?;
537 let namespace: String = row.get(2)?;
538 let blob: Vec<u8> = row.get(3)?;
539 let updated_at: Option<String> = row.get(4)?;
540 Ok(VectorRow {
541 id: format!("fact:{id}"),
542 content,
543 blob,
544 updated_at,
545 source: SearchSource::Fact {
546 fact_id: id,
547 namespace,
548 },
549 })
550 })?;
551
552 let (fact_hits, fact_count) =
553 scan_vector_rows(rows, query_embedding, min_similarity, "fact")?;
554 hits.extend(fact_hits);
555
556 if fact_count > VECTOR_SCAN_WARN_THRESHOLD {
557 tracing::warn!(
558 count = fact_count,
559 "facts table exceeds vector scan threshold ({} rows)",
560 fact_count
561 );
562 }
563 }
564
565 if search_chunks {
566 let (ns_clause, ns_params) = build_filter_clause("d.namespace", namespaces, 1);
567 let sql = format!(
568 "SELECT c.id, c.content, c.document_id, d.title, c.chunk_index, c.embedding, c.created_at
569 FROM chunks c
570 JOIN documents d ON d.id = c.document_id
571 WHERE c.embedding IS NOT NULL {}",
572 ns_clause
573 );
574
575 let mut stmt = conn.prepare(&sql)?;
576 let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
577 let id: String = row.get(0)?;
578 let content: String = row.get(1)?;
579 let document_id: String = row.get(2)?;
580 let document_title: String = row.get(3)?;
581 let chunk_index: i64 = row.get(4)?;
582 let blob: Vec<u8> = row.get(5)?;
583 let updated_at: Option<String> = row.get(6)?;
584 Ok(VectorRow {
585 id: format!("chunk:{id}"),
586 content,
587 blob,
588 updated_at,
589 source: SearchSource::Chunk {
590 chunk_id: id,
591 document_id,
592 document_title,
593 chunk_index: chunk_index as usize,
594 },
595 })
596 })?;
597
598 let (chunk_hits, chunk_count) =
599 scan_vector_rows(rows, query_embedding, min_similarity, "chunk")?;
600 hits.extend(chunk_hits);
601
602 if chunk_count > VECTOR_SCAN_WARN_THRESHOLD {
603 tracing::warn!(
604 count = chunk_count,
605 "chunks table exceeds vector scan threshold ({} rows)",
606 chunk_count
607 );
608 }
609 }
610
611 if search_messages {
612 let (sid_clause, sid_params) = build_filter_clause("m.session_id", session_ids, 1);
613 let sql = format!(
614 "SELECT m.id, m.content, m.session_id, m.role, m.embedding, m.created_at
615 FROM messages m
616 WHERE m.embedding IS NOT NULL {}",
617 sid_clause
618 );
619
620 let mut stmt = conn.prepare(&sql)?;
621 let rows = stmt.query_map(rusqlite::params_from_iter(&sid_params), |row| {
622 let message_id: i64 = row.get(0)?;
623 let content: String = row.get(1)?;
624 let session_id: String = row.get(2)?;
625 let role: String = row.get(3)?;
626 let blob: Vec<u8> = row.get(4)?;
627 let updated_at: Option<String> = row.get(5)?;
628 Ok(VectorRow {
629 id: format!("msg:{message_id}"),
630 content,
631 blob,
632 updated_at,
633 source: SearchSource::Message {
634 message_id,
635 session_id,
636 role,
637 },
638 })
639 })?;
640
641 let (message_hits, message_count) =
642 scan_vector_rows(rows, query_embedding, min_similarity, "message")?;
643 hits.extend(message_hits);
644
645 if message_count > VECTOR_SCAN_WARN_THRESHOLD {
646 tracing::warn!(
647 count = message_count,
648 "messages table exceeds vector scan threshold ({} rows)",
649 message_count
650 );
651 }
652 }
653
654 if search_episodes {
655 let (ns_clause, ns_params) = build_filter_clause("d.namespace", namespaces, 1);
656 let sql = format!(
657 "SELECT e.episode_id, e.document_id, e.search_text, e.effect_type, e.outcome, e.embedding, e.updated_at
658 FROM episodes e
659 JOIN documents d ON d.id = e.document_id
660 WHERE e.embedding IS NOT NULL {}",
661 ns_clause
662 );
663
664 let mut stmt = conn.prepare(&sql)?;
665 let rows = stmt.query_map(rusqlite::params_from_iter(&ns_params), |row| {
666 let episode_id: String = row.get(0)?;
667 let document_id: String = row.get(1)?;
668 let content: String = row.get(2)?;
669 let effect_type: String = row.get(3)?;
670 let outcome: String = row.get(4)?;
671 let blob: Vec<u8> = row.get(5)?;
672 let updated_at: Option<String> = row.get(6)?;
673 Ok(VectorRow {
674 id: episodes::episode_item_key(&episode_id),
675 content,
676 blob,
677 updated_at,
678 source: SearchSource::Episode {
679 episode_id,
680 document_id,
681 effect_type,
682 outcome,
683 },
684 })
685 })?;
686
687 let (episode_hits, episode_count) =
688 scan_vector_rows(rows, query_embedding, min_similarity, "episode")?;
689 hits.extend(episode_hits);
690
691 if episode_count > VECTOR_SCAN_WARN_THRESHOLD {
692 tracing::warn!(
693 count = episode_count,
694 "episodes table exceeds vector scan threshold ({} rows)",
695 episode_count
696 );
697 }
698 }
699
700 Ok(rank_vector_hits(hits, pool_size))
701}
702
703fn rrf_fuse_detailed(
704 bm25_hits: &[Bm25Hit],
705 vector_hits: &[VectorHit],
706 config: &SearchConfig,
707 top_k: usize,
708) -> Vec<ExplainedResult> {
709 let mut candidates: HashMap<(u8, String), RrfCandidate> = HashMap::new();
711
712 for (rank_0, hit) in bm25_hits.iter().enumerate() {
713 let key = source_dedup_key(&hit.source);
714 let rank = rank_0 + 1;
715 candidates
716 .entry(key)
717 .and_modify(|candidate| {
718 candidate.bm25_rank = Some(rank);
719 candidate.bm25_score = Some(hit.raw_score);
720 if candidate.updated_at.is_none() {
721 candidate.updated_at = hit.updated_at.clone();
722 }
723 })
724 .or_insert_with(|| RrfCandidate {
725 content: hit.content.clone(),
726 source: hit.source.clone(),
727 updated_at: hit.updated_at.clone(),
728 bm25_score: Some(hit.raw_score),
729 bm25_rank: Some(rank),
730 vector_score: None,
731 vector_rank: None,
732 vector_source_rank: None,
733 vector_source_score: None,
734 vector_reranked_from_f32: false,
735 });
736 }
737
738 for (rank_0, hit) in vector_hits.iter().enumerate() {
739 let key = source_dedup_key(&hit.source);
740 let rank = rank_0 + 1;
741 candidates
742 .entry(key)
743 .and_modify(|candidate| {
744 candidate.vector_rank = Some(rank);
745 candidate.vector_score = Some(hit.similarity);
746 candidate.vector_source_rank = hit.source_rank.or(Some(rank));
747 candidate.vector_source_score = hit.source_similarity.or(Some(hit.similarity));
748 candidate.vector_reranked_from_f32 = hit.reranked_from_f32;
749 if candidate.updated_at.is_none() {
750 candidate.updated_at = hit.updated_at.clone();
751 }
752 })
753 .or_insert_with(|| RrfCandidate {
754 content: hit.content.clone(),
755 source: hit.source.clone(),
756 updated_at: hit.updated_at.clone(),
757 bm25_score: None,
758 bm25_rank: None,
759 vector_score: Some(hit.similarity),
760 vector_rank: Some(rank),
761 vector_source_rank: hit.source_rank.or(Some(rank)),
762 vector_source_score: hit.source_similarity.or(Some(hit.similarity)),
763 vector_reranked_from_f32: hit.reranked_from_f32,
764 });
765 }
766
767 let mut explained: Vec<ExplainedResult> = candidates
768 .into_values()
769 .map(|candidate| candidate.explained(config))
770 .collect();
771
772 explained.sort_by(|a, b| {
773 b.result
774 .score
775 .partial_cmp(&a.result.score)
776 .unwrap_or(std::cmp::Ordering::Equal)
777 .then_with(|| {
778 source_dedup_key(&a.result.source).cmp(&source_dedup_key(&b.result.source))
779 })
780 });
781 explained.truncate(top_k);
782 explained
783}
784
785pub fn rrf_fuse(
787 bm25_hits: &[Bm25Hit],
788 vector_hits: &[VectorHit],
789 config: &SearchConfig,
790 top_k: usize,
791) -> Vec<SearchResult> {
792 rrf_fuse_detailed(bm25_hits, vector_hits, config, top_k)
793 .into_iter()
794 .map(|result| result.result)
795 .collect()
796}
797
798#[allow(clippy::too_many_arguments)]
799pub(crate) fn hybrid_search_detailed(
800 conn: &Connection,
801 query: &str,
802 query_embedding: &[f32],
803 config: &SearchConfig,
804 top_k: usize,
805 namespaces: Option<&[&str]>,
806 source_types: Option<&[SearchSourceType]>,
807 session_ids: Option<&[&str]>,
808) -> Result<Vec<ExplainedResult>, MemoryError> {
809 let bm25_hits = match sanitize_fts_query(query) {
810 Some(sanitized) => bm25_search(
811 conn,
812 &sanitized,
813 config.candidate_pool_size,
814 namespaces,
815 source_types,
816 session_ids,
817 )?,
818 None => Vec::new(),
819 };
820
821 let vector_hits = vector_search(
822 conn,
823 query_embedding,
824 config.candidate_pool_size,
825 config.min_similarity,
826 namespaces,
827 source_types,
828 session_ids,
829 )?;
830
831 Ok(rrf_fuse_detailed(&bm25_hits, &vector_hits, config, top_k))
832}
833
834#[allow(clippy::too_many_arguments)]
836pub fn hybrid_search_explained(
837 conn: &Connection,
838 query: &str,
839 query_embedding: &[f32],
840 config: &SearchConfig,
841 top_k: usize,
842 namespaces: Option<&[&str]>,
843 source_types: Option<&[SearchSourceType]>,
844 session_ids: Option<&[&str]>,
845) -> Result<Vec<ExplainedResult>, MemoryError> {
846 hybrid_search_detailed(
847 conn,
848 query,
849 query_embedding,
850 config,
851 top_k,
852 namespaces,
853 source_types,
854 session_ids,
855 )
856}
857
858#[allow(clippy::too_many_arguments)]
860pub fn hybrid_search(
861 conn: &Connection,
862 query: &str,
863 query_embedding: &[f32],
864 config: &SearchConfig,
865 top_k: usize,
866 namespaces: Option<&[&str]>,
867 source_types: Option<&[SearchSourceType]>,
868 session_ids: Option<&[&str]>,
869) -> Result<Vec<SearchResult>, MemoryError> {
870 Ok(hybrid_search_detailed(
871 conn,
872 query,
873 query_embedding,
874 config,
875 top_k,
876 namespaces,
877 source_types,
878 session_ids,
879 )?
880 .into_iter()
881 .map(|result| result.result)
882 .collect())
883}
884
/// What the HNSW index reported for a candidate, captured before the
/// candidate is loaded from the database.
#[cfg(feature = "hnsw")]
#[derive(Clone)]
struct HnswCandidateSeed {
    /// 1-based position in the raw HNSW hit list (hits that were later
    /// filtered out still count toward this position).
    source_rank: usize,
    /// Similarity reported by the index for this hit.
    source_similarity: f64,
}
891
/// Converts raw HNSW index hits into loaded, filtered and ranked
/// [`VectorHit`]s.
///
/// Steps:
/// 1. Hits below `config.min_similarity`, or with a NaN similarity, are
///    dropped before the database is touched.
/// 2. The remaining hits are grouped by the domain prefix of their key
///    (`fact`, `chunk`, `msg`, `episode`) and restricted to the requested
///    `source_types`. Only the first occurrence of each ID (its lowest rank)
///    is kept, and message keys whose ID is not a valid `i64` are ignored.
/// 3. Each domain is loaded with one batched query.
/// 4. Loaded hits are ordered by descending similarity (NaN last), ties are
///    broken by original HNSW rank, and the list is cut to
///    `config.candidate_pool_size`.
///
/// # Errors
/// Returns an error if any HNSW key fails to parse or a batch query fails.
#[cfg(feature = "hnsw")]
#[allow(clippy::type_complexity)]
fn resolve_hnsw_hits_batched(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<VectorHit>, MemoryError> {
    let wants = |kind: &SearchSourceType, default: bool| {
        source_types.map_or(default, |st| st.contains(kind))
    };
    let search_facts = wants(&SearchSourceType::Facts, true);
    let search_chunks = wants(&SearchSourceType::Chunks, true);
    let search_messages = wants(&SearchSourceType::Messages, false);
    let search_episodes = wants(&SearchSourceType::Episodes, true);

    let mut fact_entries: HashMap<String, HnswCandidateSeed> = HashMap::new();
    let mut chunk_entries: HashMap<String, HnswCandidateSeed> = HashMap::new();
    let mut message_entries: HashMap<i64, HnswCandidateSeed> = HashMap::new();
    let mut episode_entries: HashMap<String, HnswCandidateSeed> = HashMap::new();

    for (rank_0, hit) in hnsw_hits.iter().enumerate() {
        let similarity = hit.similarity() as f64;
        // `similarity < min` is false for NaN. Reject NaN explicitly, matching
        // the linear scan, which only keeps `similarity >= min_similarity`.
        if similarity.is_nan() || similarity < config.min_similarity {
            continue;
        }

        let (domain, raw_id) = hit.parse_key()?;
        let seed = HnswCandidateSeed {
            source_rank: rank_0 + 1,
            source_similarity: similarity,
        };

        match domain {
            "fact" if search_facts => {
                fact_entries.entry(raw_id.to_string()).or_insert(seed);
            }
            "chunk" if search_chunks => {
                chunk_entries.entry(raw_id.to_string()).or_insert(seed);
            }
            "msg" if search_messages => {
                if let Ok(message_id) = raw_id.parse::<i64>() {
                    message_entries.entry(message_id).or_insert(seed);
                }
            }
            "episode" if search_episodes => {
                episode_entries.entry(raw_id.to_string()).or_insert(seed);
            }
            _ => {}
        }
    }

    let mut hits = Vec::new();
    batch_load_fact_hits(
        conn,
        query_embedding,
        config,
        namespaces,
        &fact_entries,
        &mut hits,
    )?;
    batch_load_chunk_hits(
        conn,
        query_embedding,
        config,
        namespaces,
        &chunk_entries,
        &mut hits,
    )?;
    batch_load_message_hits(
        conn,
        query_embedding,
        config,
        session_ids,
        &message_entries,
        &mut hits,
    )?;
    batch_load_episode_hits(
        conn,
        query_embedding,
        config,
        namespaces,
        &episode_entries,
        &mut hits,
    )?;

    // A reranked similarity can still be NaN (e.g. a NaN stored in the blob).
    // Order NaN explicitly so the comparator stays a total order; `sort_by`
    // may panic otherwise.
    hits.sort_by(|a, b| {
        let by_similarity = match (a.similarity.is_nan(), b.similarity.is_nan()) {
            (false, false) => b
                .similarity
                .partial_cmp(&a.similarity)
                .unwrap_or(std::cmp::Ordering::Equal),
            (true, true) => std::cmp::Ordering::Equal,
            (true, false) => std::cmp::Ordering::Greater,
            (false, true) => std::cmp::Ordering::Less,
        };
        by_similarity.then_with(|| {
            a.source_rank
                .unwrap_or(usize::MAX)
                .cmp(&b.source_rank.unwrap_or(usize::MAX))
        })
    });
    hits.truncate(config.candidate_pool_size);
    Ok(hits)
}
1003
/// Exact cosine similarity between the query and a stored embedding blob.
///
/// Returns `Ok(None)` for an empty blob, or when the decoded embedding's
/// dimension differs from the query's.
///
/// # Errors
/// Propagates any error from `crate::db::bytes_to_embedding`.
#[cfg(feature = "hnsw")]
fn exact_similarity_from_blob(
    query_embedding: &[f32],
    blob: &[u8],
) -> Result<Option<f64>, MemoryError> {
    if blob.is_empty() {
        return Ok(None);
    }
    let decoded = crate::db::bytes_to_embedding(blob)?;
    let same_dims = decoded.len() == query_embedding.len();
    Ok(same_dims.then(|| cosine_similarity(query_embedding, &decoded) as f64))
}
1018
/// Turns a loaded candidate row plus its HNSW seed into a [`VectorHit`].
///
/// Similarity source:
/// - If `config.rerank_from_f32` is set and the stored embedding decodes with
///   a matching dimension, the similarity is recomputed exactly.
/// - Otherwise the index's similarity from `seed` is used.
///
/// `reranked_from_f32` records whether the exact recomputation actually
/// happened, so explain output never claims a rerank that fell back to the
/// index score.
///
/// Returns `Ok(None)` when the final similarity is NaN or below
/// `config.min_similarity`. This mirrors the linear scan, which keeps only
/// `similarity >= min_similarity`.
///
/// # Errors
/// Propagates blob-decoding errors from `exact_similarity_from_blob` (only
/// when reranking is enabled).
#[cfg(feature = "hnsw")]
#[allow(clippy::too_many_arguments)]
fn build_ranked_vector_hit(
    id: String,
    content: String,
    source: SearchSource,
    updated_at: Option<String>,
    embedding_blob: Option<Vec<u8>>,
    seed: &HnswCandidateSeed,
    query_embedding: &[f32],
    config: &SearchConfig,
) -> Result<Option<VectorHit>, MemoryError> {
    let exact = match embedding_blob {
        Some(blob) if config.rerank_from_f32 => exact_similarity_from_blob(query_embedding, &blob)?,
        _ => None,
    };
    let reranked_from_f32 = exact.is_some();
    let similarity = exact.unwrap_or(seed.source_similarity);

    if similarity.is_nan() || similarity < config.min_similarity {
        return Ok(None);
    }

    Ok(Some(VectorHit {
        id,
        content,
        source,
        similarity,
        updated_at,
        source_rank: Some(seed.source_rank),
        source_similarity: Some(seed.source_similarity),
        reranked_from_f32,
    }))
}
1056
/// Loads the facts named in `entries` with one `IN (...)` query.
///
/// A [`VectorHit`] is appended to `output` for each row that passes both:
/// - the namespace filter (rows outside `namespaces`, when given, are
///   skipped);
/// - `build_ranked_vector_hit`'s similarity threshold.
///
/// IDs with no matching row are silently ignored.
///
/// # Errors
/// Returns SQLite errors, or errors from `build_ranked_vector_hit`.
#[cfg(feature = "hnsw")]
fn batch_load_fact_hits(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    namespaces: Option<&[&str]>,
    entries: &HashMap<String, HnswCandidateSeed>,
    output: &mut Vec<VectorHit>,
) -> Result<(), MemoryError> {
    if entries.is_empty() {
        return Ok(());
    }

    let fact_ids: Vec<SqlValue> = entries.keys().cloned().map(SqlValue::Text).collect();
    let placeholders: Vec<String> = (1..=fact_ids.len()).map(|n| format!("?{n}")).collect();
    let sql = format!(
        "SELECT id, content, namespace, updated_at, embedding
         FROM facts
         WHERE id IN ({})",
        placeholders.join(", ")
    );

    let mut stmt = conn.prepare(&sql)?;
    let rows = stmt.query_map(rusqlite::params_from_iter(&fact_ids), |row| {
        Ok((
            row.get::<_, String>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, String>(2)?,
            row.get::<_, Option<String>>(3)?,
            row.get::<_, Option<Vec<u8>>>(4)?,
        ))
    })?;

    for row in rows {
        let (fact_id, content, namespace, updated_at, embedding_blob) = row?;
        let in_scope = match namespaces {
            Some(allowed) => allowed.contains(&namespace.as_str()),
            None => true,
        };
        let seed = match entries.get(&fact_id) {
            Some(seed) if in_scope => seed,
            _ => continue,
        };

        let key = format!("fact:{fact_id}");
        let maybe_hit = build_ranked_vector_hit(
            key,
            content,
            SearchSource::Fact { fact_id, namespace },
            updated_at,
            embedding_blob,
            seed,
            query_embedding,
            config,
        )?;
        output.extend(maybe_hit);
    }

    Ok(())
}
1120
/// Loads the chunks named in `entries`, joined to their parent documents,
/// with one `IN (...)` query.
///
/// A [`VectorHit`] is appended to `output` for each row that passes both:
/// - the namespace filter (rows whose document namespace is outside
///   `namespaces`, when given, are skipped);
/// - `build_ranked_vector_hit`'s similarity threshold.
///
/// The chunk's `created_at` is used as its timestamp. IDs with no matching row
/// are silently ignored.
///
/// # Errors
/// Returns SQLite errors, or errors from `build_ranked_vector_hit`.
#[cfg(feature = "hnsw")]
fn batch_load_chunk_hits(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    namespaces: Option<&[&str]>,
    entries: &HashMap<String, HnswCandidateSeed>,
    output: &mut Vec<VectorHit>,
) -> Result<(), MemoryError> {
    if entries.is_empty() {
        return Ok(());
    }

    let chunk_ids: Vec<SqlValue> = entries.keys().cloned().map(SqlValue::Text).collect();
    let placeholders: Vec<String> = (1..=chunk_ids.len()).map(|n| format!("?{n}")).collect();
    let sql = format!(
        "SELECT c.id, c.content, c.document_id, d.title, c.chunk_index, c.created_at, d.namespace, c.embedding
         FROM chunks c
         JOIN documents d ON d.id = c.document_id
         WHERE c.id IN ({})",
        placeholders.join(", ")
    );

    let mut stmt = conn.prepare(&sql)?;
    let rows = stmt.query_map(rusqlite::params_from_iter(&chunk_ids), |row| {
        Ok((
            row.get::<_, String>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, String>(2)?,
            row.get::<_, String>(3)?,
            row.get::<_, i64>(4)?,
            row.get::<_, Option<String>>(5)?,
            row.get::<_, String>(6)?,
            row.get::<_, Option<Vec<u8>>>(7)?,
        ))
    })?;

    for row in rows {
        let (
            chunk_id,
            content,
            document_id,
            document_title,
            chunk_index,
            updated_at,
            namespace,
            embedding_blob,
        ) = row?;
        let in_scope = match namespaces {
            Some(allowed) => allowed.contains(&namespace.as_str()),
            None => true,
        };
        let seed = match entries.get(&chunk_id) {
            Some(seed) if in_scope => seed,
            _ => continue,
        };

        let key = format!("chunk:{chunk_id}");
        let source = SearchSource::Chunk {
            chunk_id,
            document_id,
            document_title,
            chunk_index: chunk_index as usize,
        };
        let maybe_hit = build_ranked_vector_hit(
            key,
            content,
            source,
            updated_at,
            embedding_blob,
            seed,
            query_embedding,
            config,
        )?;
        output.extend(maybe_hit);
    }

    Ok(())
}
1202
/// Loads the messages named in `entries` with one `IN (...)` query.
///
/// A [`VectorHit`] is appended to `output` for each row that passes both:
/// - the session filter (rows outside `session_ids`, when given, are
///   skipped);
/// - `build_ranked_vector_hit`'s similarity threshold.
///
/// The message's `created_at` is used as its timestamp. IDs with no matching
/// row are silently ignored.
///
/// # Errors
/// Returns SQLite errors, or errors from `build_ranked_vector_hit`.
#[cfg(feature = "hnsw")]
fn batch_load_message_hits(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    session_ids: Option<&[&str]>,
    entries: &HashMap<i64, HnswCandidateSeed>,
    output: &mut Vec<VectorHit>,
) -> Result<(), MemoryError> {
    if entries.is_empty() {
        return Ok(());
    }

    let message_ids: Vec<SqlValue> = entries.keys().copied().map(SqlValue::Integer).collect();
    let placeholders: Vec<String> = (1..=message_ids.len()).map(|n| format!("?{n}")).collect();
    let sql = format!(
        "SELECT id, content, session_id, role, created_at, embedding
         FROM messages
         WHERE id IN ({})",
        placeholders.join(", ")
    );

    let mut stmt = conn.prepare(&sql)?;
    let rows = stmt.query_map(rusqlite::params_from_iter(&message_ids), |row| {
        Ok((
            row.get::<_, i64>(0)?,
            row.get::<_, String>(1)?,
            row.get::<_, String>(2)?,
            row.get::<_, String>(3)?,
            row.get::<_, Option<String>>(4)?,
            row.get::<_, Option<Vec<u8>>>(5)?,
        ))
    })?;

    for row in rows {
        let (message_id, content, session_id, role, updated_at, embedding_blob) = row?;
        let in_scope = match session_ids {
            Some(allowed) => allowed.contains(&session_id.as_str()),
            None => true,
        };
        let seed = match entries.get(&message_id) {
            Some(seed) if in_scope => seed,
            _ => continue,
        };

        let maybe_hit = build_ranked_vector_hit(
            format!("msg:{message_id}"),
            content,
            SearchSource::Message {
                message_id,
                session_id,
                role,
            },
            updated_at,
            embedding_blob,
            seed,
            query_embedding,
            config,
        )?;
        output.extend(maybe_hit);
    }

    Ok(())
}
1268
/// Loads the HNSW episode candidates in `entries` with a single
/// `WHERE e.episode_id IN (...)` query over `episodes` joined to `documents`, and
/// appends their scored hits to `output`.
///
/// When `namespaces` is `Some`, rows whose document namespace is not in that slice
/// are skipped, so `Some(&[])` skips every row. Each remaining row is paired with its
/// seed from `entries` and passed to `build_ranked_vector_hit`, using
/// `episodes::episode_item_key` as the item key and `search_text` as the content.
/// Only the `Some` results are appended, in row order.
///
/// # Errors
///
/// Returns an error if preparing or stepping the query fails, if a column cannot be
/// read as the expected type, or if `build_ranked_vector_hit` fails.
#[cfg(feature = "hnsw")]
fn batch_load_episode_hits(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    namespaces: Option<&[&str]>,
    entries: &HashMap<String, HnswCandidateSeed>,
    output: &mut Vec<VectorHit>,
) -> Result<(), MemoryError> {
    if entries.is_empty() {
        return Ok(());
    }

    // Bind each candidate episode id to its own numbered parameter: ?1, ?2, ...
    let mut slots = Vec::with_capacity(entries.len());
    let mut params: Vec<SqlValue> = Vec::with_capacity(entries.len());
    for (idx, episode_id) in (1usize..).zip(entries.keys()) {
        slots.push(format!("?{idx}"));
        params.push(SqlValue::Text(episode_id.clone()));
    }
    let placeholders = slots.join(", ");
    let sql = format!(
        "SELECT e.episode_id, e.document_id, e.search_text, e.effect_type, e.outcome, e.updated_at, d.namespace, e.embedding
         FROM episodes e
         JOIN documents d ON d.id = e.document_id
         WHERE e.episode_id IN ({placeholders})"
    );

    let mut statement = conn.prepare(&sql)?;
    let mut rows = statement.query(rusqlite::params_from_iter(&params))?;
    while let Some(row) = rows.next()? {
        // Decode every column before filtering, so a bad row errors even if it
        // would have been filtered out.
        let episode_id: String = row.get(0)?;
        let document_id: String = row.get(1)?;
        let content: String = row.get(2)?;
        let effect_type: String = row.get(3)?;
        let outcome: String = row.get(4)?;
        let updated_at: Option<String> = row.get(5)?;
        let namespace: String = row.get(6)?;
        let embedding_blob: Option<Vec<u8>> = row.get(7)?;

        let in_scope = match namespaces {
            Some(allowed) => allowed.contains(&namespace.as_str()),
            None => true,
        };
        if !in_scope {
            continue;
        }
        let seed = match entries.get(&episode_id) {
            Some(seed) => seed,
            None => continue,
        };

        // Derive the key before `episode_id` is moved into the source below.
        let item_key = episodes::episode_item_key(&episode_id);
        let hit = build_ranked_vector_hit(
            item_key,
            content,
            SearchSource::Episode {
                episode_id,
                document_id,
                effect_type,
                outcome,
            },
            updated_at,
            embedding_blob,
            seed,
            query_embedding,
            config,
        )?;
        // `None` (hit rejected) leaves `output` untouched.
        output.extend(hit);
    }

    Ok(())
}
1350
/// Runs hybrid BM25 + HNSW vector search and returns plain [`SearchResult`]s.
///
/// Delegates to `hybrid_search_with_hnsw_detailed` with identical arguments. It keeps
/// only the `result` field of each explained result and preserves their order.
///
/// # Errors
///
/// Propagates any error from the detailed search.
#[cfg(feature = "hnsw")]
// The parameter list mirrors the detailed variant one-for-one.
#[allow(clippy::too_many_arguments)]
pub fn hybrid_search_with_hnsw(
    conn: &Connection,
    query: &str,
    query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<SearchResult>, MemoryError> {
    let explained = hybrid_search_with_hnsw_detailed(
        conn,
        query,
        query_embedding,
        config,
        top_k,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;
    let plain = explained.into_iter().map(|entry| entry.result).collect();
    Ok(plain)
}
1380
/// Hybrid search that fuses BM25 hits with pre-computed HNSW hits and keeps the
/// per-result explanation.
///
/// The BM25 leg runs only when `sanitize_fts_query` yields a query, and it requests
/// `config.candidate_pool_size` candidates. The vector leg resolves `hnsw_hits`
/// through `resolve_hnsw_hits_batched`. The two lists are combined by
/// `rrf_fuse_detailed`, which is given `top_k`.
///
/// # Errors
///
/// Propagates errors from the BM25 search and from HNSW hit resolution.
#[cfg(feature = "hnsw")]
// The parameter list carries every search filter through to both legs.
#[allow(clippy::too_many_arguments)]
pub(crate) fn hybrid_search_with_hnsw_detailed(
    conn: &Connection,
    query: &str,
    query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<ExplainedResult>, MemoryError> {
    // Lexical leg: a query that sanitizes to nothing contributes no BM25 hits.
    let lexical_hits = sanitize_fts_query(query)
        .map(|fts_query| {
            bm25_search(
                conn,
                &fts_query,
                config.candidate_pool_size,
                namespaces,
                source_types,
                session_ids,
            )
        })
        .transpose()?
        .unwrap_or_default();

    // Semantic leg: hydrate the HNSW candidates from the database.
    let semantic_hits = resolve_hnsw_hits_batched(
        conn,
        query_embedding,
        config,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;

    let fused = rrf_fuse_detailed(&lexical_hits, &semantic_hits, config, top_k);
    Ok(fused)
}
1418
/// Public entry point for hybrid BM25 + HNSW search that keeps the per-result
/// explanation.
///
/// This is a direct pass-through to `hybrid_search_with_hnsw_detailed` with
/// identical arguments and return value.
///
/// # Errors
///
/// Propagates any error from the detailed search.
#[cfg(feature = "hnsw")]
// The parameter list mirrors the detailed variant one-for-one.
#[allow(clippy::too_many_arguments)]
pub fn hybrid_search_explained_with_hnsw(
    conn: &Connection,
    query: &str,
    query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<ExplainedResult>, MemoryError> {
    // Exposes the crate-private detailed search under a public name.
    hybrid_search_with_hnsw_detailed(
        conn,
        query,
        query_embedding,
        config,
        top_k,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )
}
1445
1446pub(crate) fn fts_only_search_detailed(
1447 conn: &Connection,
1448 query: &str,
1449 config: &SearchConfig,
1450 top_k: usize,
1451 namespaces: Option<&[&str]>,
1452 source_types: Option<&[SearchSourceType]>,
1453 session_ids: Option<&[&str]>,
1454) -> Result<Vec<ExplainedResult>, MemoryError> {
1455 let sanitized = match sanitize_fts_query(query) {
1456 Some(value) => value,
1457 None => return Ok(Vec::new()),
1458 };
1459 let bm25_hits = bm25_search(
1460 conn,
1461 &sanitized,
1462 top_k,
1463 namespaces,
1464 source_types,
1465 session_ids,
1466 )?;
1467 Ok(rrf_fuse_detailed(&bm25_hits, &[], config, top_k))
1468}
1469
1470pub fn fts_only_search(
1472 conn: &Connection,
1473 query: &str,
1474 config: &SearchConfig,
1475 top_k: usize,
1476 namespaces: Option<&[&str]>,
1477 source_types: Option<&[SearchSourceType]>,
1478 session_ids: Option<&[&str]>,
1479) -> Result<Vec<SearchResult>, MemoryError> {
1480 Ok(fts_only_search_detailed(
1481 conn,
1482 query,
1483 config,
1484 top_k,
1485 namespaces,
1486 source_types,
1487 session_ids,
1488 )?
1489 .into_iter()
1490 .map(|result| result.result)
1491 .collect())
1492}
1493
1494pub(crate) fn vector_only_search_detailed(
1495 conn: &Connection,
1496 query_embedding: &[f32],
1497 config: &SearchConfig,
1498 top_k: usize,
1499 namespaces: Option<&[&str]>,
1500 source_types: Option<&[SearchSourceType]>,
1501 session_ids: Option<&[&str]>,
1502) -> Result<Vec<ExplainedResult>, MemoryError> {
1503 let vector_hits = vector_search(
1504 conn,
1505 query_embedding,
1506 top_k,
1507 config.min_similarity,
1508 namespaces,
1509 source_types,
1510 session_ids,
1511 )?;
1512 Ok(rrf_fuse_detailed(&[], &vector_hits, config, top_k))
1513}
1514
1515pub fn vector_only_search(
1517 conn: &Connection,
1518 query_embedding: &[f32],
1519 config: &SearchConfig,
1520 top_k: usize,
1521 namespaces: Option<&[&str]>,
1522 source_types: Option<&[SearchSourceType]>,
1523 session_ids: Option<&[&str]>,
1524) -> Result<Vec<SearchResult>, MemoryError> {
1525 Ok(vector_only_search_detailed(
1526 conn,
1527 query_embedding,
1528 config,
1529 top_k,
1530 namespaces,
1531 source_types,
1532 session_ids,
1533 )?
1534 .into_iter()
1535 .map(|result| result.result)
1536 .collect())
1537}
1538
/// Vector-only search over pre-computed HNSW hits, returning plain [`SearchResult`]s.
///
/// Delegates to `vector_only_search_with_hnsw_detailed` with identical arguments. It
/// keeps only the `result` field of each explained result and preserves their order.
///
/// # Errors
///
/// Propagates any error from the detailed search.
#[cfg(feature = "hnsw")]
// The parameter list mirrors the detailed variant one-for-one.
#[allow(clippy::too_many_arguments)]
pub fn vector_only_search_with_hnsw(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<SearchResult>, MemoryError> {
    let explained = vector_only_search_with_hnsw_detailed(
        conn,
        query_embedding,
        config,
        top_k,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;
    let plain = explained.into_iter().map(|entry| entry.result).collect();
    Ok(plain)
}
1566
/// Vector-only search over pre-computed HNSW hits that keeps the per-result
/// explanation.
///
/// Resolves `hnsw_hits` through `resolve_hnsw_hits_batched`, then fuses the result
/// with an empty BM25 list via `rrf_fuse_detailed`, which is given `top_k`.
///
/// # Errors
///
/// Propagates any error from HNSW hit resolution.
#[cfg(feature = "hnsw")]
// The parameter list carries every search filter through to hit resolution.
#[allow(clippy::too_many_arguments)]
pub(crate) fn vector_only_search_with_hnsw_detailed(
    conn: &Connection,
    query_embedding: &[f32],
    config: &SearchConfig,
    top_k: usize,
    namespaces: Option<&[&str]>,
    source_types: Option<&[SearchSourceType]>,
    session_ids: Option<&[&str]>,
    hnsw_hits: &[crate::hnsw::HnswHit],
) -> Result<Vec<ExplainedResult>, MemoryError> {
    let semantic_hits = resolve_hnsw_hits_batched(
        conn,
        query_embedding,
        config,
        namespaces,
        source_types,
        session_ids,
        hnsw_hits,
    )?;
    // No lexical leg: fuse against an empty BM25 list.
    let fused = rrf_fuse_detailed(&[], &semantic_hits, config, top_k);
    Ok(fused)
}
1590
1591fn build_filter_clause(
1592 column: &str,
1593 values: Option<&[&str]>,
1594 param_offset: usize,
1595) -> (String, Vec<SqlValue>) {
1596 match values {
1597 Some(values) if !values.is_empty() => {
1598 let placeholders = (0..values.len())
1599 .map(|idx| format!("?{}", param_offset + idx))
1600 .collect::<Vec<_>>();
1601 let clause = format!(" AND {} IN ({})", column, placeholders.join(", "));
1602 let params = values
1603 .iter()
1604 .map(|value| SqlValue::Text((*value).to_string()))
1605 .collect();
1606 (clause, params)
1607 }
1608 _ => (String::new(), Vec::new()),
1609 }
1610}
1611
1612pub fn deduplicate_results(results: Vec<SearchResult>) -> Vec<SearchResult> {
1614 let mut seen = HashSet::new();
1615 results
1616 .into_iter()
1617 .filter(|result| seen.insert(source_dedup_key(&result.source)))
1618 .collect()
1619}