tank_sqlite/
connection.rs

1use crate::{
2    CBox, SqliteDriver, SqlitePrepared, SqliteTransaction, error_message_from_ptr,
3    extract::{extract_name, extract_value},
4};
5use async_stream::{stream, try_stream};
6use libsqlite3_sys::{
7    SQLITE_BUSY, SQLITE_DONE, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_READWRITE,
8    SQLITE_OPEN_URI, SQLITE_ROW, sqlite3, sqlite3_changes64, sqlite3_close, sqlite3_column_count,
9    sqlite3_db_handle, sqlite3_errmsg, sqlite3_finalize, sqlite3_last_insert_rowid,
10    sqlite3_open_v2, sqlite3_prepare_v2, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly,
11};
12use std::{
13    borrow::Cow,
14    ffi::{CStr, CString, c_char, c_int},
15    pin::pin,
16    ptr,
17    sync::{
18        Arc,
19        atomic::{AtomicPtr, Ordering},
20    },
21};
22use tank_core::{
23    Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result, RowLabeled,
24    RowsAffected,
25    future::Either,
26    printable_query,
27    stream::{Stream, StreamExt},
28};
29use tokio::task::spawn_blocking;
30
31pub struct SqliteConnection {
32    pub(crate) connection: CBox<*mut sqlite3>,
33    pub(crate) _transaction: bool,
34}
35
36impl SqliteConnection {
37    pub(crate) fn run_prepared(
38        &mut self,
39        statement: CBox<*mut sqlite3_stmt>,
40    ) -> impl Stream<Item = Result<QueryResult>> {
41        unsafe {
42            stream! {
43                let count = sqlite3_column_count(*statement);
44                let labels = (0..count)
45                    .map(|i| extract_name(*statement, i))
46                    .collect::<Result<Arc<[_]>>>()?;
47                loop {
48                    match sqlite3_step(*statement) {
49                        SQLITE_BUSY => {
50                            continue;
51                        }
52                        SQLITE_DONE => {
53                            if sqlite3_stmt_readonly(*statement) == 0 {
54                                yield Ok(QueryResult::Affected(RowsAffected {
55                                    rows_affected: sqlite3_changes64(*self.connection) as u64,
56                                    last_affected_id: Some(sqlite3_last_insert_rowid(*self.connection)),
57                                }))
58                            }
59                            break;
60                        }
61                        SQLITE_ROW => {
62                            yield Ok(QueryResult::RowLabeled(RowLabeled {
63                                labels: labels.clone(),
64                                values: (0..count).map(|i| extract_value(*statement, i)).collect()?,
65                            }))
66                        }
67                        _ => {
68                            let error = Error::msg(
69                                error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(*statement)))
70                                    .to_string(),
71                            );
72                            log::error!("{:#}", error);
73                            yield Err(error);
74                        }
75                    }
76                }
77            }
78        }
79    }
80
81    pub(crate) fn run_unprepared(
82        &mut self,
83        sql: String,
84    ) -> impl Stream<Item = Result<QueryResult>> {
85        try_stream! {
86            let mut len = sql.trim_end().len();
87            let buff = sql.into_bytes();
88            let mut it = CBox::new(buff.as_ptr() as *const c_char, |_| {});
89            loop {
90                let connection = CBox::new(*self.connection, |_| {});
91                let sql = CBox::new(*it, |_| {});
92                let (statement, tail) = spawn_blocking(move || unsafe {
93                    let mut statement = CBox::new(ptr::null_mut(), |p| {
94                        sqlite3_finalize(p);
95                    });
96                    let mut sql_tail = CBox::new(ptr::null(), |_| {});
97                    let rc = sqlite3_prepare_v2(
98                        *connection,
99                        *sql,
100                        len as c_int,
101                        &mut *statement,
102                        &mut *sql_tail,
103                    );
104                    if rc != SQLITE_OK {
105                        return Err(Error::msg(
106                            error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string(),
107                        ));
108                    }
109                    Ok((statement, sql_tail))
110                })
111                .await??;
112                let mut stream = pin!(self.run_prepared(statement));
113                while let Some(value) = stream.next().await {
114                    yield value?
115                }
116                unsafe {
117                    len = if *tail != ptr::null() {
118                        len - tail.offset_from_unsigned(*it)
119                    } else {
120                        0
121                    };
122                    if len == 0 {
123                        break;
124                    }
125                }
126                *it = *tail;
127            }
128        }
129    }
130}
131
132impl Executor for SqliteConnection {
133    type Driver = SqliteDriver;
134
135    fn driver(&self) -> &Self::Driver {
136        &SqliteDriver {}
137    }
138
139    async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
140        let connection = AtomicPtr::new(*self.connection);
141        let context = format!(
142            "Failed to prepare the query:\n{}",
143            printable_query!(sql.as_str())
144        );
145        let prepared = spawn_blocking(move || unsafe {
146            let connection = connection.load(Ordering::Relaxed);
147            let len = sql.len();
148            let sql = CString::new(sql.as_bytes())?;
149            let mut statement = CBox::new(ptr::null_mut(), |p| {
150                sqlite3_finalize(p);
151            });
152            let mut tail = ptr::null();
153            let rc = sqlite3_prepare_v2(
154                connection,
155                sql.as_ptr(),
156                len as c_int,
157                &mut *statement,
158                &mut tail,
159            );
160            if rc != SQLITE_OK {
161                let error =
162                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
163                        .context(context);
164                log::error!("{:#}", error);
165                return Err(error);
166            }
167            if tail != ptr::null() && *tail != '\0' as i8 {
168                let error = Error::msg(format!(
169                    "Cannot prepare more than one statement at a time (remaining: {})",
170                    CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
171                ))
172                .context(context);
173                log::error!("{:#}", error);
174                return Err(error);
175            }
176            Ok(statement)
177        })
178        .await?;
179        Ok(SqlitePrepared::new(prepared?).into())
180    }
181
182    fn run(
183        &mut self,
184        query: Query<Self::Driver>,
185    ) -> impl Stream<Item = Result<QueryResult>> + Send {
186        match query {
187            Query::Raw(sql) => Either::Left(self.run_unprepared(sql)),
188            Query::Prepared(prepared) => Either::Right(self.run_prepared(prepared.statement)),
189        }
190    }
191}
192
193impl Connection for SqliteConnection {
194    #[allow(refining_impl_trait)]
195    async fn connect(url: Cow<'static, str>) -> Result<SqliteConnection> {
196        let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
197        if !url.starts_with(&prefix) {
198            let error = Error::msg(format!(
199                "Sqlite connection url must start with `{}`",
200                &prefix
201            ));
202            log::error!("{:#}", error);
203            return Err(error);
204        }
205        let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
206            .with_context(|| format!("Invalid database url `{}`", url))?;
207        let mut connection;
208        unsafe {
209            connection = CBox::new(ptr::null_mut(), |p| {
210                if sqlite3_close(p) != SQLITE_OK {
211                    log::error!("Could not close sqlite connection")
212                }
213            });
214            let rc = sqlite3_open_v2(
215                url.as_ptr(),
216                &mut *connection,
217                SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
218                ptr::null(),
219            );
220            if rc != SQLITE_OK {
221                let error =
222                    Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
223                        .context(format!(
224                            "Failed to connect to database url `{}`",
225                            url.to_str().unwrap_or("unprintable value")
226                        ));
227                log::error!("{:#}", error);
228                return Err(error);
229            }
230        }
231        Ok(Self {
232            connection,
233            _transaction: false,
234        })
235    }
236
237    #[allow(refining_impl_trait)]
238    fn begin(&mut self) -> impl Future<Output = Result<SqliteTransaction>> {
239        SqliteTransaction::new(self)
240    }
241}