Skip to main content

sqlite_vector_rs/vtab/
mod.rs

1pub mod config;
2pub mod cursor;
3pub mod shadow;
4pub mod transaction;
5
6use std::cell::RefCell;
7use std::sync::Arc;
8
9use sqlite3_ext::query::ToParam;
10use sqlite3_ext::vtab::{
11    ChangeInfo, ChangeType, ConstraintOp, CreateVTab, DisconnectResult, FindFunctionVTab,
12    IndexInfo, TransactionVTab, UpdateVTab, VTab, VTabConnection, VTabFunctionList,
13};
14use sqlite3_ext::{Error, FromValue, Result, SQLITE_EMPTY, ValueRef, function::Context};
15
16use crate::index::HnswIndex;
17use crate::vtab::config::VectorTableConfig;
18use crate::vtab::cursor::{CursorMode, VectorCursor};
19use crate::vtab::shadow::ShadowOps;
20use crate::vtab::transaction::{IndexState, VectorTransaction};
21
// Query-plan identifiers chosen in `best_index` and handed to the cursor's `filter`
// through the index number; the numeric values are the contract between the two.
/// Plan used when no usable `knn_match` constraint is present (full scan of the rows).
const INDEX_SCAN: i32 = 0;
/// Plan used when a `knn_match` constraint on the distance column was consumed.
const INDEX_KNN: i32 = 1;
25
26/// The virtual table implementation for vector search.
27///
28/// `db` is a raw pointer to the VTabConnection that SQLite provides to connect/create.
29/// SQLite guarantees the connection outlives the virtual table, so this pointer is valid
30/// for the entire lifetime of VectorTable.
31pub struct VectorTable<'vtab> {
32    config: VectorTableConfig,
33    state: Arc<RefCell<IndexState>>,
34    /// Safety: valid for 'vtab lifetime — SQLite keeps the connection alive.
35    db: *const VTabConnection,
36    functions: VTabFunctionList<'vtab, Self>,
37}
38
39// Safety: VectorTable is only ever accessed from a single thread by SQLite's
40// virtual table machinery.
41unsafe impl Send for VectorTable<'_> {}
42unsafe impl Sync for VectorTable<'_> {}
43
44// ---------------------------------------------------------------------------
// Shadow table I/O helpers for the `_data` and `_index` shadow tables
46// ---------------------------------------------------------------------------
47
48/// Load the serialized HNSW index blob from the `_index` shadow table, if present.
49fn load_index_from_shadow(db: &VTabConnection, table_name: &str) -> Result<Option<Vec<u8>>> {
50    let sql = ShadowOps::select_index_sql(table_name);
51    match db.query_row(&sql, ["hnsw_graph"], |row| {
52        let blob = row[0].get_blob()?;
53        Ok(blob.to_vec())
54    }) {
55        Ok(buf) => Ok(Some(buf)),
56        Err(ref e) if *e == SQLITE_EMPTY => Ok(None),
57        Err(e) => Err(e),
58    }
59}
60
61/// Persist schema/config metadata to the `_index` shadow table.
62#[allow(dead_code)]
63fn save_meta_to_shadow(db: &VTabConnection, table_name: &str, meta_json: &str) -> Result<()> {
64    let sql = ShadowOps::upsert_index_sql(table_name);
65    db.execute(&sql, ["meta", meta_json])?;
66    Ok(())
67}
68
69/// Insert a new row into `_data` and return the auto-assigned rowid.
70fn insert_into_data_shadow(
71    db: &VTabConnection,
72    config: &VectorTableConfig,
73    vector_blob: &[u8],
74    metadata_args: &mut [&mut ValueRef],
75) -> Result<i64> {
76    use sqlite3_ext::query::Statement;
77    let sql = ShadowOps::insert_data_sql(config);
78    db.insert(&sql, |stmt: &mut Statement| {
79        vector_blob.bind_param(&mut *stmt, 1)?;
80        for (i, val) in metadata_args.iter_mut().enumerate() {
81            val.bind_param(&mut *stmt, (i + 2) as i32)?;
82        }
83        Ok(())
84    })
85}
86
87/// Delete a row from `_data` by rowid.
88fn delete_from_data_shadow(db: &VTabConnection, table_name: &str, rowid: i64) -> Result<()> {
89    let sql = ShadowOps::delete_data_sql(table_name);
90    db.execute(&sql, [rowid])?;
91    Ok(())
92}
93
94/// Update an existing row in `_data` by deleting and re-inserting.
95fn update_data_shadow(
96    db: &VTabConnection,
97    config: &VectorTableConfig,
98    rowid: i64,
99    vector_blob: &[u8],
100    metadata_args: &mut [&mut ValueRef],
101) -> Result<()> {
102    delete_from_data_shadow(db, &config.table_name, rowid)?;
103    insert_into_data_shadow(db, config, vector_blob, metadata_args)?;
104    Ok(())
105}
106
107// ---------------------------------------------------------------------------
108// Shared init logic used by both connect and create
109// ---------------------------------------------------------------------------
110
111#[allow(clippy::arc_with_non_send_sync)]
112fn init<'vtab>(db: &VTabConnection, args: &[&str]) -> Result<(String, VectorTable<'vtab>)> {
113    let config = VectorTableConfig::parse(args).map_err(|e| Error::Module(e.to_string()))?;
114
115    let schema = config.vtab_schema();
116
117    // Try to reload a previously persisted index; fall back to a fresh one.
118    let index = match load_index_from_shadow(db, &config.table_name) {
119        Ok(Some(buf)) => {
120            let idx = HnswIndex::new(
121                config.dim,
122                config.vtype,
123                config.metric,
124                Some(config.hnsw_params),
125            )
126            .map_err(|e| Error::Module(e.to_string()))?;
127            idx.load_from_buffer(&buf)
128                .map_err(|e| Error::Module(e.to_string()))?;
129            idx
130        }
131        _ => HnswIndex::new(
132            config.dim,
133            config.vtype,
134            config.metric,
135            Some(config.hnsw_params),
136        )
137        .map_err(|e| Error::Module(e.to_string()))?,
138    };
139
140    let state = Arc::new(RefCell::new(IndexState {
141        index,
142        dirty: false,
143        last_committed: None,
144    }));
145
146    let functions = VTabFunctionList::default();
147    // Register knn_match as a 2-arg overloaded function (col, param).
148    // ConstraintOp::Function(0) tells best_index this function can act as a constraint.
149    // The function body is a no-op returning 1 because set_omit(true) in best_index
150    // prevents SQLite from evaluating it; the real work happens in filter().
151    functions.add(
152        2,
153        "knn_match",
154        Some(ConstraintOp::Function(150)),
155        |ctx: &Context, _args: &mut [&mut ValueRef]| ctx.set_result(1i32),
156    );
157
158    let vtab = VectorTable {
159        config,
160        state,
161        db: db as *const VTabConnection,
162        functions,
163    };
164
165    Ok((schema, vtab))
166}
167
168// ---------------------------------------------------------------------------
169// VTab impl
170// ---------------------------------------------------------------------------
171
172impl<'vtab> VTab<'vtab> for VectorTable<'vtab> {
173    type Aux = ();
174    type Cursor = VectorCursor;
175
176    fn connect(
177        db: &'vtab VTabConnection,
178        _aux: &'vtab Self::Aux,
179        args: &[&str],
180    ) -> Result<(String, Self)> {
181        init(db, args)
182    }
183
184    fn best_index(&'vtab self, info: &mut IndexInfo) -> Result<()> {
185        // Distance column index = 2 + num_metadata_cols
186        let distance_col = (2 + self.config.metadata_columns.len()) as i32;
187
188        let mut found_knn = false;
189        let mut argv_next: u32 = 1;
190
191        for mut c in info.constraints() {
192            if !c.usable() {
193                continue;
194            }
195            if c.column() == distance_col
196                && let ConstraintOp::Function(_) = c.op()
197            {
198                // knn_match(distance_col, query_blob): query_blob passed to filter
199                c.set_argv_index(Some(argv_next - 1));
200                c.set_omit(true);
201                argv_next += 1;
202                found_knn = true;
203            }
204            // Capture LIMIT as the k parameter for KNN searches
205            if let ConstraintOp::Limit = c.op()
206                && found_knn
207            {
208                c.set_argv_index(Some(argv_next - 1));
209                c.set_omit(true);
210                argv_next += 1;
211            }
212        }
213
214        if found_knn {
215            info.set_index_num(INDEX_KNN);
216            info.set_estimated_cost(10.0);
217            info.set_estimated_rows(10);
218        } else {
219            info.set_index_num(INDEX_SCAN);
220            info.set_estimated_cost(1_000_000.0);
221            info.set_estimated_rows(1_000_000);
222        }
223
224        Ok(())
225    }
226
227    fn open(&'vtab self) -> Result<Self::Cursor> {
228        Ok(VectorCursor {
229            mode: CursorMode::Scan {
230                rows: Vec::new(),
231                pos: 0,
232            },
233            num_metadata_cols: self.config.metadata_columns.len(),
234            db: self.db,
235            config: &self.config as *const VectorTableConfig,
236            state: Arc::clone(&self.state),
237        })
238    }
239}
240
241// ---------------------------------------------------------------------------
242// CreateVTab impl
243// ---------------------------------------------------------------------------
244
245impl<'vtab> CreateVTab<'vtab> for VectorTable<'vtab> {
246    const SHADOW_NAMES: &'static [&'static str] = &["data", "index"];
247
248    fn create(
249        db: &'vtab VTabConnection,
250        aux: &'vtab Self::Aux,
251        args: &[&str],
252    ) -> Result<(String, Self)> {
253        let (schema, vtab) = init(db, args)?;
254
255        // Create the shadow tables
256        db.execute(&ShadowOps::create_data_table_sql(&vtab.config), ())?;
257        db.execute(&ShadowOps::create_index_table_sql(&vtab.config), ())?;
258
259        let _ = aux;
260        Ok((schema, vtab))
261    }
262
263    fn destroy(self) -> DisconnectResult<Self> {
264        // Safety: db pointer is valid for 'vtab; we're being destroyed now.
265        let db = unsafe { &*self.db };
266        for sql in ShadowOps::drop_shadow_tables_sql(&self.config.table_name) {
267            if let Err(e) = db.execute(&sql, ()) {
268                return Err((self, e));
269            }
270        }
271        Ok(())
272    }
273}
274
275// ---------------------------------------------------------------------------
276// UpdateVTab impl
277// ---------------------------------------------------------------------------
278
279impl<'vtab> UpdateVTab<'vtab> for VectorTable<'vtab> {
280    fn update(&'vtab self, info: &mut ChangeInfo) -> Result<i64> {
281        // Safety: db pointer is valid for 'vtab lifetime.
282        let db = unsafe { &*self.db };
283
284        match info.change_type() {
285            ChangeType::Delete => {
286                let rowid = info.rowid().get_i64();
287                delete_from_data_shadow(db, &self.config.table_name, rowid)?;
288                self.state
289                    .borrow()
290                    .index
291                    .remove(rowid as u64)
292                    .map_err(|e| Error::Module(e.to_string()))?;
293                self.state.borrow_mut().dirty = true;
294                Ok(0)
295            }
296            ChangeType::Insert => {
297                let args = info.args_mut();
298                // SQLite xUpdate argv layout (after argv[0] = old rowid):
299                //   args[0] = new rowid (NULL → auto-assign)
300                //   args[1] = col 0 (id)
301                //   args[2] = col 1 (vector)
302                //   args[3..3+N] = metadata cols
303                //   args[3+N] = distance (hidden, ignored on insert)
304                let vector_blob = args[2].get_blob()?.to_vec();
305                let num_meta = self.config.metadata_columns.len();
306                let meta_args = &mut args[3..3 + num_meta];
307
308                // Validate dimension and finiteness before inserting
309                self.config
310                    .vtype
311                    .validate_blob(&vector_blob, self.config.dim)
312                    .map_err(|e| Error::Module(e.to_string()))?;
313                self.config
314                    .vtype
315                    .validate_finite(&vector_blob, self.config.dim)
316                    .map_err(|e| Error::Module(e.to_string()))?;
317
318                let rowid = insert_into_data_shadow(db, &self.config, &vector_blob, meta_args)?;
319
320                let state = self.state.borrow();
321                state
322                    .index
323                    .add(rowid as u64, &vector_blob)
324                    .map_err(|e| Error::Module(e.to_string()))?;
325                drop(state);
326                self.state.borrow_mut().dirty = true;
327
328                Ok(rowid)
329            }
330            ChangeType::Update => {
331                let rowid = info.rowid().get_i64();
332                let args = info.args_mut();
333                // args[0] = new rowid, args[1] = id col, args[2] = vector, args[3+N] = distance
334                let vector_blob = args[2].get_blob()?.to_vec();
335                let num_meta = self.config.metadata_columns.len();
336                let meta_args = &mut args[3..3 + num_meta];
337
338                update_data_shadow(db, &self.config, rowid, &vector_blob, meta_args)?;
339
340                // Update index: remove old entry, add new one
341                let state = self.state.borrow();
342                state
343                    .index
344                    .remove(rowid as u64)
345                    .map_err(|e| Error::Module(e.to_string()))?;
346                state
347                    .index
348                    .add(rowid as u64, &vector_blob)
349                    .map_err(|e| Error::Module(e.to_string()))?;
350                drop(state);
351                self.state.borrow_mut().dirty = true;
352
353                Ok(rowid)
354            }
355        }
356    }
357}
358
359// ---------------------------------------------------------------------------
360// TransactionVTab impl
361// ---------------------------------------------------------------------------
362
363impl<'vtab> TransactionVTab<'vtab> for VectorTable<'vtab> {
364    type Transaction = VectorTransaction;
365
366    fn begin(&'vtab self) -> Result<Self::Transaction> {
367        Ok(VectorTransaction {
368            state: Arc::clone(&self.state),
369            table_name: self.config.table_name.clone(),
370            db: self.db,
371        })
372    }
373}
374
375// ---------------------------------------------------------------------------
376// FindFunctionVTab impl
377// ---------------------------------------------------------------------------
378
379impl<'vtab> FindFunctionVTab<'vtab> for VectorTable<'vtab> {
380    fn functions(&'vtab self) -> &'vtab VTabFunctionList<'vtab, Self> {
381        &self.functions
382    }
383}