wired_handler_hyper/data/db/mod.rs

1use std::future::Future;
2
3use diesel_async::{
4    async_connection_wrapper::AsyncConnectionWrapper,
5    pooled_connection::{
6        deadpool::{BuildError, Pool, PoolError},
7        AsyncDieselConnectionManager,
8    },
9    AsyncPgConnection,
10};
11use diesel_migrations::EmbeddedMigrations;
12use thiserror::Error;
13use wired_handler::Handler;
14
15use crate::{
16    prelude::*,
17    state::{
18        context::{HttpRequestContext, SessionlessRequestContext},
19        global_state::GlobalState,
20    },
21};
22
23pub use db_connection::*;
24pub use db_pool::*;
25
26use super::config::DbConfig;
27
28mod db_connection;
29mod db_pool;
30
/// Extension trait for getting a database connection from the pool.
pub trait ContextGetDbExt {
    /// Gets a database connection from the pool.
    ///
    /// # Errors
    ///
    /// Resolves to a [`PoolError`] if no connection could be obtained from the pool.
    fn db(&self) -> impl Future<Output = Result<DbConnection, PoolError>>;
}
36
37impl ContextGetDbExt for HttpRequestContext {
38    async fn db(&self) -> Result<DbConnection, PoolError> {
39        let db_pool = GlobalState::get_from_ctx(self)
40            .get_cloned::<DbPool>()
41            .await
42            .expect("DbPool must be inserted");
43
44        let db = db_pool.get().await?;
45        Ok(db)
46    }
47}
48
49#[non_exhaustive]
50#[derive(Debug, Error)]
51pub enum LoadDbError {
52    #[error("{0}")]
53    DbPool(#[from] BuildError),
54    #[error("{0}")]
55    MigrationDbPool(#[from] PoolError),
56    #[error("{0}")]
57    MigrationError(#[from] MigrationError),
58}
59
/// Extension trait for setting up the database connection pool (and migrations).
pub trait LoadDbExt {
    /// Loads the database and applies migrations.
    ///
    /// Builds a connection pool from `db_config`. If `migrations` converts to `Some`,
    /// they are run on a connection taken from that pool. The config and the pool are
    /// then stored in the global state.
    ///
    /// In debug mode, only generates a warning if there are pending migrations
    ///
    /// # Errors
    ///
    /// Returns a [`LoadDbError`] if the pool cannot be built, if a connection for
    /// running migrations cannot be obtained, or if running the migrations fails.
    // NOTE(review): the handler impl in this module always runs the given migrations.
    // Any debug-mode "warn instead of apply" behavior must live in the migration
    // runner (`run_migrations`), which is not visible here. Confirm this before
    // relying on the sentence above.
    fn load_db(
        &self,
        db_config: DbConfig,
        migrations: impl Into<Option<EmbeddedMigrations>>,
    ) -> impl Future<Output = Result<(), LoadDbError>>;
}
70
71impl<F: Future<Output = HttpRequestContext> + 'static + Send> LoadDbExt
72    for Handler<SessionlessRequestContext, HttpRequestContext, GlobalState, F>
73{
74    async fn load_db(
75        &self,
76        db_config: DbConfig,
77        migrations: impl Into<Option<EmbeddedMigrations>>,
78    ) -> Result<(), LoadDbError> {
79        let db_addr = &db_config.addr;
80        let db_pool: DbPool = {
81            let config: AsyncDieselConnectionManager<AsyncPgConnection> =
82                AsyncDieselConnectionManager::<AsyncPgConnection>::new(db_addr);
83            Pool::builder(config).build()?
84        };
85
86        if let Some(migrations) = migrations.into() {
87            let conn = db_pool.get().await?;
88            AsyncConnectionWrapper::from(conn)
89                .run_migrations(migrations)
90                .await?;
91        }
92
93        {
94            let global_state = self.state();
95
96            global_state.insert(db_config).await;
97            global_state.insert(db_pool).await;
98        }
99
100        Ok(())
101    }
102}