sqlx_pg_test_template_runner/
lib.rs

1use std::hash::Hasher;
2use std::str::FromStr;
3
4use sqlx::{
5    postgres::{PgConnectOptions, PgPoolOptions},
6    Connection, PgConnection, Pool, Postgres,
7};
8
9#[derive(Debug, thiserror::Error)]
10pub enum Error {
11    #[error("DATABASE_URL is missing or invalid")]
12    InvalidDatabaseUrl,
13
14    #[error("database not found for an open connection pool")]
15    DatabaseNotFound,
16
17    #[error("sqlx error: '{0}'")]
18    Sqlx(#[from] sqlx::Error),
19}
20
/// Per-test configuration consumed by [`wrap_run_test`] / [`run_test`].
#[derive(Debug, Clone, Default)]
pub struct TestArgs {
    /// Name of the template database to clone for this test.
    ///
    /// When `None`, the database named in `DATABASE_URL` is used as the template.
    pub template_name: Option<String>,

    /// Maximum number of connections in the test pool (2 when `None`).
    pub max_connections: Option<u32>,

    /// Path of the test module. It is hashed to derive the test database name
    /// and stored as that database's comment.
    pub module_path: String,
}
32
33/// Creates a new database from a template
34pub async fn create_db_from_template(
35    mut conn: PgConnection,
36    template_db_name: &str,
37    module_path: &str,
38) -> Result<(String, PgConnection), Error> {
39    let mut hasher = std::hash::DefaultHasher::new();
40    hasher.write(module_path.as_bytes());
41    let id = hasher.finish();
42
43    let db_name = format!("_sqlx_{}", id);
44
45    sqlx::query(&format!("DROP DATABASE IF EXISTS {}", db_name))
46        .execute(&mut conn)
47        .await?;
48
49    sqlx::query(&format!(
50        "CREATE DATABASE {} WITH TEMPLATE {}",
51        db_name, template_db_name
52    ))
53    .execute(&mut conn)
54    .await?;
55
56    sqlx::query(&format!(
57        "COMMENT ON DATABASE {} IS '{}'",
58        db_name, module_path
59    ))
60    .execute(&mut conn)
61    .await?;
62
63    Ok((db_name, conn))
64}
65
66/// Spawns test pool with a new database
67pub async fn spawn_test_pool(
68    connect_options: &PgConnectOptions,
69    db_name: &str,
70    max_connections: Option<u32>,
71) -> Result<Pool<Postgres>, Error> {
72    let connect_options = connect_options.clone().database(db_name);
73    let pool = PgPoolOptions::new()
74        .max_connections(max_connections.unwrap_or(2))
75        .idle_timeout(Some(std::time::Duration::from_secs(1)))
76        .connect_with(connect_options)
77        .await?;
78
79    Ok(pool)
80}
81
82/// Returns the name of the database for the test pool or error
83pub fn db_name_of_test_pool(connect_opts: &PgConnectOptions) -> Result<String, Error> {
84    connect_opts
85        .get_database()
86        .map(|s| s.to_string())
87        .ok_or(Error::DatabaseNotFound)
88}
89
90/// Closes test pool and drops the test database
91pub async fn close_test_pool(
92    conn: &mut PgConnection,
93    pool: &sqlx::Pool<Postgres>,
94) -> Result<(), Error> {
95    let db_name = db_name_of_test_pool(&pool.connect_options())?;
96
97    pool.close().await;
98
99    sqlx::query(&format!("DROP DATABASE IF EXISTS {} WITH (FORCE)", db_name))
100        .execute(conn)
101        .await?;
102
103    Ok(())
104}
105
106/// Runs an individual test
107pub async fn wrap_run_test<F, Fut>(f: F, args: TestArgs) -> Result<(), Error>
108where
109    F: Fn(Pool<Postgres>) -> Fut,
110    Fut: std::future::Future<Output = ()>,
111{
112    // Get connection string
113    let database_url = std::env::var("DATABASE_URL").map_err(|_| Error::InvalidDatabaseUrl)?;
114
115    // Try to get template database name from args, defaulting to connection database name
116    let connect_opts = PgConnectOptions::from_str(&database_url)?;
117
118    // Get template name from args or use database name from connection options
119    let template_name = &args
120        .template_name
121        .map(Ok)
122        .unwrap_or_else(|| db_name_of_test_pool(&connect_opts))?;
123
124    // Service connection == default database (postgres, username, etc)
125    let service_connect_opts = connect_opts.clone().database("");
126
127    // Create a new database from the template
128    let conn = PgConnection::connect_with(&service_connect_opts)
129        .await
130        .unwrap();
131    let (db_name, conn) = create_db_from_template(conn, template_name, &args.module_path)
132        .await
133        .unwrap();
134    conn.close().await?;
135
136    // Run test
137    let pool = spawn_test_pool(&service_connect_opts, &db_name, args.max_connections).await?;
138
139    f(pool.clone()).await;
140
141    // Close the pool & drop the test database
142    let mut conn = PgConnection::connect_with(&service_connect_opts).await?;
143    close_test_pool(&mut conn, &pool).await.unwrap();
144    conn.close().await?;
145
146    Ok(())
147}
148
149/// Runs an individual test
150pub fn run_test<F, Fut>(f: F, args: TestArgs)
151where
152    F: Fn(Pool<Postgres>) -> Fut,
153    Fut: std::future::Future<Output = ()>,
154{
155    sqlx::test_block_on(async move {
156        match wrap_run_test(f, args).await {
157            Err(e) => panic!("test failed: {e}"),
158            Ok(v) => v,
159        }
160    })
161}