// sqlx_sqlite/connection/worker.rs

1use std::borrow::Cow;
2use std::future::Future;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::thread;
6
7use futures_channel::oneshot;
8use futures_intrusive::sync::{Mutex, MutexGuard};
9use tracing::span::Span;
10
11use sqlx_core::describe::Describe;
12use sqlx_core::error::Error;
13use sqlx_core::transaction::{
14    begin_ansi_transaction_sql, commit_ansi_transaction_sql, rollback_ansi_transaction_sql,
15};
16use sqlx_core::Either;
17
18use crate::connection::describe::describe;
19use crate::connection::establish::EstablishParams;
20use crate::connection::execute;
21use crate::connection::ConnectionState;
22use crate::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
23
24use super::serialize::{deserialize, serialize, SchemaName, SqliteOwnedBuf};
25
// Each SQLite connection has a dedicated thread.

// TODO: Tweak this so that we can use a thread pool per pool of SQLite3 connections to reduce
//       OS resource usage. Low priority because a high concurrent load for SQLite3 is very
//       unlikely.

32pub(crate) struct ConnectionWorker {
33    command_tx: flume::Sender<(Command, tracing::Span)>,
34    /// Mutex for locking access to the database.
35    pub(crate) shared: Arc<WorkerSharedState>,
36}
37
38pub(crate) struct WorkerSharedState {
39    transaction_depth: AtomicUsize,
40    cached_statements_size: AtomicUsize,
41    pub(crate) conn: Mutex<ConnectionState>,
42}
43
44impl WorkerSharedState {
45    pub(crate) fn get_transaction_depth(&self) -> usize {
46        self.transaction_depth.load(Ordering::Acquire)
47    }
48
49    pub(crate) fn get_cached_statements_size(&self) -> usize {
50        self.cached_statements_size.load(Ordering::Acquire)
51    }
52}
53
54enum Command {
55    Prepare {
56        query: Box<str>,
57        tx: oneshot::Sender<Result<SqliteStatement<'static>, Error>>,
58    },
59    Describe {
60        query: Box<str>,
61        tx: oneshot::Sender<Result<Describe<Sqlite>, Error>>,
62    },
63    Execute {
64        query: Box<str>,
65        arguments: Option<SqliteArguments<'static>>,
66        persistent: bool,
67        tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
68        limit: Option<usize>,
69    },
70    Serialize {
71        schema: Option<SchemaName>,
72        tx: oneshot::Sender<Result<SqliteOwnedBuf, Error>>,
73    },
74    Deserialize {
75        schema: Option<SchemaName>,
76        data: SqliteOwnedBuf,
77        read_only: bool,
78        tx: oneshot::Sender<Result<(), Error>>,
79    },
80    Begin {
81        tx: rendezvous_oneshot::Sender<Result<(), Error>>,
82        statement: Option<Cow<'static, str>>,
83    },
84    Commit {
85        tx: rendezvous_oneshot::Sender<Result<(), Error>>,
86    },
87    Rollback {
88        tx: Option<rendezvous_oneshot::Sender<Result<(), Error>>>,
89    },
90    UnlockDb,
91    ClearCache {
92        tx: oneshot::Sender<()>,
93    },
94    Ping {
95        tx: oneshot::Sender<()>,
96    },
97    Shutdown {
98        tx: oneshot::Sender<()>,
99    },
100}
101
102impl ConnectionWorker {
103    pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
104        let (establish_tx, establish_rx) = oneshot::channel();
105
106        thread::Builder::new()
107            .name(params.thread_name.clone())
108            .spawn(move || {
109                let (command_tx, command_rx) = flume::bounded(params.command_channel_size);
110
111                let conn = match params.establish() {
112                    Ok(conn) => conn,
113                    Err(e) => {
114                        establish_tx.send(Err(e)).ok();
115                        return;
116                    }
117                };
118
119                let shared = Arc::new(WorkerSharedState {
120                    transaction_depth: AtomicUsize::new(0),
121                    cached_statements_size: AtomicUsize::new(0),
122                    // note: must be fair because in `Command::UnlockDb` we unlock the mutex
123                    // and then immediately try to relock it; an unfair mutex would immediately
124                    // grant us the lock even if another task is waiting.
125                    conn: Mutex::new(conn, true),
126                });
127                let mut conn = shared.conn.try_lock().unwrap();
128
129                if establish_tx
130                    .send(Ok(Self {
131                        command_tx,
132                        shared: Arc::clone(&shared),
133                    }))
134                    .is_err()
135                {
136                    return;
137                }
138
139                // If COMMIT or ROLLBACK is processed but not acknowledged, there would be another
140                // ROLLBACK sent when the `Transaction` drops. We need to ignore it otherwise we
141                // would rollback an already completed transaction.
142                let mut ignore_next_start_rollback = false;
143
144                for (cmd, span) in command_rx {
145                    let _guard = span.enter();
146                    match cmd {
147                        Command::Prepare { query, tx } => {
148                            tx.send(prepare(&mut conn, &query).map(|prepared| {
149                                update_cached_statements_size(
150                                    &conn,
151                                    &shared.cached_statements_size,
152                                );
153                                prepared
154                            }))
155                            .ok();
156                        }
157                        Command::Describe { query, tx } => {
158                            tx.send(describe(&mut conn, &query)).ok();
159                        }
160                        Command::Execute {
161                            query,
162                            arguments,
163                            persistent,
164                            tx,
165                            limit
166                        } => {
167                            let iter = match execute::iter(&mut conn, &query, arguments, persistent)
168                            {
169                                Ok(iter) => iter,
170                                Err(e) => {
171                                    tx.send(Err(e)).ok();
172                                    continue;
173                                }
174                            };
175
176                            match limit {
177                                None => {
178                                    for res in iter {
179                                        let has_error = res.is_err();
180                                        if tx.send(res).is_err() || has_error {
181                                            break;
182                                        }
183                                    }
184                                },
185                                Some(limit) => {
186                                    let mut iter = iter;
187                                    let mut rows_returned = 0;
188
189                                    while let Some(res) = iter.next() {
190                                        if let Ok(ok) = &res {
191                                            if ok.is_right() {
192                                                rows_returned += 1;
193                                                if rows_returned >= limit {
194                                                    drop(iter);
195                                                    let _ = tx.send(res);
196                                                    break;
197                                                }
198                                            }
199                                        }
200                                        let has_error = res.is_err();
201                                        if tx.send(res).is_err() || has_error {
202                                            break;
203                                        }
204                                    }
205                                },
206                            }
207
208                            update_cached_statements_size(&conn, &shared.cached_statements_size);
209                        }
210                        Command::Begin { tx, statement } => {
211                            let depth = shared.transaction_depth.load(Ordering::Acquire);
212
213                            let statement = match statement {
214                                // custom `BEGIN` statements are not allowed if
215                                // we're already in a transaction (we need to
216                                // issue a `SAVEPOINT` instead)
217                                Some(_) if depth > 0 => {
218                                    if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() {
219                                        break;
220                                    }
221                                    continue;
222                                },
223                                Some(statement) => statement,
224                                None => begin_ansi_transaction_sql(depth),
225                            };
226                            let res =
227                                conn.handle
228                                    .exec(statement)
229                                    .map(|_| {
230                                        shared.transaction_depth.fetch_add(1, Ordering::Release);
231                                    });
232                            let res_ok = res.is_ok();
233
234                            if tx.blocking_send(res).is_err() && res_ok {
235                                // The BEGIN was processed but not acknowledged. This means no
236                                // `Transaction` was created and so there is no way to commit /
237                                // rollback this transaction. We need to roll it back
238                                // immediately otherwise it would remain started forever.
239                                if let Err(error) = conn
240                                    .handle
241                                    .exec(rollback_ansi_transaction_sql(depth + 1))
242                                    .map(|_| {
243                                        shared.transaction_depth.fetch_sub(1, Ordering::Release);
244                                    })
245                                {
246                                    // The rollback failed. To prevent leaving the connection
247                                    // in an inconsistent state we shutdown this worker which
248                                    // causes any subsequent operation on the connection to fail.
249                                    tracing::error!(%error, "failed to rollback cancelled transaction");
250                                    break;
251                                }
252                            }
253                        }
254                        Command::Commit { tx } => {
255                            let depth = shared.transaction_depth.load(Ordering::Acquire);
256
257                            let res = if depth > 0 {
258                                conn.handle
259                                    .exec(commit_ansi_transaction_sql(depth))
260                                    .map(|_| {
261                                        shared.transaction_depth.fetch_sub(1, Ordering::Release);
262                                    })
263                            } else {
264                                Ok(())
265                            };
266                            let res_ok = res.is_ok();
267
268                            if tx.blocking_send(res).is_err() && res_ok {
269                                // The COMMIT was processed but not acknowledged. This means that
270                                // the `Transaction` doesn't know it was committed and will try to
271                                // rollback on drop. We need to ignore that rollback.
272                                ignore_next_start_rollback = true;
273                            }
274                        }
275                        Command::Rollback { tx } => {
276                            if ignore_next_start_rollback && tx.is_none() {
277                                ignore_next_start_rollback = false;
278                                continue;
279                            }
280
281                            let depth = shared.transaction_depth.load(Ordering::Acquire);
282
283                            let res = if depth > 0 {
284                                conn.handle
285                                    .exec(rollback_ansi_transaction_sql(depth))
286                                    .map(|_| {
287                                        shared.transaction_depth.fetch_sub(1, Ordering::Release);
288                                    })
289                            } else {
290                                Ok(())
291                            };
292
293                            let res_ok = res.is_ok();
294
295                            if let Some(tx) = tx {
296                                if tx.blocking_send(res).is_err() && res_ok {
297                                    // The ROLLBACK was processed but not acknowledged. This means
298                                    // that the `Transaction` doesn't know it was rolled back and
299                                    // will try to rollback again on drop. We need to ignore that
300                                    // rollback.
301                                    ignore_next_start_rollback = true;
302                                }
303                            }
304                        }
305                        Command::Serialize { schema, tx } => {
306                            tx.send(serialize(&mut conn, schema)).ok();
307                        }
308                        Command::Deserialize { schema, data, read_only, tx } => {
309                            tx.send(deserialize(&mut conn, schema, data, read_only)).ok();
310                        }
311                        Command::ClearCache { tx } => {
312                            conn.statements.clear();
313                            update_cached_statements_size(&conn, &shared.cached_statements_size);
314                            tx.send(()).ok();
315                        }
316                        Command::UnlockDb => {
317                            drop(conn);
318                            conn = futures_executor::block_on(shared.conn.lock());
319                        }
320                        Command::Ping { tx } => {
321                            tx.send(()).ok();
322                        }
323                        Command::Shutdown { tx } => {
324                            // drop the connection references before sending confirmation
325                            // and ending the command loop
326                            drop(conn);
327                            drop(shared);
328                            let _ = tx.send(());
329                            return;
330                        }
331                    }
332                }
333            })?;
334
335        establish_rx.await.map_err(|_| Error::WorkerCrashed)?
336    }
337
338    pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement<'static>, Error> {
339        self.oneshot_cmd(|tx| Command::Prepare {
340            query: query.into(),
341            tx,
342        })
343        .await?
344    }
345
346    pub(crate) async fn describe(&mut self, query: &str) -> Result<Describe<Sqlite>, Error> {
347        self.oneshot_cmd(|tx| Command::Describe {
348            query: query.into(),
349            tx,
350        })
351        .await?
352    }
353
354    pub(crate) async fn execute(
355        &mut self,
356        query: &str,
357        args: Option<SqliteArguments<'_>>,
358        chan_size: usize,
359        persistent: bool,
360        limit: Option<usize>,
361    ) -> Result<flume::Receiver<Result<Either<SqliteQueryResult, SqliteRow>, Error>>, Error> {
362        let (tx, rx) = flume::bounded(chan_size);
363
364        self.command_tx
365            .send_async((
366                Command::Execute {
367                    query: query.into(),
368                    arguments: args.map(SqliteArguments::into_static),
369                    persistent,
370                    tx,
371                    limit,
372                },
373                Span::current(),
374            ))
375            .await
376            .map_err(|_| Error::WorkerCrashed)?;
377
378        Ok(rx)
379    }
380
381    pub(crate) async fn begin(
382        &mut self,
383        statement: Option<Cow<'static, str>>,
384    ) -> Result<(), Error> {
385        self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement })
386            .await?
387    }
388
389    pub(crate) async fn commit(&mut self) -> Result<(), Error> {
390        self.oneshot_cmd_with_ack(|tx| Command::Commit { tx })
391            .await?
392    }
393
394    pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
395        self.oneshot_cmd_with_ack(|tx| Command::Rollback { tx: Some(tx) })
396            .await?
397    }
398
399    pub(crate) fn start_rollback(&mut self) -> Result<(), Error> {
400        self.command_tx
401            .send((Command::Rollback { tx: None }, Span::current()))
402            .map_err(|_| Error::WorkerCrashed)
403    }
404
405    pub(crate) async fn ping(&mut self) -> Result<(), Error> {
406        self.oneshot_cmd(|tx| Command::Ping { tx }).await
407    }
408
409    pub(crate) async fn deserialize(
410        &mut self,
411        schema: Option<SchemaName>,
412        data: SqliteOwnedBuf,
413        read_only: bool,
414    ) -> Result<(), Error> {
415        self.oneshot_cmd(|tx| Command::Deserialize {
416            schema,
417            data,
418            read_only,
419            tx,
420        })
421        .await?
422    }
423
424    pub(crate) async fn serialize(
425        &mut self,
426        schema: Option<SchemaName>,
427    ) -> Result<SqliteOwnedBuf, Error> {
428        self.oneshot_cmd(|tx| Command::Serialize { schema, tx })
429            .await?
430    }
431
432    async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
433    where
434        F: FnOnce(oneshot::Sender<T>) -> Command,
435    {
436        let (tx, rx) = oneshot::channel();
437
438        self.command_tx
439            .send_async((command(tx), Span::current()))
440            .await
441            .map_err(|_| Error::WorkerCrashed)?;
442
443        rx.await.map_err(|_| Error::WorkerCrashed)
444    }
445
446    async fn oneshot_cmd_with_ack<F, T>(&mut self, command: F) -> Result<T, Error>
447    where
448        F: FnOnce(rendezvous_oneshot::Sender<T>) -> Command,
449    {
450        let (tx, rx) = rendezvous_oneshot::channel();
451
452        self.command_tx
453            .send_async((command(tx), Span::current()))
454            .await
455            .map_err(|_| Error::WorkerCrashed)?;
456
457        rx.recv().await.map_err(|_| Error::WorkerCrashed)
458    }
459
460    pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
461        self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
462    }
463
464    pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
465        let (guard, res) = futures_util::future::join(
466            // we need to join the wait queue for the lock before we send the message
467            self.shared.conn.lock(),
468            self.command_tx
469                .send_async((Command::UnlockDb, Span::current())),
470        )
471        .await;
472
473        res.map_err(|_| Error::WorkerCrashed)?;
474
475        Ok(guard)
476    }
477
478    /// Send a command to the worker to shut down the processing thread.
479    ///
480    /// A `WorkerCrashed` error may be returned if the thread has already stopped.
481    pub(crate) fn shutdown(&mut self) -> impl Future<Output = Result<(), Error>> {
482        let (tx, rx) = oneshot::channel();
483
484        let send_res = self
485            .command_tx
486            .send((Command::Shutdown { tx }, Span::current()))
487            .map_err(|_| Error::WorkerCrashed);
488
489        async move {
490            send_res?;
491
492            // wait for the response
493            rx.await.map_err(|_| Error::WorkerCrashed)
494        }
495    }
496}
497
498fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'static>, Error> {
499    // prepare statement object (or checkout from cache)
500    let statement = conn.statements.get(query, true)?;
501
502    let mut parameters = 0;
503    let mut columns = None;
504    let mut column_names = None;
505
506    while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
507        parameters += statement.handle.bind_parameter_count();
508
509        // the first non-empty statement is chosen as the statement we pull columns from
510        if !statement.columns.is_empty() && columns.is_none() {
511            columns = Some(Arc::clone(statement.columns));
512            column_names = Some(Arc::clone(statement.column_names));
513        }
514    }
515
516    Ok(SqliteStatement {
517        sql: Cow::Owned(query.to_string()),
518        columns: columns.unwrap_or_default(),
519        column_names: column_names.unwrap_or_default(),
520        parameters,
521    })
522}
523
524fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
525    size.store(conn.statements.len(), Ordering::Release);
526}
527
528// A oneshot channel where send completes only after the receiver receives the value.
529mod rendezvous_oneshot {
530    use super::oneshot::{self, Canceled};
531
532    pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
533        let (inner_tx, inner_rx) = oneshot::channel();
534        (Sender { inner: inner_tx }, Receiver { inner: inner_rx })
535    }
536
537    pub struct Sender<T> {
538        inner: oneshot::Sender<(T, oneshot::Sender<()>)>,
539    }
540
541    impl<T> Sender<T> {
542        pub async fn send(self, value: T) -> Result<(), Canceled> {
543            let (ack_tx, ack_rx) = oneshot::channel();
544            self.inner.send((value, ack_tx)).map_err(|_| Canceled)?;
545            ack_rx.await
546        }
547
548        pub fn blocking_send(self, value: T) -> Result<(), Canceled> {
549            futures_executor::block_on(self.send(value))
550        }
551    }
552
553    pub struct Receiver<T> {
554        inner: oneshot::Receiver<(T, oneshot::Sender<()>)>,
555    }
556
557    impl<T> Receiver<T> {
558        pub async fn recv(self) -> Result<T, Canceled> {
559            let (value, ack_tx) = self.inner.await?;
560            ack_tx.send(()).map_err(|_| Canceled)?;
561            Ok(value)
562        }
563    }
564}