sqlx_scylladb_core/
testing.rs

1use std::fmt::Write;
2use std::time::SystemTime;
3use std::{ops::Deref, str::FromStr, sync::OnceLock, time::Duration};
4
5use futures_core::future::BoxFuture;
6use scylla::value::CqlTimestamp;
7use sha2::{Digest, Sha512};
8use sqlx::{Connection as _, Error, Executor, Pool, pool::PoolOptions};
9use sqlx_core::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport};
10
11use crate::{ScyllaDB, ScyllaDBConnectOptions, ScyllaDBConnection};
12
// Process-wide pool of connections to the master keyspace named by the database URL.
// `test_context` installs it on first use, `cleanup_test` reads it back, and every
// per-test pool configured in `test_context` uses it as its parent.
// Using a blocking `OnceLock` here because the critical sections are short.
static MASTER_POOL: OnceLock<Pool<ScyllaDB>> = OnceLock::new();
15
16impl TestSupport for ScyllaDB {
17    fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
18        Box::pin(async move { test_context(args).await })
19    }
20
21    fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
22        Box::pin(async move {
23            let mut conn = MASTER_POOL
24                .get()
25                .expect("cleanup_test() invoked outside `#[sqlx::test]`")
26                .acquire()
27                .await?;
28
29            do_cleanup(&mut conn, db_name).await
30        })
31    }
32
33    fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
34        Box::pin(async move { cleanup_test_dbs().await })
35    }
36
37    fn snapshot(
38        _conn: &mut Self::Connection,
39    ) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
40        todo!()
41    }
42
43    ///
44    /// ```
45    /// use sqlx_core::testing::TestSupport;
46    /// let args = ::sqlx_core::testing::TestArgs{
47    ///     test_path: "my_test_function",
48    ///     migrator: None,
49    ///     fixtures: &[],
50    /// };
51    /// let name = ::sqlx_scylladb_core::ScyllaDB::db_name(&args);
52    /// assert_eq!("sqlx_test_ai4drkqtg4lnnkdlk7fjtixcrmcc7cnwamqrm", name);
53    /// ```
54    fn db_name(args: &TestArgs) -> String {
55        let mut hasher = Sha512::new();
56        hasher.update(args.test_path.as_bytes());
57        let hash = hasher.finalize();
58        let hash = base32::encode(
59            base32::Alphabet::Rfc4648Lower { padding: false },
60            &hash[0..23],
61        );
62        // Keyspace name is supported lower and less than 48 characters.
63        let db_name = format!("sqlx_test_{}", hash);
64        debug_assert!(db_name.len() <= 48);
65        db_name
66    }
67}
68
69async fn test_context(args: &TestArgs) -> Result<TestContext<ScyllaDB>, Error> {
70    let url = get_database_url();
71
72    let master_opts = ScyllaDBConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");
73
74    let pool = PoolOptions::new()
75        .max_connections(20)
76        .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
77        .connect_lazy_with(master_opts);
78
79    let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) {
80        Ok(inserted) => inserted,
81        Err((existing, pool)) => {
82            assert_eq!(
83                existing.connect_options().get_connect_nodes(),
84                pool.connect_options().get_connect_nodes(),
85                "DATABASE_URL changed at runtime, host differs"
86            );
87
88            assert_eq!(
89                existing.connect_options().keyspace,
90                pool.connect_options().keyspace,
91                "DATABASE_URL changed at runtime, database differs"
92            );
93
94            existing
95        }
96    };
97
98    let mut conn = master_pool.acquire().await?;
99
100    conn.execute(
101        r#"
102        CREATE TABLE IF NOT EXISTS sqlx_test_databases (
103            db_name TEXT PRIMARY KEY,
104            test_path TEXT,
105            created_at TIMESTAMP
106        )
107    "#,
108    )
109    .await?;
110
111    let db_name = ScyllaDB::db_name(args);
112    do_cleanup(&mut conn, &db_name).await?;
113
114    let timestamp = SystemTime::now()
115        .duration_since(SystemTime::UNIX_EPOCH)
116        .expect("System clock is before unix epoch.")
117        .as_millis() as i64;
118    let timestamp = CqlTimestamp(timestamp);
119
120    sqlx::query("INSERT INTO sqlx_test_databases(db_name, test_path, created_at) values (?, ?, ?)")
121        .bind(&db_name)
122        .bind(args.test_path)
123        .bind(timestamp)
124        .execute(&mut *conn)
125        .await?;
126
127    conn.execute(format!("CREATE KEYSPACE IF NOT EXISTS {db_name} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}").as_str()).await?;
128
129    eprintln!("CREATED KEYSPACE {db_name}");
130
131    Ok(TestContext {
132        pool_opts: PoolOptions::new()
133            .max_connections(5)
134            .idle_timeout(Some(Duration::from_secs(1)))
135            .parent(master_pool.clone()),
136        connect_opts: master_pool
137            .connect_options()
138            .deref()
139            .clone()
140            .keyspace(&db_name),
141        db_name,
142    })
143}
144
145async fn do_cleanup(conn: &mut ScyllaDBConnection, db_name: &str) -> Result<(), Error> {
146    let delete_db_command = format!("DROP KEYSPACE IF EXISTS {db_name};");
147    conn.execute(delete_db_command.as_str()).await?;
148    sqlx::query("DELETE FROM sqlx_test_databases WHERE db_name = ?")
149        .bind(db_name)
150        .execute(&mut *conn)
151        .await?;
152
153    Ok(())
154}
155
156async fn cleanup_test_dbs() -> Result<Option<usize>, Error> {
157    let url = get_database_url();
158
159    let mut conn = ScyllaDBConnection::connect(&url).await?;
160
161    let delete_db_names: Vec<String> =
162        sqlx::query_scalar("SELECT db_name from sqlx_test_databases")
163            .fetch_all(&mut conn)
164            .await?;
165
166    if delete_db_names.is_empty() {
167        return Ok(None);
168    }
169
170    let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
171
172    let mut command = String::new();
173
174    for db_name in &delete_db_names {
175        command.clear();
176
177        writeln!(command, "drop database if exists {db_name};").ok();
178        match conn.execute(&*command).await {
179            Ok(_deleted) => {
180                deleted_db_names.push(db_name);
181            }
182            // Assume a database error just means the DB is still in use.
183            Err(Error::Database(dbe)) => {
184                eprintln!("could not clean test database {db_name:?}: {dbe}")
185            }
186            // Bubble up other errors
187            Err(e) => return Err(e),
188        }
189    }
190
191    if deleted_db_names.is_empty() {
192        return Ok(None);
193    }
194
195    sqlx::query("DELETE FROM sqlx_test_databases WHERE db_name IN(?)")
196        .bind(delete_db_names.as_slice())
197        .execute(&mut conn)
198        .await?;
199
200    let _ = conn.close().await;
201
202    Ok(Some(delete_db_names.len()))
203}
204
205fn get_database_url() -> String {
206    dotenvy::var("SCYLLADB_URL")
207        .or_else(|_| dotenvy::var("DATABSE_URL"))
208        .expect("SCYLLADB_URL or DATABASE_URL must be set")
209}
210
/// Tries to store `value` in `this`, reporting whether it was actually stored.
///
/// Returns `Ok(stored)` when `this` was empty and now holds `value`. When `this`
/// was already initialised, returns `Err((existing, value))`, handing the rejected
/// value back to the caller untouched.
fn once_lock_try_insert_polyfill<T>(this: &OnceLock<T>, value: T) -> Result<&T, (&T, T)> {
    let mut pending = Some(value);
    // The initialiser only runs when the cell is empty, so `pending` is emptied
    // exactly when our value was the one that got stored.
    let stored = this.get_or_init(|| pending.take().unwrap());

    if let Some(rejected) = pending {
        return Err((stored, rejected));
    }
    Ok(stored)
}