Skip to main content

sqlx_testcontainers_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::{Parse, ParseStream}, parse_macro_input, ItemFn, LitStr, Token, Ident};
4
5struct MacroArgs {
6    tag: Option<String>,
7    migrations: Option<String>,
8}
9
10impl Parse for MacroArgs {
11    fn parse(input: ParseStream) -> syn::Result<Self> {
12        let mut tag = None;
13        let mut migrations = None;
14
15        while !input.is_empty() {
16            let ident: Ident = input.parse()?;
17            input.parse::<Token![=]>()?;
18            let value: LitStr = input.parse()?;
19
20            match ident.to_string().as_str() {
21                "tag" => tag = Some(value.value()),
22                "migrations" => migrations = Some(value.value()),
23                _ => return Err(syn::Error::new(ident.span(), "expected `tag` or `migrations`")),
24            }
25
26            if !input.is_empty() {
27                input.parse::<Token![,]>()?;
28            }
29        }
30
31        Ok(MacroArgs { tag, migrations })
32    }
33}
34
35#[proc_macro_attribute]
36pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
37    let args = parse_macro_input!(args as MacroArgs);
38    let mut input = parse_macro_input!(input as ItemFn);
39
40    let test_name = input.sig.ident.clone();
41    let inner_name = quote::format_ident!("__inner_{}", test_name);
42    
43    // Rename the original function to __inner_...
44    input.sig.ident = inner_name.clone();
45    
46    let tag_expr = if let Some(t) = args.tag {
47        quote! { Some(#t.to_string()) }
48    } else {
49        quote! { None }
50    };
51
52    let migrate_expr = if let Some(m) = args.migrations {
53        quote! { ::sqlx::migrate!(#m).run(&pool).await.expect("Failed to run migrations"); }
54    } else {
55        quote! { ::sqlx::migrate!().run(&pool).await.expect("Failed to run migrations"); }
56    };
57
58    let expanded = quote! {
59        #[::tokio::test]
60        async fn #test_name() {
61            use ::testcontainers_modules::testcontainers::ImageExt;
62            use ::testcontainers_modules::testcontainers::runners::AsyncRunner;
63            use ::testcontainers_modules::postgres::Postgres;
64            
65            let mut image = Postgres::default();
66            let tag: Option<String> = #tag_expr;
67            let container = if let Some(tag) = tag {
68                image.with_tag(tag).start().await.expect("Failed to start postgres container")
69            } else {
70                image.start().await.expect("Failed to start postgres container")
71            };
72            
73            let host = container.get_host().await.expect("Failed to get host");
74            let host_port = container.get_host_port_ipv4(5432).await.expect("Failed to get port");
75            
76            let conn_str = format!("postgres://postgres:postgres@{}:{}/postgres", host, host_port);
77            
78            let pool = ::sqlx::postgres::PgPoolOptions::new()
79                .max_connections(1)
80                .connect(&conn_str)
81                .await
82                .expect("Failed to connect to postgres");
83            
84            #migrate_expr
85            
86            let mut conn = pool.acquire().await.expect("Failed to acquire connection").detach();
87            
88            #input
89
90            #inner_name(conn).await;
91        }
92    };
93
94    TokenStream::from(expanded)
95}