tide_sqlx/lib.rs
1//! A [Tide][] middleware which holds a pool of SQLx database connections, and automatically hands
2//! each [tide::Request][] a connection, which may transparently be either a database transaction,
3//! or a direct pooled database connection.
4//!
//! By default, transactions are used for all HTTP methods other than `GET` and `HEAD`.
//!
//! Use the `SQLxRequestExt` extension trait to get the connection for a request.
8//!
9//! ## Examples
10//!
11//! ### Basic
12//! ```no_run
13//! # #[async_std::main]
14//! # async fn main() -> anyhow::Result<()> {
15//! use sqlx::Acquire; // Or sqlx::prelude::*;
16//! use sqlx::postgres::Postgres;
17//!
18//! use tide_sqlx::SQLxMiddleware;
19//! use tide_sqlx::SQLxRequestExt;
20//!
21//! let mut app = tide::new();
22//! app.with(SQLxMiddleware::<Postgres>::new("postgres://localhost/a_database").await?);
23//!
24//! app.at("/").post(|req: tide::Request<()>| async move {
25//! let mut pg_conn = req.sqlx_conn::<Postgres>().await;
26//!
27//! sqlx::query("SELECT * FROM users")
28//! .fetch_optional(pg_conn.acquire().await?)
29//! .await;
30//!
31//! Ok("")
32//! });
33//! # Ok(())
34//! # }
35//! ```
36//!
37//! ### From sqlx `PoolOptions` and with `ConnectOptions`
38//! ```no_run
39//! # #[async_std::main]
40//! # async fn main() -> anyhow::Result<()> {
41//! use log::LevelFilter;
42//! use sqlx::{Acquire, ConnectOptions}; // Or sqlx::prelude::*;
43//! use sqlx::postgres::{PgConnectOptions, PgPoolOptions, Postgres};
44//!
45//! use tide_sqlx::SQLxMiddleware;
46//! use tide_sqlx::SQLxRequestExt;
47//!
48//! let mut connect_opts = PgConnectOptions::new();
49//! connect_opts.log_statements(LevelFilter::Debug);
50//!
51//! let pg_pool = PgPoolOptions::new()
52//! .max_connections(5)
53//! .connect_with(connect_opts)
54//! .await?;
55//!
56//! let mut app = tide::new();
57//! app.with(SQLxMiddleware::from(pg_pool));
58//!
59//! app.at("/").post(|req: tide::Request<()>| async move {
60//! let mut pg_conn = req.sqlx_conn::<Postgres>().await;
61//!
62//! sqlx::query("SELECT * FROM users")
63//! .fetch_optional(pg_conn.acquire().await?)
64//! .await;
65//!
66//! Ok("")
67//! });
68//! # Ok(())
69//! # }
70//! ```
71//!
72//! ## Why you may want to use this
73//!
//! Database transactions are very useful because they allow easy, assured rollback if something goes wrong.
//! However, transactions incur extra runtime cost, which is too expensive to justify for read operations that _do not need_ this behavior.
//!
//! To let endpoints use transactions seamlessly, this middleware manages a transaction whenever one is deemed desirable
//! (by default, for every request method other than `GET` and `HEAD`).
78//!
79//! [tide::Request]: https://docs.rs/tide/0.15.0/tide/struct.Request.html
80//! [Tide]: https://docs.rs/tide/0.15.0/tide/
81
// Building docs with the `docs` cargo feature turns on the unstable rustdoc `doc_cfg` feature,
// so items can be labelled with the cargo feature they require.
#![cfg_attr(feature = "docs", feature(doc_cfg))]
// The public type name `SQLxMiddleware` trips clippy's `upper_case_acronyms` lint.
#![allow(clippy::upper_case_acronyms)]
84
85use std::fmt::{self, Debug};
86use std::ops::{Deref, DerefMut};
87use std::sync::Arc;
88
89use async_std::sync::{RwLock, RwLockWriteGuard};
90use sqlx::pool::{Pool, PoolConnection};
91use sqlx::{Database, Transaction};
92use tide::utils::async_trait;
93use tide::{http::Method, Middleware, Next, Request, Result};
94
95#[cfg(all(feature = "tracing", debug_assertions))]
96use tracing_crate::debug_span;
97#[cfg(feature = "tracing")]
98use tracing_crate::{info_span, Instrument};
99
// Compile-time guard: refuse to build the test harness unless the `postgres` feature is enabled.
// NOTE(review): the condition checks `feature = "postgres"` while the message tells users to pass
// `--features=test`; presumably the `test` feature enables `postgres` — confirm in Cargo.toml.
#[cfg(all(not(feature = "postgres"), test))]
compile_error!("The tests must be run with --features=test");
102
/// Helpers specific to Postgres; only compiled when the `postgres` cargo feature is enabled.
#[cfg(feature = "postgres")]
#[cfg_attr(feature = "docs", doc(cfg(feature = "postgres")))]
pub mod postgres;
107
108#[doc(hidden)]
109pub enum ConnectionWrapInner<DB>
110where
111 DB: Database,
112 DB::Connection: Send + Sync + 'static,
113{
114 Transacting(Transaction<'static, DB>),
115 Plain(PoolConnection<DB>),
116}
117
118impl<DB> Debug for ConnectionWrapInner<DB>
119where
120 DB: Database,
121 DB::Connection: Send + Sync + 'static,
122{
123 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
124 match self {
125 Self::Transacting(_) => f.debug_struct("ConnectionWrapInner::Transacting").finish(),
126 Self::Plain(_) => f.debug_struct("ConnectionWrapInner::Plain").finish(),
127 }
128 }
129}
130
131impl<DB> Deref for ConnectionWrapInner<DB>
132where
133 DB: Database,
134 DB::Connection: Send + Sync + 'static,
135{
136 type Target = DB::Connection;
137
138 fn deref(&self) -> &Self::Target {
139 match self {
140 ConnectionWrapInner::Plain(c) => c,
141 ConnectionWrapInner::Transacting(c) => c,
142 }
143 }
144}
145
146impl<DB> DerefMut for ConnectionWrapInner<DB>
147where
148 DB: Database,
149 DB::Connection: Send + Sync + 'static,
150{
151 fn deref_mut(&mut self) -> &mut Self::Target {
152 match self {
153 ConnectionWrapInner::Plain(c) => c,
154 ConnectionWrapInner::Transacting(c) => c,
155 }
156 }
157}
158
159#[doc(hidden)]
160pub type ConnectionWrap<DB> = Arc<RwLock<ConnectionWrapInner<DB>>>;
161
162/// This middleware holds a pool of SQLx database connections, and automatically hands each
163/// [tide::Request][] a connection, which may transparently be either a database transaction,
164/// or a direct pooled database connection.
165///
166/// By default, transactions are used for all http methods other than `GET` and `HEAD`.
167///
168/// When using this, use the `SQLxRequestExt` extenstion trait to get the connection.
169///
170/// ## Example
171///
172/// ```no_run
173/// # #[async_std::main]
174/// # async fn main() -> anyhow::Result<()> {
175/// use sqlx::Acquire; // Or sqlx::prelude::*;
176/// use sqlx::postgres::Postgres;
177///
178/// use tide_sqlx::SQLxMiddleware;
179/// use tide_sqlx::SQLxRequestExt;
180///
181/// let mut app = tide::new();
182/// app.with(SQLxMiddleware::<Postgres>::new("postgres://localhost/a_database").await?);
183///
184/// app.at("/").post(|req: tide::Request<()>| async move {
185/// let mut pg_conn = req.sqlx_conn::<Postgres>().await;
186///
187/// sqlx::query("SELECT * FROM users")
188/// .fetch_optional(pg_conn.acquire().await?)
189/// .await;
190///
191/// Ok("")
192/// });
193/// # Ok(())
194/// # }
195/// ```
196///
197/// [tide::Request]: https://docs.rs/tide/0.15.0/tide/struct.Request.html
198#[derive(Debug, Clone)]
199pub struct SQLxMiddleware<DB>
200where
201 DB: Database,
202 DB::Connection: Send + Sync + 'static,
203{
204 pool: Pool<DB>,
205}
206
207impl<DB> SQLxMiddleware<DB>
208where
209 DB: Database,
210 DB::Connection: Send + Sync + 'static,
211{
212 /// Create a new instance of `SQLxMiddleware`.
213 pub async fn new(pgurl: &'_ str) -> std::result::Result<Self, sqlx::Error> {
214 let pool: Pool<DB> = Pool::connect(pgurl).await?;
215 Ok(Self { pool })
216 }
217}
218
219impl<DB> AsRef<Pool<DB>> for SQLxMiddleware<DB>
220where
221 DB: Database,
222 DB::Connection: Send + Sync + 'static,
223{
224 fn as_ref(&self) -> &Pool<DB> {
225 &self.pool
226 }
227}
228
229impl<DB> From<Pool<DB>> for SQLxMiddleware<DB>
230where
231 DB: Database,
232 DB::Connection: Send + Sync + 'static,
233{
234 /// Create a new instance of `SQLxMiddleware` from a `sqlx::Pool`.
235 fn from(pool: Pool<DB>) -> Self {
236 Self { pool }
237 }
238}
239
// This is complicated because of sqlx's typing. We would like a dynamic `sqlx::Executor`; however, the `Executor` trait
// cannot be made into a trait object because it has generic methods.
// Rust does not allow this because a vtable cannot hold an entry for every possible instantiation of a generic method.
// See https://doc.rust-lang.org/error-index.html#method-has-generic-type-parameters for more information.
244//
// To get a single concrete type that derefs to a `Connection` for both a transaction and a plain pooled connection,
// we use an enum with one variant per type.
// The types must be concrete because the outer type must be fetchable from `Request::ext`, which is a typemap.
247//
// The enum must be wrapped in an `Arc` because we need to commit it at the end of the middleware,
// once we've gotten a response back. Anything in `Request::ext` is lost in the endpoint unless it is manually moved
// to the `Response`. Tide may someday be able to do this automatically, but not as of 0.15. An `Arc` is the correct choice for
// sharing something between multiple owned contexts on a multi-threaded futures executor.
252//
// However, interior mutability (`RwLock`) is also required. `Acquire` needs a mutable reference to self,
// so we must obtain mutable access through the `Arc`, which is not possible with an `Arc` alone.
255//
// This makes using the request extension somewhat awkward, because it must be unwrapped into a `RwLockWriteGuard`,
// so the `SQLxRequestExt` extension trait exists to make that nicer.
258
259#[async_trait]
260impl<State, DB> Middleware<State> for SQLxMiddleware<DB>
261where
262 State: Clone + Send + Sync + 'static,
263 DB: Database,
264 DB::Connection: Send + Sync + 'static,
265{
266 async fn handle(&self, mut req: Request<State>, next: Next<'_, State>) -> Result {
267 // Dual-purpose: Avoid ever running twice, or pick up a test connection if one exists.
268 //
269 // TODO(Fishrock): implement recursive depth transactions.
270 // SQLx 0.4 Transactions which are recursive carry a Borrow to the containing Transaction.
271 // Blocked by language feature for Tide - Request extensions cannot hold Borrows.
272 if req.ext::<ConnectionWrap<DB>>().is_some() {
273 return Ok(next.run(req).await);
274 }
275
276 // TODO(Fishrock): Allow this to be overridden somehow. Maybe check part of the path.
277 let is_safe = matches!(req.method(), Method::Get | Method::Head);
278
279 let conn_wrap_inner = if is_safe {
280 let conn_fut = self.pool.acquire();
281 #[cfg(feature = "tracing")]
282 let conn_fut = conn_fut.instrument(info_span!("Acquiring database connection"));
283 ConnectionWrapInner::Plain(conn_fut.await?)
284 } else {
285 let conn_fut = self.pool.begin();
286 #[cfg(feature = "tracing")]
287 let conn_fut =
288 conn_fut.instrument(info_span!("Acquiring database transaction", "COMMIT"));
289 ConnectionWrapInner::Transacting(conn_fut.await?)
290 };
291 let conn_wrap = Arc::new(RwLock::new(conn_wrap_inner));
292 req.set_ext(conn_wrap.clone());
293
294 let res = next.run(req).await;
295
296 if res.error().is_none() {
297 if let Ok(conn_wrap_inner) = Arc::try_unwrap(conn_wrap) {
298 if let ConnectionWrapInner::Transacting(connection) = conn_wrap_inner.into_inner() {
299 // if we errored, sqlx::Transaction calls rollback on Drop.
300 let commit_fut = connection.commit();
301 #[cfg(feature = "tracing")]
302 let commit_fut = commit_fut
303 .instrument(info_span!("Commiting database transaction", "COMMIT"));
304 commit_fut.await?;
305 }
306 } else {
307 // If this is hit, it is likely that an http_types (surf::http / tide::http) Request has been kept alive and was not consumed.
308 // This would be a programmer error.
309 // Given the pool would slowly be resource-starved if we continue, there is no good way to continue.
310 //
311 // I'm bewildered, you're bewildered. Let's panic!
312 panic!("We have err'd egregiously! Could not unwrap refcounted SQLx connection for COMMIT; handler may be storing connection or request inappropiately?")
313 }
314 }
315
316 Ok(res)
317 }
318}
319
320/// An extension trait for [tide::Request][] which does proper unwrapping of the connection from [`req.ext()`][].
321///
322/// [`req.ext()`]: https://docs.rs/tide/0.15.0/tide/struct.Request.html#method.ext
323/// [tide::Request]: https://docs.rs/tide/0.15.0/tide/struct.Request.html
324#[async_trait]
325pub trait SQLxRequestExt {
326 /// Get the SQLx connection for the current Request.
327 ///
328 /// This will return a "write" guard from a read-write lock.
329 /// Under the hood this will transparently be either a postgres transaction or a direct pooled connection.
330 ///
331 /// This will panic with an expect message if the `SQLxMiddleware` has not been run.
332 ///
333 /// ## Example
334 ///
335 /// ```no_run
336 /// # #[async_std::main]
337 /// # async fn main() -> anyhow::Result<()> {
338 /// # use tide_sqlx::SQLxMiddleware;
339 /// # use sqlx::postgres::Postgres;
340 /// #
341 /// # let mut app = tide::new();
342 /// # app.with(SQLxMiddleware::<Postgres>::new("postgres://localhost/a_database").await?);
343 /// #
344 /// use sqlx::Acquire; // Or sqlx::prelude::*;
345 ///
346 /// use tide_sqlx::SQLxRequestExt;
347 ///
348 /// app.at("/").post(|req: tide::Request<()>| async move {
349 /// let mut pg_conn = req.sqlx_conn::<Postgres>().await;
350 ///
351 /// sqlx::query("SELECT * FROM users")
352 /// .fetch_optional(pg_conn.acquire().await?)
353 /// .await;
354 ///
355 /// Ok("")
356 /// });
357 /// # Ok(())
358 /// # }
359 /// ```
360 async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
361 where
362 DB: Database,
363 DB::Connection: Send + Sync + 'static;
364}
365
366#[async_trait]
367impl<T: Send + Sync + 'static> SQLxRequestExt for Request<T> {
368 async fn sqlx_conn<'req, DB>(&'req self) -> RwLockWriteGuard<'req, ConnectionWrapInner<DB>>
369 where
370 DB: Database,
371 DB::Connection: Send + Sync + 'static,
372 {
373 let sqlx_conn: &ConnectionWrap<DB> = self
374 .ext()
375 .expect("You must install SQLx middleware providing ConnectionWrap");
376 let rwlock_fut = sqlx_conn.write();
377 #[cfg(all(feature = "tracing", debug_assertions))]
378 let rwlock_fut =
379 rwlock_fut.instrument(debug_span!("Database connection RwLockWriteGuard acquire"));
380 rwlock_fut.await
381 }
382}