tank_sqlite/
connection.rs

1use crate::{
2    CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction,
3    extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9    borrow::Cow,
10    ffi::{CStr, CString, c_char, c_int},
11    mem, ptr,
12    str::FromStr,
13    sync::{
14        Arc,
15        atomic::{AtomicPtr, Ordering},
16    },
17};
18use tank_core::{
19    AsQuery, Connection, Error, ErrorContext, Executor, Prepared, Query, QueryResult, Result,
20    RowLabeled, RowsAffected, error_message_from_ptr, send_value, stream::Stream, truncate_long,
21};
22use tokio::task::spawn_blocking;
23
24/// Wrapper for a SQLite `sqlite3` connection pointer used by the SQLite driver.
25///
26/// Provides helpers to prepare/execute statements and stream results into `tank_core` result types.
27pub struct SQLiteConnection {
28    pub(crate) connection: CBox<*mut sqlite3>,
29    pub(crate) _transaction: bool,
30}
31
32impl SQLiteConnection {
33    pub fn last_error(&self) -> String {
34        unsafe {
35            let errcode = sqlite3_errcode(*self.connection);
36            format!(
37                "Error ({errcode}): {}",
38                error_message_from_ptr(&sqlite3_errmsg(*self.connection)),
39            )
40        }
41    }
42
43    pub(crate) fn do_run_prepared(
44        connection: *mut sqlite3,
45        statement: *mut sqlite3_stmt,
46        tx: Sender<Result<QueryResult>>,
47    ) {
48        unsafe {
49            let count = sqlite3_column_count(statement);
50            let labels = match (0..count)
51                .map(|i| extract_name(statement, i))
52                .collect::<Result<Arc<[_]>>>()
53            {
54                Ok(labels) => labels,
55                Err(error) => {
56                    send_value!(tx, Err(error.into()));
57                    return;
58                }
59            };
60            loop {
61                match sqlite3_step(statement) {
62                    SQLITE_BUSY => {
63                        continue;
64                    }
65                    SQLITE_DONE => {
66                        if sqlite3_stmt_readonly(statement) == 0 {
67                            send_value!(
68                                tx,
69                                Ok(QueryResult::Affected(RowsAffected {
70                                    rows_affected: Some(sqlite3_changes64(connection) as _),
71                                    last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
72                                }))
73                            );
74                        }
75                        break;
76                    }
77                    SQLITE_ROW => {
78                        let values = match (0..count)
79                            .map(|i| extract_value(statement, i))
80                            .collect::<Result<_>>()
81                        {
82                            Ok(value) => value,
83                            Err(error) => {
84                                send_value!(tx, Err(error));
85                                return;
86                            }
87                        };
88                        send_value!(
89                            tx,
90                            Ok(QueryResult::Row(RowLabeled {
91                                labels: labels.clone(),
92                                values: values,
93                            }))
94                        )
95                    }
96                    _ => {
97                        send_value!(
98                            tx,
99                            Err(Error::msg(
100                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
101                                    statement,
102                                )))
103                                .to_string(),
104                            ))
105                        );
106                        break;
107                    }
108                }
109            }
110        }
111    }
112
113    pub(crate) fn do_run_unprepared(
114        connection: *mut sqlite3,
115        sql: &str,
116        tx: Sender<Result<QueryResult>>,
117    ) {
118        unsafe {
119            let sql = sql.trim();
120            let mut it = sql.as_ptr() as *const c_char;
121            let mut len = sql.len();
122            loop {
123                let (statement, tail) = {
124                    let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
125                        sqlite3_finalize(p);
126                    }));
127                    let mut sql_tail = ptr::null();
128                    let rc = sqlite3_prepare_v2(
129                        connection,
130                        it,
131                        len as c_int,
132                        &mut *statement.statement,
133                        &mut sql_tail,
134                    );
135                    if rc != SQLITE_OK {
136                        send_value!(
137                            tx,
138                            Err(Error::msg(
139                                error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
140                            ))
141                        );
142                        return;
143                    }
144                    (statement, sql_tail)
145                };
146                Self::do_run_prepared(connection, statement.statement(), tx.clone());
147                len = if tail != ptr::null() {
148                    len - tail.offset_from_unsigned(it)
149                } else {
150                    0
151                };
152                if len == 0 {
153                    break;
154                }
155                it = tail;
156            }
157        };
158    }
159}
160
161impl Executor for SQLiteConnection {
162    type Driver = SQLiteDriver;
163
164    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
165        let connection = AtomicPtr::new(*self.connection);
166        let context = format!(
167            "While preparing the query:\n{}",
168            truncate_long!(sql.as_str())
169        );
170        let prepared = spawn_blocking(move || unsafe {
171            let connection = connection.load(Ordering::Relaxed);
172            let len = sql.len();
173            let sql = CString::new(sql.as_bytes())?;
174            let mut statement = CBox::new(ptr::null_mut(), |p| {
175                sqlite3_finalize(p);
176            });
177            let mut tail = ptr::null();
178            let rc = sqlite3_prepare_v2(
179                connection,
180                sql.as_ptr(),
181                len as c_int,
182                &mut *statement,
183                &mut tail,
184            );
185            if rc != SQLITE_OK {
186                let error =
187                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
188                        .context(context);
189                log::error!("{:#}", error);
190                return Err(error);
191            }
192            if tail != ptr::null() && *tail != '\0' as i8 {
193                let error = Error::msg(format!(
194                    "Cannot prepare more than one statement at a time (remaining: {})",
195                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
196                ))
197                .context(context);
198                log::error!("{:#}", error);
199                return Err(error);
200            }
201            Ok(statement)
202        })
203        .await?;
204        Ok(SQLitePrepared::new(prepared?).into())
205    }
206
207    fn run<'s>(
208        &'s mut self,
209        query: impl AsQuery<Self::Driver> + 's,
210    ) -> impl Stream<Item = Result<QueryResult>> + Send {
211        let mut query = query.as_query();
212        let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
213        let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
214        let connection = AtomicPtr::new(*self.connection);
215        let mut owned = mem::take(query.as_mut());
216        let join = spawn_blocking(move || {
217            match &mut owned {
218                Query::Raw(query) => {
219                    Self::do_run_unprepared(connection.load(Ordering::Relaxed), query, tx);
220                }
221                Query::Prepared(prepared) => {
222                    Self::do_run_prepared(
223                        connection.load(Ordering::Relaxed),
224                        prepared.statement(),
225                        tx,
226                    );
227                    let _ = prepared.clear_bindings();
228                }
229            }
230            owned
231        });
232        try_stream! {
233            while let Ok(result) = rx.recv_async().await {
234                yield result.map_err(|e| {
235                    let error = e.context(context.clone());
236                    log::error!("{:#}", error);
237                    error
238                })?;
239            }
240            *query.as_mut() = mem::take(&mut join.await?);
241        }
242    }
243}
244
245impl Connection for SQLiteConnection {
246    #[allow(refining_impl_trait)]
247    async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
248        let context = format!("While trying to connect to `{}`", truncate_long!(url));
249        let url = Self::sanitize_url(url)?;
250        let url = CString::from_str(&url.as_str().replacen("sqlite://", "file:", 1))
251            .with_context(|| context.clone())?;
252        let mut connection;
253        unsafe {
254            connection = CBox::new(ptr::null_mut(), |p| {
255                if sqlite3_close(p) != SQLITE_OK {
256                    log::error!("Could not close sqlite connection")
257                }
258            });
259            let rc = sqlite3_open_v2(
260                url.as_ptr(),
261                &mut *connection,
262                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
263                ptr::null(),
264            );
265            if rc != SQLITE_OK {
266                let error =
267                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
268                        .context(context);
269                log::error!("{:#}", error);
270                return Err(error);
271            }
272        }
273        Ok(Self {
274            connection,
275            _transaction: false,
276        })
277    }
278
279    #[allow(refining_impl_trait)]
280    fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
281        SQLiteTransaction::new(self)
282    }
283}