1use std::borrow::Cow;
2use std::future::Future;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use std::sync::Arc;
5use std::thread;
6
7use futures_channel::oneshot;
8use futures_intrusive::sync::{Mutex, MutexGuard};
9use tracing::span::Span;
10
11use sqlx_core::describe::Describe;
12use sqlx_core::error::Error;
13use sqlx_core::transaction::{
14 begin_ansi_transaction_sql, commit_ansi_transaction_sql, rollback_ansi_transaction_sql,
15};
16use sqlx_core::Either;
17
18use crate::connection::describe::describe;
19use crate::connection::establish::EstablishParams;
20use crate::connection::execute;
21use crate::connection::ConnectionState;
22use crate::{Sqlite, SqliteArguments, SqliteQueryResult, SqliteRow, SqliteStatement};
23
24use super::serialize::{deserialize, serialize, SchemaName, SqliteOwnedBuf};
25
/// Handle to a background thread that executes commands against a single SQLite connection.
///
/// Commands are queued on `command_tx`; state the worker thread publishes (and the
/// connection itself, behind a mutex) is reachable through `shared`.
pub(crate) struct ConnectionWorker {
    /// Command queue to the worker thread. Each command is paired with the tracing span the
    /// worker enters while processing it.
    command_tx: flume::Sender<(Command, tracing::Span)>,
    /// State shared with the worker thread.
    pub(crate) shared: Arc<WorkerSharedState>,
}
37
38pub(crate) struct WorkerSharedState {
39 transaction_depth: AtomicUsize,
40 cached_statements_size: AtomicUsize,
41 pub(crate) conn: Mutex<ConnectionState>,
42}
43
44impl WorkerSharedState {
45 pub(crate) fn get_transaction_depth(&self) -> usize {
46 self.transaction_depth.load(Ordering::Acquire)
47 }
48
49 pub(crate) fn get_cached_statements_size(&self) -> usize {
50 self.cached_statements_size.load(Ordering::Acquire)
51 }
52}
53
54enum Command {
55 Prepare {
56 query: Box<str>,
57 tx: oneshot::Sender<Result<SqliteStatement<'static>, Error>>,
58 },
59 Describe {
60 query: Box<str>,
61 tx: oneshot::Sender<Result<Describe<Sqlite>, Error>>,
62 },
63 Execute {
64 query: Box<str>,
65 arguments: Option<SqliteArguments<'static>>,
66 persistent: bool,
67 tx: flume::Sender<Result<Either<SqliteQueryResult, SqliteRow>, Error>>,
68 limit: Option<usize>,
69 },
70 Serialize {
71 schema: Option<SchemaName>,
72 tx: oneshot::Sender<Result<SqliteOwnedBuf, Error>>,
73 },
74 Deserialize {
75 schema: Option<SchemaName>,
76 data: SqliteOwnedBuf,
77 read_only: bool,
78 tx: oneshot::Sender<Result<(), Error>>,
79 },
80 Begin {
81 tx: rendezvous_oneshot::Sender<Result<(), Error>>,
82 statement: Option<Cow<'static, str>>,
83 },
84 Commit {
85 tx: rendezvous_oneshot::Sender<Result<(), Error>>,
86 },
87 Rollback {
88 tx: Option<rendezvous_oneshot::Sender<Result<(), Error>>>,
89 },
90 UnlockDb,
91 ClearCache {
92 tx: oneshot::Sender<()>,
93 },
94 Ping {
95 tx: oneshot::Sender<()>,
96 },
97 Shutdown {
98 tx: oneshot::Sender<()>,
99 },
100}
101
102impl ConnectionWorker {
103 pub(crate) async fn establish(params: EstablishParams) -> Result<Self, Error> {
104 let (establish_tx, establish_rx) = oneshot::channel();
105
106 thread::Builder::new()
107 .name(params.thread_name.clone())
108 .spawn(move || {
109 let (command_tx, command_rx) = flume::bounded(params.command_channel_size);
110
111 let conn = match params.establish() {
112 Ok(conn) => conn,
113 Err(e) => {
114 establish_tx.send(Err(e)).ok();
115 return;
116 }
117 };
118
119 let shared = Arc::new(WorkerSharedState {
120 transaction_depth: AtomicUsize::new(0),
121 cached_statements_size: AtomicUsize::new(0),
122 conn: Mutex::new(conn, true),
126 });
127 let mut conn = shared.conn.try_lock().unwrap();
128
129 if establish_tx
130 .send(Ok(Self {
131 command_tx,
132 shared: Arc::clone(&shared),
133 }))
134 .is_err()
135 {
136 return;
137 }
138
139 let mut ignore_next_start_rollback = false;
143
144 for (cmd, span) in command_rx {
145 let _guard = span.enter();
146 match cmd {
147 Command::Prepare { query, tx } => {
148 tx.send(prepare(&mut conn, &query).map(|prepared| {
149 update_cached_statements_size(
150 &conn,
151 &shared.cached_statements_size,
152 );
153 prepared
154 }))
155 .ok();
156 }
157 Command::Describe { query, tx } => {
158 tx.send(describe(&mut conn, &query)).ok();
159 }
160 Command::Execute {
161 query,
162 arguments,
163 persistent,
164 tx,
165 limit
166 } => {
167 let iter = match execute::iter(&mut conn, &query, arguments, persistent)
168 {
169 Ok(iter) => iter,
170 Err(e) => {
171 tx.send(Err(e)).ok();
172 continue;
173 }
174 };
175
176 match limit {
177 None => {
178 for res in iter {
179 let has_error = res.is_err();
180 if tx.send(res).is_err() || has_error {
181 break;
182 }
183 }
184 },
185 Some(limit) => {
186 let mut iter = iter;
187 let mut rows_returned = 0;
188
189 while let Some(res) = iter.next() {
190 if let Ok(ok) = &res {
191 if ok.is_right() {
192 rows_returned += 1;
193 if rows_returned >= limit {
194 drop(iter);
195 let _ = tx.send(res);
196 break;
197 }
198 }
199 }
200 let has_error = res.is_err();
201 if tx.send(res).is_err() || has_error {
202 break;
203 }
204 }
205 },
206 }
207
208 update_cached_statements_size(&conn, &shared.cached_statements_size);
209 }
210 Command::Begin { tx, statement } => {
211 let depth = shared.transaction_depth.load(Ordering::Acquire);
212
213 let statement = match statement {
214 Some(_) if depth > 0 => {
218 if tx.blocking_send(Err(Error::InvalidSavePointStatement)).is_err() {
219 break;
220 }
221 continue;
222 },
223 Some(statement) => statement,
224 None => begin_ansi_transaction_sql(depth),
225 };
226 let res =
227 conn.handle
228 .exec(statement)
229 .map(|_| {
230 shared.transaction_depth.fetch_add(1, Ordering::Release);
231 });
232 let res_ok = res.is_ok();
233
234 if tx.blocking_send(res).is_err() && res_ok {
235 if let Err(error) = conn
240 .handle
241 .exec(rollback_ansi_transaction_sql(depth + 1))
242 .map(|_| {
243 shared.transaction_depth.fetch_sub(1, Ordering::Release);
244 })
245 {
246 tracing::error!(%error, "failed to rollback cancelled transaction");
250 break;
251 }
252 }
253 }
254 Command::Commit { tx } => {
255 let depth = shared.transaction_depth.load(Ordering::Acquire);
256
257 let res = if depth > 0 {
258 conn.handle
259 .exec(commit_ansi_transaction_sql(depth))
260 .map(|_| {
261 shared.transaction_depth.fetch_sub(1, Ordering::Release);
262 })
263 } else {
264 Ok(())
265 };
266 let res_ok = res.is_ok();
267
268 if tx.blocking_send(res).is_err() && res_ok {
269 ignore_next_start_rollback = true;
273 }
274 }
275 Command::Rollback { tx } => {
276 if ignore_next_start_rollback && tx.is_none() {
277 ignore_next_start_rollback = false;
278 continue;
279 }
280
281 let depth = shared.transaction_depth.load(Ordering::Acquire);
282
283 let res = if depth > 0 {
284 conn.handle
285 .exec(rollback_ansi_transaction_sql(depth))
286 .map(|_| {
287 shared.transaction_depth.fetch_sub(1, Ordering::Release);
288 })
289 } else {
290 Ok(())
291 };
292
293 let res_ok = res.is_ok();
294
295 if let Some(tx) = tx {
296 if tx.blocking_send(res).is_err() && res_ok {
297 ignore_next_start_rollback = true;
302 }
303 }
304 }
305 Command::Serialize { schema, tx } => {
306 tx.send(serialize(&mut conn, schema)).ok();
307 }
308 Command::Deserialize { schema, data, read_only, tx } => {
309 tx.send(deserialize(&mut conn, schema, data, read_only)).ok();
310 }
311 Command::ClearCache { tx } => {
312 conn.statements.clear();
313 update_cached_statements_size(&conn, &shared.cached_statements_size);
314 tx.send(()).ok();
315 }
316 Command::UnlockDb => {
317 drop(conn);
318 conn = futures_executor::block_on(shared.conn.lock());
319 }
320 Command::Ping { tx } => {
321 tx.send(()).ok();
322 }
323 Command::Shutdown { tx } => {
324 drop(conn);
327 drop(shared);
328 let _ = tx.send(());
329 return;
330 }
331 }
332 }
333 })?;
334
335 establish_rx.await.map_err(|_| Error::WorkerCrashed)?
336 }
337
338 pub(crate) async fn prepare(&mut self, query: &str) -> Result<SqliteStatement<'static>, Error> {
339 self.oneshot_cmd(|tx| Command::Prepare {
340 query: query.into(),
341 tx,
342 })
343 .await?
344 }
345
346 pub(crate) async fn describe(&mut self, query: &str) -> Result<Describe<Sqlite>, Error> {
347 self.oneshot_cmd(|tx| Command::Describe {
348 query: query.into(),
349 tx,
350 })
351 .await?
352 }
353
354 pub(crate) async fn execute(
355 &mut self,
356 query: &str,
357 args: Option<SqliteArguments<'_>>,
358 chan_size: usize,
359 persistent: bool,
360 limit: Option<usize>,
361 ) -> Result<flume::Receiver<Result<Either<SqliteQueryResult, SqliteRow>, Error>>, Error> {
362 let (tx, rx) = flume::bounded(chan_size);
363
364 self.command_tx
365 .send_async((
366 Command::Execute {
367 query: query.into(),
368 arguments: args.map(SqliteArguments::into_static),
369 persistent,
370 tx,
371 limit,
372 },
373 Span::current(),
374 ))
375 .await
376 .map_err(|_| Error::WorkerCrashed)?;
377
378 Ok(rx)
379 }
380
381 pub(crate) async fn begin(
382 &mut self,
383 statement: Option<Cow<'static, str>>,
384 ) -> Result<(), Error> {
385 self.oneshot_cmd_with_ack(|tx| Command::Begin { tx, statement })
386 .await?
387 }
388
389 pub(crate) async fn commit(&mut self) -> Result<(), Error> {
390 self.oneshot_cmd_with_ack(|tx| Command::Commit { tx })
391 .await?
392 }
393
394 pub(crate) async fn rollback(&mut self) -> Result<(), Error> {
395 self.oneshot_cmd_with_ack(|tx| Command::Rollback { tx: Some(tx) })
396 .await?
397 }
398
399 pub(crate) fn start_rollback(&mut self) -> Result<(), Error> {
400 self.command_tx
401 .send((Command::Rollback { tx: None }, Span::current()))
402 .map_err(|_| Error::WorkerCrashed)
403 }
404
405 pub(crate) async fn ping(&mut self) -> Result<(), Error> {
406 self.oneshot_cmd(|tx| Command::Ping { tx }).await
407 }
408
409 pub(crate) async fn deserialize(
410 &mut self,
411 schema: Option<SchemaName>,
412 data: SqliteOwnedBuf,
413 read_only: bool,
414 ) -> Result<(), Error> {
415 self.oneshot_cmd(|tx| Command::Deserialize {
416 schema,
417 data,
418 read_only,
419 tx,
420 })
421 .await?
422 }
423
424 pub(crate) async fn serialize(
425 &mut self,
426 schema: Option<SchemaName>,
427 ) -> Result<SqliteOwnedBuf, Error> {
428 self.oneshot_cmd(|tx| Command::Serialize { schema, tx })
429 .await?
430 }
431
432 async fn oneshot_cmd<F, T>(&mut self, command: F) -> Result<T, Error>
433 where
434 F: FnOnce(oneshot::Sender<T>) -> Command,
435 {
436 let (tx, rx) = oneshot::channel();
437
438 self.command_tx
439 .send_async((command(tx), Span::current()))
440 .await
441 .map_err(|_| Error::WorkerCrashed)?;
442
443 rx.await.map_err(|_| Error::WorkerCrashed)
444 }
445
446 async fn oneshot_cmd_with_ack<F, T>(&mut self, command: F) -> Result<T, Error>
447 where
448 F: FnOnce(rendezvous_oneshot::Sender<T>) -> Command,
449 {
450 let (tx, rx) = rendezvous_oneshot::channel();
451
452 self.command_tx
453 .send_async((command(tx), Span::current()))
454 .await
455 .map_err(|_| Error::WorkerCrashed)?;
456
457 rx.recv().await.map_err(|_| Error::WorkerCrashed)
458 }
459
460 pub(crate) async fn clear_cache(&mut self) -> Result<(), Error> {
461 self.oneshot_cmd(|tx| Command::ClearCache { tx }).await
462 }
463
464 pub(crate) async fn unlock_db(&mut self) -> Result<MutexGuard<'_, ConnectionState>, Error> {
465 let (guard, res) = futures_util::future::join(
466 self.shared.conn.lock(),
468 self.command_tx
469 .send_async((Command::UnlockDb, Span::current())),
470 )
471 .await;
472
473 res.map_err(|_| Error::WorkerCrashed)?;
474
475 Ok(guard)
476 }
477
478 pub(crate) fn shutdown(&mut self) -> impl Future<Output = Result<(), Error>> {
482 let (tx, rx) = oneshot::channel();
483
484 let send_res = self
485 .command_tx
486 .send((Command::Shutdown { tx }, Span::current()))
487 .map_err(|_| Error::WorkerCrashed);
488
489 async move {
490 send_res?;
491
492 rx.await.map_err(|_| Error::WorkerCrashed)
494 }
495 }
496}
497
498fn prepare(conn: &mut ConnectionState, query: &str) -> Result<SqliteStatement<'static>, Error> {
499 let statement = conn.statements.get(query, true)?;
501
502 let mut parameters = 0;
503 let mut columns = None;
504 let mut column_names = None;
505
506 while let Some(statement) = statement.prepare_next(&mut conn.handle)? {
507 parameters += statement.handle.bind_parameter_count();
508
509 if !statement.columns.is_empty() && columns.is_none() {
511 columns = Some(Arc::clone(statement.columns));
512 column_names = Some(Arc::clone(statement.column_names));
513 }
514 }
515
516 Ok(SqliteStatement {
517 sql: Cow::Owned(query.to_string()),
518 columns: columns.unwrap_or_default(),
519 column_names: column_names.unwrap_or_default(),
520 parameters,
521 })
522}
523
524fn update_cached_statements_size(conn: &ConnectionState, size: &AtomicUsize) {
525 size.store(conn.statements.len(), Ordering::Release);
526}
527
528mod rendezvous_oneshot {
530 use super::oneshot::{self, Canceled};
531
532 pub fn channel<T>() -> (Sender<T>, Receiver<T>) {
533 let (inner_tx, inner_rx) = oneshot::channel();
534 (Sender { inner: inner_tx }, Receiver { inner: inner_rx })
535 }
536
537 pub struct Sender<T> {
538 inner: oneshot::Sender<(T, oneshot::Sender<()>)>,
539 }
540
541 impl<T> Sender<T> {
542 pub async fn send(self, value: T) -> Result<(), Canceled> {
543 let (ack_tx, ack_rx) = oneshot::channel();
544 self.inner.send((value, ack_tx)).map_err(|_| Canceled)?;
545 ack_rx.await
546 }
547
548 pub fn blocking_send(self, value: T) -> Result<(), Canceled> {
549 futures_executor::block_on(self.send(value))
550 }
551 }
552
553 pub struct Receiver<T> {
554 inner: oneshot::Receiver<(T, oneshot::Sender<()>)>,
555 }
556
557 impl<T> Receiver<T> {
558 pub async fn recv(self) -> Result<T, Canceled> {
559 let (value, ack_tx) = self.inner.await?;
560 ack_tx.send(()).map_err(|_| Canceled)?;
561 Ok(value)
562 }
563 }
564}