teststack_macros/
lib.rs

1use custom::CustomContainer;
2use db::{DbContainer, DbName};
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse::Parse, parse_macro_input, punctuated::Punctuated, ItemFn};
6
7mod custom;
8mod db;
9
10enum Container {
11    Db(DbContainer),
12    Custom(CustomContainer),
13}
14
15impl Parse for Container {
16    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
17        if db::peek(&input) {
18            Ok(Container::Db(DbContainer::parse(input)?))
19        } else if custom::peek(&input) {
20            Ok(Container::Custom(CustomContainer::parse(input)?))
21        } else {
22            Err(input.error("expected a container"))
23        }
24    }
25}
26
27/// Configure the stack of containers for a test.
28///
29/// Example:
30/// ```rust
31/// use teststack::stack;
32/// #[stack(postgres(db_name = "test"))]
33/// #[tokio::test]
34/// async fn test(pool: sqlx::PgPool) {
35///    let db_name: String = sqlx::query_scalar("SELECT current_database()")
36///        .fetch_one(&pool)
37///        .await
38///        .unwrap();
39///    assert_eq!(db_name, "test");
40/// }
41/// ```
42#[proc_macro_attribute]
43pub fn stack(attr: TokenStream, item: TokenStream) -> TokenStream {
44    let containers =
45        parse_macro_input!(attr with Punctuated::<Container, syn::Token![,]>::parse_terminated);
46    let input = parse_macro_input!(item as ItemFn);
47    let containers = containers.into_iter().collect();
48    expand_test(input, containers)
49}
50
51fn expand_test(input: ItemFn, containers: Vec<Container>) -> TokenStream {
52    let ret = &input.sig.output;
53    let name = &input.sig.ident;
54    let body = &input.block;
55    let attrs = &input.attrs;
56    let args = &input.sig.inputs;
57    let sqlx_test = attrs
58        .iter()
59        .any(|attr| attr.path().segments.iter().any(|s| s.ident == "sqlx"));
60
61    let container_vars = (0..containers.len())
62        .map(|i| format_ident!("container_{i}"))
63        .collect::<Vec<_>>();
64
65    let containers = containers
66        .iter()
67        .map(|container| match container {
68            Container::Db(db) => {
69                let db_name = match &db.conf.db_name {
70                    DbName::Random => quote! {::teststack::DbName::Random },
71                    DbName::Static(name) => {
72                        quote! {::teststack::DbName::Static(#name.to_string()) }
73                    }
74                    DbName::Default => quote! { ::teststack::DbName::Default },
75                };
76                match db.name {
77                    "postgres" => quote! { ::teststack::postgres(#db_name) },
78                    "mysql" => quote! { ::teststack::mysql(#db_name) },
79                    _ => panic!("Unknown container type: {}", name),
80                }
81            }
82            Container::Custom(custom) => {
83                let expr = &custom.expr;
84                quote! { ::teststack::custom(#expr) }
85            }
86        })
87        .collect::<Vec<_>>();
88
89    if sqlx_test {
90        quote! {
91            #[allow(unnameable_test_items)]
92            #[::core::prelude::v1::test]
93            fn #name() #ret {
94                let rt = ::tokio::runtime::Builder::new_current_thread()
95                    .enable_all()
96                    .build()
97                    .unwrap();
98                rt.block_on(async {
99                    ::tokio::join!(#(#containers),*);
100                });
101                #(#attrs)*
102                fn #name(#args) #ret {
103                    #body
104                }
105                #name()
106            }
107        }
108        .into()
109    } else {
110        let test_args = args.iter().enumerate().map(|(i, arg)| {
111            let arg = match arg {
112                syn::FnArg::Typed(arg) => arg,
113                _ => panic!("Expected a typed argument"),
114            };
115            let ty = &arg.ty;
116            let container_ident = format_ident!("container_{i}");
117            quote! { ::teststack::Init::<#ty>::init(#container_ident).await }
118        });
119        quote! {
120            #(#attrs)*
121            async fn #name() #ret {
122                use ::teststack::Init;
123               let (#(#container_vars),*,) = ::tokio::join!(#(#containers),*);
124                async fn #name(#args) #ret {
125                    #body
126                }
127                #name(#(#test_args),*).await
128            }
129        }
130        .into()
131    }
132}