1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
#![deny(
missing_docs,
trivial_casts,
trivial_numeric_casts,
unused_extern_crates,
unused_import_braces,
unused_qualifications
)]
#![warn(missing_debug_implementations, dead_code, clippy::unwrap_used, clippy::expect_used)]
use darling::FromMeta;
use proc_macro::TokenStream;
use proc_macro2::Span;
use quote::{format_ident, quote};
use syn::{parse_macro_input, AttributeArgs, Ident};
mod generators;
/// Arguments of a single `pool(...)` entry in the `test` attribute, parsed via darling's `FromMeta`.
///
/// NOTE(review): these fields are consumed by the `generators` module, which is not visible here;
/// descriptions marked "presumably" are inferred from names — confirm against `generators`.
#[derive(Debug, FromMeta)]
pub(crate) struct Pool {
    /// Required key `variable` (it has no `#[darling(default)]`, so parsing fails without it).
    /// Presumably the name the pool is exposed under in the test body; it is also the base of
    /// the identifier built by `database_name_var`.
    variable: Ident,
    /// Optional key `transaction_variable`; `None` when omitted. Presumably names a variable
    /// bound to a transaction on this pool — TODO confirm in `generators`.
    #[darling(default)]
    transaction_variable: Option<Ident>,
    /// Optional key `migrations`; `None` when omitted. Presumably a path to a migrations
    /// directory — TODO confirm in `generators`.
    #[darling(default)]
    migrations: Option<String>,
    /// Optional flag `skip_migrations`; `false` when omitted. Presumably disables running
    /// migrations for this pool — TODO confirm in `generators`.
    #[darling(default)]
    skip_migrations: bool,
}
impl Pool {
    /// Returns the identifier `__<variable>_db_name` for this pool.
    ///
    /// Judging by its name, it is the generated local holding the pool's database name; the
    /// declaration itself is emitted elsewhere (presumably in `generators`).
    fn database_name_var(&self) -> Ident {
        let pool_variable = &self.variable;
        format_ident!("__{}_db_name", pool_variable)
    }
}
/// Parsed arguments of the `test` attribute, e.g. `pool(...), pool(...)`.
#[derive(Debug, FromMeta)]
pub(crate) struct MacroArgs {
    /// One entry per `pool(...)` argument; `#[darling(multiple)]` lets the key be repeated.
    #[darling(multiple)]
    pool: Vec<Pool>,
}
#[proc_macro_attribute]
pub fn test(test_attr: TokenStream, item: TokenStream) -> TokenStream {
let mut input = syn::parse_macro_input!(item as syn::ItemFn);
let test_attr_args = parse_macro_input!(test_attr as AttributeArgs);
let test_attr: MacroArgs = match MacroArgs::from_list(&test_attr_args) {
Ok(v) => v,
Err(e) => {
return TokenStream::from(e.write_errors());
}
};
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &mut input.sig;
let body = &input.block;
let runtime = if let Some(runtime) = generators::runtime() {
runtime
} else {
return syn::Error::new(
Span::call_site(),
"One of 'runtime-actix' and 'runtime-tokio' features needs to be enabled",
)
.into_compile_error()
.into();
};
if sig.asyncness.is_none() {
return syn::Error::new_spanned(
input.sig.fn_token,
"the async keyword is missing from the function declaration",
)
.to_compile_error()
.into();
}
sig.asyncness = None;
let database_name_vars = generators::database_name_vars(&test_attr);
let database_creators = generators::database_creators(&test_attr);
let database_migrations_exposures = generators::database_migrations_exposures(&test_attr);
let database_closers = generators::database_closers(&test_attr);
let database_destructors = generators::database_destructors(&test_attr);
(quote! {
#[::core::prelude::v1::test]
#(#attrs)*
#vis #sig {
sqlx_database_tester::dotenv::dotenv().ok();
#(#database_name_vars)*
#runtime.block_on(async {
#[allow(clippy::expect_used)]
let db_pool = sqlx::PgPool::connect(
&sqlx_database_tester::get_target_database_uri(
&sqlx_database_tester::get_database_uri(), sqlx_database_tester::derive_db_prefix(
&sqlx_database_tester::get_database_uri()).expect("Getting database name").as_deref().unwrap_or_default()).expect("URI parsing")
).await.expect("connecting to db for creation");
#(#database_creators)*
});
let result = std::panic::catch_unwind(|| {
#runtime.block_on(async {
#(#database_migrations_exposures)*
let res = #body;
#(#database_closers)*
res
})
});
#runtime.block_on(async {
#[allow(clippy::expect_used)]
let db_pool = sqlx::PgPool::connect(
&sqlx_database_tester::get_target_database_uri(
&sqlx_database_tester::get_database_uri(), sqlx_database_tester::derive_db_prefix(
&sqlx_database_tester::get_database_uri()).expect("Getting database name").as_deref().unwrap_or_default()).expect("URI parsing")
).await.expect("connecting to db for deletion");
#(#database_destructors)*
});
match result {
std::result::Result::Err(_) => std::panic!("The main test function crashed, the test database got cleaned"),
std::result::Result::Ok(o) => o
}
}
}).into()
}