sqlx_pg_test_template_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::Parser, MetaNameValue};
4
5type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
6type Error = Box<dyn std::error::Error>;
7type Result<T> = std::result::Result<T, Error>;
8
/// Options parsed from the `#[test(...)]` attribute; every field is optional and
/// defaults to `None` when the corresponding key is absent.
#[derive(Default)]
struct Args {
    /// Value of `max_connections = N`, if given.
    max_connections: Option<u32>,
    /// Value of `template = "..."`, if given.
    template_name: Option<String>,
}
14
15/// Enables sqlx_db_test capabilities for a test
16#[proc_macro_attribute]
17pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
18    let input = syn::parse_macro_input!(input as syn::ItemFn);
19    let args = args;
20
21    match expand(args, input) {
22        Ok(ts) => ts,
23        Err(e) => {
24            if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
25                parse_err.to_compile_error().into()
26            } else {
27                let msg = e.to_string();
28                quote!(::std::compile_error!(#msg)).into()
29            }
30        }
31    }
32}
33
34/// Runs actual expansion of the `#[test]` attribute
35fn expand(args: TokenStream, input: syn::ItemFn) -> Result<TokenStream> {
36    let parser = AttributeArgs::parse_terminated;
37    let args = parser.parse2(args.into())?;
38    let args = parse_args(args)?;
39
40    expand_with_args(input, args)
41}
42
43fn parse_args(attr_args: AttributeArgs) -> syn::Result<Args> {
44    let mut args = Args::default();
45
46    for arg in attr_args {
47        let path = arg.path().clone();
48
49        match arg {
50            syn::Meta::NameValue(MetaNameValue { value, .. }) if path.is_ident("template") => {
51                args.template_name = Some(parse_lit_str(&value)?);
52            }
53
54            syn::Meta::NameValue(MetaNameValue { value, .. })
55                if path.is_ident("max_connections") =>
56            {
57                let digits = parse_lit_int(&value)?;
58                let mc: u32 = digits
59                    .parse()
60                    .map_err(|_| syn::Error::new_spanned(value, "expected u32 number"))?;
61
62                args.max_connections = Some(mc);
63            }
64
65            arg => {
66                return Err(syn::Error::new_spanned(
67                    arg,
68                    r#"expected `template = "database_name"` and/or `max_connections = 5`"#,
69                ))
70            }
71        }
72    }
73
74    Ok(args)
75}
76
77fn expand_with_args(input: syn::ItemFn, args: Args) -> Result<TokenStream> {
78    let ret = &input.sig.output;
79    let name = &input.sig.ident;
80    let inputs = &input.sig.inputs;
81    let body = &input.block;
82    let attrs = &input.attrs;
83
84    let template_name = match args.template_name {
85        None => quote! { None },
86        Some(name) => quote! { Some(#name.to_string()) },
87    };
88
89    let max_connections = match args.max_connections {
90        None => quote! { None },
91        Some(mc) => quote! { Some(#mc) },
92    };
93
94    let name_str = name.to_string();
95
96    Ok(quote! {
97        #(#attrs)*
98        #[::core::prelude::v1::test]
99        fn #name() #ret {
100            async fn #name(#inputs) #ret {
101                #body
102            };
103
104            let test_args = ::sqlx_pg_test_template::TestArgs {
105                template_name: #template_name,
106                max_connections: #max_connections,
107                module_path: format!("{}::{}", module_path!().to_string(), #name_str),
108            };
109
110            sqlx_pg_test_template::run_test(#name, test_args)
111
112            // TODO: check timeout of pool going out of scope. main problem is that sqlx does
113            // not export core trait.
114            //
115            // let close_timed_out = sqlx::rt::timeout(Duration::from_secs(10), pool.close())
116            //     .await
117            //     .is_err();
118
119            // if close_timed_out {
120            //     eprintln!("test {test_path} held onto Pool after exiting");
121            // }
122
123        }
124    }
125    .into())
126}
127
128fn parse_lit_str(expr: &syn::Expr) -> syn::Result<String> {
129    match expr {
130        syn::Expr::Lit(syn::ExprLit {
131            lit: syn::Lit::Str(lit),
132            ..
133        }) => Ok(lit.value()),
134        _ => Err(syn::Error::new_spanned(expr, "expected string")),
135    }
136}
137
138fn parse_lit_int(expr: &syn::Expr) -> syn::Result<String> {
139    match expr {
140        syn::Expr::Lit(syn::ExprLit {
141            lit: syn::Lit::Int(lit),
142            ..
143        }) => Ok(lit.base10_digits().to_owned()),
144        _ => Err(syn::Error::new_spanned(expr, "expected integer")),
145    }
146}