Skip to main content

umbral_core/
db.rs

1//! Database pool registry and connection helpers.
2//!
3//! ## DbPool: the multi-backend seam
4//!
5//! [`DbPool`] is a small enum that wraps either a [`sqlx::SqlitePool`]
6//! or a [`sqlx::PgPool`]. It's the type [`connect`] returns and the
7//! type [`AppBuilder::database`](crate::app::AppBuilder::database)
8//! stores, so the framework remembers which backend each registered
9//! alias is connected to.
10//!
11//! ### Why an enum, not `sqlx::AnyPool`
12//!
13//! `sqlx::AnyPool` is the more "correct" abstraction at the type
14//! level: one pool type that dispatches to the right driver at
15//! runtime. But it has a real-world cost — sea-query-binder (the
16//! crate the QuerySet uses to bind parameters) doesn't have an
17//! `Any` backend; values must be bound through the per-driver
18//! query builder. Forcing every plugin and the queryset onto
19//! `AnyPool` therefore turns the simple multi-backend goal into a
20//! cascade through every binding site.
21//!
22//! The enum is the right shape for now. Every plugin still gets a
23//! typed `SqlitePool` from [`pool`] / [`pool_for`], and the
24//! ergonomics of `sqlx::query(...)` against that pool stay
25//! identical. Phase 2 of the Postgres rollout (per `FEATURES.md`)
26//! threads the variant choice through the migration engine and
27//! queryset; Phase 1 only needs the type seam.
28//!
29//! ### Postgres at boot, today
30//!
//! [`connect`] accepts both `sqlite://...` and `postgres://...`
//! URLs and returns a [`DbPool`] of the matching variant. Its
//! scheme detection is intended to match [`crate::backend::detect`];
//! the two must be kept in step whenever a new scheme is added.
35//!
36//! At Phase 1 the rest of the framework (queryset, migration
37//! engine, every plugin) still reads through [`pool`] / [`pool_for`]
38//! which hand back a `SqlitePool`. If the registered pool is
39//! actually a `PgPool`, those functions panic with a clear
40//! "Postgres support arrives in Phase 2" message. That's
41//! deliberate: the type seam exists, but callers that aren't
42//! ready for Postgres surface immediately at runtime rather than
43//! limping along and producing wrong results.
44
45use std::collections::HashMap;
46use std::pin::Pin;
47use std::sync::OnceLock;
48
49use sqlx::sqlite::{SqliteConnectOptions, SqlitePoolOptions, SqliteSynchronous};
50use sqlx::{ConnectOptions, PgPool, SqlitePool};
51use std::str::FromStr;
52use std::time::Duration;
53
54pub mod route_context;
55pub mod router;
56
57pub use route_context::{RouteContext, TenantKey, current as route_context};
58pub use router::{Alias, DatabaseRouter, DefaultRouter, RouteOp, Schema, router};
59
/// A pool of database connections, typed by backend.
///
/// [`connect`] picks the variant from the URL scheme; the `From` impls
/// wrap a pool the caller has already opened. Use [`DbPool::as_sqlite`] /
/// [`DbPool::as_postgres`] to reach the driver-specific pool.
///
/// Cloning is cheap — both variants wrap an `Arc`-backed inner
/// pool, so a `clone()` just bumps the refcount.
#[derive(Debug, Clone)]
pub enum DbPool {
    /// SQLite-backed connection pool. The default through Phase 1
    /// and the only variant [`pool`] / [`pool_for`] will hand back.
    Sqlite(SqlitePool),
    /// Postgres-backed connection pool. Connectable at Phase 1, but
    /// any caller that reaches it through [`DbPool::sqlite_or_panic`]
    /// (which includes [`pool`] and [`pool_for`]) panics with a
    /// "lands in Phase 2" message. The seam itself is the deliverable here.
    Postgres(PgPool),
}
76
77impl DbPool {
78    /// Borrow the inner `SqlitePool`. Returns `None` for a Postgres
79    /// pool. Phase 1 callers that haven't migrated to the dispatch
80    /// API yet typically reach for [`Self::sqlite_or_panic`]; the
81    /// returned-Option variant is for the (rare today) code that
82    /// wants to gracefully fall back.
83    pub fn as_sqlite(&self) -> Option<&SqlitePool> {
84        match self {
85            DbPool::Sqlite(p) => Some(p),
86            DbPool::Postgres(_) => None,
87        }
88    }
89
90    /// Borrow the inner `PgPool`. Returns `None` for a SQLite pool.
91    pub fn as_postgres(&self) -> Option<&PgPool> {
92        match self {
93            DbPool::Sqlite(_) => None,
94            DbPool::Postgres(p) => Some(p),
95        }
96    }
97
98    /// Borrow the inner `SqlitePool`, panicking with a clear "Postgres
99    /// support arrives in Phase 2" message on a Postgres variant. Used
100    /// by [`pool`] and [`pool_for`] so existing plugin code (that
101    /// expects a `SqlitePool`) doesn't quietly limp along when the
102    /// operator connects to Postgres.
103    pub fn sqlite_or_panic(&self) -> &SqlitePool {
104        self.as_sqlite().expect(
105            "umbral: a Postgres pool is registered but this code path \
106             still reads SqlitePool. Full Postgres support lands in \
107             Phase 2 of the rollout — see FEATURES.md and the \
108             `DbPool` rustdoc.",
109        )
110    }
111
112    /// The string identifier of the underlying backend. Matches
113    /// [`crate::backend::DatabaseBackend::name`] for the active
114    /// pool variant.
115    pub fn backend_name(&self) -> &'static str {
116        match self {
117            DbPool::Sqlite(_) => "sqlite",
118            DbPool::Postgres(_) => "postgres",
119        }
120    }
121}
122
123impl From<SqlitePool> for DbPool {
124    fn from(pool: SqlitePool) -> Self {
125        DbPool::Sqlite(pool)
126    }
127}
128
129impl From<PgPool> for DbPool {
130    fn from(pool: PgPool) -> Self {
131        DbPool::Postgres(pool)
132    }
133}
134
/// Build-time pool registry, keyed by alias.
///
/// Written exactly once by [`init`] and only read afterwards.
/// The "default" pool is always present after `App::build()` succeeds.
static POOLS: OnceLock<HashMap<String, DbPool>> = OnceLock::new();
138
/// Runtime tenant-pool registry for **database-per-tenant** multitenancy:
/// pools registered AFTER `App::build()`, as tenants are onboarded (e.g. by a
/// `DatabaseRouter` that maps a request's tenant to its own database). The
/// static `POOLS` map above is set once at build; this `RwLock`-backed map
/// grows at runtime via [`register_tenant_pool`], which also creates the map
/// lazily on its first call. Stored pools are leaked to `&'static` on insert —
/// a tenant pool lives for the whole process (you never drop one mid-serve),
/// so [`pool_for_dispatched`] keeps its zero-cost `&'static DbPool` return:
/// the `&'static` is copied out before the read guard drops, so no lock guard
/// ever escapes.
static DYNAMIC_POOLS: OnceLock<std::sync::RwLock<HashMap<String, &'static DbPool>>> =
    OnceLock::new();
