sea_orm_rocket/database.rs
1use std::marker::PhantomData;
2use std::ops::DerefMut;
3
4use rocket::fairing::{self, Fairing, Info, Kind};
5use rocket::http::Status;
6use rocket::request::{FromRequest, Outcome, Request};
7use rocket::{error, info_, Build, Ignite, Phase, Rocket, Sentinel};
8
9use rocket::figment::providers::Serialized;
10use rocket::yansi::Paint;
11
12#[cfg(feature = "rocket_okapi")]
13use rocket_okapi::{
14 gen::OpenApiGenerator,
15 request::{OpenApiFromRequest, RequestHeaderInput},
16};
17
18use crate::Pool;
19
20/// Derivable trait which ties a database [`Pool`] with a configuration name.
21///
22/// This trait should rarely, if ever, be implemented manually. Instead, it
23/// should be derived:
24///
25/// ```ignore
26/// use sea_orm_rocket::{Database};
27/// # use sea_orm_rocket::MockPool as SeaOrmPool;
28///
29/// #[derive(Database, Debug)]
30/// #[database("sea_orm")]
31/// struct Db(SeaOrmPool);
32///
33/// #[launch]
34/// fn rocket() -> _ {
35/// rocket::build().attach(Db::init())
36/// }
37/// ```
38///
39/// See the [`Database` derive](derive@crate::Database) for details.
40pub trait Database:
41 From<Self::Pool> + DerefMut<Target = Self::Pool> + Send + Sync + 'static
42{
43 /// The [`Pool`] type of connections to this database.
44 ///
45 /// When `Database` is derived, this takes the value of the `Inner` type in
46 /// `struct Db(Inner)`.
47 type Pool: Pool;
48
49 /// The configuration name for this database.
50 ///
51 /// When `Database` is derived, this takes the value `"name"` in the
52 /// `#[database("name")]` attribute.
53 const NAME: &'static str;
54
55 /// Returns a fairing that initializes the database and its connection pool.
56 ///
57 /// # Example
58 ///
59 /// ```rust
60 /// # mod _inner {
61 /// # use rocket::launch;
62 /// use sea_orm_rocket::Database;
63 /// # use sea_orm_rocket::MockPool as SeaOrmPool;
64 ///
65 /// #[derive(Database)]
66 /// #[database("sea_orm")]
67 /// struct Db(SeaOrmPool);
68 ///
69 /// #[launch]
70 /// fn rocket() -> _ {
71 /// rocket::build().attach(Db::init())
72 /// }
73 /// # }
74 /// ```
75 fn init() -> Initializer<Self> {
76 Initializer::new()
77 }
78
79 /// Returns a reference to the initialized database in `rocket`. The
80 /// initializer fairing returned by `init()` must have already executed for
81 /// `Option` to be `Some`. This is guaranteed to be the case if the fairing
82 /// is attached and either:
83 ///
84 /// * Rocket is in the [`Orbit`](rocket::Orbit) phase. That is, the
85 /// application is running. This is always the case in request guards
86 /// and liftoff fairings,
87 /// * _or_ Rocket is in the [`Build`](rocket::Build) or
88 /// [`Ignite`](rocket::Ignite) phase and the `Initializer` fairing has
89 /// already been run. This is the case in all fairing callbacks
90 /// corresponding to fairings attached _after_ the `Initializer`
91 /// fairing.
92 ///
93 /// # Example
94 ///
95 /// Run database migrations in an ignite fairing. It is imperative that the
96 /// migration fairing be registered _after_ the `init()` fairing.
97 ///
98 /// ```rust
99 /// # mod _inner {
100 /// # use rocket::launch;
101 /// use rocket::fairing::{self, AdHoc};
102 /// use rocket::{Build, Rocket};
103 ///
104 /// use sea_orm_rocket::Database;
105 /// # use sea_orm_rocket::MockPool as SeaOrmPool;
106 ///
107 /// #[derive(Database)]
108 /// #[database("sea_orm")]
109 /// struct Db(SeaOrmPool);
110 ///
111 /// async fn run_migrations(rocket: Rocket<Build>) -> fairing::Result {
112 /// if let Some(db) = Db::fetch(&rocket) {
113 /// // run migrations using `db`. get the inner type with &db.0.
114 /// Ok(rocket)
115 /// } else {
116 /// Err(rocket)
117 /// }
118 /// }
119 ///
120 /// #[launch]
121 /// fn rocket() -> _ {
122 /// rocket::build()
123 /// .attach(Db::init())
124 /// .attach(AdHoc::try_on_ignite("DB Migrations", run_migrations))
125 /// }
126 /// # }
127 /// ```
128 fn fetch<P: Phase>(rocket: &Rocket<P>) -> Option<&Self> {
129 if let Some(db) = rocket.state() {
130 return Some(db);
131 }
132
133 let dbtype = std::any::type_name::<Self>();
134 let fairing = Paint::new(format!("{dbtype}::init()")).bold();
135 error!(
136 "Attempted to fetch unattached database `{}`.",
137 Paint::new(dbtype).bold()
138 );
139 info_!(
140 "`{}` fairing must be attached prior to using this database.",
141 fairing
142 );
143 None
144 }
145}
146
147/// A [`Fairing`] which initializes a [`Database`] and its connection pool.
148///
149/// A value of this type can be created for any type `D` that implements
150/// [`Database`] via the [`Database::init()`] method on the type. Normally, a
151/// value of this type _never_ needs to be constructed directly. This
152/// documentation exists purely as a reference.
153///
154/// This fairing initializes a database pool. Specifically, it:
155///
156/// 1. Reads the configuration at `database.db_name`, where `db_name` is
157/// [`Database::NAME`].
158///
159/// 2. Sets [`Config`](crate::Config) defaults on the configuration figment.
160///
161/// 3. Calls [`Pool::init()`].
162///
163/// 4. Stores the database instance in managed storage, retrievable via
164/// [`Database::fetch()`].
165///
166/// The name of the fairing itself is `Initializer<D>`, with `D` replaced with
167/// the type name `D` unless a name is explicitly provided via
168/// [`Self::with_name()`].
169pub struct Initializer<D: Database>(Option<&'static str>, PhantomData<fn() -> D>);
170
171/// A request guard which retrieves a single connection to a [`Database`].
172///
173/// For a database type of `Db`, a request guard of `Connection<Db>` retrieves a
174/// single connection to `Db`.
175///
176/// The request guard succeeds if the database was initialized by the
177/// [`Initializer`] fairing and a connection is available within
178/// [`connect_timeout`](crate::Config::connect_timeout) seconds.
179/// * If the `Initializer` fairing was _not_ attached, the guard _fails_ with
180/// status `InternalServerError`. A [`Sentinel`] guards this condition, and so
181/// this type of failure is unlikely to occur. A `None` error is returned.
182/// * If a connection is not available within `connect_timeout` seconds or
183/// another error occurs, the guard _fails_ with status `ServiceUnavailable`
184/// and the error is returned in `Some`.
185pub struct Connection<'a, D: Database>(&'a <D::Pool as Pool>::Connection);
186
187impl<D: Database> Initializer<D> {
188 /// Returns a database initializer fairing for `D`.
189 ///
190 /// This method should never need to be called manually. See the [crate
191 /// docs](crate) for usage information.
192 #[allow(clippy::new_without_default)]
193 pub fn new() -> Self {
194 Self(None, std::marker::PhantomData)
195 }
196
197 /// Returns a database initializer fairing for `D` with name `name`.
198 ///
199 /// This method should never need to be called manually. See the [crate
200 /// docs](crate) for usage information.
201 pub fn with_name(name: &'static str) -> Self {
202 Self(Some(name), std::marker::PhantomData)
203 }
204}
205
206impl<'a, D: Database> Connection<'a, D> {
207 /// Returns the internal connection value. See the [`Connection` Deref
208 /// column](crate#supported-drivers) for the expected type of this value.
209 pub fn into_inner(self) -> &'a <D::Pool as Pool>::Connection {
210 self.0
211 }
212}
213
#[cfg(feature = "rocket_okapi")]
impl<'r, D: Database> OpenApiFromRequest<'r> for Connection<'r, D> {
    /// `Connection` is resolved from Rocket's managed state rather than from
    /// anything in the incoming request, so it contributes no request input
    /// to the generated OpenAPI document.
    fn from_request_input(
        _generator: &mut OpenApiGenerator,
        _param_name: String,
        _is_required: bool,
    ) -> rocket_okapi::Result<RequestHeaderInput> {
        let nothing_to_document = RequestHeaderInput::None;
        Ok(nothing_to_document)
    }
}
224
225#[rocket::async_trait]
226impl<D: Database> Fairing for Initializer<D> {
227 fn info(&self) -> Info {
228 Info {
229 name: self.0.unwrap_or_else(std::any::type_name::<Self>),
230 kind: Kind::Ignite,
231 }
232 }
233
234 async fn on_ignite(&self, rocket: Rocket<Build>) -> fairing::Result {
235 let workers: usize = rocket
236 .figment()
237 .extract_inner(rocket::Config::WORKERS)
238 .unwrap_or_else(|_| rocket::Config::default().workers);
239
240 let figment = rocket
241 .figment()
242 .focus(&format!("databases.{}", D::NAME))
243 .merge(Serialized::default("max_connections", workers * 4))
244 .merge(Serialized::default("connect_timeout", 5))
245 .merge(Serialized::default("sqlx_logging", true));
246
247 match <D::Pool>::init(&figment).await {
248 Ok(pool) => Ok(rocket.manage(D::from(pool))),
249 Err(e) => {
250 error!("failed to initialize database: {}", e);
251 Err(rocket)
252 }
253 }
254 }
255}
256
257#[rocket::async_trait]
258impl<'r, D: Database> FromRequest<'r> for Connection<'r, D> {
259 type Error = Option<<D::Pool as Pool>::Error>;
260
261 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
262 match D::fetch(req.rocket()) {
263 Some(pool) => Outcome::Success(Connection(pool.borrow())),
264 None => Outcome::Error((Status::InternalServerError, None)),
265 }
266 }
267}
268
269impl<D: Database> Sentinel for Connection<'_, D> {
270 fn abort(rocket: &Rocket<Ignite>) -> bool {
271 D::fetch(rocket).is_none()
272 }
273}