shuttle_aws_rds/
lib.rs

1#![doc = include_str!("../README.md")]
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use shuttle_service::{
6    resource::{ProvisionResourceRequest, ResourceType},
7    DatabaseResource, DbInput, Environment, Error, IntoResource, ResourceFactory,
8    ResourceInputBuilder,
9};
10
11#[cfg(any(feature = "diesel-async-bb8", feature = "diesel-async-deadpool"))]
12use diesel_async::pooled_connection::AsyncDieselConnectionManager;
13
14#[cfg(feature = "diesel-async-bb8")]
15use diesel_async::pooled_connection::bb8 as diesel_bb8;
16
17#[cfg(feature = "diesel-async-deadpool")]
18use diesel_async::pooled_connection::deadpool as diesel_deadpool;
19
// Pool sizing shared by every pool-producing `IntoResource` impl. Those impls are
// feature-gated, so with none of the pool features enabled nothing reads these
// constants; `allow(dead_code)` keeps that configuration warning-free.

/// Lower bound handed to pools that accept one (`min_connections` / `min_idle`).
#[allow(dead_code)] // unused unless a pool feature is enabled
const MIN_CONNECTIONS: u32 = 1;

/// Upper bound on the number of connections any pool built here may open.
#[allow(dead_code)] // unused unless a pool feature is enabled
const MAX_CONNECTIONS: u32 = 5;
24
25/// Conditionally request a Shuttle resource
26#[derive(Serialize, Deserialize)]
27#[serde(untagged)]
28pub enum MaybeRequest {
29    Request(ProvisionResourceRequest),
30    NotRequest(DatabaseResource),
31}
32
// Generates a feature-gated resource builder for one AWS RDS database engine.
//
// * `$feature`      – Cargo feature that gates every generated item.
// * `$struct_ident` – name of the public builder struct.
// * `$res_type`     – `ResourceType` variant sent in provisioning requests.
//
// Only the builder side is generated here; the `IntoResource` conversions for
// `OutputWrapper` are written by hand, outside the macro.
macro_rules! aws_engine {
    ($feature:expr, $struct_ident:ident, $res_type:ident) => {
        #[cfg(feature = $feature)]
        #[doc = concat!("Shuttle managed AWS RDS ", stringify!($struct_ident), " instance")]
        #[derive(Default)]
        pub struct $struct_ident(DbInput);

        #[cfg(feature = $feature)]
        impl $struct_ident {
            /// Use a custom connection string for local runs.
            ///
            /// In a local run with this set, the URI is used directly and no
            /// provisioning request is sent. In deployment a provisioning request
            /// is always sent.
            pub fn local_uri(mut self, local_uri: &str) -> Self {
                self.0.local_uri = Some(local_uri.to_string());
                self
            }

            /// Use something other than the project name as the DB name.
            pub fn database_name(mut self, database_name: &str) -> Self {
                self.0.db_name = Some(database_name.to_string());
                self
            }

            /// Wrap this builder's settings, serialized to JSON, in a request
            /// asking Shuttle to provision a database of this engine.
            ///
            /// # Errors
            ///
            /// Returns an error if the settings cannot be serialized to JSON.
            fn provision_request(self) -> Result<MaybeRequest, Error> {
                // Propagate instead of panicking: this is library code.
                let config = serde_json::to_value(self.0)
                    .map_err(shuttle_service::error::CustomError::new)?;

                Ok(MaybeRequest::Request(ProvisionResourceRequest {
                    r#type: ResourceType::$res_type,
                    config,
                }))
            }
        }

        #[cfg(feature = $feature)]
        #[async_trait::async_trait]
        impl ResourceInputBuilder for $struct_ident {
            type Input = MaybeRequest;
            type Output = OutputWrapper;

            /// Decide whether Shuttle has to provision a database.
            ///
            /// If the run is local and a `local_uri` was set, provisioning is
            /// skipped and that URI is passed through unchanged. Otherwise a
            /// provisioning request is produced.
            async fn build(self, factory: &ResourceFactory) -> Result<Self::Input, Error> {
                let md = factory.get_metadata();
                match md.env {
                    Environment::Deployment => self.provision_request(),
                    Environment::Local => match self.0.local_uri {
                        Some(local_uri) => Ok(MaybeRequest::NotRequest(
                            DatabaseResource::ConnectionString(local_uri),
                        )),
                        None => self.provision_request(),
                    },
                }
            }
        }
    };
}

aws_engine!("postgres", Postgres, DatabaseAwsRdsPostgres);
aws_engine!("mysql", MySql, DatabaseAwsRdsMySql);
aws_engine!("mariadb", MariaDB, DatabaseAwsRdsMariaDB);
88
89#[derive(Serialize, Deserialize)]
90#[serde(transparent)]
91pub struct OutputWrapper(DatabaseResource);
92
93#[async_trait]
94impl IntoResource<String> for OutputWrapper {
95    async fn into_resource(self) -> Result<String, Error> {
96        Ok(match self.0 {
97            DatabaseResource::ConnectionString(s) => s,
98            DatabaseResource::Info(info) => info.connection_string(true),
99        })
100    }
101}
102
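// These `IntoResource` impls are written by hand instead of inside `aws_engine!`:
// generating them once per engine would emit two conflicting
// `impl IntoResource<sqlx::MySqlPool> for OutputWrapper` blocks whenever both the
// `mysql` and `mariadb` features are enabled.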
104
#[cfg(feature = "diesel-async")]
mod _diesel_async {
    //! Single (unpooled) `diesel-async` connections.
    use super::*;

    /// Opens one `AsyncPgConnection` to the database.
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<diesel_async::AsyncPgConnection> for OutputWrapper {
        async fn into_resource(self) -> Result<diesel_async::AsyncPgConnection, Error> {
            use diesel_async::{AsyncConnection, AsyncPgConnection};

            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(AsyncPgConnection::establish(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }

    /// Opens one `AsyncMysqlConnection` to the database (MySQL or MariaDB).
    ///
    /// # Errors
    ///
    /// Fails if the connection cannot be established.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<diesel_async::AsyncMysqlConnection> for OutputWrapper {
        async fn into_resource(self) -> Result<diesel_async::AsyncMysqlConnection, Error> {
            use diesel_async::{AsyncConnection, AsyncMysqlConnection};

            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(AsyncMysqlConnection::establish(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }
}
135
#[cfg(feature = "diesel-async-bb8")]
mod _diesel_async_bb8 {
    //! `bb8` connection pools for `diesel-async`.
    use super::*;

    /// Builds a `bb8` pool of `AsyncPgConnection`s with `min_idle` set to
    /// `MIN_CONNECTIONS` and `max_size` set to `MAX_CONNECTIONS`.
    ///
    /// # Errors
    ///
    /// Fails if `bb8` reports an error while building the pool.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<diesel_bb8::Pool<diesel_async::AsyncPgConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_bb8::Pool<diesel_async::AsyncPgConnection>, Error> {
            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(diesel_bb8::Pool::builder()
                .min_idle(Some(MIN_CONNECTIONS))
                .max_size(MAX_CONNECTIONS)
                .build(AsyncDieselConnectionManager::new(connection_string))
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }

    /// Builds a `bb8` pool of `AsyncMysqlConnection`s (MySQL or MariaDB) with
    /// `min_idle` set to `MIN_CONNECTIONS` and `max_size` set to `MAX_CONNECTIONS`.
    ///
    /// # Errors
    ///
    /// Fails if `bb8` reports an error while building the pool.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<diesel_bb8::Pool<diesel_async::AsyncMysqlConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_bb8::Pool<diesel_async::AsyncMysqlConnection>, Error> {
            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(diesel_bb8::Pool::builder()
                .min_idle(Some(MIN_CONNECTIONS))
                .max_size(MAX_CONNECTIONS)
                .build(AsyncDieselConnectionManager::new(connection_string))
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }
}
174
#[cfg(feature = "diesel-async-deadpool")]
mod _diesel_async_deadpool {
    //! `deadpool` connection pools for `diesel-async`.
    use super::*;

    /// Builds a `deadpool` pool of `AsyncPgConnection`s capped at
    /// `MAX_CONNECTIONS`. No minimum is configured here.
    ///
    /// # Errors
    ///
    /// Fails if building the pool returns an error.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<diesel_deadpool::Pool<diesel_async::AsyncPgConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_deadpool::Pool<diesel_async::AsyncPgConnection>, Error> {
            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(
                diesel_deadpool::Pool::builder(AsyncDieselConnectionManager::new(
                    connection_string,
                ))
                // Lossless widening: u32 always fits in usize on supported targets.
                .max_size(MAX_CONNECTIONS as usize)
                .build()
                .map_err(shuttle_service::error::CustomError::new)?,
            )
        }
    }

    /// Builds a `deadpool` pool of `AsyncMysqlConnection`s (MySQL or MariaDB)
    /// capped at `MAX_CONNECTIONS`. No minimum is configured here.
    ///
    /// # Errors
    ///
    /// Fails if building the pool returns an error.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<diesel_deadpool::Pool<diesel_async::AsyncMysqlConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_deadpool::Pool<diesel_async::AsyncMysqlConnection>, Error> {
            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(
                diesel_deadpool::Pool::builder(AsyncDieselConnectionManager::new(
                    connection_string,
                ))
                // Lossless widening: u32 always fits in usize on supported targets.
                .max_size(MAX_CONNECTIONS as usize)
                .build()
                .map_err(shuttle_service::error::CustomError::new)?,
            )
        }
    }
}
217
#[cfg(feature = "sqlx")]
mod _sqlx {
    //! `sqlx` connection pools.
    use super::*;

    /// Connects a `sqlx::PgPool` with between `MIN_CONNECTIONS` and
    /// `MAX_CONNECTIONS` connections.
    ///
    /// # Errors
    ///
    /// Fails if the pool cannot connect to the database.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<sqlx::PgPool> for OutputWrapper {
        async fn into_resource(self) -> Result<sqlx::PgPool, Error> {
            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(sqlx::postgres::PgPoolOptions::new()
                .min_connections(MIN_CONNECTIONS)
                .max_connections(MAX_CONNECTIONS)
                .connect(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }

    /// Connects a `sqlx::MySqlPool` (MySQL or MariaDB) with between
    /// `MIN_CONNECTIONS` and `MAX_CONNECTIONS` connections.
    ///
    /// # Errors
    ///
    /// Fails if the pool cannot connect to the database.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<sqlx::MySqlPool> for OutputWrapper {
        async fn into_resource(self) -> Result<sqlx::MySqlPool, Error> {
            // Pick the `String` impl explicitly and propagate its error instead of
            // unwrapping it.
            let connection_string =
                <OutputWrapper as IntoResource<String>>::into_resource(self).await?;

            Ok(sqlx::mysql::MySqlPoolOptions::new()
                .min_connections(MIN_CONNECTIONS)
                .max_connections(MAX_CONNECTIONS)
                .connect(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }
}