150
/// Global default for whether ORM write terminals should wrap in a
/// transaction. Set by `AppBuilder::atomic_transactions(...)` (published
/// through [`init_atomic_default`]); read through [`atomic_default`] by
/// every terminal that supports `.atomic()` / `.non_atomic()`. Unset
/// (the default) reads as `false`, i.e. "no wrapping" — preserves existing
/// behaviour for apps that don't opt in.
static ATOMIC_DEFAULT: OnceLock<bool> = OnceLock::new();
157
158/// Publish the app-wide atomic-transactions default. Called by
159/// `AppBuilder::build()` exactly when the user set the flag via
160/// `atomic_transactions(...)`. Idempotent across re-init attempts —
161/// the first set wins, matching the rest of the OnceLock-backed
162/// ambient state.
163pub(crate) fn init_atomic_default(enabled: bool) {
164    let _ = ATOMIC_DEFAULT.set(enabled);
165}
166
167/// Read the app-wide atomic-transactions default. Returns `false` when
168/// the builder didn't call `atomic_transactions(...)` (or when the
169/// ambient state hasn't been published yet, as in unit tests that
170/// drive the ORM with `.on(&pool)` and never call `App::build()`).
171pub fn atomic_default() -> bool {
172    *ATOMIC_DEFAULT.get().unwrap_or(&false)
173}
174
175/// Initialize the pool registry. Called by `AppBuilder::build()` only.
176pub(crate) fn init(pools: HashMap<String, DbPool>) {
177    POOLS
178        .set(pools)
179        .expect("umbral::db::init called more than once");
180}
181
182/// Return the default connection pool, typed as a [`SqlitePool`].
183///
184/// This is the function every plugin and the queryset call. The
185/// internal storage is a [`DbPool`]; this unwraps to the
186/// `SqlitePool` variant or panics with a Phase-2 hint, matching
187/// the documented Phase 1 contract.
188///
189/// # Panics
190///
191/// Panics if `App::build()` hasn't run or the registered default
192/// pool is Postgres.
193pub fn pool() -> SqlitePool {
194    pool_dispatched().sqlite_or_panic().clone()
195}
196
197/// Return the default connection pool as a typed [`DbPool`].
198///
199/// Use this from code that's ready to dispatch on backend (the
200/// migration engine and queryset will move to this surface in
201/// Phase 2). Plugin code can stay on [`pool`] until then.
202///
203/// # Panics
204///
205/// Panics if `App::build()` hasn't run.
206pub fn pool_dispatched() -> &'static DbPool {
207    POOLS
208        .get()
209        .expect("umbral: db pool not initialised — did you call App::build()?")
210        .get("default")
211        .expect("umbral: no default database registered")
212}
213
214/// Like [`pool_dispatched`] but returns `None` instead of panicking
215/// when no pool is registered yet (`App::build()` hasn't run, or this
216/// is a pure SQL-building call such as `QuerySet::to_sql` in a test with
217/// no app booted). Used by runtime advisory paths that must not crash a
218/// query-builder call — see the RIGHT-JOIN-on-old-SQLite warning.
219pub fn try_pool_dispatched() -> Option<&'static DbPool> {
220    POOLS.get().and_then(|pools| pools.get("default"))
221}
222
223/// Return a named connection pool, typed as a [`SqlitePool`].
224///
225/// # Panics
226///
227/// Panics if `App::build()` hasn't run, the alias isn't registered,
228/// or the registered pool is Postgres.
229pub fn pool_for(alias: &str) -> SqlitePool {
230    pool_for_dispatched(alias).sqlite_or_panic().clone()
231}
232
233/// Return a named connection pool as a typed [`DbPool`]. Phase 2
234/// surface; see [`pool_dispatched`].
235///
236/// Resolution order: the build-time `POOLS` map first, then the runtime
237/// [`register_tenant_pool`] registry (database-per-tenant). Panics only when
238/// the alias is in neither.
239pub fn pool_for_dispatched(alias: &str) -> &'static DbPool {
240    if let Some(p) = POOLS.get().and_then(|pools| pools.get(alias)) {
241        return p;
242    }
243    if let Some(p) = DYNAMIC_POOLS
244        .get()
245        .and_then(|reg| reg.read().ok().and_then(|m| m.get(alias).copied()))
246    {
247        return p;
248    }
249    if POOLS.get().is_none() {
250        panic!("umbral: db pool not initialised — did you call App::build()?");
251    }
252    panic!("umbral: no database registered under alias '{alias}'");
253}
254
255/// Register a database pool under `alias` at runtime — the database-per-tenant
256/// seam. Unlike the build-time `App::builder().database(alias, pool)` (which
257/// fills the static pool map), this may be called any time after `App::build()`
258/// as tenants are onboarded. First-write-wins: re-registering an existing alias
259/// is a no-op (a re-resolution of the same tenant won't churn its pool) and the
260/// surplus pool is dropped without leaking. The stored pool is leaked to
261/// `&'static` because tenant pools are process-lifetime.
262///
263/// A [`DatabaseRouter`](crate::db::router::DatabaseRouter) whose
264/// `db_for_read`/`db_for_write` returns `alias` for a tenant request then routes
265/// that tenant's queries to this pool.
266pub fn register_tenant_pool(alias: impl Into<String>, pool: DbPool) {
267    let alias = alias.into();
268    let mut guard = DYNAMIC_POOLS
269        .get_or_init(|| std::sync::RwLock::new(HashMap::new()))
270        .write()
271        .expect("umbral: dynamic pool registry poisoned");
272    if guard.contains_key(&alias) {
273        return; // first-write-wins; `pool` is dropped here, not leaked
274    }
275    let leaked: &'static DbPool = Box::leak(Box::new(pool));
276    guard.insert(alias, leaked);
277}
278
279/// True if `alias` resolves to a registered pool — build-time `POOLS` or the
280/// runtime tenant registry. A router can use this to fall back to the default
281/// pool for a tenant whose database hasn't been onboarded yet.
282pub fn pool_alias_registered(alias: &str) -> bool {
283    POOLS.get().is_some_and(|p| p.contains_key(alias))
284        || DYNAMIC_POOLS
285            .get()
286            .and_then(|reg| reg.read().ok().map(|m| m.contains_key(alias)))
287            .unwrap_or(false)
288}
289
290/// Ping the default database pool with a backend-appropriate liveness
291/// query (`SELECT 1`).
292///
293/// Resolves the ambient pool via [`pool_dispatched`] and dispatches:
294///
295/// - **SQLite** — `SELECT 1` via the sqlite driver.
296/// - **Postgres** — `SELECT 1` via the postgres driver.
297///
298/// Returns `Ok(())` when the pool is reachable. Returns
299/// `Err(sqlx::Error)` on any connection or query failure so callers
300/// can map it to a wire-friendly string without exposing the full sqlx
301/// error type.
302///
303/// # Panics
304///
305/// Panics if `App::build()` hasn't run (same contract as
306/// [`pool_dispatched`]).
307pub async fn ping() -> Result<(), sqlx::Error> {
308    match pool_dispatched() {
309        DbPool::Sqlite(p) => {
310            sqlx::query("SELECT 1").execute(p).await.map(|_| ())
311        }
312        DbPool::Postgres(p) => {
313            sqlx::query("SELECT 1").execute(p).await.map(|_| ())
314        }
315    }
316}
317
318/// List every registered pool alias, sorted alphabetically.
319///
320/// Used by the migration engine to walk each DB in deterministic
321/// order so per-DB tracking tables get created and per-DB diffs run
322/// against the right model subset. The `"default"` alias is always
323/// present after `App::build()` succeeds and lands wherever
324/// alphabetical sort puts it (typically first).
325///
326/// # Panics
327///
328/// Panics if `App::build()` hasn't run.
329pub fn registered_aliases() -> Vec<String> {
330    let mut aliases: Vec<String> = POOLS
331        .get()
332        .expect("umbral: db pool not initialised — did you call App::build()?")
333        .keys()
334        .cloned()
335        .collect();
336    aliases.sort();
337    aliases
338}
339
340/// Open a new connection pool for the given database URL.
341///
342/// Dispatches on the URL scheme:
343///
344/// - `sqlite://...` or `sqlite::memory:` returns a
345///   [`DbPool::Sqlite`].
346/// - `postgres://...` / `postgresql://...` returns a
347///   [`DbPool::Postgres`].
348///
349/// Any other scheme surfaces as an `sqlx::Error::Configuration`.
350/// For callers that already have a typed pool, [`From`] impls on
351/// [`DbPool`] convert directly: `let dp: DbPool = sqlite_pool.into();`.
352pub async fn connect(url: &str) -> Result<DbPool, sqlx::Error> {
353    let scheme = url
354        .split("://")
355        .next()
356        .and_then(|s| s.split(':').next())
357        .unwrap_or(url);
358    match scheme {
359        "sqlite" => Ok(DbPool::Sqlite(connect_sqlite(url).await?)),
360        "postgres" | "postgresql" => Ok(DbPool::Postgres(connect_postgres(url).await?)),
361        other => Err(sqlx::Error::Configuration(
362            format!(
363                "umbral::db::connect: unsupported URL scheme `{other}://`. \
364                 Phase 1 supports `sqlite://` and `postgres://`."
365            )
366            .into(),
367        )),
368    }
369}
370
/// The effective pool configuration, resolved from [`crate::settings`]
/// when installed and falling back to the documented production defaults
/// otherwise (a pool can be opened before settings are installed). Shared
/// by [`connect_postgres`] and [`connect_sqlite`] so both backends honour
/// the same `UMBRAL_DB_*` knobs (gaps2 #91).
struct PoolConfig {
    // Upper bound on pooled connections. Consumers clamp this to at least 1
    // before handing it to sqlx.
    max_connections: u32,
    // Lower bound on pooled connections.
    min_connections: u32,
    // How long, in seconds, an acquire may wait before failing.
    acquire_timeout_secs: u64,
    // Idle-connection timeout in seconds; `None` when no idle timeout is
    // configured.
    idle_timeout_secs: Option<u64>,
    // Maximum connection lifetime in seconds; `None` when no lifetime cap is
    // configured.
    max_lifetime_secs: Option<u64>,
    // Forwarded to sqlx's `test_before_acquire` pool option.
    test_before_acquire: bool,
}
384
385impl PoolConfig {
386    fn resolve() -> Self {
387        match crate::settings::get_opt() {
388            Some(s) => PoolConfig {
389                max_connections: s.db_max_connections,
390                min_connections: s.db_min_connections,
391                acquire_timeout_secs: s.db_acquire_timeout_secs,
392                idle_timeout_secs: s.db_idle_timeout_secs,
393                max_lifetime_secs: s.db_max_lifetime_secs,
394                test_before_acquire: s.db_test_before_acquire,
395            },
396            // Defaults mirror the `default_db_*` fns in `settings`.
397            None => PoolConfig {
398                max_connections: 10,
399                min_connections: 0,
400                acquire_timeout_secs: 30,
401                idle_timeout_secs: Some(600),
402                max_lifetime_secs: Some(1800),
403                test_before_acquire: true,
404            },
405        }
406    }
407
408    /// Emit one operator-facing line describing the pool that's about to
409    /// be built, so the effective config is visible in the boot log.
410    fn log(&self, backend: &str) {
411        tracing::info!(
412            backend,
413            max_connections = self.max_connections.max(1),
414            min_connections = self.min_connections,
415            acquire_timeout_secs = self.acquire_timeout_secs,
416            idle_timeout_secs = ?self.idle_timeout_secs,
417            max_lifetime_secs = ?self.max_lifetime_secs,
418            test_before_acquire = self.test_before_acquire,
419            "umbral: opening database pool"
420        );
421    }
422}
423
424/// Open a Postgres pool from a URL with umbral's pool configuration.
425///
426/// PERF-5 / gaps2 #91: bare `PgPool::connect` uses sqlx's defaults with
427/// **no acquire timeout**, so a saturated pool blocks request tasks
428/// forever. We always apply the full set of pool knobs — `max_connections`,
429/// `min_connections`, a bounded `acquire_timeout` (fail fast),
430/// `idle_timeout`, `max_lifetime`, and `test_before_acquire` — read from
431/// [`crate::settings`] when available (falling back to the documented
432/// production defaults if the pool is opened before settings are
433/// installed). `idle_timeout`/`max_lifetime` are only applied when `Some`;
434/// a `None` (env `0`/empty) leaves that recycling disabled.
435pub async fn connect_postgres(url: &str) -> Result<PgPool, sqlx::Error> {
436    use std::time::Duration;
437    let cfg = PoolConfig::resolve();
438    cfg.log("postgres");
439
440    let mut opts = sqlx::postgres::PgPoolOptions::new()
441        .max_connections(cfg.max_connections.max(1))
442        .min_connections(cfg.min_connections)
443        .acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
444        .test_before_acquire(cfg.test_before_acquire);
445    if let Some(secs) = cfg.idle_timeout_secs {
446        opts = opts.idle_timeout(Duration::from_secs(secs));
447    }
448    if let Some(secs) = cfg.max_lifetime_secs {
449        opts = opts.max_lifetime(Duration::from_secs(secs));
450    }
451    opts.connect(url).await
452}
453
454/// Open a SQLite-backed pool from a URL.
455///
456/// Applies the standard production PRAGMAs to every connection in the
457/// pool: WAL journal, NORMAL synchronous, a 5-second busy-timeout, and
458/// foreign-key enforcement on. Without these, a fresh `SqlitePool` ends
459/// up in `journal_mode = DELETE` + `synchronous = FULL` — the safe
460/// SQLite defaults that cost ~1-4 seconds per concurrent INSERT once
461/// any other connection touches the file (the rollback-journal lock
462/// serialises writers).
463///
464/// | PRAGMA | Value | Why |
465/// |---|---|---|
466/// | `journal_mode` | `WAL` | Readers don't block writers; a single writer at a time but no full-file lock. Order-of-magnitude faster for any concurrent workload — typically the session/auth/audit tables fanning out. |
467/// | `synchronous` | `NORMAL` | Skips the per-commit fsync of the rollback journal; safe with WAL since the WAL log is fsynced on checkpoint. The official SQLite docs call this the right pairing with WAL for "most applications". |
468/// | `busy_timeout` | `5000ms` | Wait up to 5 s for a contended writer to release the lock before raising `SQLITE_BUSY`. Without this, two concurrent writers immediately race to error. |
469/// | `foreign_keys` | `ON` | sqlite turns FK enforcement off by default. The ORM emits `REFERENCES` clauses assuming they're respected — turning it on per connection makes the FK contract real. |
470///
471/// **In-memory URLs are backed by a process-unique temp file.** A bare
472/// `sqlite::memory:` gives every connection in the pool its OWN private,
473/// empty database, so a table created on one connection is invisible to a
474/// query that lands on another — and a shared in-memory database doesn't
475/// survive the connection (or the tokio runtime) that created it being
476/// dropped. Both surface as a flaky "no such table" whenever a pool is
477/// reused across queries or test cases. Routing in-memory URLs through a
478/// small temp file (which every connection sees and which persists for the
479/// process) sidesteps both — the same approach `umbral-testing::TempPool`
480/// already documents. File-backed (`sqlite://app.db`) and Postgres URLs are
481/// untouched.
482pub async fn connect_sqlite(url: &str) -> Result<SqlitePool, sqlx::Error> {
483    use std::sync::atomic::{AtomicU64, Ordering};
484    static MEM_SEQ: AtomicU64 = AtomicU64::new(0);
485
486    let lower = url.to_ascii_lowercase();
487    let in_memory = lower.contains(":memory:") || lower.contains("mode=memory");
488
489    let opts = if in_memory {
490        let n = MEM_SEQ.fetch_add(1, Ordering::Relaxed);
491        let path =
492            std::env::temp_dir().join(format!("umbral_mem_{}_{n}.sqlite", std::process::id()));
493        // Best-effort: remove a stale file from a previous run with this
494        // exact (pid, seq) — pids recycle. WAL/SHM siblings are recreated.
495        let _ = std::fs::remove_file(&path);
496        SqliteConnectOptions::new()
497            .filename(&path)
498            .create_if_missing(true)
499    } else {
500        SqliteConnectOptions::from_str(url)?
501    };
502    let opts = opts
503        .journal_mode(sqlx::sqlite::SqliteJournalMode::Wal)
504        .synchronous(SqliteSynchronous::Normal)
505        .busy_timeout(Duration::from_secs(5))
506        .foreign_keys(true)
507        // Disable per-statement logging — sqlx's default INFO-level
508        // logger reads every statement before execution, which adds a
509        // measurable per-query overhead under load. The `slow statement`
510        // WARN at the 1-second threshold stays on, since it goes via a
511        // separate log target.
512        .log_statements(tracing::log::LevelFilter::Off);
513
514    // gaps2 #91: apply the same settings-driven pool knobs as Postgres so
515    // a single `UMBRAL_DB_*` configuration governs every backend. SQLite is
516    // effectively single-writer (WAL serialises writers behind one lock),
517    // so a large `max_connections` mainly buys concurrent *readers*; the
518    // knob is still honoured rather than hardcoding a divergent SQLite path.
519    let cfg = PoolConfig::resolve();
520    cfg.log("sqlite");
521    let mut pool_opts = SqlitePoolOptions::new()
522        .max_connections(cfg.max_connections.max(1))
523        .min_connections(cfg.min_connections)
524        .acquire_timeout(Duration::from_secs(cfg.acquire_timeout_secs))
525        .test_before_acquire(cfg.test_before_acquire);
526    if let Some(secs) = cfg.idle_timeout_secs {
527        pool_opts = pool_opts.idle_timeout(Duration::from_secs(secs));
528    }
529    if let Some(secs) = cfg.max_lifetime_secs {
530        pool_opts = pool_opts.max_lifetime(Duration::from_secs(secs));
531    }
532    pool_opts.connect_with(opts).await
533}
534
535/// Gracefully close the ambient default database pool (gaps2 #91).
536///
537/// Call this once during shutdown — after the HTTP server has stopped
538/// accepting connections — to let sqlx flush in-flight work and close
539/// every pooled connection cleanly rather than having them dropped
540/// abruptly when the process exits. For SQLite this also lets WAL
541/// checkpoint; for Postgres it sends a clean `Terminate` so the server
542/// doesn't log the connections as unexpectedly lost.
543///
544/// Closing is terminal: the ambient [`OnceLock`] is left in place (it
545/// can't be unset), so the pool object remains registered but is closed.
546/// Acquiring from a closed pool errors, which is the intended post-
547/// shutdown behaviour. A no-op if no pool was ever registered.
548///
549/// ```rust,ignore
550/// // in your shutdown handler, after the server stops:
551/// umbral::db::close().await;
552/// ```
553pub async fn close() {
554    if let Some(pools) = POOLS.get() {
555        for db in pools.values() {
556            match db {
557                DbPool::Sqlite(p) => p.close().await,
558                DbPool::Postgres(p) => p.close().await,
559            }
560        }
561    }
562}
563
564// =============================================================================
565// Transaction support
566// =============================================================================
567
/// An active database transaction, typed by backend.
///
/// `Transaction` wraps either a `sqlx::Transaction<'static, sqlx::Sqlite>` or
/// a `sqlx::Transaction<'static, sqlx::Postgres>` and provides the executor
/// surface needed by the ORM's query terminals.
///
/// ## How to obtain one
///
/// The typical path is through the top-level closure helpers. The closure
/// must return a [`TxFuture`], so wrap the `async move` block in `Box::pin`:
///
/// ```rust,ignore
/// use umbral::db::transaction;
///
/// let order = transaction(|tx| Box::pin(async move {
///     let o = Order::objects().on_tx(tx).create(new_order).await?;
///     Inventory::objects().on_tx(tx).filter(...).update_values(...).await?;
///     Ok::<_, MyError>(o)
/// })).await?;
/// ```
///
/// For manual control (committing or rolling back yourself) call
/// [`begin`] / [`begin_sqlite`] / [`begin_pg`] directly.
///
/// ## Executor contract
///
/// The `as_sqlite_mut` / `as_pg_mut` accessors return a mutable reference to
/// the underlying sqlx transaction so ORM internals can call
/// `sqlx::query(...).execute(&mut *inner)`. Both the `QuerySet::on_tx` and
/// `Manager::create_in_tx` methods receive `&mut Transaction` and dispatch
/// through these accessors.
pub struct Transaction {
    // The backend-specific sqlx transaction. Kept private so the only ways
    // in are `begin*` and the only ways out are `commit` / `rollback`.
    inner: TransactionInner,
}
601
/// Backend-specific payload of a [`Transaction`]: exactly one of the two
/// sqlx transaction types, both holding a `'static` connection.
enum TransactionInner {
    /// A transaction opened on a SQLite pool.
    Sqlite(sqlx::Transaction<'static, sqlx::Sqlite>),
    /// A transaction opened on a Postgres pool.
    Postgres(sqlx::Transaction<'static, sqlx::Postgres>),
}
606
607impl Transaction {
608    /// Return a mutable reference to the inner SQLite transaction, or `None`
609    /// when this is a Postgres transaction.
610    pub fn as_sqlite_mut(&mut self) -> Option<&mut sqlx::Transaction<'static, sqlx::Sqlite>> {
611        match &mut self.inner {
612            TransactionInner::Sqlite(tx) => Some(tx),
613            TransactionInner::Postgres(_) => None,
614        }
615    }
616
617    /// Return a mutable reference to the inner Postgres transaction, or `None`
618    /// when this is a SQLite transaction.
619    pub fn as_pg_mut(&mut self) -> Option<&mut sqlx::Transaction<'static, sqlx::Postgres>> {
620        match &mut self.inner {
621            TransactionInner::Sqlite(_) => None,
622            TransactionInner::Postgres(tx) => Some(tx),
623        }
624    }
625
626    /// The backend name — `"sqlite"` or `"postgres"`. Mirrors
627    /// [`DbPool::backend_name`] so shared dispatch helpers can use the same
628    /// match arm.
629    pub fn backend_name(&self) -> &'static str {
630        match &self.inner {
631            TransactionInner::Sqlite(_) => "sqlite",
632            TransactionInner::Postgres(_) => "postgres",
633        }
634    }
635
636    /// Commit the transaction explicitly.
637    ///
638    /// The closure-based helpers ([`transaction`] / [`transaction_sqlite`] /
639    /// [`transaction_pg`]) call this automatically on `Ok`. Use this only
640    /// when you obtained the transaction via [`begin`] / [`begin_sqlite`] /
641    /// [`begin_pg`] and are driving the lifecycle yourself.
642    pub async fn commit(self) -> Result<(), sqlx::Error> {
643        match self.inner {
644            TransactionInner::Sqlite(tx) => tx.commit().await,
645            TransactionInner::Postgres(tx) => tx.commit().await,
646        }
647    }
648
649    /// Roll back the transaction explicitly.
650    ///
651    /// The closure-based helpers call this automatically on `Err`. Use this
652    /// only in the manual-control pattern.
653    pub async fn rollback(self) -> Result<(), sqlx::Error> {
654        match self.inner {
655            TransactionInner::Sqlite(tx) => tx.rollback().await,
656            TransactionInner::Postgres(tx) => tx.rollback().await,
657        }
658    }
659}
660
661/// Begin a transaction against the ambient pool.
662///
663/// The `Transaction` is dropped-and-rolled-back if neither `commit` nor
664/// `rollback` is called before it goes out of scope (sqlx's drop impl).
665/// Most callers use the higher-level [`transaction`] / [`transaction_sqlite`]
666/// / [`transaction_pg`] closures instead.
667///
668/// # Panics
669///
670/// Panics if `App::build()` hasn't run.
671pub async fn begin() -> Result<Transaction, sqlx::Error> {
672    match pool_dispatched() {
673        DbPool::Sqlite(pool) => {
674            let tx = pool.begin().await?;
675            Ok(Transaction {
676                inner: TransactionInner::Sqlite(tx),
677            })
678        }
679        DbPool::Postgres(pool) => {
680            let tx = pool.begin().await?;
681            Ok(Transaction {
682                inner: TransactionInner::Postgres(tx),
683            })
684        }
685    }
686}
687
688/// Begin a transaction against an explicit SQLite pool.
689pub async fn begin_sqlite(pool: &sqlx::SqlitePool) -> Result<Transaction, sqlx::Error> {
690    let tx = pool.begin().await?;
691    Ok(Transaction {
692        inner: TransactionInner::Sqlite(tx),
693    })
694}
695
696/// Begin a transaction against an explicit Postgres pool.
697pub async fn begin_pg(pool: &sqlx::PgPool) -> Result<Transaction, sqlx::Error> {
698    let tx = pool.begin().await?;
699    Ok(Transaction {
700        inner: TransactionInner::Postgres(tx),
701    })
702}
703
704
/// Pinned, boxed, `Send` `Future` with a lifetime parameter, resolving to
/// `Result<T, E>`.
///
/// This is the required shape for the closure argument to
/// [`transaction`] / [`transaction_sqlite`] / [`transaction_pg`].
/// The lifetime `'a` ties the future to the `&'a mut Transaction`
/// reference so the borrow checker can verify that the transaction
/// outlives the async work being done inside it.
///
/// Call sites construct this by wrapping the `async move` block in
/// `Box::pin`:
///
/// ```rust,ignore
/// use umbral::db::transaction;
///
/// transaction(|tx| {
///     Box::pin(async move {
///         Post::objects().on_tx(tx).create(new_post).await?;
///         Ok::<_, MyError>(())
///     })
/// }).await?;
/// ```
///
/// The `async move { ... }` block captures the `&mut Transaction` by
/// move and the `Box::pin(...)` wrapper satisfies the HRTB bound.
pub type TxFuture<'a, T, E> = Pin<Box<dyn std::future::Future<Output = Result<T, E>> + Send + 'a>>;
731
732/// Run an async closure inside a database transaction against the ambient pool.
733///
734/// The closure receives `&mut Transaction`. On `Ok` the transaction is
735/// committed; on `Err` it is rolled back. Returns the closure's `Ok` value
736/// on success.
737///
738/// The closure must return a `TxFuture` (a `Pin<Box<dyn Future>>`).
739/// Use `Box::pin(async move { ... })`:
740///
741/// ```rust,ignore
742/// use umbral::db::transaction;
743///
744/// let order = transaction(|tx| Box::pin(async move {
745///     let o = Order::objects().on_tx(tx).create(new_order).await?;
746///     Inventory::objects()
747///         .on_tx(tx)
748///         .filter(inv::PRODUCT_ID.eq(sku))
749///         .update_values(delta)
750///         .await?;
751///     Ok::<_, MyError>(o)
752/// })).await?;
753/// ```
754///
755/// # Panics
756///
757/// Panics if `App::build()` hasn't run.
758pub async fn transaction<F, T, E>(f: F) -> Result<T, E>
759where
760    for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
761    E: From<sqlx::Error>,
762{
763    let mut tx = begin().await.map_err(E::from)?;
764    match f(&mut tx).await {
765        Ok(val) => {
766            tx.commit().await.map_err(E::from)?;
767            Ok(val)
768        }
769        Err(e) => {
770            // Best-effort rollback — if it fails we surface the original error.
771            let _ = tx.rollback().await;
772            Err(e)
773        }
774    }
775}
776
777/// Run an async closure inside a SQLite transaction against an explicit pool.
778///
779/// The SQLite-specific variant of [`transaction`] for callers that want to
780/// pin to SQLite regardless of what the ambient pool is, or that are running
781/// outside of `App::build()` (e.g. tests).
782///
783/// See [`transaction`] for the closure shape.
784pub async fn transaction_sqlite<F, T, E>(pool: &sqlx::SqlitePool, f: F) -> Result<T, E>
785where
786    for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
787    E: From<sqlx::Error>,
788{
789    let mut tx = begin_sqlite(pool).await.map_err(E::from)?;
790    match f(&mut tx).await {
791        Ok(val) => {
792            tx.commit().await.map_err(E::from)?;
793            Ok(val)
794        }
795        Err(e) => {
796            let _ = tx.rollback().await;
797            Err(e)
798        }
799    }
800}
801
802/// Run an async closure inside a Postgres transaction against an explicit pool.
803///
804/// The Postgres-specific variant of [`transaction`] for callers that want to
805/// pin to Postgres or run outside `App::build()`.
806///
807/// See [`transaction`] for the closure shape.
808pub async fn transaction_pg<F, T, E>(pool: &sqlx::PgPool, f: F) -> Result<T, E>
809where
810    for<'a> F: FnOnce(&'a mut Transaction) -> TxFuture<'a, T, E>,
811    E: From<sqlx::Error>,
812{
813    let mut tx = begin_pg(pool).await.map_err(E::from)?;
814    match f(&mut tx).await {
815        Ok(val) => {
816            tx.commit().await.map_err(E::from)?;
817            Ok(val)
818        }
819        Err(e) => {
820            let _ = tx.rollback().await;
821            Err(e)
822        }
823    }
824}
825
#[cfg(test)]
mod tests {
    use super::*;

