Skip to main content

rustbasic_core/sql/driver/sqlite/mod.rs

1use std::ffi::{c_char, c_int, c_void, CStr, CString};
2use crate::sql::driver::error::SqlError;
3use crate::sql::driver::{SqlValue, SqlColumn, SqlRow, QueryResult};
4
/// Opaque handle to an open SQLite database connection (`sqlite3` in C).
///
/// Values of this type are never created in Rust; they only exist behind the
/// raw pointers handed out by the C library. The zero-sized `_private` field
/// prevents construction outside this module, and the `_marker` field makes
/// the type `!Send`, `!Sync` and `!Unpin`, following the Rustonomicon's
/// recommendation for opaque FFI types.
#[allow(non_camel_case_types)] // name deliberately mirrors the C type
#[repr(C)]
pub struct sqlite3 {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
10
/// Opaque handle to a compiled SQLite statement (`sqlite3_stmt` in C).
///
/// Like [`sqlite3`], this type is only used behind raw pointers returned by
/// the C library. The marker field makes it `!Send`, `!Sync` and `!Unpin`,
/// as recommended for opaque FFI types.
#[allow(non_camel_case_types)] // name deliberately mirrors the C type
#[repr(C)]
pub struct sqlite3_stmt {
    _private: [u8; 0],
    _marker: std::marker::PhantomData<(*mut u8, std::marker::PhantomPinned)>,
}
15
// Result codes returned by the SQLite C API.

/// The call completed successfully.
const SQLITE_OK: c_int = 0;
/// `sqlite3_step` has another result row ready.
const SQLITE_ROW: c_int = 100;
/// `sqlite3_step` has finished executing the statement.
const SQLITE_DONE: c_int = 101;

// Storage-class codes reported by `sqlite3_column_type`.

/// 64-bit signed integer value.
const SQLITE_INTEGER: c_int = 1;
/// 64-bit IEEE floating-point value.
const SQLITE_FLOAT: c_int = 2;
/// Text value.
const SQLITE_TEXT: c_int = 3;
/// Binary blob value.
const SQLITE_BLOB: c_int = 4;
/// SQL NULL.
const SQLITE_NULL: c_int = 5;
26
27#[link(name = "sqlite3")]
28unsafe extern "C" {
29    fn sqlite3_open(filename: *const c_char, ppDb: *mut *mut sqlite3) -> c_int;
30    fn sqlite3_close(db: *mut sqlite3) -> c_int;
31    fn sqlite3_errmsg(db: *mut sqlite3) -> *const c_char;
32    
33    fn sqlite3_prepare_v2(
34        db: *mut sqlite3,
35        zSql: *const c_char,
36        nByte: c_int,
37        ppStmt: *mut *mut sqlite3_stmt,
38        pzTail: *mut *const c_char,
39    ) -> c_int;
40    
41    fn sqlite3_finalize(pStmt: *mut sqlite3_stmt) -> c_int;
42    fn sqlite3_step(pStmt: *mut sqlite3_stmt) -> c_int;
43    
44    // Binding parameters
45    fn sqlite3_bind_null(pStmt: *mut sqlite3_stmt, index: c_int) -> c_int;
46    fn sqlite3_bind_int64(pStmt: *mut sqlite3_stmt, index: c_int, value: i64) -> c_int;
47    fn sqlite3_bind_double(pStmt: *mut sqlite3_stmt, index: c_int, value: f64) -> c_int;
48    fn sqlite3_bind_text(
49        pStmt: *mut sqlite3_stmt,
50        index: c_int,
51        value: *const c_char,
52        n: c_int,
53        destructor: Option<unsafe extern "C" fn(*mut c_void)>,
54    ) -> c_int;
55    fn sqlite3_bind_blob(
56        pStmt: *mut sqlite3_stmt,
57        index: c_int,
58        value: *const c_void,
59        n: c_int,
60        destructor: Option<unsafe extern "C" fn(*mut c_void)>,
61    ) -> c_int;
62
63    // Retrieving columns
64    fn sqlite3_column_count(pStmt: *mut sqlite3_stmt) -> c_int;
65    fn sqlite3_column_name(pStmt: *mut sqlite3_stmt, N: c_int) -> *const c_char;
66    fn sqlite3_column_type(pStmt: *mut sqlite3_stmt, iCol: c_int) -> c_int;
67    
68    fn sqlite3_column_int64(pStmt: *mut sqlite3_stmt, iCol: c_int) -> i64;
69    fn sqlite3_column_double(pStmt: *mut sqlite3_stmt, iCol: c_int) -> f64;
70    fn sqlite3_column_text(pStmt: *mut sqlite3_stmt, iCol: c_int) -> *const c_char;
71    fn sqlite3_column_blob(pStmt: *mut sqlite3_stmt, iCol: c_int) -> *const c_void;
72    fn sqlite3_column_bytes(pStmt: *mut sqlite3_stmt, iCol: c_int) -> c_int;
73    
74    fn sqlite3_changes(db: *mut sqlite3) -> c_int;
75    fn sqlite3_last_insert_rowid(db: *mut sqlite3) -> i64;
76}
77
78pub struct SqliteConnection {
79    db: *mut sqlite3,
80}
81
82unsafe impl Send for SqliteConnection {}
83
84impl Drop for SqliteConnection {
85    fn drop(&mut self) {
86        if !self.db.is_null() {
87            unsafe {
88                sqlite3_close(self.db);
89            }
90        }
91    }
92}
93
94impl SqliteConnection {
95    pub fn connect(path: &str) -> Result<Self, SqlError> {
96        let c_path = CString::new(path)
97            .map_err(|e| SqlError::Other(format!("Invalid path: {}", e)))?;
98        let mut db = std::ptr::null_mut();
99        let rc = unsafe { sqlite3_open(c_path.as_ptr(), &mut db) };
100        if rc != SQLITE_OK {
101            let err_msg = if db.is_null() {
102                "Failed to open database".to_string()
103            } else {
104                unsafe {
105                    let err = sqlite3_errmsg(db);
106                    CStr::from_ptr(err).to_string_lossy().into_owned()
107                }
108            };
109            if !db.is_null() {
110                unsafe { sqlite3_close(db); }
111            }
112            return Err(SqlError::Other(err_msg));
113        }
114        Ok(Self { db })
115    }
116
117    fn prepare_and_bind(&self, sql: &str, params: &[SqlValue]) -> Result<*mut sqlite3_stmt, SqlError> {
118        let c_sql = CString::new(sql)
119            .map_err(|e| SqlError::Other(format!("Invalid SQL string: {}", e)))?;
120        let mut stmt = std::ptr::null_mut();
121        let rc = unsafe {
122            sqlite3_prepare_v2(self.db, c_sql.as_ptr(), -1, &mut stmt, std::ptr::null_mut())
123        };
124        if rc != SQLITE_OK {
125            let err_msg = unsafe {
126                let err = sqlite3_errmsg(self.db);
127                CStr::from_ptr(err).to_string_lossy().into_owned()
128            };
129            return Err(SqlError::Other(err_msg));
130        }
131
132        let sqlite_transient: Option<unsafe extern "C" fn(*mut c_void)> = unsafe {
133            std::mem::transmute(-1isize)
134        };
135
136        for (i, param) in params.iter().enumerate() {
137            let idx = (i + 1) as c_int;
138            let rc = unsafe {
139                match param {
140                    SqlValue::Null => sqlite3_bind_null(stmt, idx),
141                    SqlValue::Integer(val) => sqlite3_bind_int64(stmt, idx, *val),
142                    SqlValue::Real(val) => sqlite3_bind_double(stmt, idx, *val),
143                    SqlValue::Text(val) => {
144                        let c_str = CString::new(val.as_str()).map_err(|e| SqlError::Other(e.to_string()))?;
145                        sqlite3_bind_text(stmt, idx, c_str.as_ptr(), -1, sqlite_transient)
146                    }
147                    SqlValue::Blob(val) => {
148                        sqlite3_bind_blob(stmt, idx, val.as_ptr() as *const c_void, val.len() as c_int, sqlite_transient)
149                    }
150                }
151            };
152            if rc != SQLITE_OK {
153                let err_msg = unsafe {
154                    let err = sqlite3_errmsg(self.db);
155                    CStr::from_ptr(err).to_string_lossy().into_owned()
156                };
157                unsafe { sqlite3_finalize(stmt); }
158                return Err(SqlError::Other(err_msg));
159            }
160        }
161
162        Ok(stmt)
163    }
164
165    pub fn execute(&mut self, sql: &str, params: &[SqlValue]) -> Result<QueryResult, SqlError> {
166        let stmt = self.prepare_and_bind(sql, params)?;
167        let mut rc = unsafe { sqlite3_step(stmt) };
168        
169        while rc == SQLITE_ROW {
170            rc = unsafe { sqlite3_step(stmt) };
171        }
172
173        if rc != SQLITE_DONE {
174            let err_msg = unsafe {
175                let err = sqlite3_errmsg(self.db);
176                CStr::from_ptr(err).to_string_lossy().into_owned()
177            };
178            unsafe { sqlite3_finalize(stmt); }
179            return Err(SqlError::Other(err_msg));
180        }
181
182        unsafe { sqlite3_finalize(stmt); }
183
184        let rows_affected = unsafe { sqlite3_changes(self.db) } as u64;
185        let last_insert_id = unsafe { sqlite3_last_insert_rowid(self.db) } as u64;
186
187        Ok(QueryResult {
188            rows_affected,
189            last_insert_id,
190        })
191    }
192
193    pub fn query(&mut self, sql: &str, params: &[SqlValue]) -> Result<Vec<SqlRow>, SqlError> {
194        let stmt = self.prepare_and_bind(sql, params)?;
195        
196        let col_count = unsafe { sqlite3_column_count(stmt) } as usize;
197        let mut columns = Vec::with_capacity(col_count);
198        for i in 0..col_count {
199            let name_ptr = unsafe { sqlite3_column_name(stmt, i as c_int) };
200            let name = if name_ptr.is_null() {
201                format!("column_{}", i)
202            } else {
203                unsafe { CStr::from_ptr(name_ptr).to_string_lossy().into_owned() }
204            };
205            columns.push(SqlColumn { name });
206        }
207
208        let mut rows = Vec::new();
209        loop {
210            let rc = unsafe { sqlite3_step(stmt) };
211            if rc == SQLITE_DONE {
212                break;
213            }
214            if rc != SQLITE_ROW {
215                let err_msg = unsafe {
216                    let err = sqlite3_errmsg(self.db);
217                    CStr::from_ptr(err).to_string_lossy().into_owned()
218                };
219                unsafe { sqlite3_finalize(stmt); }
220                return Err(SqlError::Other(err_msg));
221            }
222
223            let mut values = Vec::with_capacity(col_count);
224            for i in 0..col_count {
225                let col_idx = i as c_int;
226                let val_type = unsafe { sqlite3_column_type(stmt, col_idx) };
227                let val = match val_type {
228                    SQLITE_NULL => SqlValue::Null,
229                    SQLITE_INTEGER => {
230                        let v = unsafe { sqlite3_column_int64(stmt, col_idx) };
231                        SqlValue::Integer(v)
232                    }
233                    SQLITE_FLOAT => {
234                        let v = unsafe { sqlite3_column_double(stmt, col_idx) };
235                        SqlValue::Real(v)
236                    }
237                    SQLITE_TEXT => {
238                        let text_ptr = unsafe { sqlite3_column_text(stmt, col_idx) };
239                        let s = if text_ptr.is_null() {
240                            String::new()
241                        } else {
242                            unsafe { CStr::from_ptr(text_ptr as *const c_char).to_string_lossy().into_owned() }
243                        };
244                        SqlValue::Text(s)
245                    }
246                    SQLITE_BLOB => {
247                        let blob_ptr = unsafe { sqlite3_column_blob(stmt, col_idx) };
248                        let num_bytes = unsafe { sqlite3_column_bytes(stmt, col_idx) } as usize;
249                        let mut bytes = vec![0u8; num_bytes];
250                        if !blob_ptr.is_null() && num_bytes > 0 {
251                            unsafe {
252                                std::ptr::copy_nonoverlapping(blob_ptr as *const u8, bytes.as_mut_ptr(), num_bytes);
253                            }
254                        }
255                        SqlValue::Blob(bytes)
256                    }
257                    _ => SqlValue::Null,
258                };
259                values.push(val);
260            }
261            rows.push(SqlRow {
262                columns: columns.clone(),
263                values,
264            });
265        }
266
267        unsafe { sqlite3_finalize(stmt); }
268        Ok(rows)
269    }
270
271    pub fn begin(&mut self) -> Result<SqliteTransaction<'_>, SqlError> {
272        self.execute("BEGIN", &[])?;
273        Ok(SqliteTransaction {
274            conn: self,
275            committed: false,
276        })
277    }
278}
279
280pub struct SqliteTransaction<'a> {
281    conn: &'a mut SqliteConnection,
282    committed: bool,
283}
284
285impl<'a> SqliteTransaction<'a> {
286    pub fn execute(&mut self, sql: &str, params: &[SqlValue]) -> Result<QueryResult, SqlError> {
287        self.conn.execute(sql, params)
288    }
289
290    pub fn query(&mut self, sql: &str, params: &[SqlValue]) -> Result<Vec<SqlRow>, SqlError> {
291        self.conn.query(sql, params)
292    }
293
294    pub fn commit(mut self) -> Result<(), SqlError> {
295        self.conn.execute("COMMIT", &[])?;
296        self.committed = true;
297        Ok(())
298    }
299
300    pub fn rollback(mut self) -> Result<(), SqlError> {
301        self.conn.execute("ROLLBACK", &[])?;
302        self.committed = true;
303        Ok(())
304    }
305}
306
307impl<'a> Drop for SqliteTransaction<'a> {
308    fn drop(&mut self) {
309        if !self.committed {
310            let _ = self.conn.execute("ROLLBACK", &[]);
311        }
312    }
313}
314
315impl crate::sql::driver::SqlConnection for SqliteConnection {
316    fn execute(&mut self, sql: &str, params: &[SqlValue]) -> Result<QueryResult, SqlError> {
317        self.execute(sql, params)
318    }
319
320    fn query(&mut self, sql: &str, params: &[SqlValue]) -> Result<Vec<SqlRow>, SqlError> {
321        self.query(sql, params)
322    }
323}
324
325impl<'a> crate::sql::driver::SqlConnection for SqliteTransaction<'a> {
326    fn execute(&mut self, sql: &str, params: &[SqlValue]) -> Result<QueryResult, SqlError> {
327        self.execute(sql, params)
328    }
329
330    fn query(&mut self, sql: &str, params: &[SqlValue]) -> Result<Vec<SqlRow>, SqlError> {
331        self.query(sql, params)
332    }
333}