//! `sqlx_pg_test_template_runner/lib.rs`
//!
//! Helpers for running sqlx Postgres tests, each against its own database cloned from a
//! template database.

use std::hash::Hasher;
use std::str::FromStr;

use sqlx::{
    postgres::{PgConnectOptions, PgPoolOptions},
    Connection, PgConnection, Pool, Postgres,
};

/// Errors produced while preparing, running, or cleaning up a template-based test database.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The `DATABASE_URL` environment variable could not be read.
    ///
    /// A value that is present but fails to parse as Postgres connect options is reported
    /// as [`Error::Sqlx`] instead, because the parse error comes from sqlx.
    #[error("DATABASE_URL is missing or invalid")]
    InvalidDatabaseUrl,

    /// A set of connect options carried no database name.
    ///
    /// Raised by `db_name_of_test_pool`: both when inspecting a test pool's options and when
    /// falling back to the database named in `DATABASE_URL` as the template name.
    #[error("database not found for an open connection pool")]
    DatabaseNotFound,

    /// Any error reported by sqlx (parsing options, connecting, or executing a statement).
    ///
    /// `#[from]` lets `?` convert a `sqlx::Error` into this variant.
    #[error("sqlx error: '{0}'")]
    Sqlx(#[from] sqlx::Error),
}

/// Per-test configuration consumed by [`run_test`] / [`wrap_run_test`].
#[derive(Debug, Clone)]
pub struct TestArgs {
    /// Name of the template database to clone.
    ///
    /// When `None`, the database named in `DATABASE_URL` is used as the template.
    pub template_name: Option<String>,

    /// Maximum number of connections in the test pool (2 when `None`).
    pub max_connections: Option<u32>,

    /// Module path of the test.
    ///
    /// It is hashed to derive the per-test database name and stored as that database's
    /// comment.
    pub module_path: String,
}

/// Creates a fresh database cloned from `template_db_name` and returns its name.
///
/// The new database is named `_sqlx_<hash>`, where `<hash>` is the `u64` produced by hashing
/// `module_path` with [`std::hash::DefaultHasher`], so the same `module_path` maps to the same
/// database name. Any existing database with that name is dropped first, which also removes
/// leftovers from an earlier run that did not clean up. `module_path` is then stored as the
/// database's comment so the origin of each test database is visible.
///
/// Returns the new database name together with `conn`, so the caller can reuse or close it.
///
/// # Errors
///
/// Returns [`Error::Sqlx`] if any of the `DROP` / `CREATE` / `COMMENT` statements fails.
pub async fn create_db_from_template(
    mut conn: PgConnection,
    template_db_name: &str,
    module_path: &str,
) -> Result<(String, PgConnection), Error> {
    let mut hasher = std::hash::DefaultHasher::new();
    hasher.write(module_path.as_bytes());
    let db_name = format!("_sqlx_{}", hasher.finish());

    // Identifiers are double-quoted so names taken verbatim from DATABASE_URL (which may
    // contain `-`, uppercase letters, ...) are used exactly as written instead of being
    // case-folded or rejected by the SQL parser.
    let quoted_db = quote_ident(&db_name);

    sqlx::query(&format!("DROP DATABASE IF EXISTS {quoted_db}"))
        .execute(&mut conn)
        .await?;

    sqlx::query(&format!(
        "CREATE DATABASE {quoted_db} WITH TEMPLATE {}",
        quote_ident(template_db_name)
    ))
    .execute(&mut conn)
    .await?;

    // The comment is a string literal, so embedded single quotes must be escaped.
    sqlx::query(&format!(
        "COMMENT ON DATABASE {quoted_db} IS {}",
        quote_literal(module_path)
    ))
    .execute(&mut conn)
    .await?;

    Ok((db_name, conn))
}

/// Quotes a Postgres identifier, doubling any embedded `"`.
fn quote_ident(ident: &str) -> String {
    format!("\"{}\"", ident.replace('"', "\"\""))
}

/// Quotes a Postgres string literal, doubling any embedded `'`.
fn quote_literal(value: &str) -> String {
    format!("'{}'", value.replace('\'', "''"))
}

