tide_sqlx/
lib.rs

1//! A [Tide][] middleware which holds a pool of SQLx database connections, and automatically hands
2//! each [tide::Request][] a connection, which may transparently be either a database transaction,
3//! or a direct pooled database connection.
4//!
5//! By default, transactions are used for all http methods other than `GET` and `HEAD`.
6//!
//! When using this, use the `SQLxRequestExt` extension trait to get the connection.
8//!
9//! ## Examples
10//!
11//! ### Basic
12//! ```no_run
13//! # #[async_std::main]
14//! # async fn main() -> anyhow::Result<()> {
15//! use sqlx::Acquire; // Or sqlx::prelude::*;
16//! use sqlx::postgres::Postgres;
17//!
18//! use tide_sqlx::SQLxMiddleware;
19//! use tide_sqlx::SQLxRequestExt;
20//!
21//! let mut app = tide::new();
22//! app.with(SQLxMiddleware::<Postgres>::new("postgres://localhost/a_database").await?);
23//!
24//! app.at("/").post(|req: tide::Request<()>| async move {
25//!     let mut pg_conn = req.sqlx_conn::<Postgres>().await;
26//!
27//!     sqlx::query("SELECT * FROM users")
28//!         .fetch_optional(pg_conn.acquire().await?)
29//!         .await;
30//!
31//!     Ok("")
32//! });
33//! # Ok(())
34//! # }
35//! ```
36//!
37//! ### From sqlx `PoolOptions` and with `ConnectOptions`
38//! ```no_run
39//! # #[async_std::main]
40//! # async fn main() -> anyhow::Result<()> {
41//! use log::LevelFilter;
42//! use sqlx::{Acquire, ConnectOptions}; // Or sqlx::prelude::*;
43//! use sqlx::postgres::{PgConnectOptions, PgPoolOptions, Postgres};
44//!
45//! use tide_sqlx::SQLxMiddleware;
46//! use tide_sqlx::SQLxRequestExt;
47//!
48//! let mut connect_opts = PgConnectOptions::new();
49//! connect_opts.log_statements(LevelFilter::Debug);
50//!
51//! let pg_pool = PgPoolOptions::new()
52//!     .max_connections(5)
53//!     .connect_with(connect_opts)
54//!     .await?;
55//!
56//! let mut app = tide::new();
57//! app.with(SQLxMiddleware::from(pg_pool));
58//!
59//! app.at("/").post(|req: tide::Request<()>| async move {
60//!     let mut pg_conn = req.sqlx_conn::<Postgres>().await;
61//!
62//!     sqlx::query("SELECT * FROM users")
63//!         .fetch_optional(pg_conn.acquire().await?)
64//!         .await;
65//!
66//!     Ok("")
67//! });
68//! # Ok(())
69//! # }
70//! ```
71//!
72//! ## Why you may want to use this
73//!
74//! Database transactions are very useful because they allow easy, assured rollback if something goes wrong.
75//! However, transactions incur extra runtime cost which is too expensive to justify for READ operations that _do not need_ this behavior.
76//!
77//! In order to allow transactions to be used seamlessly in endpoints, this middleware manages a transaction if one is deemed desirable.
78//!
79//! [tide::Request]: https://docs.rs/tide/0.15.0/tide/struct.Request.html
80//! [Tide]: https://docs.rs/tide/0.15.0/tide/
81
// The public type keeps its established `SQLx…` spelling (e.g. `SQLxMiddleware`) instead of the
// `Sqlx…` form clippy's `upper_case_acronyms` lint would ask for.
#![allow(clippy::upper_case_acronyms)] // SQLxMiddleware
// Only when building with the crate's `docs` feature, enable rustdoc's `doc_cfg` so the
// `doc(cfg(...))` annotations used below (e.g. on `mod postgres`) are rendered.
#![cfg_attr(feature = "docs", feature(doc_cfg))]
84
85use std::fmt::{self, Debug};
86use std::ops::{Deref, DerefMut};
87use std::sync::Arc;
88
89use async_std::sync::{RwLock, RwLockWriteGuard};
90use sqlx::pool::{Pool, PoolConnection};
91use sqlx::{Database, Transaction};
92use tide::utils::async_trait;
93use tide::{http::Method, Middleware, Next, Request, Result};
94
95#[cfg(all(feature = "tracing", debug_assertions))]
96use tracing_crate::debug_span;
97#[cfg(feature = "tracing")]
98use tracing_crate::{info_span, Instrument};
99
// The test build requires the `postgres` feature; fail at compile time with a clear message
// instead of producing confusing missing-item errors.
// NOTE(review): the message names `--features=test` while the cfg checks `postgres`; presumably
// a `test` feature enables `postgres` — confirm against Cargo.toml.
#[cfg(all(test, not(feature = "postgres")))]
compile_error!("The tests must be run with --features=test");

