sqlx_testcontainers_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::{Parse, ParseStream}, parse_macro_input, ItemFn, LitStr, Token, Ident};
4
/// Options parsed from the attribute's argument list, e.g.
/// `#[test(tag = "...", migrations = "...")]`. Both keys are optional.
struct MacroArgs {
    /// Directory given to `sqlx::migrate!(...)`; when `None`, the argument-less
    /// `sqlx::migrate!()` form is generated instead.
    migrations: Option<String>,
    /// Image tag applied to the Postgres container via `with_tag`; when `None`, the
    /// image is started as returned by `Postgres::default()`.
    tag: Option<String>,
}
9
10impl Parse for MacroArgs {
11 fn parse(input: ParseStream) -> syn::Result<Self> {
12 let mut tag = None;
13 let mut migrations = None;
14
15 while !input.is_empty() {
16 let ident: Ident = input.parse()?;
17 input.parse::<Token![=]>()?;
18 let value: LitStr = input.parse()?;
19
20 match ident.to_string().as_str() {
21 "tag" => tag = Some(value.value()),
22 "migrations" => migrations = Some(value.value()),
23 _ => return Err(syn::Error::new(ident.span(), "expected `tag` or `migrations`")),
24 }
25
26 if !input.is_empty() {
27 input.parse::<Token![,]>()?;
28 }
29 }
30
31 Ok(MacroArgs { tag, migrations })
32 }
33}
34
35#[proc_macro_attribute]
36pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
37 let args = parse_macro_input!(args as MacroArgs);
38 let mut input = parse_macro_input!(input as ItemFn);
39
40 let test_name = input.sig.ident.clone();
41 let inner_name = quote::format_ident!("__inner_{}", test_name);
42
43 input.sig.ident = inner_name.clone();
45
46 let tag_expr = if let Some(t) = args.tag {
47 quote! { Some(#t.to_string()) }
48 } else {
49 quote! { None }
50 };
51
52 let migrate_expr = if let Some(m) = args.migrations {
53 quote! { ::sqlx::migrate!(#m).run(&pool).await.expect("Failed to run migrations"); }
54 } else {
55 quote! { ::sqlx::migrate!().run(&pool).await.expect("Failed to run migrations"); }
56 };
57
58 let expanded = quote! {
59 #[::tokio::test]
60 async fn #test_name() {
61 use ::testcontainers_modules::testcontainers::ImageExt;
62 use ::testcontainers_modules::testcontainers::runners::AsyncRunner;
63 use ::testcontainers_modules::postgres::Postgres;
64
65 let mut image = Postgres::default();
66 let tag: Option<String> = #tag_expr;
67 let container = if let Some(tag) = tag {
68 image.with_tag(tag).start().await.expect("Failed to start postgres container")
69 } else {
70 image.start().await.expect("Failed to start postgres container")
71 };
72
73 let host = container.get_host().await.expect("Failed to get host");
74 let host_port = container.get_host_port_ipv4(5432).await.expect("Failed to get port");
75
76 let conn_str = format!("postgres://postgres:postgres@{}:{}/postgres", host, host_port);
77
78 let pool = ::sqlx::postgres::PgPoolOptions::new()
79 .max_connections(1)
80 .connect(&conn_str)
81 .await
82 .expect("Failed to connect to postgres");
83
84 #migrate_expr
85
86 let mut conn = pool.acquire().await.expect("Failed to acquire connection").detach();
87
88 #input
89
90 #inner_name(conn).await;
91 }
92 };
93
94 TokenStream::from(expanded)
95}