Skip to main content

sqlite_vector_rs/vtab/
cursor.rs

1use std::cell::RefCell;
2use std::sync::Arc;
3
4use sqlite3_ext::{
5    Error, FallibleIteratorMut, FromValue, Result, ValueRef,
6    vtab::{ColumnContext, VTabConnection, VTabCursor},
7};
8
9use crate::vtab::config::VectorTableConfig;
10use crate::vtab::shadow::ShadowOps;
11use crate::vtab::transaction::IndexState;
12
/// Index number that selects the KNN path in `VectorCursor::filter`; every other
/// value falls through to a full table scan.
///
/// Must stay in sync with `INDEX_KNN` in mod.rs.
const INDEX_KNN: i32 = 1;
15
16pub enum CursorMode {
17    Scan { rows: Vec<ScanRow>, pos: usize },
18    Knn { results: Vec<KnnRow>, pos: usize },
19}
20
/// One materialised row of the vector table.
///
/// Derives `Debug`, `Clone`, `PartialEq` and `Eq` so rows can be logged, copied
/// and compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScanRow {
    /// Row identifier; reported both as column 0 and as the rowid.
    pub id: i64,
    /// Raw bytes of the stored vector blob.
    pub vector: Vec<u8>,
    /// One entry per metadata column; `None` stands for SQL NULL.
    pub metadata: Vec<Option<Vec<u8>>>,
}
26
/// One row returned by a KNN query: the stored row plus its distance.
///
/// Derives `Debug`, `Clone` and `PartialEq` so results can be logged, copied and
/// compared directly in tests (`Eq` is not available because of the `f64` field).
#[derive(Debug, Clone, PartialEq)]
pub struct KnnRow {
    /// Row identifier; reported both as column 0 and as the rowid.
    pub id: i64,
    /// Raw bytes of the stored vector blob.
    pub vector: Vec<u8>,
    /// One entry per metadata column; `None` stands for SQL NULL.
    pub metadata: Vec<Option<Vec<u8>>>,
    /// Distance reported by the index search for this row.
    pub distance: f64,
}
33
34pub struct VectorCursor {
35    pub mode: CursorMode,
36    pub num_metadata_cols: usize,
37    /// Safety: valid for the vtab lifetime — SQLite keeps the connection alive.
38    pub db: *const VTabConnection,
39    /// Safety: valid for the vtab lifetime — VectorTable owns the config.
40    pub config: *const VectorTableConfig,
41    pub state: Arc<RefCell<IndexState>>,
42}
43
44// Safety: VectorCursor is only ever accessed from a single thread by SQLite.
45unsafe impl Send for VectorCursor {}
46unsafe impl Sync for VectorCursor {}
47
48impl VectorCursor {
49    fn current_id(&self) -> i64 {
50        match &self.mode {
51            CursorMode::Scan { rows, pos } => rows[*pos].id,
52            CursorMode::Knn { results, pos } => results[*pos].id,
53        }
54    }
55
56    fn current_vector(&self) -> &[u8] {
57        match &self.mode {
58            CursorMode::Scan { rows, pos } => &rows[*pos].vector,
59            CursorMode::Knn { results, pos } => &results[*pos].vector,
60        }
61    }
62
63    fn current_metadata(&self) -> &[Option<Vec<u8>>] {
64        match &self.mode {
65            CursorMode::Scan { rows, pos } => &rows[*pos].metadata,
66            CursorMode::Knn { results, pos } => &results[*pos].metadata,
67        }
68    }
69
70    fn current_distance(&self) -> Option<f64> {
71        match &self.mode {
72            CursorMode::Scan { .. } => None,
73            CursorMode::Knn { results, pos } => Some(results[*pos].distance),
74        }
75    }
76
77    fn len(&self) -> usize {
78        match &self.mode {
79            CursorMode::Scan { rows, .. } => rows.len(),
80            CursorMode::Knn { results, .. } => results.len(),
81        }
82    }
83
84    fn pos(&self) -> usize {
85        match &self.mode {
86            CursorMode::Scan { pos, .. } => *pos,
87            CursorMode::Knn { pos, .. } => *pos,
88        }
89    }
90
91    fn set_pos(&mut self, new_pos: usize) {
92        match &mut self.mode {
93            CursorMode::Scan { pos, .. } => *pos = new_pos,
94            CursorMode::Knn { pos, .. } => *pos = new_pos,
95        }
96    }
97}
98
99impl VTabCursor for VectorCursor {
100    fn filter(
101        &mut self,
102        index_num: i32,
103        _index_str: Option<&str>,
104        args: &mut [&mut ValueRef],
105    ) -> Result<()> {
106        // Safety: db and config pointers are valid for the vtab lifetime.
107        let db = unsafe { &*self.db };
108        let config = unsafe { &*self.config };
109
110        match index_num {
111            INDEX_KNN => {
112                // args[0] = query vector blob (from knn_match function constraint)
113                // args[1] = k (from LIMIT clause, if present)
114                if args.is_empty() {
115                    return Err(Error::Module(
116                        "knn_match requires a query vector argument".into(),
117                    ));
118                }
119                let query_blob = args[0].get_blob()?.to_vec();
120                let k = if args.len() > 1 {
121                    args[1].get_i64() as usize
122                } else {
123                    // Default k when no LIMIT is specified
124                    100
125                };
126
127                let state = self.state.borrow();
128                let hits = state
129                    .index
130                    .search(&query_blob, k)
131                    .map_err(|e| Error::Module(e.to_string()))?;
132
133                let mut results = Vec::with_capacity(hits.len());
134                for (key, dist) in hits {
135                    if let Some(row) = fetch_row_by_id(db, config, key as i64)? {
136                        results.push(KnnRow {
137                            id: row.id,
138                            vector: row.vector,
139                            metadata: row.metadata,
140                            distance: dist as f64,
141                        });
142                    }
143                }
144                self.mode = CursorMode::Knn { results, pos: 0 };
145            }
146            _ => {
147                let rows = scan_all_rows(db, config)?;
148                self.mode = CursorMode::Scan { rows, pos: 0 };
149            }
150        }
151
152        Ok(())
153    }
154
155    fn next(&mut self) -> Result<()> {
156        let new_pos = self.pos() + 1;
157        self.set_pos(new_pos);
158        Ok(())
159    }
160
161    fn eof(&mut self) -> bool {
162        self.pos() >= self.len()
163    }
164
165    fn column(&mut self, idx: usize, ctx: &ColumnContext) -> Result<()> {
166        // Column layout: 0=id, 1=vector, 2..2+N=metadata[0..N], last=distance
167        match idx {
168            0 => {
169                ctx.set_result(self.current_id())?;
170            }
171            1 => {
172                ctx.set_result(self.current_vector())?;
173            }
174            i if i >= 2 && i < 2 + self.num_metadata_cols => {
175                let meta_idx = i - 2;
176                match &self.current_metadata()[meta_idx] {
177                    Some(blob) => ctx.set_result(blob.as_slice())?,
178                    None => ctx.set_result(())?,
179                }
180            }
181            _ => {
182                // distance column (last)
183                match self.current_distance() {
184                    Some(d) => ctx.set_result(d)?,
185                    None => ctx.set_result(())?,
186                }
187            }
188        }
189        Ok(())
190    }
191
192    fn rowid(&mut self) -> Result<i64> {
193        Ok(self.current_id())
194    }
195}
196
// ---------------------------------------------------------------------------
// Helpers duplicated here to avoid circular imports (mirror mod.rs helpers)
// ---------------------------------------------------------------------------
200
201fn scan_all_rows(db: &VTabConnection, config: &VectorTableConfig) -> Result<Vec<ScanRow>> {
202    let sql = ShadowOps::select_all_data_sql(&config.table_name);
203    let num_meta = config.metadata_columns.len();
204    let mut stmt = db.prepare(&sql)?;
205    stmt.query(())?;
206    let mut rows = Vec::new();
207    while let Some(row) = stmt.next()? {
208        let id = row[0].get_i64();
209        let vector = row[1].get_blob()?.to_vec();
210        let mut metadata = Vec::with_capacity(num_meta);
211        for i in 0..num_meta {
212            if row[2 + i].is_null() {
213                metadata.push(None);
214            } else {
215                metadata.push(Some(row[2 + i].get_blob()?.to_vec()));
216            }
217        }
218        rows.push(ScanRow {
219            id,
220            vector,
221            metadata,
222        });
223    }
224    Ok(rows)
225}
226
227fn fetch_row_by_id(
228    db: &VTabConnection,
229    config: &VectorTableConfig,
230    id: i64,
231) -> Result<Option<ScanRow>> {
232    use sqlite3_ext::SQLITE_EMPTY;
233    let sql = ShadowOps::select_data_sql(&config.table_name);
234    let num_meta = config.metadata_columns.len();
235    match db.query_row(&sql, [id], |row| {
236        let id = row[0].get_i64();
237        let vector = row[1].get_blob()?.to_vec();
238        let mut metadata = Vec::with_capacity(num_meta);
239        for i in 0..num_meta {
240            if row[2 + i].is_null() {
241                metadata.push(None);
242            } else {
243                metadata.push(Some(row[2 + i].get_blob()?.to_vec()));
244            }
245        }
246        Ok(ScanRow {
247            id,
248            vector,
249            metadata,
250        })
251    }) {
252        Ok(row) => Ok(Some(row)),
253        Err(ref e) if *e == SQLITE_EMPTY => Ok(None),
254        Err(e) => Err(e),
255    }
256}