1#![doc = include_str!("../README.md")]
2
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use shuttle_service::{
6 resource::{ProvisionResourceRequest, ResourceType},
7 DatabaseResource, DbInput, Environment, Error, IntoResource, ResourceFactory,
8 ResourceInputBuilder,
9};
10
11#[cfg(any(feature = "diesel-async-bb8", feature = "diesel-async-deadpool"))]
12use diesel_async::pooled_connection::AsyncDieselConnectionManager;
13
14#[cfg(feature = "diesel-async-bb8")]
15use diesel_async::pooled_connection::bb8 as diesel_bb8;
16
17#[cfg(feature = "diesel-async-deadpool")]
18use diesel_async::pooled_connection::deadpool as diesel_deadpool;
19
// The pool-size bounds are only read by the pooled backends (bb8, deadpool, sqlx),
// all of which are behind cargo features. `allow(dead_code)` keeps builds that
// enable none of those features free of unused-constant warnings.

/// Lower bound on pooled connections (bb8 `min_idle`, sqlx `min_connections`).
#[allow(dead_code)]
const MIN_CONNECTIONS: u32 = 1;

/// Upper bound on pooled connections (bb8/deadpool `max_size`, sqlx `max_connections`).
#[allow(dead_code)]
const MAX_CONNECTIONS: u32 = 5;

// Compile-time guard so an inconsistent edit to the bounds above fails the build
// instead of producing a pool configured with min > max.
const _: () = assert!(MIN_CONNECTIONS <= MAX_CONNECTIONS);
24
/// Value produced by the engine builders' `build` step (their `ResourceInputBuilder::Input`).
///
/// It holds one of two things:
/// - a request asking the platform to provision a database, or
/// - a [`DatabaseResource`] that is already usable: a connection string supplied via
///   `local_uri` during a local run.
#[derive(Serialize, Deserialize)]
// Untagged: each variant is (de)serialized as its bare inner value. When
// deserializing, serde tries the variants in declaration order, so `Request` is
// attempted before `NotRequest`. Keep this order.
#[serde(untagged)]
pub enum MaybeRequest {
    /// Ask the provisioner to create the database described by this request.
    Request(ProvisionResourceRequest),
    /// A resource that needs no provisioning.
    NotRequest(DatabaseResource),
}
32
// Generates the builder for one AWS RDS database engine.
//
// Each invocation expands, behind `#[cfg(feature = $feature)]`, to:
// - `pub struct $struct_ident(DbInput)`, a builder with `local_uri` and
//   `database_name` setters;
// - a `ResourceInputBuilder` impl that turns the builder into a `MaybeRequest`
//   for `ResourceType::$res_type`.
macro_rules! aws_engine {
    ($feature:expr, $struct_ident:ident, $res_type:ident) => {
        #[doc = concat!(
            "Builder for an AWS RDS database requested as `ResourceType::",
            stringify!($res_type),
            "`."
        )]
        #[cfg(feature = $feature)]
        #[derive(Default)]
        pub struct $struct_ident(DbInput);

        #[cfg(feature = $feature)]
        impl $struct_ident {
            /// Use this connection string instead of provisioning a database when
            /// running locally. Deployments always provision and ignore this value.
            #[must_use]
            pub fn local_uri(mut self, local_uri: &str) -> Self {
                self.0.local_uri = Some(local_uri.to_string());
                self
            }

            /// Set the database name (`db_name`) sent in the provisioning config.
            #[must_use]
            pub fn database_name(mut self, database_name: &str) -> Self {
                self.0.db_name = Some(database_name.to_string());
                self
            }
        }

        #[cfg(feature = $feature)]
        #[async_trait::async_trait]
        impl ResourceInputBuilder for $struct_ident {
            type Input = MaybeRequest;
            type Output = OutputWrapper;

            /// Decide between reusing a local connection string and requesting provisioning.
            ///
            /// # Errors
            /// Returns an error if the builder's `DbInput` cannot be serialized to JSON.
            async fn build(self, factory: &ResourceFactory) -> Result<Self::Input, Error> {
                let md = factory.get_metadata();

                // A user-supplied local URI short-circuits provisioning, but only for
                // local runs. Every other case falls through to a provisioning request.
                if matches!(md.env, Environment::Local) {
                    if let Some(local_uri) = self.0.local_uri {
                        return Ok(MaybeRequest::NotRequest(
                            DatabaseResource::ConnectionString(local_uri),
                        ));
                    }
                }

                // Propagate serialization failures instead of panicking.
                let config = serde_json::to_value(self.0)
                    .map_err(shuttle_service::error::CustomError::new)?;

                Ok(MaybeRequest::Request(ProvisionResourceRequest {
                    r#type: ResourceType::$res_type,
                    config,
                }))
            }
        }
    };
}

aws_engine!("postgres", Postgres, DatabaseAwsRdsPostgres);
aws_engine!("mysql", MySql, DatabaseAwsRdsMySql);
aws_engine!("mariadb", MariaDB, DatabaseAwsRdsMariaDB);
88
89#[derive(Serialize, Deserialize)]
90#[serde(transparent)]
91pub struct OutputWrapper(DatabaseResource);
92
93#[async_trait]
94impl IntoResource<String> for OutputWrapper {
95 async fn into_resource(self) -> Result<String, Error> {
96 Ok(match self.0 {
97 DatabaseResource::ConnectionString(s) => s,
98 DatabaseResource::Info(info) => info.connection_string(true),
99 })
100 }
101}
102
/// Single-connection `diesel-async` resources.
#[cfg(feature = "diesel-async")]
mod _diesel_async {
    use super::*;