/// Spawns test pool with a new database
pub async fn spawn_test_pool(
    connect_options: &PgConnectOptions,
    db_name: &str,
    max_connections: Option<u32>,
) -> Result<Pool<Postgres>, Error> {
    let connect_options = connect_options.clone().database(db_name);
    let pool = PgPoolOptions::new()
        .max_connections(max_connections.unwrap_or(2))
        .idle_timeout(Some(std::time::Duration::from_secs(1)))
        .connect_with(connect_options)
        .await?;

    Ok(pool)
}

/// Returns the name of the database for the test pool or error
pub fn db_name_of_test_pool(connect_opts: &PgConnectOptions) -> Result<String, Error> {
    connect_opts
        .get_database()
        .map(|s| s.to_string())
        .ok_or(Error::DatabaseNotFound)
}

/// Closes `pool` and drops the database it was connected to.
///
/// The database name is read from the pool's connect options. The pool is closed first so its
/// own connections are released. Then `DROP DATABASE IF EXISTS ... WITH (FORCE)` is issued
/// over `conn`, which is expected to be a separate service connection to another database
/// (this is how `wrap_run_test` calls it).
///
/// # Errors
///
/// - [`Error::DatabaseNotFound`] if the pool's options carry no database name.
/// - [`Error::Sqlx`] if the `DROP DATABASE` statement fails.
pub async fn close_test_pool(
    conn: &mut PgConnection,
    pool: &sqlx::Pool<Postgres>,
) -> Result<(), Error> {
    let db_name = db_name_of_test_pool(&pool.connect_options())?;

    pool.close().await;

    // Quote the identifier (doubling embedded `"`) so the exact name from the connect options
    // is dropped, even if it contains characters such as `-` or uppercase letters.
    let quoted_db = format!("\"{}\"", db_name.replace('"', "\"\""));
    sqlx::query(&format!("DROP DATABASE IF EXISTS {quoted_db} WITH (FORCE)"))
        .execute(conn)
        .await?;

    Ok(())
}

/// Runs an individual test
pub async fn wrap_run_test<F, Fut>(f: F, args: TestArgs) -> Result<(), Error>
where
    F: Fn(Pool<Postgres>) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    // Get connection string
    let database_url = std::env::var("DATABASE_URL").map_err(|_| Error::InvalidDatabaseUrl)?;

    // Try to get template database name from args, defaulting to connection database name
    let connect_opts = PgConnectOptions::from_str(&database_url)?;

    // Get template name from args or use database name from connection options
    let template_name = &args
        .template_name
        .map(Ok)
        .unwrap_or_else(|| db_name_of_test_pool(&connect_opts))?;

    // Service connection == default database (postgres, username, etc)
    let service_connect_opts = connect_opts.clone().database("");

    // Create a new database from the template
    let conn = PgConnection::connect_with(&service_connect_opts)
        .await
        .unwrap();
    let (db_name, conn) = create_db_from_template(conn, template_name, &args.module_path)
        .await
        .unwrap();
    conn.close().await?;

    // Run test
    let pool = spawn_test_pool(&service_connect_opts, &db_name, args.max_connections).await?;

    f(pool.clone()).await;

    // Close the pool & drop the test database
    let mut conn = PgConnection::connect_with(&service_connect_opts).await?;
    close_test_pool(&mut conn, &pool).await.unwrap();
    conn.close().await?;

    Ok(())
}

/// Synchronous entry point for a single test: drives [`wrap_run_test`] to completion with
/// `sqlx::test_block_on`.
///
/// # Panics
///
/// Panics with `test failed: <error>` if [`wrap_run_test`] returns an error.
pub fn run_test<F, Fut>(f: F, args: TestArgs)
where
    F: Fn(Pool<Postgres>) -> Fut,
    Fut: std::future::Future<Output = ()>,
{
    let outcome = async move {
        if let Err(e) = wrap_run_test(f, args).await {
            panic!("test failed: {e}");
        }
    };
    sqlx::test_block_on(outcome)
}