#[cfg(feature = "postgres")]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgres")))]
/// Helpers specific to Postgres
pub mod postgres;
107
108#[doc(hidden)]
109pub enum ConnectionWrapInner<DB>
110where
111    DB: Database,
112    DB::Connection: Send + Sync + 'static,
113{
114    Transacting(Transaction<'static, DB>),
115    Plain(PoolConnection<DB>),
116}
117
118impl<DB> Debug for ConnectionWrapInner<DB>
119where
120    DB: Database,
121    DB::Connection: Send + Sync + 'static,
122{
123    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124        match self {
125            Self::Transacting(_) => f.debug_struct("ConnectionWrapInner::Transacting").finish(),
126            Self::Plain(_) => f.debug_struct("ConnectionWrapInner::Plain").finish(),
127        }
128    }
129}
130
131impl<DB> Deref for ConnectionWrapInner<DB>
132where
133    DB: Database,
134    DB::Connection: Send + Sync + 'static,
135{
136    type Target = DB::Connection;
137
138    fn deref(&self) -> &Self::Target {
139        match self {
140            ConnectionWrapInner::Plain(c) => c,
141            ConnectionWrapInner::Transacting(c) => c,
142        }
143    }
144}
145
146impl<DB> DerefMut for ConnectionWrapInner<DB>
147where
148    DB: Database,
149    DB::Connection: Send + Sync + 'static,
150{
151    fn deref_mut(&mut self) -> &mut Self::Target {
152        match self {
153            ConnectionWrapInner::Plain(c) => c,
154            ConnectionWrapInner::Transacting(c) => c,
155        }
156    }
157}
158
159#[doc(hidden)]
160pub type ConnectionWrap<DB> = Arc<RwLock<ConnectionWrapInner<DB>>>;
161
162/// This middleware holds a pool of SQLx database connections, and automatically hands each
163/// [tide::Request][] a connection, which may transparently be either a database transaction,
164/// or a direct pooled database connection.
165///
166/// By default, transactions are used for all http methods other than `GET` and `HEAD`.
167///
168/// When using this, use the `SQLxRequestExt` extenstion trait to get the connection.
169///
170/// ## Example
171///
172/// ```no_run
173/// # #[async_std::main]
174/// # async fn main() -> anyhow::Result<()> {
175/// use sqlx::Acquire; // Or sqlx::prelude::*;
176/// use sqlx::postgres::Postgres;
177///
178/// use tide_sqlx::SQLxMiddleware;
179/// use tide_sqlx::SQLxRequestExt;
180///
181/// let mut app = tide::new();
182/// app.with(SQLxMiddleware::<Postgres>::new("postgres://localhost/a_database").await?);
183///
184/// app.at("/").post(|req: tide::Request<()>| async move {
185///     let mut pg_conn = req.sqlx_conn::<Postgres>().await;
186///
187///     sqlx::query("SELECT * FROM users")
188///         .fetch_optional(pg_conn.acquire().await?)
189///         .await;
190///
191///     Ok("")
192/// });
193/// # Ok(())
194/// # }
195/// ```
196///
197/// [tide::Request]: https://docs.rs/tide/0.15.0/tide/struct.Request.html
198#[derive(Debug, Clone)]
199pub struct SQLxMiddleware<DB>
200where
201    DB: Database,
202    DB::Connection: Send + Sync + 'static,
203{
204    pool: Pool<DB>,
205}
206
207impl<DB> SQLxMiddleware<DB>
208where
209    DB: Database,
210    DB::Connection: Send + Sync + 'static,
211{
212    /// Create a new instance of `SQLxMiddleware`.
213    pub async fn new(pgurl: &'_ str) -> std::result::Result<Self, sqlx::Error> {
214        let pool: Pool<DB> = Pool::connect(pgurl).await?;
215        Ok(Self { pool })
216    }
217}
218
219impl<DB> AsRef<Pool<DB>> for SQLxMiddleware<DB>
220where
221    DB: Database,
222    DB::Connection: Send + Sync + 'static,
223{
224    fn as_ref(&self) -> &Pool<DB> {
225        &self.pool
226    }
227}
228
229impl<DB> From<Pool<DB>> for SQLxMiddleware<DB>
230where
231    DB: Database,
232    DB::Connection: Send + Sync + 'static,
233{
234    /// Create a new instance of `SQLxMiddleware` from a `sqlx::Pool`.
235    fn from(pool: Pool<DB>) -> Self {
236        Self { pool }
237    }
238}
239
// This is complicated because of sqlx's typing. We would like a dynamic `sqlx::Executor`; however, the `Executor` trait
// cannot be made into a trait object because it has generic methods.
// Rust does not allow this because a vtable cannot hold an entry for every possible instantiation of a generic method.
// See https://doc.rust-lang.org/error-index.html#method-has-generic-type-parameters for more information.
244//
// In order to get one concrete type, covering both a transaction and a plain pooled connection, that we can deref to a `Connection`, we make an enum with one variant per type.
// The types must be concrete and non-generic because the outer type must be fetchable from `Request::ext`, which is a typemap.
247//
// The enum must be in an `Arc` because we want to be able to commit it at the end of the middleware,
// once we've gotten a response back. This is because anything in `Request::ext` is lost in the endpoint unless it is manually moved
// to the `Response`. Tide may someday be able to do this automatically, but not as of 0.15. An `Arc` is the correct choice to share
// something between multiple owned contexts on a multithreaded futures executor.
252//
// However, interior mutability (`RwLock`) is also required because `Acquire` requires a mutable self reference,
// so we must obtain a mutable lock through the `Arc`, which is not possible with an `Arc` alone.
255//
// This makes using the extension of the request somewhat awkward, because it needs to be unwrapped into a `RwLockWriteGuard`,
// and so the `SQLxRequestExt` extension trait exists to make that nicer.
258
// Middleware entry point. For each request this:
//   1. passes straight through if a `ConnectionWrap<DB>` is already in the request extensions
//      (this middleware already ran, or one was injected, e.g. by a test);
//   2. otherwise checks out a plain pooled connection for `GET`/`HEAD`, or begins a transaction
//      for every other method, and stores it in the request extensions;
//   3. after the rest of the chain has produced a response without an error, reclaims sole
//      ownership of the connection and commits it if it is a transaction.
#[async_trait]
impl<State, DB> Middleware<State> for SQLxMiddleware<DB>
where
    State: Clone + Send + Sync + 'static,
    DB: Database,
    DB::Connection: Send + Sync + 'static,
{
    /// Run the rest of the middleware chain with a database connection attached to `req`.
    ///
    /// # Errors
    ///
    /// Returns an error if a pooled connection cannot be acquired, a transaction cannot be
    /// begun, or committing the transaction fails.
    ///
    /// # Panics
    ///
    /// Panics if the response carries no error but `Arc::try_unwrap` on the shared connection
    /// handle fails (something still holds a clone, e.g. a request kept alive by the handler).
    /// This happens for plain connections as well as transactions.
    async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
        // Dual-purpose: Avoid ever running twice, or pick up a test connection if one exists.
        //
        // TODO(Fishrock): implement recursive depth transactions.
        //   SQLx 0.4 Transactions which are recursive carry a Borrow to the containing Transaction.
        //   Blocked by language feature for Tide - Request extensions cannot hold Borrows.
        if req.ext::<ConnectionWrap<DB>>().is_some() {
            return Ok(next.run(req).await);
        }

        // `GET`/`HEAD` get a plain pooled connection; every other method gets a transaction.
        // TODO(Fishrock): Allow this to be overridden somehow. Maybe check part of the path.
        let is_safe = matches!(req.method(), Method::Get | Method::Head);

        // With the `tracing` feature, each acquisition is wrapped in an `info` span.
        let conn_wrap_inner = if is_safe {
            let conn_fut = self.pool.acquire();
            #[cfg(feature = "tracing")]
            let conn_fut = conn_fut.instrument(info_span!("Acquiring database connection"));
            ConnectionWrapInner::Plain(conn_fut.await?)
        } else {
            let conn_fut = self.pool.begin();
            #[cfg(feature = "tracing")]
            let conn_fut =
                conn_fut.instrument(info_span!("Acquiring database transaction", "COMMIT"));
            ConnectionWrapInner::Transacting(conn_fut.await?)
        };
        // One reference goes into the request; this one is kept so the connection can be
        // reclaimed (and committed) after the endpoint returns.
        let conn_wrap = Arc::new(RwLock::new(conn_wrap_inner));
        req.set_ext(conn_wrap.clone());

        let res = next.run(req).await;

        // On the error path nothing is committed: `conn_wrap` is simply dropped at the end of
        // this function, and sqlx::Transaction calls rollback on Drop.
        if res.error().is_none() {
            // Succeeds only if no other clone of the handle (e.g. inside the request) is alive.
            if let Ok(conn_wrap_inner) = Arc::try_unwrap(conn_wrap) {
                if let ConnectionWrapInner::Transacting(connection) = conn_wrap_inner.into_inner() {
                    // Success path: make the transaction's changes permanent.
                    let commit_fut = connection.commit();
                    #[cfg(feature = "tracing")]
                    let commit_fut = commit_fut
                        .instrument(info_span!("Commiting database transaction", "COMMIT"));
                    commit_fut.await?;
                }
            } else {
                // If this is hit, it is likely that an http_types (surf::http / tide::http) Request has been kept alive and was not consumed.
                // This would be a programmer error.
                // Given the pool would slowly be resource-starved if we continue, there is no good way to continue.
                //
                // I'm bewildered, you're bewildered. Let's panic!
                panic!("We have err'd egregiously! Could not unwrap refcounted SQLx connection for COMMIT; handler may be storing connection or request inappropiately?")
            }
        }

        Ok(res)
    }
}
319
320/// An extension trait for [tide::Request][] which does proper unwrapping of the connection from [`req.ext()`][].
321///
322/// [`req.ext()`]: https://docs.rs/tide/0.15.0/tide/struct.Request.html#method.ext
323/// [tide::Request]: https://docs.rs/tide/0.15.0/tide/struct.Request.html
324#[async_trait]
325pub trait SQLxRequestExt {
326    /// Get the SQLx connection for the current Request.
327    ///
328    /// This will return a "write" guard from a read-write lock.
329    /// Under the hood this will transparently be either a postgres transaction or a direct pooled connection.
330    ///
331    /// This will panic with an expect message if the `SQLxMiddleware` has not been run.
332    ///
333    /// ## Example
334    ///
335    /// ```no_run
336    /// # #[async_std::main]
337    /// # async fn main() -> anyhow::Result<()> {
338    /// # use tide_sqlx::SQLxMiddleware;
339    /// # use sqlx::postgres::Postgres;
340    /// #
341    /// # let mut app = tide::new();
342    /// # app.with(SQLxMiddleware::<Postgres>::new("postgres://localhost/a_database").await?);
343    /// #
344    /// use sqlx::Acquire; // Or sqlx::prelude::*;
345    ///
346    /// use tide_sqlx::SQLxRequestExt;
347    ///
348    /// app.at("/").post(|req: tide::Request<()>| async move {
349    ///     let mut pg_conn = req.sqlx_conn::<Postgres>().await;
350    ///
351    ///     sqlx::query("SELECT * FROM users")
352    ///         .fetch_optional(pg_conn.acquire().await?)
353    ///         .await;
354    ///
355    ///     Ok("")
356    /// });
357    /// # Ok(())
358    /// # }
359    /// ```
360    async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
361    where
362        DB: Database,
363        DB::Connection: Send + Sync + 'static;
364}
365
366#[async_trait]
367impl<T: Send + Sync + 'static> SQLxRequestExt for Request<T> {
368    async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
369    where
370        DB: Database,
371        DB::Connection: Send + Sync + 'static,
372    {
373        let sqlx_conn: &ConnectionWrap<DB> = self
374            .ext()
375            .expect("You must install SQLx middleware providing ConnectionWrap");
376        let rwlock_fut = sqlx_conn.write();
377        #[cfg(all(feature = "tracing", debug_assertions))]
378        let rwlock_fut =
379            rwlock_fut.instrument(debug_span!("Database connection RwLockWriteGuard acquire"));
380        rwlock_fut.await
381    }
382}