    /// Opens one `AsyncPgConnection` to the database's connection string.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from
    /// `AsyncPgConnection::establish`.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<diesel_async::AsyncPgConnection> for OutputWrapper {
        async fn into_resource(self) -> Result<diesel_async::AsyncPgConnection, Error> {
            use diesel_async::{AsyncConnection, AsyncPgConnection};

            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(AsyncPgConnection::establish(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }

    /// Opens one `AsyncMysqlConnection` (used for both MySQL and MariaDB).
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from
    /// `AsyncMysqlConnection::establish`.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<diesel_async::AsyncMysqlConnection> for OutputWrapper {
        async fn into_resource(self) -> Result<diesel_async::AsyncMysqlConnection, Error> {
            use diesel_async::{AsyncConnection, AsyncMysqlConnection};

            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(AsyncMysqlConnection::establish(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }
}
135
/// Pooled `diesel-async` resources backed by bb8.
#[cfg(feature = "diesel-async-bb8")]
mod _diesel_async_bb8 {
    use super::*;

    /// Builds a bb8 pool of `AsyncPgConnection`s with `min_idle = MIN_CONNECTIONS`
    /// and `max_size = MAX_CONNECTIONS`.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from bb8's `build`.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<diesel_bb8::Pool<diesel_async::AsyncPgConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_bb8::Pool<diesel_async::AsyncPgConnection>, Error> {
            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(diesel_bb8::Pool::builder()
                .min_idle(Some(MIN_CONNECTIONS))
                .max_size(MAX_CONNECTIONS)
                .build(AsyncDieselConnectionManager::new(connection_string))
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }

    /// Builds a bb8 pool of `AsyncMysqlConnection`s (MySQL and MariaDB) with
    /// `min_idle = MIN_CONNECTIONS` and `max_size = MAX_CONNECTIONS`.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from bb8's `build`.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<diesel_bb8::Pool<diesel_async::AsyncMysqlConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_bb8::Pool<diesel_async::AsyncMysqlConnection>, Error> {
            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(diesel_bb8::Pool::builder()
                .min_idle(Some(MIN_CONNECTIONS))
                .max_size(MAX_CONNECTIONS)
                .build(AsyncDieselConnectionManager::new(connection_string))
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }
}
174
/// Pooled `diesel-async` resources backed by deadpool.
#[cfg(feature = "diesel-async-deadpool")]
mod _diesel_async_deadpool {
    use super::*;

    /// Builds a deadpool pool of `AsyncPgConnection`s capped at `MAX_CONNECTIONS`.
    ///
    /// No minimum pool size is configured here, so `MIN_CONNECTIONS` is not applied.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from deadpool's `build`.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<diesel_deadpool::Pool<diesel_async::AsyncPgConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_deadpool::Pool<diesel_async::AsyncPgConnection>, Error> {
            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(
                diesel_deadpool::Pool::builder(AsyncDieselConnectionManager::new(
                    connection_string,
                ))
                // u32 -> usize is lossless wherever `usize` is at least 32 bits wide.
                .max_size(MAX_CONNECTIONS as usize)
                .build()
                .map_err(shuttle_service::error::CustomError::new)?,
            )
        }
    }

    /// Builds a deadpool pool of `AsyncMysqlConnection`s (MySQL and MariaDB) capped
    /// at `MAX_CONNECTIONS`.
    ///
    /// No minimum pool size is configured here, so `MIN_CONNECTIONS` is not applied.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from deadpool's `build`.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<diesel_deadpool::Pool<diesel_async::AsyncMysqlConnection>> for OutputWrapper {
        async fn into_resource(
            self,
        ) -> Result<diesel_deadpool::Pool<diesel_async::AsyncMysqlConnection>, Error> {
            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(
                diesel_deadpool::Pool::builder(AsyncDieselConnectionManager::new(
                    connection_string,
                ))
                // u32 -> usize is lossless wherever `usize` is at least 32 bits wide.
                .max_size(MAX_CONNECTIONS as usize)
                .build()
                .map_err(shuttle_service::error::CustomError::new)?,
            )
        }
    }
}
217
/// Pooled `sqlx` resources.
#[cfg(feature = "sqlx")]
mod _sqlx {
    use super::*;

    /// Connects a `sqlx::PgPool` with `min_connections = MIN_CONNECTIONS` and
    /// `max_connections = MAX_CONNECTIONS`.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from `connect`.
    #[cfg(feature = "postgres")]
    #[async_trait]
    impl IntoResource<sqlx::PgPool> for OutputWrapper {
        async fn into_resource(self) -> Result<sqlx::PgPool, Error> {
            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(sqlx::postgres::PgPoolOptions::new()
                .min_connections(MIN_CONNECTIONS)
                .max_connections(MAX_CONNECTIONS)
                .connect(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }

    /// Connects a `sqlx::MySqlPool` (MySQL and MariaDB) with
    /// `min_connections = MIN_CONNECTIONS` and `max_connections = MAX_CONNECTIONS`.
    ///
    /// # Errors
    /// Propagates failures from resolving the connection string or from `connect`.
    #[cfg(any(feature = "mysql", feature = "mariadb"))]
    #[async_trait]
    impl IntoResource<sqlx::MySqlPool> for OutputWrapper {
        async fn into_resource(self) -> Result<sqlx::MySqlPool, Error> {
            // Resolve through the `String` impl; propagate its error rather than panic.
            let connection_string = IntoResource::<String>::into_resource(self).await?;

            Ok(sqlx::mysql::MySqlPoolOptions::new()
                .min_connections(MIN_CONNECTIONS)
                .max_connections(MAX_CONNECTIONS)
                .connect(&connection_string)
                .await
                .map_err(shuttle_service::error::CustomError::new)?)
        }
    }
}