sqlx_database_tester_macros/lib.rs

//! Macros for sqlx-database-tester

use darling::FromMeta;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, AttributeArgs, Ident};
mod generators;
9
/// Pool configuration
#[derive(Debug, FromMeta)]
pub(crate) struct Pool {
	/// The variable holding the pool that will be exposed to the test function
	variable: Ident,
	/// The optional transaction variable
	#[darling(default)]
	transaction_variable: Option<Ident>,
	/// The migration directory path
	#[darling(default)]
	migrations: Option<String>,
	/// Whether migrations should be skipped
	#[darling(default)]
	skip_migrations: bool,
}

impl Pool {
	/// Return the identifier of the variable that will contain the database name
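	/// (e.g. a pool with `variable = my_pool` yields `__my_pool_db_name`)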
	fn database_name_var(&self) -> Ident {
		format_ident!("__{}_db_name", &self.variable)
	}
}

/// Test case configuration
#[derive(Debug, FromMeta)]
pub(crate) struct MacroArgs {
	/// Sqlx log level
	#[darling(default)]
	level: String,
	/// The database pool configurations exposed to the test function
	#[darling(multiple)]
	pool: Vec<Pool>,
}

/// Marks an async test function that exposes database pools to its scope
///
/// ## Macro attributes:
///
/// Each `pool(...)` attribute accepts:
///
/// - `variable`: Variable of the `PgPool` to be exposed to the function scope
///   (mandatory)
/// - `transaction_variable`: Optional variable for a transaction started on
///   the pool, exposed to the function scope
/// - `migrations`: Path to the sqlx migrations directory for the specified
///   pool (falls back to the default `./migrations` directory if left out)
/// - `skip_migrations`: If present, doesn't run any migrations
///
/// ```
/// #[sqlx_database_tester::test(
///     pool(variable = "default_migrated_pool"),
///     pool(variable = "migrated_pool", migrations = "./other_dir_migrations"),
///     pool(variable = "empty_db_pool",
///          transaction_variable = "empty_db_transaction",
///          skip_migrations),
/// )]
/// async fn test_server_start() {
///     let migrated_pool_tables = sqlx::query!("SELECT * FROM pg_catalog.pg_tables")
///         .fetch_all(&migrated_pool)
///         .await
///         .unwrap();
///     let empty_pool_tables = sqlx::query!("SELECT * FROM pg_catalog.pg_tables")
///         .fetch_all(&empty_db_pool)
///         .await
///         .unwrap();
///     println!("Migrated pool tables: \n {:#?}", migrated_pool_tables);
///     println!("Empty pool tables: \n {:#?}", empty_pool_tables);
/// }
/// ```
#[proc_macro_attribute]
pub fn test(test_attr: TokenStream, item: TokenStream) -> TokenStream {
	let mut input = syn::parse_macro_input!(item as syn::ItemFn);
	let test_attr_args = parse_macro_input!(test_attr as AttributeArgs);
	let test_attr: MacroArgs = match MacroArgs::from_list(&test_attr_args) {
		Ok(v) => v,
		Err(e) => {
			return TokenStream::from(e.write_errors());
		}
	};

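	// Pull apart the original function so its pieces can be reassembled around the generated scaffolding.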
	let level = test_attr.level.as_str();
	let attrs = &input.attrs;
	let vis = &input.vis;
	let sig = &mut input.sig;
	let body = &input.block;

	let runtime = if let Some(runtime) = generators::runtime() {
		runtime
	} else {
		return syn::Error::new(
			Span::call_site(),
			"One of 'runtime-actix' and 'runtime-tokio' features needs to be enabled",
		)
		.into_compile_error()
		.into();
	};

	if sig.asyncness.is_none() {
		return syn::Error::new_spanned(
			input.sig.fn_token,
			"the async keyword is missing from the function declaration",
		)
		.to_compile_error()
		.into();
	}

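	// Remove the `async` keyword: the generated function is a synchronous #[test] that drives the runtime itself.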
	sig.asyncness = None;

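	// Generate the per-pool code fragments that are spliced into the test body below.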
	let database_name_vars = generators::database_name_vars(&test_attr);
	let database_creators = generators::database_creators(&test_attr);
	let database_migrations_exposures = generators::database_migrations_exposures(&test_attr);
	let database_closers = generators::database_closers(&test_attr);
	let database_destructors = generators::database_destructors(&test_attr);
	let sleep = generators::sleep();

	(quote! {
		#[::core::prelude::v1::test]
		#(#attrs)*
		#vis #sig {
			/// Maximum number of tries to attempt database connection
			const MAX_RETRIES: u8 = 30;
			/// Time between retries, in seconds
			const TIME_BETWEEN_RETRIES: u64 = 10;

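			// Helper that connects to the database derived from the configured URI, retrying on failure.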
			#[allow(clippy::expect_used)]
			async fn connect_with_retry() -> Result<sqlx::PgPool, sqlx::Error> {
				let mut i = 0;
				loop {
					let db_pool = sqlx::PgPool::connect_with(sqlx_database_tester::connect_options(
						sqlx_database_tester::derive_db_prefix(&sqlx_database_tester::get_database_uri())
							.expect("Getting database name")
							.as_deref()
							.unwrap_or_default(),
						#level,
					))
					.await;
					match db_pool {
						Ok(pool) => break Ok(pool),
						Err(e) => {
							if i >= MAX_RETRIES {
								break Err(e);
							}
						}
					}
					#sleep(std::time::Duration::from_secs(TIME_BETWEEN_RETRIES)).await;
					i += 1;
				}
			}

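			// Load environment variables (such as the database URI) from a .env file, if present.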
			sqlx_database_tester::dotenv::dotenv().ok();
			#(#database_name_vars)*
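			// Create the test databases before running the test body.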
			#runtime.block_on(async {
				#[allow(clippy::expect_used)]
				let db_pool = connect_with_retry().await.expect("connecting to db for creation");
				#(#database_creators)*
			});

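			// Run the test body inside catch_unwind so the databases are cleaned up even if it panics.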
			let result = std::panic::catch_unwind(|| {
				#runtime.block_on(async {
					#(#database_migrations_exposures)*
					let res = #body;
					#(#database_closers)*
					res
				})
			});

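			// Drop the test databases regardless of the test outcome.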
			#runtime.block_on(async {
				#[allow(clippy::expect_used)]
				let db_pool = connect_with_retry().await.expect("connecting to db for deletion");
				#(#database_destructors)*
			});

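			// Re-raise a panic from the test body (after cleanup) or return its result.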
			match result {
				std::result::Result::Err(_) => std::panic!("The main test function panicked; the test databases were cleaned up"),
				std::result::Result::Ok(o) => o
			}
		}
	}).into()
}