sqlx_scylladb_core/
testing.rs1use std::fmt::Write;
2use std::time::SystemTime;
3use std::{ops::Deref, str::FromStr, sync::OnceLock, time::Duration};
4
5use futures_core::future::BoxFuture;
6use scylla::value::CqlTimestamp;
7use sha2::{Digest, Sha512};
8use sqlx::{Connection as _, Error, Executor, Pool, pool::PoolOptions};
9use sqlx_core::testing::{FixtureSnapshot, TestArgs, TestContext, TestSupport};
10
11use crate::{ScyllaDB, ScyllaDBConnectOptions, ScyllaDBConnection};
12
13static MASTER_POOL: OnceLock<Pool<ScyllaDB>> = OnceLock::new();
15
16impl TestSupport for ScyllaDB {
17 fn test_context(args: &TestArgs) -> BoxFuture<'_, Result<TestContext<Self>, Error>> {
18 Box::pin(async move { test_context(args).await })
19 }
20
21 fn cleanup_test(db_name: &str) -> BoxFuture<'_, Result<(), Error>> {
22 Box::pin(async move {
23 let mut conn = MASTER_POOL
24 .get()
25 .expect("cleanup_test() invoked outside `#[sqlx::test]`")
26 .acquire()
27 .await?;
28
29 do_cleanup(&mut conn, db_name).await
30 })
31 }
32
33 fn cleanup_test_dbs() -> BoxFuture<'static, Result<Option<usize>, Error>> {
34 Box::pin(async move { cleanup_test_dbs().await })
35 }
36
37 fn snapshot(
38 _conn: &mut Self::Connection,
39 ) -> BoxFuture<'_, Result<FixtureSnapshot<Self>, Error>> {
40 todo!()
41 }
42
43 fn db_name(args: &TestArgs) -> String {
55 let mut hasher = Sha512::new();
56 hasher.update(args.test_path.as_bytes());
57 let hash = hasher.finalize();
58 let hash = base32::encode(
59 base32::Alphabet::Rfc4648Lower { padding: false },
60 &hash[0..23],
61 );
62 let db_name = format!("sqlx_test_{}", hash);
64 debug_assert!(db_name.len() <= 48);
65 db_name
66 }
67}
68
69async fn test_context(args: &TestArgs) -> Result<TestContext<ScyllaDB>, Error> {
70 let url = get_database_url();
71
72 let master_opts = ScyllaDBConnectOptions::from_str(&url).expect("failed to parse DATABASE_URL");
73
74 let pool = PoolOptions::new()
75 .max_connections(20)
76 .after_release(|_conn, _| Box::pin(async move { Ok(false) }))
77 .connect_lazy_with(master_opts);
78
79 let master_pool = match once_lock_try_insert_polyfill(&MASTER_POOL, pool) {
80 Ok(inserted) => inserted,
81 Err((existing, pool)) => {
82 assert_eq!(
83 existing.connect_options().get_connect_nodes(),
84 pool.connect_options().get_connect_nodes(),
85 "DATABASE_URL changed at runtime, host differs"
86 );
87
88 assert_eq!(
89 existing.connect_options().keyspace,
90 pool.connect_options().keyspace,
91 "DATABASE_URL changed at runtime, database differs"
92 );
93
94 existing
95 }
96 };
97
98 let mut conn = master_pool.acquire().await?;
99
100 conn.execute(
101 r#"
102 CREATE TABLE IF NOT EXISTS sqlx_test_databases (
103 db_name TEXT PRIMARY KEY,
104 test_path TEXT,
105 created_at TIMESTAMP
106 )
107 "#,
108 )
109 .await?;
110
111 let db_name = ScyllaDB::db_name(args);
112 do_cleanup(&mut conn, &db_name).await?;
113
114 let timestamp = SystemTime::now()
115 .duration_since(SystemTime::UNIX_EPOCH)
116 .expect("System clock is before unix epoch.")
117 .as_millis() as i64;
118 let timestamp = CqlTimestamp(timestamp);
119
120 sqlx::query("INSERT INTO sqlx_test_databases(db_name, test_path, created_at) values (?, ?, ?)")
121 .bind(&db_name)
122 .bind(args.test_path)
123 .bind(timestamp)
124 .execute(&mut *conn)
125 .await?;
126
127 conn.execute(format!("CREATE KEYSPACE IF NOT EXISTS {db_name} WITH replication = {{'class': 'SimpleStrategy', 'replication_factor': 1}}").as_str()).await?;
128
129 eprintln!("CREATED KEYSPACE {db_name}");
130
131 Ok(TestContext {
132 pool_opts: PoolOptions::new()
133 .max_connections(5)
134 .idle_timeout(Some(Duration::from_secs(1)))
135 .parent(master_pool.clone()),
136 connect_opts: master_pool
137 .connect_options()
138 .deref()
139 .clone()
140 .keyspace(&db_name),
141 db_name,
142 })
143}
144
145async fn do_cleanup(conn: &mut ScyllaDBConnection, db_name: &str) -> Result<(), Error> {
146 let delete_db_command = format!("DROP KEYSPACE IF EXISTS {db_name};");
147 conn.execute(delete_db_command.as_str()).await?;
148 sqlx::query("DELETE FROM sqlx_test_databases WHERE db_name = ?")
149 .bind(db_name)
150 .execute(&mut *conn)
151 .await?;
152
153 Ok(())
154}
155
156async fn cleanup_test_dbs() -> Result<Option<usize>, Error> {
157 let url = get_database_url();
158
159 let mut conn = ScyllaDBConnection::connect(&url).await?;
160
161 let delete_db_names: Vec<String> =
162 sqlx::query_scalar("SELECT db_name from sqlx_test_databases")
163 .fetch_all(&mut conn)
164 .await?;
165
166 if delete_db_names.is_empty() {
167 return Ok(None);
168 }
169
170 let mut deleted_db_names = Vec::with_capacity(delete_db_names.len());
171
172 let mut command = String::new();
173
174 for db_name in &delete_db_names {
175 command.clear();
176
177 writeln!(command, "drop database if exists {db_name};").ok();
178 match conn.execute(&*command).await {
179 Ok(_deleted) => {
180 deleted_db_names.push(db_name);
181 }
182 Err(Error::Database(dbe)) => {
184 eprintln!("could not clean test database {db_name:?}: {dbe}")
185 }
186 Err(e) => return Err(e),
188 }
189 }
190
191 if deleted_db_names.is_empty() {
192 return Ok(None);
193 }
194
195 sqlx::query("DELETE FROM sqlx_test_databases WHERE db_name IN(?)")
196 .bind(delete_db_names.as_slice())
197 .execute(&mut conn)
198 .await?;
199
200 let _ = conn.close().await;
201
202 Ok(Some(delete_db_names.len()))
203}
204
205fn get_database_url() -> String {
206 dotenvy::var("SCYLLADB_URL")
207 .or_else(|_| dotenvy::var("DATABSE_URL"))
208 .expect("SCYLLADB_URL or DATABASE_URL must be set")
209}
210
/// Stable stand-in for the unstable `OnceLock::try_insert`.
///
/// If the cell is still empty, stores `value` and returns `Ok` with a reference to
/// it. If the cell was already initialised, hands `value` back untouched as
/// `Err((current, value))`, where `current` is the value already stored.
fn once_lock_try_insert_polyfill<T>(this: &OnceLock<T>, value: T) -> Result<&T, (&T, T)> {
    let mut slot = Some(value);
    // The initialiser runs (and empties `slot`) only if this call fills the cell.
    let current = this.get_or_init(|| slot.take().unwrap());

    if let Some(rejected) = slot {
        Err((current, rejected))
    } else {
        Ok(current)
    }
}