sea_orm_rocket/
database.rs

1use std::marker::PhantomData;
2use std::ops::DerefMut;
3
4use rocket::fairing::{self, Fairing, Info, Kind};
5use rocket::http::Status;
6use rocket::request::{FromRequest, Outcome, Request};
7use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel};
8
9use rocket::figment::providers::Serialized;
10use rocket::yansi::Paint;
11
12#[cfg(feature = "rocket_okapi")]
13use rocket_okapi::{
14    gen::OpenApiGenerator,
15    request::{OpenApiFromRequest, RequestHeaderInput},
16};
17
18use crate::Pool;
19
20/// Derivable trait which ties a database [`Pool`] with a configuration name.
21///
22/// This trait should rarely, if ever, be implemented manually. Instead, it
23/// should be derived:
24///
25/// ```ignore
26/// use sea_orm_rocket::{Database};
27/// # use sea_orm_rocket::MockPool as SeaOrmPool;
28///
29/// #[derive(Database, Debug)]
30/// #[database("sea_orm")]
31/// struct Db(SeaOrmPool);
32///
33/// #[launch]
34/// fn rocket() -> _ {
35///     rocket::build().attach(Db::init())
36/// }
37/// ```
38///
39/// See the [`Database` derive](derive@crate::Database) for details.
40pub trait Database:
41    From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static
42{
43    /// The [`Pool`] type of connections to this database.
44    ///
45    /// When `Database` is derived, this takes the value of the `Inner` type in
46    /// `struct Db(Inner)`.
47    type Pool: Pool;
48
49    /// The configuration name for this database.
50    ///
51    /// When `Database` is derived, this takes the value `"name"` in the
52    /// `#[database("name")]` attribute.
53    const NAME: &'static str;
54
55    /// Returns a fairing that initializes the database and its connection pool.
56    ///
57    /// # Example
58    ///
59    /// ```rust
60    /// # mod _inner {
61    /// # use rocket::launch;
62    /// use sea_orm_rocket::Database;
63    /// # use sea_orm_rocket::MockPool as SeaOrmPool;
64    ///
65    /// #[derive(Database)]
66    /// #[database("sea_orm")]
67    /// struct Db(SeaOrmPool);
68    ///
69    /// #[launch]
70    /// fn rocket() -> _ {
71    ///     rocket::build().attach(Db::init())
72    /// }
73    /// # }
74    /// ```
75    fn init() -> Initializer<Self> {
76        Initializer::new()
77    }
78
79    /// Returns a reference to the initialized database in `rocket`. The
80    /// initializer fairing returned by `init()` must have already executed for
81    /// `Option` to be `Some`. This is guaranteed to be the case if the fairing
82    /// is attached and either:
83    ///
84    ///   * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
85    ///     application is running. This is always the case in request guards
86    ///     and liftoff fairings,
87    ///   * _or_ Rocket is in the [`Build`](rocket::Build) or
88    ///     [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
89    ///     already been run. This is the case in all fairing callbacks
90    ///     corresponding to fairings attached _after_ the `Initializer`
91    ///     fairing.
92    ///
93    /// # Example
94    ///
95    /// Run database migrations in an ignite fairing. It is imperative that the
96    /// migration fairing be registered _after_ the `init()` fairing.
97    ///
98    /// ```rust
99    /// # mod _inner {
100    /// # use rocket::launch;
101    /// use rocket::fairing::{self, AdHoc};
102    /// use rocket::{Build, Rocket};
103    ///
104    /// use sea_orm_rocket::Database;
105    /// # use sea_orm_rocket::MockPool as SeaOrmPool;
106    ///
107    /// #[derive(Database)]
108    /// #[database("sea_orm")]
109    /// struct Db(SeaOrmPool);
110    ///
111    /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
112    ///     if let Some(db) = Db::fetch(&rocket) {
113    ///         // run migrations using `db`. get the inner type with &db.0.
114    ///         Ok(rocket)
115    ///     } else {
116    ///         Err(rocket)
117    ///     }
118    /// }
119    ///
120    /// #[launch]
121    /// fn rocket() -> _ {
122    ///     rocket::build()
123    ///         .attach(Db::init())
124    ///         .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
125    /// }
126    /// # }
127    /// ```
128    fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
129        if let Some(db) = rocket.state() {
130            return Some(db);
131        }
132
133        let dbtype = std::any::type_name::<Self>();
134        let fairing = Paint::new(format!("{dbtype}::init()")).bold();
135        error!(
136            "Attempted to fetch unattached database `{}`.",
137            Paint::new(dbtype).bold()
138        );
139        info_!(
140            "`{}` fairing must be attached prior to using this database.",
141            fairing
142        );
143        None
144    }
145}
146
147/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
148///
149/// A value of this type can be created for any type `D` that implements
150/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
151/// value of this type _never_ needs to be constructed directly. This
152/// documentation exists purely as a reference.
153///
154/// This fairing initializes a database pool. Specifically, it:
155///
156///   1. Reads the configuration at `database.db_name`, where `db_name` is
157///      [`Database::NAME`].
158///
159///   2. Sets [`Config`](crate::Config) defaults on the configuration figment.
160///
161///   3. Calls [`Pool::init()`].
162///
163///   4. Stores the database instance in managed storage, retrievable via
164///      [`Database::fetch()`].
165///
166/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
167/// the type name `D` unless a name is explicitly provided via
168/// [`Self::with_name()`].
169pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);
170
171/// A request guard which retrieves a single connection to a [`Database`].
172///
173/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
174/// single connection to `Db`.
175///
176/// The request guard succeeds if the database was initialized by the
177/// [`Initializer`] fairing and a connection is available within
178/// [`connect_timeout`](crate::Config::connect_timeout) seconds.
179///   * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
180///   status `InternalServerError`. A [`Sentinel`] guards this condition, and so
181///   this type of failure is unlikely to occur. A `None` error is returned.
182///   * If a connection is not available within `connect_timeout` seconds or
183///   another error occurs, the guard _fails_ with status `ServiceUnavailable`
184///   and the error is returned in `Some`.
185pub struct Connection<'a, D: Database>(&'a <D::Pool as Pool>::Connection);
186
187impl<D: Database> Initializer<D> {
188    /// Returns a database initializer fairing for `D`.
189    ///
190    /// This method should never need to be called manually. See the [crate
191    /// docs](crate) for usage information.
192    #[allow(clippy::new_without_default)]
193    pub fn new() -> Self {
194        Self(None, std::marker::PhantomData)
195    }
196
197    /// Returns a database initializer fairing for `D` with name `name`.
198    ///
199    /// This method should never need to be called manually. See the [crate
200    /// docs](crate) for usage information.
201    pub fn with_name(name: &'static str) -> Self {
202        Self(Some(name), std::marker::PhantomData)
203    }
204}
205
206impl<'a, D: Database> Connection<'a, D> {
207    /// Returns the internal connection value. See the [`Connection` Deref
208    /// column](crate#supported-drivers) for the expected type of this value.
209    pub fn into_inner(self) -> &'a <D::Pool as Pool>::Connection {
210        self.0
211    }
212}
213
#[cfg(feature = "rocket_okapi")]
impl<'r, D: Database> OpenApiFromRequest<'r> for Connection<'r, D> {
    /// Declares no request input for OpenAPI generation, so this guard adds
    /// nothing to the generated specification.
    fn from_request_input(
        _gen: &mut OpenApiGenerator,
        _name: String,
        _required: bool,
    ) -> rocket_okapi::Result<RequestHeaderInput> {
        let no_input = RequestHeaderInput::None;
        Ok(no_input)
    }
}
224
225#[rocket::async_trait]
226impl<D: Database> Fairing for Initializer<D> {
227    fn info(&self) -> Info {
228        Info {
229            name: self.0.unwrap_or_else(std::any::type_name::<Self>),
230            kind: Kind::Ignite,
231        }
232    }
233
234    async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
235        let workers: usize = rocket
236            .figment()
237            .extract_inner(rocket::Config::WORKERS)
238            .unwrap_or_else(|_| rocket::Config::default().workers);
239
240        let figment = rocket
241            .figment()
242            .focus(&format!("databases.{}", D::NAME))
243            .merge(Serialized::default("max_connections", workers * 4))
244            .merge(Serialized::default("connect_timeout", 5))
245            .merge(Serialized::default("sqlx_logging", true));
246
247        match <D::Pool>::init(&figment).await {
248            Ok(pool) => Ok(rocket.manage(D::from(pool))),
249            Err(e) => {
250                error!("failed to initialize database: {}", e);
251                Err(rocket)
252            }
253        }
254    }
255}
256
257#[rocket::async_trait]
258impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> {
259    type Error = Option<<D::Pool as Pool>::Error>;
260
261    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
262        match D::fetch(req.rocket()) {
263            Some(pool) => Outcome::Success(Connection(pool.borrow())),
264            None => Outcome::Error((Status::InternalServerError, None)),
265        }
266    }
267}
268
269impl<D: Database> Sentinel for Connection<'_, D> {
270    fn abort(rocket: &Rocket<Ignite>) -> bool {
271        D::fetch(rocket).is_none()
272    }
273}