umbral_core/db.rs
1//! Database pool registry and connection helpers.
2//!
3//! ## DbPool: the multi-backend seam
4//!
5//! [`DbPool`] is a small enum that wraps either a [`sqlx::SqlitePool`]
6//! or a [`sqlx::PgPool`]. It's the type [`connect`] returns and the
7//! type [`AppBuilder::database`](crate::app::AppBuilder::database)
8//! stores, so the framework remembers which backend each registered
9//! alias is connected to.
10//!
11//! ### Why an enum, not `sqlx::AnyPool`
12//!
13//! `sqlx::AnyPool` is the more "correct" abstraction at the type
14//! level: one pool type that dispatches to the right driver at
15//! runtime. But it has a real-world cost — sea-query-binder (the
16//! crate the QuerySet uses to bind parameters) doesn't have an
17//! `Any` backend; values must be bound through the per-driver
18//! query builder. Forcing every plugin and the queryset onto
19//! `AnyPool` therefore turns the simple multi-backend goal into a
20//! cascade through every binding site.
21//!
22//! The enum is the right shape for now. Every plugin still gets a
23//! typed `SqlitePool` from [`pool`] / [`pool_for`], and the
24//! ergonomics of `sqlx::query(...)` against that pool stay
25//! identical. Phase 2 of the Postgres rollout (per `FEATURES.md`)
26//! threads the variant choice through the migration engine and
27//! queryset; Phase 1 only needs the type seam.
28//!
29//! ### Postgres at boot, today
30//!
31//! [`connect`] accepts both `sqlite://...` and `postgres://...`
//! URLs and returns a [`DbPool`] of the matching variant. The scheme
//! detection is meant to mirror [`crate::backend::detect`], but `connect`
//! parses the scheme itself — keep the two in step when adding a backend.
35//!
36//! At Phase 1 the rest of the framework (queryset, migration
37//! engine, every plugin) still reads through [`pool`] / [`pool_for`]
38//! which hand back a `SqlitePool`. If the registered pool is
39//! actually a `PgPool`, those functions panic with a clear
40//! "Postgres support arrives in Phase 2" message. That's
41//! deliberate: the type seam exists, but callers that aren't
42//! ready for Postgres surface immediately at runtime rather than
43//! limping along and producing wrong results.
44
45use std::collections::HashMap;
46use std::pin::Pin;
47use std::sync::OnceLock;
48
49use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteSynchronous};
50use sqlx::{ConnectOptions, PgPool, SqlitePool};
51use std::str::FromStr;
52use std::time::Duration;
53
54pub mod route_context;
55pub mod router;
56
57pub use route_context::{RouteContext, TenantKey, current as route_context};
58pub use router::{Alias, DatabaseRouter, DefaultRouter, RouteOp, Schema, router};
59
60/// A pool of database connections, typed by backend.
61///
62/// Cloning is cheap — both variants wrap an `Arc`-backed inner
63/// pool, so a `clone()` just bumps the refcount.
64#[derive(Debug, Clone)]
65pub enum DbPool {
66 /// SQLite-backed connection pool. The default through Phase 1
67 /// and the only variant the queryset / migration engine accepts
68 /// today.
69 Sqlite(SqlitePool),
70 /// Postgres-backed connection pool. Connectable at Phase 1, but
71 /// any code path that calls into the queryset or migration
72 /// engine against this variant panics with a clear "arrives in
73 /// Phase 2" message. The seam itself is the deliverable here.
74 Postgres(PgPool),
75}
76
77impl DbPool {
78 /// Borrow the inner `SqlitePool`. Returns `None` for a Postgres
79 /// pool. Phase 1 callers that haven't migrated to the dispatch
80 /// API yet typically reach for [`Self::sqlite_or_panic`]; the
81 /// returned-Option variant is for the (rare today) code that
82 /// wants to gracefully fall back.
83 pub fn as_sqlite(&self) -> Option<&SqlitePool> {
84 match self {
85 DbPool::Sqlite(p) => Some(p),
86 DbPool::Postgres(_) => None,
87 }
88 }
89
90 /// Borrow the inner `PgPool`. Returns `None` for a SQLite pool.
91 pub fn as_postgres(&self) -> Option<&PgPool> {
92 match self {
93 DbPool::Sqlite(_) => None,
94 DbPool::Postgres(p) => Some(p),
95 }
96 }
97
98 /// Borrow the inner `SqlitePool`, panicking with a clear "Postgres
99 /// support arrives in Phase 2" message on a Postgres variant. Used
100 /// by [`pool`] and [`pool_for`] so existing plugin code (that
101 /// expects a `SqlitePool`) doesn't quietly limp along when the
102 /// operator connects to Postgres.
103 pub fn sqlite_or_panic(&self) -> &SqlitePool {
104 self.as_sqlite().expect(
105 "umbral: a Postgres pool is registered but this code path \
106 still reads SqlitePool. Full Postgres support lands in \
107 Phase 2 of the rollout — see FEATURES.md and the \
108 `DbPool` rustdoc.",
109 )
110 }
111
112 /// The string identifier of the underlying backend. Matches
113 /// [`crate::backend::DatabaseBackend::name`] for the active
114 /// pool variant.
115 pub fn backend_name(&self) -> &'static str {
116 match self {
117 DbPool::Sqlite(_) => "sqlite",
118 DbPool::Postgres(_) => "postgres",
119 }
120 }
121}
122
123impl From<SqlitePool> for DbPool {
124 fn from(pool: SqlitePool) -> Self {
125 DbPool::Sqlite(pool)
126 }
127}
128
129impl From<PgPool> for DbPool {
130 fn from(pool: PgPool) -> Self {
131 DbPool::Postgres(pool)
132 }
133}
134
135/// Holds all registered database pools, keyed by alias.
136/// The "default" pool is always present after `App::build()` succeeds.
137static POOLS: OnceLock<HashMap<String, DbPool>> = OnceLock::new();
138
139/// Runtime tenant-pool registry for **database-per-tenant** multitenancy:
140/// pools registered AFTER `App::build()`, as tenants are onboarded (e.g. by a
141/// `DatabaseRouter` that maps a request's tenant to its own database). The
142/// static `POOLS` map above is set once at build; this `RwLock`-backed map
143/// grows at runtime via [`register_tenant_pool`]. Stored pools are leaked to
144/// `&'static` on insert — a tenant pool lives for the whole process (you never
145/// drop one mid-serve), so [`pool_for_dispatched`] keeps its zero-cost
146/// `&'static DbPool` return: the `&'static` is copied out before the read guard
147/// drops, so no lock guard ever escapes.
148static DYNAMIC_POOLS: OnceLock<std::sync::RwLock<HashMap<String, &'static DbPool>>> =
149 OnceLock::new();
150
/// App-wide default for wrapping ORM write terminals in a transaction.
///
/// Published by `AppBuilder::atomic_transactions(...)` and consulted by every
/// terminal that offers `.atomic()` / `.non_atomic()`. While it is unset
/// (the default), terminals do not wrap, so apps that never opt in keep
/// their existing behaviour.
static ATOMIC_DEFAULT: OnceLock<bool> = OnceLock::new();
157
158/// Publish the app-wide atomic-transactions default. Called by
159/// `AppBuilder::build()` exactly when the user set the flag via
160/// `atomic_transactions(...)`. Idempotent across re-init attempts —
161/// the first set wins, matching the rest of the OnceLock-backed
162/// ambient state.
163pub(crate) fn init_atomic_default(enabled: bool) {
164 let _ = ATOMIC_DEFAULT.set(enabled);
165}
166
167/// Read the app-wide atomic-transactions default. Returns `false` when
168/// the builder didn't call `atomic_transactions(...)` (or when the
169/// ambient state hasn't been published yet, as in unit tests that
170/// drive the ORM with `.on(&pool)` and never call `App::build()`).
171pub fn atomic_default() -> bool {
172 *ATOMIC_DEFAULT.get().unwrap_or(&false)
173}
174
175/// Initialize the pool registry. Called by `AppBuilder::build()` only.
176pub(crate) fn init(pools: HashMap<String, DbPool>) {
177 POOLS
178 .set(pools)
179 .expect("umbral::db::init called more than once");
180}
181
182/// Return the default connection pool, typed as a [`SqlitePool`].
183///
184/// This is the function every plugin and the queryset call. The
185/// internal storage is a [`DbPool`]; this unwraps to the
186/// `SqlitePool` variant or panics with a Phase-2 hint, matching
187/// the documented Phase 1 contract.
188///
189/// # Panics
190///
191/// Panics if `App::build()` hasn't run or the registered default
192/// pool is Postgres.
193pub fn pool() -> SqlitePool {
194 pool_dispatched().sqlite_or_panic().clone()
195}
196
197/// Return the default connection pool as a typed [`DbPool`].
198///
199/// Use this from code that's ready to dispatch on backend (the
200/// migration engine and queryset will move to this surface in
201/// Phase 2). Plugin code can stay on [`pool`] until then.
202///
203/// # Panics
204///
205/// Panics if `App::build()` hasn't run.
206pub fn pool_dispatched() -> &'static DbPool {
207 POOLS
208 .get()
209 .expect("umbral: db pool not initialised — did you call App::build()?")
210 .get("default")
211 .expect("umbral: no default database registered")
212}
213
214/// Like [`pool_dispatched`] but returns `None` instead of panicking
215/// when no pool is registered yet (`App::build()` hasn't run, or this
216/// is a pure SQL-building call such as `QuerySet::to_sql` in a test with
217/// no app booted). Used by runtime advisory paths that must not crash a
218/// query-builder call — see the RIGHT-JOIN-on-old-SQLite warning.
219pub fn try_pool_dispatched() -> Option<&'static DbPool> {
220 POOLS.get().and_then(|pools| pools.get("default"))
221}
222
223/// Return a named connection pool, typed as a [`SqlitePool`].
224///
225/// # Panics
226///
227/// Panics if `App::build()` hasn't run, the alias isn't registered,
228/// or the registered pool is Postgres.
229pub fn pool_for(alias: &str) -> SqlitePool {
230 pool_for_dispatched(alias).sqlite_or_panic().clone()
231}
232
233/// Return a named connection pool as a typed [`DbPool`]. Phase 2
234/// surface; see [`pool_dispatched`].
235///
236/// Resolution order: the build-time `POOLS` map first, then the runtime
237/// [`register_tenant_pool`] registry (database-per-tenant). Panics only when
238/// the alias is in neither.
239pub fn pool_for_dispatched(alias: &str) -> &'static DbPool {
240 if let Some(p) = POOLS.get().and_then(|pools| pools.get(alias)) {
241 return p;
242 }
243 if let Some(p) = DYNAMIC_POOLS
244 .get()
245 .and_then(|reg| reg.read().ok().and_then(|m| m.get(alias).copied()))
246 {
247 return p;
248 }
249 if POOLS.get().is_none() {
250 panic!("umbral: db pool not initialised — did you call App::build()?");
251 }
252 panic!("umbral: no database registered under alias '{alias}'");
253}
254
255/// Register a database pool under `alias` at runtime — the database-per-tenant
256/// seam. Unlike the build-time `App::builder().database(alias, pool)` (which
257/// fills the static pool map), this may be called any time after `App::build()`
258/// as tenants are onboarded. First-write-wins: re-registering an existing alias
259/// is a no-op (a re-resolution of the same tenant won't churn its pool) and the
260/// surplus pool is dropped without leaking. The stored pool is leaked to
261/// `&'static` because tenant pools are process-lifetime.
262///
263/// A [`DatabaseRouter`](crate::db::router::DatabaseRouter) whose
264/// `db_for_read`/`db_for_write` returns `alias` for a tenant request then routes
265/// that tenant's queries to this pool.
266pub fn register_tenant_pool(alias: impl Into<String>, pool: DbPool) {
267 let alias = alias.into();
268 let mut guard = DYNAMIC_POOLS
269 .get_or_init(|| std::sync::RwLock::new(HashMap::new()))
270 .write()
271 .expect("umbral: dynamic pool registry poisoned");
272 if guard.contains_key(&alias) {
273 return; // first-write-wins; `pool` is dropped here, not leaked
274 }
275 let leaked: &'static DbPool = Box::leak(Box::new(pool));
276 guard.insert(alias, leaked);
277}
278
279/// True if `alias` resolves to a registered pool — build-time `POOLS` or the
280/// runtime tenant registry. A router can use this to fall back to the default
281/// pool for a tenant whose database hasn't been onboarded yet.
282pub fn pool_alias_registered(alias: &str) -> bool {
283 POOLS.get().is_some_and(|p| p.contains_key(alias))
284 || DYNAMIC_POOLS
285 .get()
286 .and_then(|reg| reg.read().ok().map(|m| m.contains_key(alias)))
287 .unwrap_or(false)
288}
289
290/// Ping the default database pool with a backend-appropriate liveness
291/// query (`SELECT 1`).
292///
293/// Resolves the ambient pool via [`pool_dispatched`] and dispatches:
294///
295/// - **SQLite** — `SELECT 1` via the sqlite driver.
296/// - **Postgres** — `SELECT 1` via the postgres driver.
297///
298/// Returns `Ok(())` when the pool is reachable. Returns
299/// `Err(sqlx::Error)` on any connection or query failure so callers
300/// can map it to a wire-friendly string without exposing the full sqlx
301/// error type.
302///
303/// # Panics
304///
305/// Panics if `App::build()` hasn't run (same contract as
306/// [`pool_dispatched`]).
307pub async fn ping() -> Result<(), sqlx::Error> {
308 match pool_dispatched() {
309 DbPool::Sqlite(p) => {
310 sqlx::query("SELECT 1").execute(p).await.map(|_| ())
311 }
312 DbPool::Postgres(p) => {
313 sqlx::query("SELECT 1").execute(p).await.map(|_| ())
314 }
315 }
316}
317
318/// List every registered pool alias, sorted alphabetically.
319///
320/// Used by the migration engine to walk each DB in deterministic
321/// order so per-DB tracking tables get created and per-DB diffs run
322/// against the right model subset. The `"default"` alias is always
323/// present after `App::build()` succeeds and lands wherever
324/// alphabetical sort puts it (typically first).
325///
326/// # Panics
327///
328/// Panics if `App::build()` hasn't run.
329pub fn registered_aliases() -> Vec<String> {
330 let mut aliases: Vec<String> = POOLS
331 .get()
332 .expect("umbral: db pool not initialised — did you call App::build()?")
333 .keys()
334 .cloned()
335 .collect();
336 aliases.sort();
337 aliases
338}
339
340/// Open a new connection pool for the given database URL.
341///
342/// Dispatches on the URL scheme:
343///
344/// - `sqlite://...` or `sqlite::memory:` returns a
345/// [`DbPool::Sqlite`].
346/// - `postgres://...` / `postgresql://...` returns a
347/// [`DbPool::Postgres`].
348///
349/// Any other scheme surfaces as an `sqlx::Error::Configuration`.
350/// For callers that already have a typed pool, [`From`] impls on
351/// [`DbPool`] convert directly: `let dp: DbPool = sqlite_pool.into();`.
352pub async fn connect(url: &str) -> Result<DbPool, sqlx::Error> {
353 let scheme = url
354 .split("://")
355 .next()
356 .and_then(|s| s.split(':').next())
357 .unwrap_or(url);
358 match scheme {
359 "sqlite" => Ok(DbPool::Sqlite(connect_sqlite(url).await?)),
360 "postgres" | "postgresql" => Ok(DbPool::Postgres(connect_postgres(url).await?)),
361 other => Err(sqlx::Error::Configuration(
362 format!(
363 "umbral::db::connect: unsupported URL scheme `{other}://`. \
364 Phase 1 supports `sqlite://` and `postgres://`."
365 )
366 .into(),
367 )),
368 }
369}
370
/// Pool sizing and recycling knobs shared by every backend.
///
/// [`PoolConfig::resolve`] builds it from [`crate::settings`] when those are
/// installed, or from built-in production defaults otherwise (a pool may be
/// opened before settings exist). Both [`connect_postgres`] and
/// [`connect_sqlite`] consume it, so a single set of `UMBRAL_DB_*` knobs
/// governs every backend (gaps2 #91).
struct PoolConfig {
    /// Upper bound on open connections. The connect functions clamp it to at
    /// least 1.
    max_connections: u32,
    /// Number of connections to keep open even when idle.
    min_connections: u32,
    /// Seconds to wait for a free connection before failing.
    acquire_timeout_secs: u64,
    /// Idle recycling threshold in seconds; `None` when the setting is `0` or empty.
    idle_timeout_secs: Option<u64>,
    /// Maximum connection lifetime in seconds; `None` when the setting is `0` or empty.
    max_lifetime_secs: Option<u64>,
    /// Whether a connection is health-checked before it is handed out.
    test_before_acquire: bool,
}
384
385impl PoolConfig {
386 fn resolve() -> Self {
387 match crate::settings::get_opt() {
388 Some(s) => PoolConfig {
389 max_connections: s.db_max_connections,
390 min_connections: s.db_min_connections,
391 acquire_timeout_secs: s.db_acquire_timeout_secs,
392 idle_timeout_secs: s.db_idle_timeout_secs,
393 max_lifetime_secs: s.db_max_lifetime_secs,
394 test_before_acquire: s.db_test_before_acquire,
395 },
396 // Defaults mirror the `default_db_*` fns in `settings`.
397 None => PoolConfig {
398 max_connections: 10,
399 min_connections: 0,
400 acquire_timeout_secs: 30,
401 idle_timeout_secs: Some(600),
402 max_lifetime_secs: Some(1800),
403 test_before_acquire: true,
404 },
405 }
406 }
407
408 /// Emit one operator-facing line describing the pool that's about to
409 /// be built, so the effective config is visible in the boot log.
410 fn log(&self, backend: &str) {
411 tracing::info!(
412 backend,
413 max_connections = self.max_connections.max(1),
414 min_connections = self.min_connections,
415 acquire_timeout_secs = self.acquire_timeout_secs,
416 idle_timeout_secs = ?self.idle_timeout_secs,
417 max_lifetime_secs = ?self.max_lifetime_secs,
418 test_before_acquire = self.test_before_acquire,
419 "umbral: opening database pool"
420 );
421 }
422}
423
424/// Open a Postgres pool from a URL with umbral's pool configuration.
425///
426/// PERF-5 / gaps2 #91: bare `PgPool::connect` uses sqlx's defaults with
427/// **no acquire timeout**, so a saturated pool blocks request tasks
428/// forever. We always apply the full set of pool knobs — `max_connections`,
429/// `min_connections`, a bounded `acquire_timeout` (fail fast),
430/// `idle_timeout`, `max_lifetime`, and `test_before_acquire` — read from
431/// [`crate::settings`] when available (falling back to the documented
432/// production defaults if the pool is opened before settings are
433/// installed). `idle_timeout`/`max_lifetime` are only applied when `Some`;
434/// a `None` (env `0`/empty) leaves that recycling disabled.
435pub async fn connect_postgres(url: &str) -> Result<PgPool, sqlx::Error> {
436 use std::time::Duration;
437 let cfg = PoolConfig::resolve();
438 cfg.log("postgres");
439
440 let mut opts = sqlx::postgres::PgPoolOptions::new()
441 .max_connections(cfg.max_connections.max(1))
442 .min_connections(cfg.min_connections)
443 .acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
444 .test_before_acquire(cfg.test_before_acquire);
445 if let Some(secs) = cfg.idle_timeout_secs {
446 opts = opts.idle_timeout(Duration::from_secs(secs));
447 }
448 if let Some(secs) = cfg.max_lifetime_secs {
449 opts = opts.max_lifetime(Duration::from_secs(secs));
450 }
451 opts.connect(url).await
452}
453
454/// Open a SQLite-backed pool from a URL.
455///
456/// Applies the standard production PRAGMAs to every connection in the
457/// pool: WAL journal, NORMAL synchronous, a 5-second busy-timeout, and
458/// foreign-key enforcement on. Without these, a fresh `SqlitePool` ends
459/// up in `journal_mode = DELETE` + `synchronous = FULL` — the safe
460/// SQLite defaults that cost ~1-4 seconds per concurrent INSERT once
461/// any other connection touches the file (the rollback-journal lock
462/// serialises writers).
463///
464/// | PRAGMA | Value | Why |
465/// |---|---|---|
466/// | `journal_mode` | `WAL` | Readers don't block writers; a single writer at a time but no full-file lock. Order-of-magnitude faster for any concurrent workload — typically the session/auth/audit tables fanning out. |
467/// | `synchronous` | `NORMAL` | Skips the per-commit fsync of the rollback journal; safe with WAL since the WAL log is fsynced on checkpoint. The official SQLite docs call this the right pairing with WAL for "most applications". |
468/// | `busy_timeout` | `5000ms` | Wait up to 5 s for a contended writer to release the lock before raising `SQLITE_BUSY`. Without this, two concurrent writers immediately race to error. |
469/// | `foreign_keys` | `ON` | sqlite turns FK enforcement off by default. The ORM emits `REFERENCES` clauses assuming they're respected — turning it on per connection makes the FK contract real. |
470///
471/// **In-memory URLs are backed by a process-unique temp file.** A bare
472/// `sqlite::memory:` gives every connection in the pool its OWN private,
473/// empty database, so a table created on one connection is invisible to a
474/// query that lands on another — and a shared in-memory database doesn't
475/// survive the connection (or the tokio runtime) that created it being
476/// dropped. Both surface as a flaky "no such table" whenever a pool is
477/// reused across queries or test cases. Routing in-memory URLs through a
478/// small temp file (which every connection sees and which persists for the
479/// process) sidesteps both — the same approach `umbral-testing::TempPool`
480/// already documents. File-backed (`sqlite://app.db`) and Postgres URLs are
481/// untouched.
482pub async fn connect_sqlite(url: &str) -> Result<SqlitePool, sqlx::Error> {
483 use std::sync::atomic::{AtomicU64, Ordering};
484 static MEM_SEQ: AtomicU64 = AtomicU64::new(0);
485
486 let lower = url.to_ascii_lowercase();
487 let in_memory = lower.contains(":memory:") || lower.contains("mode=memory");
488
489 let opts = if in_memory {
490 let n = MEM_SEQ.fetch_add(1, Ordering::Relaxed);
491 let path =
492 std::env::temp_dir().join(format!("umbral_mem_{}_{n}.sqlite", std::process::id()));
493 // Best-effort: remove a stale file from a previous run with this
494 // exact (pid, seq) — pids recycle. WAL/SHM siblings are recreated.
495 let _ = std::fs::remove_file(&path);
496 SqliteConnectOptions::new()
497 .filename(&path)
498 .create_if_missing(true)
499 } else {
500 SqliteConnectOptions::from_str(url)?
501 };
502 let opts = opts
503 .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
504 .synchronous(SqliteSynchronous::Normal)
505 .busy_timeout(Duration::from_secs(5))
506 .foreign_keys(true)
507 // Disable per-statement logging — sqlx's default INFO-level
508 // logger reads every statement before execution, which adds a
509 // measurable per-query overhead under load. The `slow statement`
510 // WARN at the 1-second threshold stays on, since it goes via a
511 // separate log target.
512 .log_statements(tracing::log::LevelFilter::Off);
513
514 // gaps2 #91: apply the same settings-driven pool knobs as Postgres so
515 // a single `UMBRAL_DB_*` configuration governs every backend. SQLite is
516 // effectively single-writer (WAL serialises writers behind one lock),
517 // so a large `max_connections` mainly buys concurrent *readers*; the
518 // knob is still honoured rather than hardcoding a divergent SQLite path.
519 let cfg = PoolConfig::resolve();
520 cfg.log("sqlite");
521 let mut pool_opts = SqlitePoolOptions::new()
522 .max_connections(cfg.max_connections.max(1))
523 .min_connections(cfg.min_connections)
524 .acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
525 .test_before_acquire(cfg.test_before_acquire);
526 if let Some(secs) = cfg.idle_timeout_secs {
527 pool_opts = pool_opts.idle_timeout(Duration::from_secs(secs));
528 }
529 if let Some(secs) = cfg.max_lifetime_secs {
530 pool_opts = pool_opts.max_lifetime(Duration::from_secs(secs));
531 }
532 pool_opts.connect_with(opts).await
533}
534
535/// Gracefully close the ambient default database pool (gaps2 #91).
536///
537/// Call this once during shutdown — after the HTTP server has stopped
538/// accepting connections — to let sqlx flush in-flight work and close
539/// every pooled connection cleanly rather than having them dropped
540/// abruptly when the process exits. For SQLite this also lets WAL
541/// checkpoint; for Postgres it sends a clean `Terminate` so the server
542/// doesn't log the connections as unexpectedly lost.
543///
544/// Closing is terminal: the ambient [`OnceLock`] is left in place (it
545/// can't be unset), so the pool object remains registered but is closed.
546/// Acquiring from a closed pool errors, which is the intended post-
547/// shutdown behaviour. A no-op if no pool was ever registered.
548///
549/// ```rust,ignore
550/// // in your shutdown handler, after the server stops:
551/// umbral::db::close().await;
552/// ```
553pub async fn close() {
554 if let Some(pools) = POOLS.get() {
555 for db in pools.values() {
556 match db {
557 DbPool::Sqlite(p) => p.close().await,
558 DbPool::Postgres(p) => p.close().await,
559 }
560 }
561 }
562}
563
564// =============================================================================
565// Transaction support
566// =============================================================================
567
568/// An active database transaction, typed by backend.
569///
570/// `Transaction` wraps either a `sqlx::Transaction<'static, sqlx::Sqlite>` or
571/// a `sqlx::Transaction<'static, sqlx::Postgres>` and provides the executor
572/// surface needed by the ORM's query terminals.
573///
574/// ## How to obtain one
575///
576/// The typical path is through the top-level closure helpers:
577///
578/// ```rust,ignore
579/// use umbral::db::transaction;
580///
581/// let order = transaction(|tx| async move {
582/// let o = Order::objects().on_tx(tx).create(new_order).await?;
583/// Inventory::objects().on_tx(tx).filter(...).update_values(...).await?;
584/// Ok::<_, MyError>(o)
585/// }).await?;
586/// ```
587///
588/// For manual control (committing or rolling back yourself) call
589/// [`begin`] / [`begin_sqlite`] / [`begin_pg`] directly.
590///
591/// ## Executor contract
592///
593/// The `as_sqlite_mut` / `as_pg_mut` accessors return a mutable reference to
594/// the underlying sqlx transaction so ORM internals can call
595/// `sqlx::query(...).execute(&mut *inner)`. Both the `QuerySet::on_tx` and
596/// `Manager::create_in_tx` methods receive `&mut Transaction` and dispatch
597/// through these accessors.
598pub struct Transaction {
599 inner: TransactionInner,
600}
601
602enum TransactionInner {
603 Sqlite(sqlx::Transaction<'static, sqlx::Sqlite>),
604 Postgres(sqlx::Transaction<'static, sqlx::Postgres>),
605}
606
607impl Transaction {
608 /// Return a mutable reference to the inner SQLite transaction, or `None`
609 /// when this is a Postgres transaction.
610 pub fn as_sqlite_mut(&mut self) -> Option<&mut sqlx::Transaction<'static, sqlx::Sqlite>> {
611 match &mut self.inner {
612 TransactionInner::Sqlite(tx) => Some(tx),
613 TransactionInner::Postgres(_) => None,
614 }
615 }
616
617 /// Return a mutable reference to the inner Postgres transaction, or `None`
618 /// when this is a SQLite transaction.
619 pub fn as_pg_mut(&mut self) -> Option<&mut sqlx::Transaction<'static, sqlx::Postgres>> {
620 match &mut self.inner {
621 TransactionInner::Sqlite(_) => None,
622 TransactionInner::Postgres(tx) => Some(tx),
623 }
624 }
625
626 /// The backend name — `"sqlite"` or `"postgres"`. Mirrors
627 /// [`DbPool::backend_name`] so shared dispatch helpers can use the same
628 /// match arm.
629 pub fn backend_name(&self) -> &'static str {
630 match &self.inner {
631 TransactionInner::Sqlite(_) => "sqlite",
632 TransactionInner::Postgres(_) => "postgres",
633 }
634 }
635
636 /// Commit the transaction explicitly.
637 ///
638 /// The closure-based helpers ([`transaction`] / [`transaction_sqlite`] /
639 /// [`transaction_pg`]) call this automatically on `Ok`. Use this only
640 /// when you obtained the transaction via [`begin`] / [`begin_sqlite`] /
641 /// [`begin_pg`] and are driving the lifecycle yourself.
642 pub async fn commit(self) -> Result<(), sqlx::Error> {
643 match self.inner {
644 TransactionInner::Sqlite(tx) => tx.commit().await,
645 TransactionInner::Postgres(tx) => tx.commit().await,
646 }
647 }
648
649 /// Roll back the transaction explicitly.
650 ///
651 /// The closure-based helpers call this automatically on `Err`. Use this
652 /// only in the manual-control pattern.
653 pub async fn rollback(self) -> Result<(), sqlx::Error> {
654 match self.inner {
655 TransactionInner::Sqlite(tx) => tx.rollback().await,
656 TransactionInner::Postgres(tx) => tx.rollback().await,
657 }
658 }
659}
660
661/// Begin a transaction against the ambient pool.
662///
663/// The `Transaction` is dropped-and-rolled-back if neither `commit` nor
664/// `rollback` is called before it goes out of scope (sqlx's drop impl).
665/// Most callers use the higher-level [`transaction`] / [`transaction_sqlite`]
666/// / [`transaction_pg`] closures instead.
667///
668/// # Panics
669///
670/// Panics if `App::build()` hasn't run.
671pub async fn begin() -> Result<Transaction, sqlx::Error> {
672 match pool_dispatched() {
673 DbPool::Sqlite(pool) => {
674 let tx = pool.begin().await?;
675 Ok(Transaction {
676 inner: TransactionInner::Sqlite(tx),
677 })
678 }
679 DbPool::Postgres(pool) => {
680 let tx = pool.begin().await?;
681 Ok(Transaction {
682 inner: TransactionInner::Postgres(tx),
683 })
684 }
685 }
686}
687
688/// Begin a transaction against an explicit SQLite pool.
689pub async fn begin_sqlite(pool: &sqlx::SqlitePool) -> Result<Transaction, sqlx::Error> {
690 let tx = pool.begin().await?;
691 Ok(Transaction {
692 inner: TransactionInner::Sqlite(tx),
693 })
694}
695
696/// Begin a transaction against an explicit Postgres pool.
697pub async fn begin_pg(pool: &sqlx::PgPool) -> Result<Transaction, sqlx::Error> {
698 let tx = pool.begin().await?;
699 Ok(Transaction {
700 inner: TransactionInner::Postgres(tx),
701 })
702}
703
704
/// Pinned, boxed, `Send` future returned by a transaction closure.
///
/// This is the required return type for the closure passed to
/// [`transaction`] / [`transaction_sqlite`] / [`transaction_pg`]. The
/// lifetime `'a` ties the future to the `&'a mut Transaction` it was given,
/// so the borrow checker can verify that the transaction outlives the async
/// work done with it.
///
/// Build one by wrapping an `async move` block in `Box::pin(...)`:
///
/// ```rust,ignore
/// use umbral::db::transaction;
///
/// transaction(|tx| {
///     Box::pin(async move {
///         Post::objects().on_tx(tx).create(new_post).await?;
///         Ok::<_, MyError>(())
///     })
/// }).await?;
/// ```
///
/// The `async move` block captures the `&mut Transaction` by move, and the
/// `Box::pin` wrapper turns it into this concrete type, which satisfies the
/// closure's higher-ranked bound. `futures::FutureExt::boxed()` produces the
/// same `Pin<Box<dyn Future + Send + 'a>>` type and is an equivalent
/// alternative.
pub type TxFuture<'a, T, E> = Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send + 'a>>;
731
732/// Run an async closure inside a database transaction against the ambient pool.
733///
734/// The closure receives `&mut Transaction`. On `Ok` the transaction is
735/// committed; on `Err` it is rolled back. Returns the closure's `Ok` value
736/// on success.
737///
738/// The closure must return a `TxFuture` (a `Pin<Box<dyn Future>>`).
739/// Use `Box::pin(async move { ... })`:
740///
741/// ```rust,ignore
742/// use umbral::db::transaction;
743///
744/// let order = transaction(|tx| Box::pin(async move {
745/// let o = Order::objects().on_tx(tx).create(new_order).await?;
746/// Inventory::objects()
747/// .on_tx(tx)
748/// .filter(inv::PRODUCT_ID.eq(sku))
749/// .update_values(delta)
750/// .await?;
751/// Ok::<_, MyError>(o)
752/// })).await?;
753/// ```
754///
755/// # Panics
756///
757/// Panics if `App::build()` hasn't run.
758pub async fn transaction<F, T, E>(f: F) -> Result<T, E>
759where
760 for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
761 E: From<sqlx::Error>,
762{
763 let mut tx = begin().await.map_err(E::from)?;
764 match f(&mut tx).await {
765 Ok(val) => {
766 tx.commit().await.map_err(E::from)?;
767 Ok(val)
768 }
769 Err(e) => {
770 // Best-effort rollback — if it fails we surface the original error.
771 let _ = tx.rollback().await;
772 Err(e)
773 }
774 }
775}
776
777/// Run an async closure inside a SQLite transaction against an explicit pool.
778///
779/// The SQLite-specific variant of [`transaction`] for callers that want to
780/// pin to SQLite regardless of what the ambient pool is, or that are running
781/// outside of `App::build()` (e.g. tests).
782///
783/// See [`transaction`] for the closure shape.
784pub async fn transaction_sqlite<F, T, E>(pool: &sqlx::SqlitePool, f: F) -> Result<T, E>
785where
786 for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
787 E: From<sqlx::Error>,
788{
789 let mut tx = begin_sqlite(pool).await.map_err(E::from)?;
790 match f(&mut tx).await {
791 Ok(val) => {
792 tx.commit().await.map_err(E::from)?;
793 Ok(val)
794 }
795 Err(e) => {
796 let _ = tx.rollback().await;
797 Err(e)
798 }
799 }
800}
801
802/// Run an async closure inside a Postgres transaction against an explicit pool.
803///
804/// The Postgres-specific variant of [`transaction`] for callers that want to
805/// pin to Postgres or run outside `App::build()`.
806///
807/// See [`transaction`] for the closure shape.
808pub async fn transaction_pg<F, T, E>(pool: &sqlx::PgPool, f: F) -> Result<T, E>
809where
810 for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
811 E: From<sqlx::Error>,
812{
813 let mut tx = begin_pg(pool).await.map_err(E::from)?;
814 match f(&mut tx).await {
815 Ok(val) => {
816 tx.commit().await.map_err(E::from)?;
817 Ok(val)
818 }
819 Err(e) => {
820 let _ = tx.rollback().await;
821 Err(e)
822 }
823 }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    // `pool` and `pool_for` read the process-wide `POOLS` `OnceLock`, which
    // can only be set once per process. Under cargo test's parallel runner
    // that makes them unreliable to cover directly without `serial_test` or
    // a refactor, so they're intentionally out of scope here. For the same
    // reason the "pool() panics before init" path isn't exercised: another
    // test in the same process may already have populated the lock.
    //
    // This mirrors the settings module's stance on its own `init`/`get` pair.

    /// `connect` hands back a SQLite pool, wrapped in `DbPool::Sqlite`, that
    /// we can actually run queries through.
    #[tokio::test]
    async fn connect_returns_a_working_pool_against_in_memory_sqlite() {
        let pool = connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite should always connect");

        let sqlite = pool.as_sqlite().expect("should be Sqlite variant");
        let (one,): (i64,) = sqlx::query_as("SELECT 1")
            .fetch_one(sqlite)
            .await
            .expect("SELECT 1 should succeed on a fresh pool");

        assert_eq!(one, 1);
    }

    /// A string with no recognised scheme is rejected by `connect`'s own
    /// scheme check, before any driver parses it. We only assert that it
    /// errors; the exact variant isn't part of the contract.
    #[tokio::test]
    async fn connect_errors_on_malformed_url() {
        let result = connect("not-a-real-url").await;
        assert!(
            result.is_err(),
            "expected connect to reject a malformed url, got Ok"
        );
    }

    /// MySQL and similar schemes that umbral hasn't shipped yet surface as a
    /// clear configuration error rather than a driver-internal one.
    #[tokio::test]
    async fn connect_rejects_unsupported_scheme() {
        let result = connect("mysql://user:pass@host/db").await;
        match result {
            Err(sqlx::Error::Configuration(msg)) => {
                assert!(msg.to_string().contains("mysql"));
            }
            other => panic!("expected Configuration error, got {other:?}"),
        }
    }

    /// `From<SqlitePool>` and the variant accessors round-trip.
    #[tokio::test]
    async fn sqlite_pool_round_trips_through_dbpool() {
        let sp = SqlitePool::connect("sqlite::memory:").await.unwrap();
        let dp: DbPool = sp.clone().into();
        assert_eq!(dp.backend_name(), "sqlite");
        assert!(dp.as_sqlite().is_some());
        assert!(dp.as_postgres().is_none());
    }
}
890}