sqlx_database_tester_macros/
lib.rs1use darling::FromMeta;
4use proc_macro::TokenStream;
5use proc_macro2::Span;
6use quote::{format_ident, quote};
7use syn::{parse_macro_input, AttributeArgs, Ident};
8mod generators;
9
10#[derive(Debug, FromMeta)]
12pub(crate) struct Pool {
13 variable: Ident,
15 #[darling(default)]
17 transaction_variable: Option<Ident>,
18 #[darling(default)]
20 migrations: Option<String>,
21 #[darling(default)]
23 skip_migrations: bool,
24}
25
26impl Pool {
27 fn database_name_var(&self) -> Ident {
29 format_ident!("__{}_db_name", &self.variable)
30 }
31}
32
33#[derive(Debug, FromMeta)]
35pub(crate) struct MacroArgs {
36 #[darling(default)]
38 level: String,
39 #[darling(multiple)]
41 pool: Vec<Pool>,
42}
43
44#[proc_macro_attribute]
76pub fn test(test_attr: TokenStream, item: TokenStream) -> TokenStream {
77 let mut input = syn::parse_macro_input!(item as syn::ItemFn);
78 let test_attr_args = parse_macro_input!(test_attr as AttributeArgs);
79 let test_attr: MacroArgs = match MacroArgs::from_list(&test_attr_args) {
80 Ok(v) => v,
81 Err(e) => {
82 return TokenStream::from(e.write_errors());
83 }
84 };
85
86 let level = test_attr.level.as_str();
87 let attrs = &input.attrs;
88 let vis = &input.vis;
89 let sig = &mut input.sig;
90 let body = &input.block;
91
92 let runtime = if let Some(runtime) = generators::runtime() {
93 runtime
94 } else {
95 return syn::Error::new(
96 Span::call_site(),
97 "One of 'runtime-actix' and 'runtime-tokio' features needs to be enabled",
98 )
99 .into_compile_error()
100 .into();
101 };
102
103 if sig.asyncness.is_none() {
104 return syn::Error::new_spanned(
105 input.sig.fn_token,
106 "the async keyword is missing from the function declaration",
107 )
108 .to_compile_error()
109 .into();
110 }
111
112 sig.asyncness = None;
113
114 let database_name_vars = generators::database_name_vars(&test_attr);
115 let database_creators = generators::database_creators(&test_attr);
116 let database_migrations_exposures = generators::database_migrations_exposures(&test_attr);
117 let database_closers = generators::database_closers(&test_attr);
118 let database_destructors = generators::database_destructors(&test_attr);
119 let sleep = generators::sleep();
120
121 (quote! {
122 #[::core::prelude::v1::test]
123 #(#attrs)*
124 #vis #sig {
125 const MAX_RETRIES: u8 = 30;
127 const TIME_BETWEEN_RETRIES: u64 = 10;
129
130 #[allow(clippy::expect_used)]
131 async fn connect_with_retry() -> Result<sqlx::PgPool, sqlx::Error> {
132 let mut i = 0;
133 loop {
134 let db_pool = sqlx::PgPool::connect_with(sqlx_database_tester::connect_options(
135 sqlx_database_tester::derive_db_prefix(&sqlx_database_tester::get_database_uri())
136 .expect("Getting database name")
137 .as_deref()
138 .unwrap_or_default(),
139 #level,
140 ))
141 .await;
142 match db_pool {
143 Ok(pool) => break Ok(pool),
144 Err(e) => {
145 if i >= MAX_RETRIES {
146 break Err(e);
147 }
148 }
149 }
150 #sleep(std::time::Duration::from_secs(TIME_BETWEEN_RETRIES)).await;
151 i += 1;
152 }
153 }
154
155 sqlx_database_tester::dotenv::dotenv().ok();
156 #(#database_name_vars)*
157 #runtime.block_on(async {
158 #[allow(clippy::expect_used)]
159 let db_pool = connect_with_retry().await.expect("connecting to db for creation");
160 #(#database_creators)*
161 });
162
163 let result = std::panic::catch_unwind(|| {
164 #runtime.block_on(async {
165 #(#database_migrations_exposures)*
166 let res = #body;
167 #(#database_closers)*
168 res
169 })
170 });
171
172 #runtime.block_on(async {
173 #[allow(clippy::expect_used)]
174 let db_pool = connect_with_retry().await.expect("connecting to db for deletion");
175 #(#database_destructors)*
176 });
177
178 match result {
179 std::result::Result::Err(_) => std::panic!("The main test function crashed, the test database got cleaned"),
180 std::result::Result::Ok(o) => o
181 }
182 }
183 }).into()
184}