1use custom::CustomContainer;
2use db::{DbContainer, DbName};
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::{parse::Parse, parse_macro_input, punctuated::Punctuated, ItemFn};
6
7mod custom;
8mod db;
9
10enum Container {
11 Db(DbContainer),
12 Custom(CustomContainer),
13}
14
15impl Parse for Container {
16 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
17 if db::peek(&input) {
18 Ok(Container::Db(DbContainer::parse(input)?))
19 } else if custom::peek(&input) {
20 Ok(Container::Custom(CustomContainer::parse(input)?))
21 } else {
22 Err(input.error("expected a container"))
23 }
24 }
25}
26
27#[proc_macro_attribute]
43pub fn stack(attr: TokenStream, item: TokenStream) -> TokenStream {
44 let containers =
45 parse_macro_input!(attr with Punctuated::<Container, syn::Token![,]>::parse_terminated);
46 let input = parse_macro_input!(item as ItemFn);
47 let containers = containers.into_iter().collect();
48 expand_test(input, containers)
49}
50
51fn expand_test(input: ItemFn, containers: Vec<Container>) -> TokenStream {
52 let ret = &input.sig.output;
53 let name = &input.sig.ident;
54 let body = &input.block;
55 let attrs = &input.attrs;
56 let args = &input.sig.inputs;
57 let sqlx_test = attrs
58 .iter()
59 .any(|attr| attr.path().segments.iter().any(|s| s.ident == "sqlx"));
60
61 let container_vars = (0..containers.len())
62 .map(|i| format_ident!("container_{i}"))
63 .collect::<Vec<_>>();
64
65 let containers = containers
66 .iter()
67 .map(|container| match container {
68 Container::Db(db) => {
69 let db_name = match &db.conf.db_name {
70 DbName::Random => quote! {::teststack::DbName::Random },
71 DbName::Static(name) => {
72 quote! {::teststack::DbName::Static(#name.to_string()) }
73 }
74 DbName::Default => quote! { ::teststack::DbName::Default },
75 };
76 match db.name {
77 "postgres" => quote! { ::teststack::postgres(#db_name) },
78 "mysql" => quote! { ::teststack::mysql(#db_name) },
79 _ => panic!("Unknown container type: {}", name),
80 }
81 }
82 Container::Custom(custom) => {
83 let expr = &custom.expr;
84 quote! { ::teststack::custom(#expr) }
85 }
86 })
87 .collect::<Vec<_>>();
88
89 if sqlx_test {
90 quote! {
91 #[allow(unnameable_test_items)]
92 #[::core::prelude::v1::test]
93 fn #name() #ret {
94 let rt = ::tokio::runtime::Builder::new_current_thread()
95 .enable_all()
96 .build()
97 .unwrap();
98 rt.block_on(async {
99 ::tokio::join!(#(#containers),*);
100 });
101 #(#attrs)*
102 fn #name(#args) #ret {
103 #body
104 }
105 #name()
106 }
107 }
108 .into()
109 } else {
110 let test_args = args.iter().enumerate().map(|(i, arg)| {
111 let arg = match arg {
112 syn::FnArg::Typed(arg) => arg,
113 _ => panic!("Expected a typed argument"),
114 };
115 let ty = &arg.ty;
116 let container_ident = format_ident!("container_{i}");
117 quote! { ::teststack::Init::<#ty>::init(#container_ident).await }
118 });
119 quote! {
120 #(#attrs)*
121 async fn #name() #ret {
122 use ::teststack::Init;
123 let (#(#container_vars),*,) = ::tokio::join!(#(#containers),*);
124 async fn #name(#args) #ret {
125 #body
126 }
127 #name(#(#test_args),*).await
128 }
129 }
130 .into()
131 }
132}