Skip to main content

sqlite_vector_rs/vtab/transaction.rs

1use std::cell::RefCell;
2use std::sync::Arc;
3
4use sqlite3_ext::query::ToParam;
5use sqlite3_ext::vtab::VTabConnection;
6use sqlite3_ext::{Error, Result};
7
8use crate::index::HnswIndex;
9use crate::vtab::shadow::ShadowOps;
10
11pub struct IndexState {
12    pub index: HnswIndex,
13    pub dirty: bool,
14    pub last_committed: Option<Vec<u8>>,
15}
16
17pub struct VectorTransaction {
18    pub state: Arc<RefCell<IndexState>>,
19    pub table_name: String,
20    /// Safety: valid for the vtab lifetime — SQLite keeps the connection alive.
21    pub db: *const VTabConnection,
22}
23
24// Safety: VectorTransaction is only ever accessed from a single thread by SQLite.
25unsafe impl Send for VectorTransaction {}
26unsafe impl Sync for VectorTransaction {}
27
28impl sqlite3_ext::vtab::VTabTransaction for VectorTransaction {
29    fn sync(&mut self) -> Result<()> {
30        let mut s = self.state.borrow_mut();
31        if s.dirty {
32            let buf = s
33                .index
34                .save_to_buffer()
35                .map_err(|e| Error::Module(e.to_string()))?;
36
37            // Persist serialized HNSW graph to the _index shadow table.
38            use sqlite3_ext::query::Statement;
39            let db = unsafe { &*self.db };
40            let sql = ShadowOps::upsert_index_sql(&self.table_name);
41            db.insert(&sql, |stmt: &mut Statement| {
42                "hnsw_graph".bind_param(&mut *stmt, 1)?;
43                buf.as_slice().bind_param(&mut *stmt, 2)?;
44                Ok(())
45            })?;
46
47            s.last_committed = Some(buf);
48            s.dirty = false;
49        }
50        Ok(())
51    }
52
53    fn commit(self) -> Result<()> {
54        // sync() has already serialized and persisted; nothing more to do.
55        Ok(())
56    }
57
58    fn rollback(self) -> Result<()> {
59        let mut s = self.state.borrow_mut();
60        if let Some(ref buf) = s.last_committed.clone() {
61            s.index
62                .load_from_buffer(buf)
63                .map_err(|e| Error::Module(e.to_string()))?;
64        }
65        s.dirty = false;
66        Ok(())
67    }
68
69    fn savepoint(&mut self, _n: i32) -> Result<()> {
70        Ok(())
71    }
72
73    fn release(&mut self, _n: i32) -> Result<()> {
74        Ok(())
75    }
76
77    fn rollback_to(&mut self, _n: i32) -> Result<()> {
78        let mut s = self.state.borrow_mut();
79        if let Some(ref buf) = s.last_committed.clone() {
80            s.index
81                .load_from_buffer(buf)
82                .map_err(|e| Error::Module(e.to_string()))?;
83        }
84        s.dirty = false;
85        Ok(())
86    }
87}