1use crate::Result;
7use rusqlite::Connection;
8
/// One row returned by [`search`]: a matching chunk together with its relevance score.
///
/// `PartialEq` is derived so hits can be compared directly (for example in tests); `Eq` is
/// not available because `score` is a floating-point value.
#[derive(Debug, Clone, PartialEq)]
pub struct FtsResult {
    /// Value of the `id` column of the matching `chunks` row.
    pub chunk_id: String,
    /// Value of the `doc_id` column of the matching `chunks` row.
    pub doc_id: String,
    /// Text stored in the `content` column of the matching `chunks` row.
    pub content: String,
    /// Relevance score. [`search`] fills this with `-bm25(chunks_fts)` and sorts by it in
    /// descending order, so a larger value means a better match.
    pub score: f64,
    /// Value of the `position` column of the matching `chunks` row
    /// (presumably the chunk's ordinal within its document; confirm against the schema).
    pub position: i64,
}
24
/// Rewrites free-form input as a space-separated list of double-quoted FTS5 terms.
///
/// The input is split on whitespace. Every `"` is removed from each piece, pieces that end up
/// empty are discarded, and each survivor is wrapped in double quotes. Quoting makes words
/// such as `AND` reach FTS5 as string literals rather than bare keywords.
///
/// Returns an empty string when nothing usable remains.
fn escape_fts5_query(query: &str) -> String {
    let mut quoted_terms: Vec<String> = Vec::new();
    for piece in query.split_whitespace() {
        // Drop embedded quotes so a piece cannot close its own quoting early.
        let bare: String = piece.chars().filter(|&ch| ch != '"').collect();
        if bare.is_empty() {
            continue;
        }
        quoted_terms.push(format!("\"{bare}\""));
    }
    quoted_terms.join(" ")
}
45
46pub fn search(conn: &Connection, query: &str, k: usize) -> Result<Vec<FtsResult>> {
52 let escaped = escape_fts5_query(query);
53 if escaped.is_empty() {
54 return Ok(Vec::new());
55 }
56
57 let mut stmt = conn
58 .prepare_cached(
59 "SELECT c.id, c.doc_id, c.content, -bm25(chunks_fts) AS score, c.position
60 FROM chunks_fts
61 JOIN chunks c ON chunks_fts.rowid = c.rowid
62 WHERE chunks_fts MATCH ?1
63 ORDER BY score DESC
64 LIMIT ?2",
65 )
66 .map_err(|e| crate::Error::Query(format!("Failed to prepare FTS5 search: {e}")))?;
67
68 let results = stmt
69 .query_map(rusqlite::params![escaped, k as i64], |row| {
70 Ok(FtsResult {
71 chunk_id: row.get(0)?,
72 doc_id: row.get(1)?,
73 content: row.get(2)?,
74 score: row.get(3)?,
75 position: row.get(4)?,
76 })
77 })
78 .map_err(|e| crate::Error::Query(format!("FTS5 search failed: {e}")))?
79 .collect::<std::result::Result<Vec<_>, _>>()
80 .map_err(|e| crate::Error::Query(format!("FTS5 result mapping failed: {e}")))?;
81
82 Ok(results)
83}
84
85pub fn optimize(conn: &Connection) -> Result<()> {
89 conn.execute("INSERT INTO chunks_fts(chunks_fts) VALUES('optimize')", [])
90 .map_err(|e| crate::Error::Query(format!("FTS5 optimize failed: {e}")))?;
91 Ok(())
92}
93
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sqlite::schema;

    /// Opens a fresh in-memory database and applies the project schema to it.
    fn setup() -> Connection {
        let db = Connection::open_in_memory().unwrap();
        schema::initialize(&db).unwrap();
        db
    }

    /// Inserts one chunk row, first creating its parent document row if it is missing.
    fn insert_chunk(conn: &Connection, doc_id: &str, chunk_id: &str, content: &str, pos: i64) {
        let ensure_document = "INSERT OR IGNORE INTO documents (id, content) VALUES (?1, '')";
        conn.execute(ensure_document, rusqlite::params![doc_id]).unwrap();

        let add_chunk =
            "INSERT INTO chunks (id, doc_id, content, position) VALUES (?1, ?2, ?3, ?4)";
        conn.execute(add_chunk, rusqlite::params![chunk_id, doc_id, content, pos])
            .unwrap();
    }

    /// A query whose terms all occur in one chunk returns that chunk first.
    #[test]
    fn test_search_returns_results() {
        let db = setup();
        insert_chunk(&db, "doc1", "c1", "SIMD vector operations for tensor math", 0);
        insert_chunk(&db, "doc1", "c2", "GPU kernel dispatch and scheduling", 1);

        let hits = search(&db, "SIMD tensor", 10).unwrap();
        let top_id = hits.first().map(|hit| hit.chunk_id.as_str());
        assert_eq!(top_id, Some("c1"));
    }

    /// The chunk repeating the query terms more often ranks above the one mentioning them
    /// once, and an unrelated chunk is absent.
    #[test]
    fn test_search_bm25_ordering() {
        let db = setup();
        insert_chunk(&db, "d1", "c1", "machine learning algorithms", 0);
        insert_chunk(
            &db,
            "d2",
            "c2",
            "machine learning machine learning machine learning deep learning",
            0,
        );
        insert_chunk(&db, "d3", "c3", "cooking recipes for dinner", 0);

        let hits = search(&db, "machine learning", 10).unwrap();
        let ids: Vec<&str> = hits.iter().map(|hit| hit.chunk_id.as_str()).collect();
        assert!(ids.len() >= 2);
        assert_eq!(ids[0], "c2");
        assert_eq!(ids[1], "c1");
        assert!(!ids.contains(&"c3"));
    }

    /// An empty query yields no hits rather than an error.
    #[test]
    fn test_search_empty_query() {
        let db = setup();
        insert_chunk(&db, "d1", "c1", "some content", 0);

        let hits = search(&db, "", 10).unwrap();
        assert_eq!(hits.len(), 0);
    }

    /// Terms that occur in no chunk yield no hits.
    #[test]
    fn test_search_no_matches() {
        let db = setup();
        insert_chunk(&db, "d1", "c1", "SIMD vector operations", 0);

        let hits = search(&db, "cryptocurrency blockchain", 10).unwrap();
        assert_eq!(hits.len(), 0);
    }

    /// Every token, including the keyword-like `AND`, comes out wrapped in double quotes.
    #[test]
    fn test_escape_fts5_query_special_chars() {
        let escaped = escape_fts5_query("hello AND world");
        for quoted in ["\"hello\"", "\"AND\"", "\"world\""] {
            assert!(escaped.contains(quoted));
        }
    }

    /// A query for `tokenize` finds a chunk that only contains inflected variants of it.
    #[test]
    fn test_porter_stemming() {
        let db = setup();
        insert_chunk(&db, "d1", "c1", "tokenizer tokenization tokenizing", 0);

        let hits = search(&db, "tokenize", 10).unwrap();
        assert!(!hits.is_empty(), "Porter stemmer should conflate 'tokenize' variants");
    }

    /// Running `optimize` on a populated index succeeds.
    #[test]
    fn test_optimize_does_not_error() {
        let db = setup();
        insert_chunk(&db, "d1", "c1", "some content", 0);

        optimize(&db).unwrap();
    }

    /// Scores exposed to callers are strictly positive for matching chunks.
    #[test]
    fn test_scores_are_positive() {
        let db = setup();
        insert_chunk(&db, "d1", "c1", "machine learning algorithms", 0);

        let hits = search(&db, "machine learning", 10).unwrap();
        assert!(!hits.is_empty());
        assert!(
            hits.iter().all(|hit| hit.score > 0.0),
            "Negated BM25 scores should be positive"
        );
    }
}