    // `pool` and `pool_for` read the process-wide `POOLS` `OnceLock`, which
    // can only be set once per process. Under cargo test's parallel runner
    // that makes them unreliable to cover directly without `serial_test` or
    // a refactor, so they're intentionally out of scope here. Same reason
    // the "pool() panics before init" path isn't exercised: another test in
    // the same process may have already populated the lock.
    //
    // Mirrors the settings module's stance on its own `init`/`get` pair.

    /// `connect` hands back a SQLite pool wrapped in `DbPool::Sqlite` we
    /// can actually run queries through.
    #[tokio::test]
    async fn connect_returns_a_working_pool_against_in_memory_sqlite() {
        let pool = connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite should always connect");

        let sqlite = pool.as_sqlite().expect("should be Sqlite variant");
        let (one,): (i64,) = sqlx::query_as("SELECT 1")
            .fetch_one(sqlite)
            .await
            .expect("SELECT 1 should succeed on a fresh pool");

        assert_eq!(one, 1);
    }

    /// A URL sqlx can't parse surfaces as a plain `sqlx::Error`. We don't
    /// pin the variant — the family is the contract.
    #[tokio::test]
    async fn connect_errors_on_malformed_url() {
        let result = connect("not-a-real-url").await;
        assert!(
            result.is_err(),
            "expected sqlx to reject a malformed url, got Ok"
        );
    }

    /// MySQL and similar schemes that umbral hasn't shipped yet
    /// surface as a clear configuration error rather than a
    /// driver-internal one.
    #[tokio::test]
    async fn connect_rejects_unsupported_scheme() {
        let result = connect("mysql://user:pass@host/db").await;
        match result {
            Err(sqlx::Error::Configuration(msg)) => {
                // The message should name the scheme so the operator can
                // see which URL was rejected.
                assert!(
                    msg.to_string().contains("mysql"),
                    "configuration error should mention the rejected scheme, got: {msg}"
                );
            }
            other => panic!("expected Configuration error, got {other:?}"),
        }
    }

    /// `From<SqlitePool>` and the variant accessors round-trip.
    #[tokio::test]
    async fn sqlite_pool_round_trips_through_dbpool() {
        let sp = SqlitePool::connect("sqlite::memory:")
            .await
            .expect("in-memory sqlite should always connect");
        // The pool isn't used again after the conversion, so move it
        // instead of cloning.
        let dp: DbPool = sp.into();
        assert_eq!(dp.backend_name(), "sqlite");
        assert!(dp.as_sqlite().is_some());
        assert!(dp.as_postgres().is_none());
    }
}