Skip to main content

tank_sqlite/
connection.rs

1use crate::{
2    CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction,
3    extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9    borrow::Cow,
10    ffi::{CStr, CString, c_char, c_int},
11    mem, ptr,
12    str::FromStr,
13    sync::{
14        Arc,
15        atomic::{AtomicPtr, Ordering},
16    },
17};
18use tank_core::{
19    AsQuery, Connection, Error, ErrorContext, Executor, Prepared, Query, QueryResult, RawQuery,
20    Result, RowLabeled, RowsAffected, error_message_from_ptr, send_value, stream::Stream,
21    truncate_long,
22};
23use tokio::task::spawn_blocking;
24
/// Wrapper for a SQLite `sqlite3` connection pointer used by the SQLite driver.
///
/// Provides helpers to prepare/execute statements and stream results into `tank_core` result
/// types. Statement work is handed to blocking threads (`spawn_blocking`) together with the raw
/// handle; the struct itself holds no state besides that handle.
pub struct SQLiteConnection {
    /// Raw `sqlite3*` handle. `connect` creates it with a deleter that calls `sqlite3_close`
    /// (presumably run when the `CBox` is dropped — confirm against `CBox`).
    pub(crate) connection: CBox<*mut sqlite3>,
}
31
32impl SQLiteConnection {
33    pub fn last_error(&self) -> String {
34        unsafe {
35            let errcode = sqlite3_errcode(*self.connection);
36            format!(
37                "Error ({errcode}): {}",
38                error_message_from_ptr(&sqlite3_errmsg(*self.connection)),
39            )
40        }
41    }
42
43    pub(crate) fn do_run_prepared(
44        connection: *mut sqlite3,
45        statement: *mut sqlite3_stmt,
46        tx: Sender<Result<QueryResult>>,
47    ) {
48        unsafe {
49            let count = sqlite3_column_count(statement);
50            let labels = match (0..count)
51                .map(|i| extract_name(statement, i))
52                .collect::<Result<Arc<[_]>>>()
53            {
54                Ok(labels) => labels,
55                Err(error) => {
56                    send_value!(tx, Err(error.into()));
57                    return;
58                }
59            };
60            loop {
61                match sqlite3_step(statement) {
62                    SQLITE_BUSY => {
63                        continue;
64                    }
65                    SQLITE_DONE => {
66                        if sqlite3_stmt_readonly(statement) == 0 {
67                            send_value!(
68                                tx,
69                                Ok(QueryResult::Affected(RowsAffected {
70                                    rows_affected: Some(sqlite3_changes64(connection) as _),
71                                    last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
72                                }))
73                            );
74                        }
75                        break;
76                    }
77                    SQLITE_ROW => {
78                        let values = match (0..count)
79                            .map(|i| extract_value(statement, i))
80                            .collect::<Result<_>>()
81                        {
82                            Ok(value) => value,
83                            Err(error) => {
84                                send_value!(tx, Err(error));
85                                return;
86                            }
87                        };
88                        send_value!(
89                            tx,
90                            Ok(QueryResult::Row(RowLabeled {
91                                labels: labels.clone(),
92                                values: values,
93                            }))
94                        )
95                    }
96                    _ => {
97                        send_value!(
98                            tx,
99                            Err(Error::msg(
100                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
101                                    statement,
102                                )))
103                                .to_string(),
104                            ))
105                        );
106                        break;
107                    }
108                }
109            }
110        }
111    }
112
113    pub(crate) fn do_run_unprepared(
114        connection: *mut sqlite3,
115        sql: &str,
116        tx: Sender<Result<QueryResult>>,
117    ) {
118        unsafe {
119            let sql = sql.trim();
120            let mut it = sql.as_ptr() as *const c_char;
121            let mut len = sql.len();
122            loop {
123                let (statement, tail) = {
124                    let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
125                        sqlite3_finalize(p);
126                    }));
127                    let mut sql_tail = ptr::null();
128                    let rc = sqlite3_prepare_v2(
129                        connection,
130                        it,
131                        len as c_int,
132                        &mut *statement.statement,
133                        &mut sql_tail,
134                    );
135                    if rc != SQLITE_OK {
136                        send_value!(
137                            tx,
138                            Err(Error::msg(
139                                error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
140                            ))
141                        );
142                        return;
143                    }
144                    (statement, sql_tail)
145                };
146                Self::do_run_prepared(connection, statement.statement(), tx.clone());
147                len = if tail != ptr::null() {
148                    len - tail.offset_from_unsigned(it)
149                } else {
150                    0
151                };
152                if len == 0 {
153                    break;
154                }
155                it = tail;
156            }
157        };
158    }
159}
160
161impl Executor for SQLiteConnection {
162    type Driver = SQLiteDriver;
163
164    async fn do_prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
165        let connection = AtomicPtr::new(*self.connection);
166        let context = format!("While preparing the query:\n{}", truncate_long!(sql));
167        let prepared = spawn_blocking(move || unsafe {
168            let connection = connection.load(Ordering::Relaxed);
169            let len = sql.len();
170            let sql = CString::new(sql.into_bytes())?;
171            let mut statement = CBox::new(ptr::null_mut(), |p| {
172                sqlite3_finalize(p);
173            });
174            let mut tail = ptr::null();
175            let rc = sqlite3_prepare_v2(
176                connection,
177                sql.as_ptr(),
178                len as c_int,
179                &mut *statement,
180                &mut tail,
181            );
182            if rc != SQLITE_OK {
183                let error =
184                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
185                        .context(context);
186                log::error!("{:#}", error);
187                return Err(error);
188            }
189            if tail != ptr::null() && *tail != '\0' as i8 {
190                let error = Error::msg(format!(
191                    "Cannot prepare more than one statement at a time (remaining: {})",
192                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
193                ))
194                .context(context);
195                log::error!("{:#}", error);
196                return Err(error);
197            }
198            Ok(statement)
199        })
200        .await?;
201        Ok(SQLitePrepared::new(prepared?).into())
202    }
203
204    fn run<'s>(
205        &'s mut self,
206        query: impl AsQuery<Self::Driver> + 's,
207    ) -> impl Stream<Item = Result<QueryResult>> + Send {
208        let mut query = query.as_query();
209        let context = Arc::new(format!("While running the query:\n{}", query.as_mut()));
210        let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
211        let connection = AtomicPtr::new(*self.connection);
212        let mut owned = mem::take(query.as_mut());
213        let join = spawn_blocking(move || {
214            match &mut owned {
215                Query::Raw(RawQuery(sql)) => {
216                    Self::do_run_unprepared(connection.load(Ordering::Relaxed), sql, tx);
217                }
218                Query::Prepared(prepared) => {
219                    Self::do_run_prepared(
220                        connection.load(Ordering::Relaxed),
221                        prepared.statement(),
222                        tx,
223                    );
224                    let _ = prepared.clear_bindings();
225                }
226            }
227            owned
228        });
229        try_stream! {
230            while let Ok(result) = rx.recv_async().await {
231                yield result.map_err(|e| {
232                    let error = e.context(context.clone());
233                    log::error!("{:#}", error);
234                    error
235                })?;
236            }
237            *query.as_mut() = mem::take(&mut join.await?);
238        }
239    }
240}
241
242impl Connection for SQLiteConnection {
243    async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
244        let context = format!("While trying to connect to `{}`", truncate_long!(url));
245        let url = Self::sanitize_url(url)?;
246        let url = CString::from_str(&url.as_str().replacen("sqlite://", "file:", 1))
247            .with_context(|| context.clone())?;
248        let mut connection;
249        unsafe {
250            connection = CBox::new(ptr::null_mut(), |p| {
251                if sqlite3_close(p) != SQLITE_OK {
252                    log::error!("Could not close sqlite connection")
253                }
254            });
255            let rc = sqlite3_open_v2(
256                url.as_ptr(),
257                &mut *connection,
258                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
259                ptr::null(),
260            );
261            if rc != SQLITE_OK {
262                let error =
263                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
264                        .context(context);
265                log::error!("{:#}", error);
266                return Err(error);
267            }
268        }
269        Ok(Self { connection })
270    }
271
272    fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
273        SQLiteTransaction::new(self)
274    }
275}