sqlite_vector_rs/vtab/
transaction.rs1use std::cell::RefCell;
2use std::sync::Arc;
3
4use sqlite3_ext::query::ToParam;
5use sqlite3_ext::vtab::VTabConnection;
6use sqlite3_ext::{Error, Result};
7
8use crate::index::HnswIndex;
9use crate::vtab::shadow::ShadowOps;
10
/// An HNSW index held in memory, plus the bookkeeping used to persist it and
/// to revert uncommitted changes.
pub struct IndexState {
    /// The live in-memory index.
    pub index: HnswIndex,
    /// Whether `index` may differ from `last_committed`. `sync` persists the
    /// index only when this is `true`, and `sync`/rollbacks reset it to
    /// `false` (presumably set to `true` by write paths elsewhere — confirm).
    pub dirty: bool,
    /// Most recent serialized snapshot of `index` (as produced by
    /// `save_to_buffer`). Rollbacks reload the index from this buffer.
    /// `None` means no snapshot is available; how it is initialised is not
    /// visible here.
    pub last_committed: Option<Vec<u8>>,
}
16
/// Transaction handle that persists or reverts the shared HNSW index state in
/// response to SQLite's virtual-table transaction callbacks.
pub struct VectorTransaction {
    /// Index state shared via `Arc` (presumably with the owning virtual
    /// table — confirm at the construction site).
    pub state: Arc<RefCell<IndexState>>,
    /// Table name passed to `ShadowOps::upsert_index_sql` to build the SQL
    /// that stores the serialized index.
    pub table_name: String,
    /// Raw pointer to the connection used to write the serialized index.
    /// NOTE(review): must stay valid for as long as this transaction exists;
    /// nothing in this type enforces that.
    pub db: *const VTabConnection,
}
23
// SAFETY / NOTE(review): `Arc<RefCell<IndexState>>` is neither `Send` nor
// `Sync`, and `db` is a raw pointer, so the compiler cannot prove either impl.
// They are only sound if a `VectorTransaction` (and the `Arc` it holds) is
// never accessed from two threads at once — presumably guaranteed because
// SQLite drives these callbacks from the thread using the connection.
// Confirm why these bounds are needed (likely a requirement imposed by
// `sqlite3_ext`) and whether they can be removed.
unsafe impl Send for VectorTransaction {}
unsafe impl Sync for VectorTransaction {}
27
28impl sqlite3_ext::vtab::VTabTransaction for VectorTransaction {
29 fn sync(&mut self) -> Result<()> {
30 let mut s = self.state.borrow_mut();
31 if s.dirty {
32 let buf = s
33 .index
34 .save_to_buffer()
35 .map_err(|e| Error::Module(e.to_string()))?;
36
37 use sqlite3_ext::query::Statement;
39 let db = unsafe { &*self.db };
40 let sql = ShadowOps::upsert_index_sql(&self.table_name);
41 db.insert(&sql, |stmt: &mut Statement| {
42 "hnsw_graph".bind_param(&mut *stmt, 1)?;
43 buf.as_slice().bind_param(&mut *stmt, 2)?;
44 Ok(())
45 })?;
46
47 s.last_committed = Some(buf);
48 s.dirty = false;
49 }
50 Ok(())
51 }
52
53 fn commit(self) -> Result<()> {
54 Ok(())
56 }
57
58 fn rollback(self) -> Result<()> {
59 let mut s = self.state.borrow_mut();
60 if let Some(ref buf) = s.last_committed.clone() {
61 s.index
62 .load_from_buffer(buf)
63 .map_err(|e| Error::Module(e.to_string()))?;
64 }
65 s.dirty = false;
66 Ok(())
67 }
68
69 fn savepoint(&mut self, _n: i32) -> Result<()> {
70 Ok(())
71 }
72
73 fn release(&mut self, _n: i32) -> Result<()> {
74 Ok(())
75 }
76
77 fn rollback_to(&mut self, _n: i32) -> Result<()> {
78 let mut s = self.state.borrow_mut();
79 if let Some(ref buf) = s.last_committed.clone() {
80 s.index
81 .load_from_buffer(buf)
82 .map_err(|e| Error::Module(e.to_string()))?;
83 }
84 s.dirty = false;
85 Ok(())
86 }
87}