1use crate::{
2 CBox, SQLiteDriver, SQLitePrepared, SQLiteTransaction, error_message_from_ptr,
3 extract::{extract_name, extract_value},
4};
5use async_stream::try_stream;
6use flume::Sender;
7use libsqlite3_sys::*;
8use std::{
9 borrow::Cow,
10 ffi::{CStr, CString, c_char, c_int},
11 mem, ptr,
12 sync::{
13 Arc,
14 atomic::{AtomicPtr, Ordering},
15 },
16};
17use tank_core::{
18 AsQuery, Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result,
19 RowLabeled, RowsAffected, send_value, stream::Stream, truncate_long,
20};
21use tokio::task::spawn_blocking;
22
/// A connection to an SQLite database, owning the raw `sqlite3` handle.
///
/// Instances are created by `Connection::connect`, which opens the database and stores the
/// handle together with the destructor that closes it.
pub struct SQLiteConnection {
    /// Raw `sqlite3*` handle wrapped in a `CBox`; presumably the destructor passed to
    /// `CBox::new` (which closes the handle) runs when this value is dropped -- confirm in `CBox`.
    pub(crate) connection: CBox<*mut sqlite3>,
    /// Initialized to `false` by `connect` and not read anywhere in this file.
    /// NOTE(review): looks like a placeholder for transaction state -- confirm before relying on it.
    pub(crate) _transaction: bool,
}
27
28impl SQLiteConnection {
29 pub fn last_error(&self) -> String {
30 unsafe {
31 let errcode = sqlite3_errcode(*self.connection);
32 format!(
33 "Error ({errcode}): {}",
34 error_message_from_ptr(&sqlite3_errmsg(*self.connection)).to_string(),
35 )
36 }
37 }
38
39 pub(crate) fn do_run_prepared(
40 connection: *mut sqlite3,
41 statement: *mut sqlite3_stmt,
42 tx: Sender<Result<QueryResult>>,
43 ) {
44 unsafe {
45 let count = sqlite3_column_count(statement);
46 let labels = match (0..count)
47 .map(|i| extract_name(statement, i))
48 .collect::<Result<Arc<[_]>>>()
49 {
50 Ok(labels) => labels,
51 Err(error) => {
52 send_value!(tx, Err(error.into()));
53 return;
54 }
55 };
56 loop {
57 match sqlite3_step(statement) {
58 SQLITE_BUSY => {
59 continue;
60 }
61 SQLITE_DONE => {
62 if sqlite3_stmt_readonly(statement) == 0 {
63 send_value!(
64 tx,
65 Ok(QueryResult::Affected(RowsAffected {
66 rows_affected: sqlite3_changes64(connection) as u64,
67 last_affected_id: Some(sqlite3_last_insert_rowid(connection)),
68 }))
69 );
70 }
71 break;
72 }
73 SQLITE_ROW => {
74 let values = match (0..count)
75 .map(|i| extract_value(statement, i))
76 .collect::<Result<_>>()
77 {
78 Ok(value) => value,
79 Err(error) => {
80 send_value!(tx, Err(error));
81 return;
82 }
83 };
84 send_value!(
85 tx,
86 Ok(QueryResult::Row(RowLabeled {
87 labels: labels.clone(),
88 values: values,
89 }))
90 )
91 }
92 _ => {
93 send_value!(
94 tx,
95 Err(Error::msg(
96 error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(
97 statement,
98 )))
99 .to_string(),
100 ))
101 );
102 return;
103 }
104 }
105 }
106 }
107 }
108
109 pub(crate) fn do_run_unprepared(
110 connection: *mut sqlite3,
111 sql: &str,
112 tx: Sender<Result<QueryResult>>,
113 ) {
114 unsafe {
115 let sql = sql.trim();
116 let mut it = sql.as_ptr() as *const c_char;
117 let mut len = sql.len();
118 loop {
119 let (statement, tail) = {
120 let mut statement = SQLitePrepared::new(CBox::new(ptr::null_mut(), |p| {
121 sqlite3_finalize(p);
122 }));
123 let mut sql_tail = ptr::null();
124 let rc = sqlite3_prepare_v2(
125 connection,
126 it,
127 len as c_int,
128 &mut *statement.statement,
129 &mut sql_tail,
130 );
131 if rc != SQLITE_OK {
132 send_value!(
133 tx,
134 Err(Error::msg(
135 error_message_from_ptr(&sqlite3_errmsg(connection)).to_string(),
136 ))
137 );
138 return;
139 }
140 (statement, sql_tail)
141 };
142 Self::do_run_prepared(connection, statement.statement(), tx.clone());
143 len = if tail != ptr::null() {
144 len - tail.offset_from_unsigned(it)
145 } else {
146 0
147 };
148 if len == 0 {
149 break;
150 }
151 it = tail;
152 }
153 };
154 }
155}
156
157impl Executor for SQLiteConnection {
158 type Driver = SQLiteDriver;
159
160 fn driver(&self) -> &Self::Driver {
161 &SQLiteDriver {}
162 }
163
164 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
165 let connection = AtomicPtr::new(*self.connection);
166 let context = format!(
167 "While preparing the query:\n{}",
168 truncate_long!(sql.as_str())
169 );
170 let prepared = spawn_blocking(move || unsafe {
171 let connection = connection.load(Ordering::Relaxed);
172 let len = sql.len();
173 let sql = CString::new(sql.as_bytes())?;
174 let mut statement = CBox::new(ptr::null_mut(), |p| {
175 sqlite3_finalize(p);
176 });
177 let mut tail = ptr::null();
178 let rc = sqlite3_prepare_v2(
179 connection,
180 sql.as_ptr(),
181 len as c_int,
182 &mut *statement,
183 &mut tail,
184 );
185 if rc != SQLITE_OK {
186 let error =
187 Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
188 .context(context);
189 log::error!("{:#}", error);
190 return Err(error);
191 }
192 if tail != ptr::null() && *tail != '\0' as i8 {
193 let error = Error::msg(format!(
194 "Cannot prepare more than one statement at a time (remaining: {})",
195 CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
196 ))
197 .context(context);
198 log::error!("{:#}", error);
199 return Err(error);
200 }
201 Ok(statement)
202 })
203 .await?;
204 Ok(SQLitePrepared::new(prepared?).into())
205 }
206
207 fn run<'s>(
208 &'s mut self,
209 query: impl AsQuery<Self::Driver> + 's,
210 ) -> impl Stream<Item = Result<QueryResult>> + Send {
211 let mut query = query.as_query();
212 let context = Arc::new(format!("While executing the query:\n{}", query.as_mut()));
213 let (tx, rx) = flume::unbounded::<Result<QueryResult>>();
214 let connection = AtomicPtr::new(*self.connection);
215 let mut owned = mem::take(query.as_mut());
216 let join = spawn_blocking(move || {
217 match &mut owned {
218 Query::Raw(query) => {
219 Self::do_run_unprepared(connection.load(Ordering::Relaxed), query, tx);
220 }
221 Query::Prepared(prepared) => Self::do_run_prepared(
222 connection.load(Ordering::Relaxed),
223 prepared.statement(),
224 tx,
225 ),
226 }
227 owned
228 });
229 try_stream! {
230 while let Ok(result) = rx.recv_async().await {
231 yield result.map_err(|e| {
232 let error = e.context(context.clone());
233 log::error!("{:#}", error);
234 error
235 })?;
236 }
237 *query.as_mut() = mem::take(&mut join.await?);
238 }
239 }
240}
241
242impl Connection for SQLiteConnection {
243 #[allow(refining_impl_trait)]
244 async fn connect(url: Cow<'static, str>) -> Result<SQLiteConnection> {
245 let context = || format!("While trying to connect to `{}`", truncate_long!(url));
246 let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
247 if !url.starts_with(&prefix) {
248 let error = Error::msg(format!(
249 "SQLite connection url must start with `{}`",
250 &prefix
251 ))
252 .context(context());
253 log::error!("{:#}", error);
254 return Err(error);
255 }
256 let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
257 .with_context(context)?;
258 let mut connection;
259 unsafe {
260 connection = CBox::new(ptr::null_mut(), |p| {
261 if sqlite3_close(p) != SQLITE_OK {
262 log::error!("Could not close sqlite connection")
263 }
264 });
265 let rc = sqlite3_open_v2(
266 url.as_ptr(),
267 &mut *connection,
268 SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
269 ptr::null(),
270 );
271 if rc != SQLITE_OK {
272 let error =
273 Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
274 .context(context());
275 log::error!("{:#}", error);
276 return Err(error);
277 }
278 }
279 Ok(Self {
280 connection,
281 _transaction: false,
282 })
283 }
284
285 #[allow(refining_impl_trait)]
286 fn begin(&mut self) -> impl Future<Output = Result<SQLiteTransaction<'_>>> {
287 SQLiteTransaction::new(self)
288 }
289}