1use std::cell::RefCell;
2use std::sync::Arc;
3
4use sqlite3_ext::{
5 Error, FallibleIteratorMut, FromValue, Result, ValueRef,
6 vtab::{ColumnContext, VTabConnection, VTabCursor},
7};
8
9use crate::vtab::config::VectorTableConfig;
10use crate::vtab::shadow::ShadowOps;
11use crate::vtab::transaction::IndexState;
12
/// `index_num` value that makes `filter` run a nearest-neighbour search; every other
/// value falls back to a full scan of the data shadow table.
const INDEX_KNN: i32 = 1;
15
16pub enum CursorMode {
17 Scan { rows: Vec<ScanRow>, pos: usize },
18 Knn { results: Vec<KnnRow>, pos: usize },
19}
20
/// One row read from the table's data shadow table.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq` so rows can be logged, copied and compared
/// (e.g. in tests); every field type supports these traits.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScanRow {
    /// Row id (column 0 of the shadow-table query).
    pub id: i64,
    /// Raw vector blob (column 1 of the shadow-table query).
    pub vector: Vec<u8>,
    /// One entry per metadata column; `None` where the stored value was SQL NULL.
    pub metadata: Vec<Option<Vec<u8>>>,
}
26
/// One nearest-neighbour hit: the row fetched from the data shadow table plus the
/// distance the index search reported for it.
///
/// Derives `Debug`/`Clone`/`PartialEq` for logging and tests (`Eq` is omitted because of
/// the `f64` distance).
#[derive(Debug, Clone, PartialEq)]
pub struct KnnRow {
    /// Row id of the matched row.
    pub id: i64,
    /// Raw vector blob of the matched row.
    pub vector: Vec<u8>,
    /// One entry per metadata column; `None` where the stored value was SQL NULL.
    pub metadata: Vec<Option<Vec<u8>>>,
    /// Distance returned by the index search, widened to `f64`.
    pub distance: f64,
}
33
34pub struct VectorCursor {
35 pub mode: CursorMode,
36 pub num_metadata_cols: usize,
37 pub db: *const VTabConnection,
39 pub config: *const VectorTableConfig,
41 pub state: Arc<RefCell<IndexState>>,
42}
43
44unsafe impl Send for VectorCursor {}
46unsafe impl Sync for VectorCursor {}
47
48impl VectorCursor {
49 fn current_id(&self) -> i64 {
50 match &self.mode {
51 CursorMode::Scan { rows, pos } => rows[*pos].id,
52 CursorMode::Knn { results, pos } => results[*pos].id,
53 }
54 }
55
56 fn current_vector(&self) -> &[u8] {
57 match &self.mode {
58 CursorMode::Scan { rows, pos } => &rows[*pos].vector,
59 CursorMode::Knn { results, pos } => &results[*pos].vector,
60 }
61 }
62
63 fn current_metadata(&self) -> &[Option<Vec<u8>>] {
64 match &self.mode {
65 CursorMode::Scan { rows, pos } => &rows[*pos].metadata,
66 CursorMode::Knn { results, pos } => &results[*pos].metadata,
67 }
68 }
69
70 fn current_distance(&self) -> Option<f64> {
71 match &self.mode {
72 CursorMode::Scan { .. } => None,
73 CursorMode::Knn { results, pos } => Some(results[*pos].distance),
74 }
75 }
76
77 fn len(&self) -> usize {
78 match &self.mode {
79 CursorMode::Scan { rows, .. } => rows.len(),
80 CursorMode::Knn { results, .. } => results.len(),
81 }
82 }
83
84 fn pos(&self) -> usize {
85 match &self.mode {
86 CursorMode::Scan { pos, .. } => *pos,
87 CursorMode::Knn { pos, .. } => *pos,
88 }
89 }
90
91 fn set_pos(&mut self, new_pos: usize) {
92 match &mut self.mode {
93 CursorMode::Scan { pos, .. } => *pos = new_pos,
94 CursorMode::Knn { pos, .. } => *pos = new_pos,
95 }
96 }
97}
98
99impl VTabCursor for VectorCursor {
100 fn filter(
101 &mut self,
102 index_num: i32,
103 _index_str: Option<&str>,
104 args: &mut [&mut ValueRef],
105 ) -> Result<()> {
106 let db = unsafe { &*self.db };
108 let config = unsafe { &*self.config };
109
110 match index_num {
111 INDEX_KNN => {
112 if args.is_empty() {
115 return Err(Error::Module(
116 "knn_match requires a query vector argument".into(),
117 ));
118 }
119 let query_blob = args[0].get_blob()?.to_vec();
120 let k = if args.len() > 1 {
121 args[1].get_i64() as usize
122 } else {
123 100
125 };
126
127 let state = self.state.borrow();
128 let hits = state
129 .index
130 .search(&query_blob, k)
131 .map_err(|e| Error::Module(e.to_string()))?;
132
133 let mut results = Vec::with_capacity(hits.len());
134 for (key, dist) in hits {
135 if let Some(row) = fetch_row_by_id(db, config, key as i64)? {
136 results.push(KnnRow {
137 id: row.id,
138 vector: row.vector,
139 metadata: row.metadata,
140 distance: dist as f64,
141 });
142 }
143 }
144 self.mode = CursorMode::Knn { results, pos: 0 };
145 }
146 _ => {
147 let rows = scan_all_rows(db, config)?;
148 self.mode = CursorMode::Scan { rows, pos: 0 };
149 }
150 }
151
152 Ok(())
153 }
154
155 fn next(&mut self) -> Result<()> {
156 let new_pos = self.pos() + 1;
157 self.set_pos(new_pos);
158 Ok(())
159 }
160
161 fn eof(&mut self) -> bool {
162 self.pos() >= self.len()
163 }
164
165 fn column(&mut self, idx: usize, ctx: &ColumnContext) -> Result<()> {
166 match idx {
168 0 => {
169 ctx.set_result(self.current_id())?;
170 }
171 1 => {
172 ctx.set_result(self.current_vector())?;
173 }
174 i if i >= 2 && i < 2 + self.num_metadata_cols => {
175 let meta_idx = i - 2;
176 match &self.current_metadata()[meta_idx] {
177 Some(blob) => ctx.set_result(blob.as_slice())?,
178 None => ctx.set_result(())?,
179 }
180 }
181 _ => {
182 match self.current_distance() {
184 Some(d) => ctx.set_result(d)?,
185 None => ctx.set_result(())?,
186 }
187 }
188 }
189 Ok(())
190 }
191
192 fn rowid(&mut self) -> Result<i64> {
193 Ok(self.current_id())
194 }
195}
196
197fn scan_all_rows(db: &VTabConnection, config: &VectorTableConfig) -> Result<Vec<ScanRow>> {
202 let sql = ShadowOps::select_all_data_sql(&config.table_name);
203 let num_meta = config.metadata_columns.len();
204 let mut stmt = db.prepare(&sql)?;
205 stmt.query(())?;
206 let mut rows = Vec::new();
207 while let Some(row) = stmt.next()? {
208 let id = row[0].get_i64();
209 let vector = row[1].get_blob()?.to_vec();
210 let mut metadata = Vec::with_capacity(num_meta);
211 for i in 0..num_meta {
212 if row[2 + i].is_null() {
213 metadata.push(None);
214 } else {
215 metadata.push(Some(row[2 + i].get_blob()?.to_vec()));
216 }
217 }
218 rows.push(ScanRow {
219 id,
220 vector,
221 metadata,
222 });
223 }
224 Ok(rows)
225}
226
227fn fetch_row_by_id(
228 db: &VTabConnection,
229 config: &VectorTableConfig,
230 id: i64,
231) -> Result<Option<ScanRow>> {
232 use sqlite3_ext::SQLITE_EMPTY;
233 let sql = ShadowOps::select_data_sql(&config.table_name);
234 let num_meta = config.metadata_columns.len();
235 match db.query_row(&sql, [id], |row| {
236 let id = row[0].get_i64();
237 let vector = row[1].get_blob()?.to_vec();
238 let mut metadata = Vec::with_capacity(num_meta);
239 for i in 0..num_meta {
240 if row[2 + i].is_null() {
241 metadata.push(None);
242 } else {
243 metadata.push(Some(row[2 + i].get_blob()?.to_vec()));
244 }
245 }
246 Ok(ScanRow {
247 id,
248 vector,
249 metadata,
250 })
251 }) {
252 Ok(row) => Ok(Some(row)),
253 Err(ref e) if *e == SQLITE_EMPTY => Ok(None),
254 Err(e) => Err(e),
255 }
256}