sqlx_pg_test_template_runner/
lib.rs1use std::hash::Hasher;
2use std::str::FromStr;
3
4use sqlx::{
5 postgres::{PgConnectOptions, PgPoolOptions},
6 Connection, PgConnection, Pool, Postgres,
7};
8
9#[derive(Debug, thiserror::Error)]
10pub enum Error {
11 #[error("DATABASE_URL is missing or invalid")]
12 InvalidDatabaseUrl,
13
14 #[error("database not found for an open connection pool")]
15 DatabaseNotFound,
16
17 #[error("sqlx error: '{0}'")]
18 Sqlx(#[from] sqlx::Error),
19}
20
/// Per-test configuration consumed by `run_test` / `wrap_run_test`.
///
/// Implements `Default`, so callers can write
/// `TestArgs { module_path: ..., ..Default::default() }`.
#[derive(Debug, Clone, Default)]
pub struct TestArgs {
    /// Database to clone for this test. When `None`, the database named in
    /// `DATABASE_URL` is used as the template.
    pub template_name: Option<String>,

    /// Maximum number of connections in the per-test pool. When `None`, 2 is used.
    pub max_connections: Option<u32>,

    /// Identifier of the test. It is hashed to derive the per-test database name and is
    /// stored as that database's comment.
    pub module_path: String,
}

33pub async fn create_db_from_template(
35 mut conn: PgConnection,
36 template_db_name: &str,
37 module_path: &str,
38) -> Result<(String, PgConnection), Error> {
39 let mut hasher = std::hash::DefaultHasher::new();
40 hasher.write(module_path.as_bytes());
41 let id = hasher.finish();
42
43 let db_name = format!("_sqlx_{}", id);
44
45 sqlx::query(&format!("DROP DATABASE IF EXISTS {}", db_name))
46 .execute(&mut conn)
47 .await?;
48
49 sqlx::query(&format!(
50 "CREATE DATABASE {} WITH TEMPLATE {}",
51 db_name, template_db_name
52 ))
53 .execute(&mut conn)
54 .await?;
55
56 sqlx::query(&format!(
57 "COMMENT ON DATABASE {} IS '{}'",
58 db_name, module_path
59 ))
60 .execute(&mut conn)
61 .await?;
62
63 Ok((db_name, conn))
64}
65
66pub async fn spawn_test_pool(
68 connect_options: &PgConnectOptions,
69 db_name: &str,
70 max_connections: Option<u32>,
71) -> Result<Pool<Postgres>, Error> {
72 let connect_options = connect_options.clone().database(db_name);
73 let pool = PgPoolOptions::new()
74 .max_connections(max_connections.unwrap_or(2))
75 .idle_timeout(Some(std::time::Duration::from_secs(1)))
76 .connect_with(connect_options)
77 .await?;
78
79 Ok(pool)
80}
81
82pub fn db_name_of_test_pool(connect_opts: &PgConnectOptions) -> Result<String, Error> {
84 connect_opts
85 .get_database()
86 .map(|s| s.to_string())
87 .ok_or(Error::DatabaseNotFound)
88}
89
90pub async fn close_test_pool(
92 conn: &mut PgConnection,
93 pool: &sqlx::Pool<Postgres>,
94) -> Result<(), Error> {
95 let db_name = db_name_of_test_pool(&pool.connect_options())?;
96
97 pool.close().await;
98
99 sqlx::query(&format!("DROP DATABASE IF EXISTS {} WITH (FORCE)", db_name))
100 .execute(conn)
101 .await?;
102
103 Ok(())
104}
105
106pub async fn wrap_run_test<F, Fut>(f: F, args: TestArgs) -> Result<(), Error>
108where
109 F: Fn(Pool<Postgres>) -> Fut,
110 Fut: std::future::Future<Output = ()>,
111{
112 let database_url = std::env::var("DATABASE_URL").map_err(|_| Error::InvalidDatabaseUrl)?;
114
115 let connect_opts = PgConnectOptions::from_str(&database_url)?;
117
118 let template_name = &args
120 .template_name
121 .map(Ok)
122 .unwrap_or_else(|| db_name_of_test_pool(&connect_opts))?;
123
124 let service_connect_opts = connect_opts.clone().database("");
126
127 let conn = PgConnection::connect_with(&service_connect_opts)
129 .await
130 .unwrap();
131 let (db_name, conn) = create_db_from_template(conn, template_name, &args.module_path)
132 .await
133 .unwrap();
134 conn.close().await?;
135
136 let pool = spawn_test_pool(&service_connect_opts, &db_name, args.max_connections).await?;
138
139 f(pool.clone()).await;
140
141 let mut conn = PgConnection::connect_with(&service_connect_opts).await?;
143 close_test_pool(&mut conn, &pool).await.unwrap();
144 conn.close().await?;
145
146 Ok(())
147}
148
149pub fn run_test<F, Fut>(f: F, args: TestArgs)
151where
152 F: Fn(Pool<Postgres>) -> Fut,
153 Fut: std::future::Future<Output = ()>,
154{
155 sqlx::test_block_on(async move {
156 match wrap_run_test(f, args).await {
157 Err(e) => panic!("test failed: {e}"),
158 Ok(v) => v,
159 }
160 })
161}