1use crate::{
2 CBox, SqliteDriver, SqlitePrepared, SqliteTransaction, error_message_from_ptr,
3 extract::{extract_name, extract_value},
4};
5use async_stream::{stream, try_stream};
6use libsqlite3_sys::{
7 SQLITE_BUSY, SQLITE_DONE, SQLITE_OK, SQLITE_OPEN_CREATE, SQLITE_OPEN_READWRITE,
8 SQLITE_OPEN_URI, SQLITE_ROW, sqlite3, sqlite3_changes64, sqlite3_close, sqlite3_column_count,
9 sqlite3_db_handle, sqlite3_errmsg, sqlite3_finalize, sqlite3_last_insert_rowid,
10 sqlite3_open_v2, sqlite3_prepare_v2, sqlite3_step, sqlite3_stmt, sqlite3_stmt_readonly,
11};
12use std::{
13 borrow::Cow,
14 ffi::{CStr, CString, c_char, c_int},
15 pin::pin,
16 ptr,
17 sync::{
18 Arc,
19 atomic::{AtomicPtr, Ordering},
20 },
21};
22use tank_core::{
23 Connection, Driver, Error, ErrorContext, Executor, Query, QueryResult, Result, RowLabeled,
24 RowsAffected,
25 future::Either,
26 printable_query,
27 stream::{Stream, StreamExt},
28};
29use tokio::task::spawn_blocking;
30
31pub struct SqliteConnection {
32 pub(crate) connection: CBox<*mut sqlite3>,
33 pub(crate) _transaction: bool,
34}
35
36impl SqliteConnection {
37 pub(crate) fn run_prepared(
38 &mut self,
39 statement: CBox<*mut sqlite3_stmt>,
40 ) -> impl Stream<Item = Result<QueryResult>> {
41 unsafe {
42 stream! {
43 let count = sqlite3_column_count(*statement);
44 let labels = (0..count)
45 .map(|i| extract_name(*statement, i))
46 .collect::<Result<Arc<[_]>>>()?;
47 loop {
48 match sqlite3_step(*statement) {
49 SQLITE_BUSY => {
50 continue;
51 }
52 SQLITE_DONE => {
53 if sqlite3_stmt_readonly(*statement) == 0 {
54 yield Ok(QueryResult::Affected(RowsAffected {
55 rows_affected: sqlite3_changes64(*self.connection) as u64,
56 last_affected_id: Some(sqlite3_last_insert_rowid(*self.connection)),
57 }))
58 }
59 break;
60 }
61 SQLITE_ROW => {
62 yield Ok(QueryResult::RowLabeled(RowLabeled {
63 labels: labels.clone(),
64 values: (0..count).map(|i| extract_value(*statement, i)).collect()?,
65 }))
66 }
67 _ => {
68 let error = Error::msg(
69 error_message_from_ptr(&sqlite3_errmsg(sqlite3_db_handle(*statement)))
70 .to_string(),
71 );
72 log::error!("{:#}", error);
73 yield Err(error);
74 }
75 }
76 }
77 }
78 }
79 }
80
81 pub(crate) fn run_unprepared(
82 &mut self,
83 sql: String,
84 ) -> impl Stream<Item = Result<QueryResult>> {
85 try_stream! {
86 let mut len = sql.trim_end().len();
87 let buff = sql.into_bytes();
88 let mut it = CBox::new(buff.as_ptr() as *const c_char, |_| {});
89 loop {
90 let connection = CBox::new(*self.connection, |_| {});
91 let sql = CBox::new(*it, |_| {});
92 let (statement, tail) = spawn_blocking(move || unsafe {
93 let mut statement = CBox::new(ptr::null_mut(), |p| {
94 sqlite3_finalize(p);
95 });
96 let mut sql_tail = CBox::new(ptr::null(), |_| {});
97 let rc = sqlite3_prepare_v2(
98 *connection,
99 *sql,
100 len as c_int,
101 &mut *statement,
102 &mut *sql_tail,
103 );
104 if rc != SQLITE_OK {
105 return Err(Error::msg(
106 error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string(),
107 ));
108 }
109 Ok((statement, sql_tail))
110 })
111 .await??;
112 let mut stream = pin!(self.run_prepared(statement));
113 while let Some(value) = stream.next().await {
114 yield value?
115 }
116 unsafe {
117 len = if *tail != ptr::null() {
118 len - tail.offset_from_unsigned(*it)
119 } else {
120 0
121 };
122 if len == 0 {
123 break;
124 }
125 }
126 *it = *tail;
127 }
128 }
129 }
130}
131
132impl Executor for SqliteConnection {
133 type Driver = SqliteDriver;
134
135 fn driver(&self) -> &Self::Driver {
136 &SqliteDriver {}
137 }
138
139 async fn prepare(&mut self, sql: String) -> Result<Query<Self::Driver>> {
140 let connection = AtomicPtr::new(*self.connection);
141 let context = format!(
142 "Failed to prepare the query:\n{}",
143 printable_query!(sql.as_str())
144 );
145 let prepared = spawn_blocking(move || unsafe {
146 let connection = connection.load(Ordering::Relaxed);
147 let len = sql.len();
148 let sql = CString::new(sql.as_bytes())?;
149 let mut statement = CBox::new(ptr::null_mut(), |p| {
150 sqlite3_finalize(p);
151 });
152 let mut tail = ptr::null();
153 let rc = sqlite3_prepare_v2(
154 connection,
155 sql.as_ptr(),
156 len as c_int,
157 &mut *statement,
158 &mut tail,
159 );
160 if rc != SQLITE_OK {
161 let error =
162 Error::msg(error_message_from_ptr(&sqlite3_errmsg(connection)).to_string())
163 .context(context);
164 log::error!("{:#}", error);
165 return Err(error);
166 }
167 if tail != ptr::null() && *tail != '\0' as i8 {
168 let error = Error::msg(format!(
169 "Cannot prepare more than one statement at a time (remaining: {})",
170 CStr::from_ptr(tail).to_str().unwrap_or("unprintable")
171 ))
172 .context(context);
173 log::error!("{:#}", error);
174 return Err(error);
175 }
176 Ok(statement)
177 })
178 .await?;
179 Ok(SqlitePrepared::new(prepared?).into())
180 }
181
182 fn run(
183 &mut self,
184 query: Query<Self::Driver>,
185 ) -> impl Stream<Item = Result<QueryResult>> + Send {
186 match query {
187 Query::Raw(sql) => Either::Left(self.run_unprepared(sql)),
188 Query::Prepared(prepared) => Either::Right(self.run_prepared(prepared.statement)),
189 }
190 }
191}
192
193impl Connection for SqliteConnection {
194 #[allow(refining_impl_trait)]
195 async fn connect(url: Cow<'static, str>) -> Result<SqliteConnection> {
196 let prefix = format!("{}://", <Self::Driver as Driver>::NAME);
197 if !url.starts_with(&prefix) {
198 let error = Error::msg(format!(
199 "Sqlite connection url must start with `{}`",
200 &prefix
201 ));
202 log::error!("{:#}", error);
203 return Err(error);
204 }
205 let url = CString::new(format!("file:{}", url.trim_start_matches(&prefix)))
206 .with_context(|| format!("Invalid database url `{}`", url))?;
207 let mut connection;
208 unsafe {
209 connection = CBox::new(ptr::null_mut(), |p| {
210 if sqlite3_close(p) != SQLITE_OK {
211 log::error!("Could not close sqlite connection")
212 }
213 });
214 let rc = sqlite3_open_v2(
215 url.as_ptr(),
216 &mut *connection,
217 SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
218 ptr::null(),
219 );
220 if rc != SQLITE_OK {
221 let error =
222 Error::msg(error_message_from_ptr(&sqlite3_errmsg(*connection)).to_string())
223 .context(format!(
224 "Failed to connect to database url `{}`",
225 url.to_str().unwrap_or("unprintable value")
226 ));
227 log::error!("{:#}", error);
228 return Err(error);
229 }
230 }
231 Ok(Self {
232 connection,
233 _transaction: false,
234 })
235 }
236
237 #[allow(refining_impl_trait)]
238 fn begin(&mut self) -> impl Future<Output = Result<SqliteTransaction>> {
239 SqliteTransaction::new(self)
240 }
241}