sqlx_testcontainers_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{
4 parse::{Parse, ParseStream},
5 parse_macro_input, Ident, ItemFn, LitStr, Token,
6};
7
/// Arguments accepted by the `test` attribute, written as comma-separated
/// `key = "value"` pairs, e.g. `#[test(tag = "16-alpine", migrations = "./db")]`.
///
/// A field is `None` when its key was not supplied.
struct MacroArgs {
    /// Tag applied to the Postgres image via `with_tag` (`tag = "..."`).
    tag: Option<String>,
    /// Path handed to `sqlx::migrate!` (`migrations = "..."`).
    migrations: Option<String>,
}
12
13impl Parse for MacroArgs {
14 fn parse(input: ParseStream) -> syn::Result<Self> {
15 let mut tag = None;
16 let mut migrations = None;
17
18 while !input.is_empty() {
19 let ident: Ident = input.parse()?;
20 input.parse::<Token![=]>()?;
21 let value: LitStr = input.parse()?;
22
23 match ident.to_string().as_str() {
24 "tag" => tag = Some(value.value()),
25 "migrations" => migrations = Some(value.value()),
26 _ => {
27 return Err(syn::Error::new(
28 ident.span(),
29 "expected `tag` or `migrations`",
30 ))
31 }
32 }
33
34 if !input.is_empty() {
35 input.parse::<Token![,]>()?;
36 }
37 }
38
39 Ok(MacroArgs { tag, migrations })
40 }
41}
42
43#[proc_macro_attribute]
44pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
45 let args = parse_macro_input!(args as MacroArgs);
46 let mut input = parse_macro_input!(input as ItemFn);
47
48 let test_name = input.sig.ident.clone();
49 let inner_name = quote::format_ident!("__inner_{}", test_name);
50
51 input.sig.ident = inner_name.clone();
53
54 let tag_expr = if let Some(t) = args.tag {
55 quote! { Some(#t.to_string()) }
56 } else {
57 quote! { None }
58 };
59
60 let migrate_expr = if let Some(m) = args.migrations {
61 quote! { ::sqlx::migrate!(#m).run(&pool).await.expect("Failed to run migrations"); }
62 } else {
63 quote! { ::sqlx::migrate!().run(&pool).await.expect("Failed to run migrations"); }
64 };
65
66 let expanded = quote! {
67 #[::tokio::test]
68 async fn #test_name() {
69 use ::testcontainers_modules::testcontainers::ImageExt;
70 use ::testcontainers_modules::testcontainers::runners::AsyncRunner;
71 use ::testcontainers_modules::postgres::Postgres;
72
73 let mut image = Postgres::default();
74 let tag: Option<String> = #tag_expr;
75 let container = if let Some(tag) = tag {
76 image.with_tag(tag).start().await.expect("Failed to start postgres container")
77 } else {
78 image.start().await.expect("Failed to start postgres container")
79 };
80
81 let host = container.get_host().await.expect("Failed to get host");
82 let host_port = container.get_host_port_ipv4(5432).await.expect("Failed to get port");
83
84 let conn_str = format!("postgres://postgres:postgres@{}:{}/postgres", host, host_port);
85
86 let pool = ::sqlx::postgres::PgPoolOptions::new()
87 .max_connections(1)
88 .connect(&conn_str)
89 .await
90 .expect("Failed to connect to postgres");
91
92 #migrate_expr
93
94 let mut conn = pool.acquire().await.expect("Failed to acquire connection").detach();
95
96 #input
97
98 #inner_name(conn).await;
99 }
100 };
101
102 TokenStream::from(expanded)
103}