sqlx_pg_test_template_macros/
lib.rs1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse::Parser, MetaNameValue};
4
5type AttributeArgs = syn::punctuated::Punctuated<syn::Meta, syn::Token![,]>;
6type Error = Box<dyn std::error::Error>;
7type Result<T> = std::result::Result<T, Error>;
8
/// Settings parsed from the attribute's arguments.
///
/// A field stays `None` (its `Default`) when the matching key is not given.
#[derive(Default)]
struct Args {
    /// Value of `max_connections = N`.
    max_connections: Option<u32>,
    /// Value of `template = "database_name"`.
    template_name: Option<String>,
}
14
15#[proc_macro_attribute]
17pub fn test(args: TokenStream, input: TokenStream) -> TokenStream {
18 let input = syn::parse_macro_input!(input as syn::ItemFn);
19 let args = args;
20
21 match expand(args, input) {
22 Ok(ts) => ts,
23 Err(e) => {
24 if let Some(parse_err) = e.downcast_ref::<syn::Error>() {
25 parse_err.to_compile_error().into()
26 } else {
27 let msg = e.to_string();
28 quote!(::std::compile_error!(#msg)).into()
29 }
30 }
31 }
32}
33
34fn expand(args: TokenStream, input: syn::ItemFn) -> Result<TokenStream> {
36 let parser = AttributeArgs::parse_terminated;
37 let args = parser.parse2(args.into())?;
38 let args = parse_args(args)?;
39
40 expand_with_args(input, args)
41}
42
43fn parse_args(attr_args: AttributeArgs) -> syn::Result<Args> {
44 let mut args = Args::default();
45
46 for arg in attr_args {
47 let path = arg.path().clone();
48
49 match arg {
50 syn::Meta::NameValue(MetaNameValue { value, .. }) if path.is_ident("template") => {
51 args.template_name = Some(parse_lit_str(&value)?);
52 }
53
54 syn::Meta::NameValue(MetaNameValue { value, .. })
55 if path.is_ident("max_connections") =>
56 {
57 let digits = parse_lit_int(&value)?;
58 let mc: u32 = digits
59 .parse()
60 .map_err(|_| syn::Error::new_spanned(value, "expected u32 number"))?;
61
62 args.max_connections = Some(mc);
63 }
64
65 arg => {
66 return Err(syn::Error::new_spanned(
67 arg,
68 r#"expected `template = "database_name"` and/or `max_connections = 5`"#,
69 ))
70 }
71 }
72 }
73
74 Ok(args)
75}
76
77fn expand_with_args(input: syn::ItemFn, args: Args) -> Result<TokenStream> {
78 let ret = &input.sig.output;
79 let name = &input.sig.ident;
80 let inputs = &input.sig.inputs;
81 let body = &input.block;
82 let attrs = &input.attrs;
83
84 let template_name = match args.template_name {
85 None => quote! { None },
86 Some(name) => quote! { Some(#name.to_string()) },
87 };
88
89 let max_connections = match args.max_connections {
90 None => quote! { None },
91 Some(mc) => quote! { Some(#mc) },
92 };
93
94 let name_str = name.to_string();
95
96 Ok(quote! {
97 #(#attrs)*
98 #[::core::prelude::v1::test]
99 fn #name() #ret {
100 async fn #name(#inputs) #ret {
101 #body
102 };
103
104 let test_args = ::sqlx_pg_test_template::TestArgs {
105 template_name: #template_name,
106 max_connections: #max_connections,
107 module_path: format!("{}::{}", module_path!().to_string(), #name_str),
108 };
109
110 sqlx_pg_test_template::run_test(#name, test_args)
111
112 }
124 }
125 .into())
126}
127
128fn parse_lit_str(expr: &syn::Expr) -> syn::Result<String> {
129 match expr {
130 syn::Expr::Lit(syn::ExprLit {
131 lit: syn::Lit::Str(lit),
132 ..
133 }) => Ok(lit.value()),
134 _ => Err(syn::Error::new_spanned(expr, "expected string")),
135 }
136}
137
138fn parse_lit_int(expr: &syn::Expr) -> syn::Result<String> {
139 match expr {
140 syn::Expr::Lit(syn::ExprLit {
141 lit: syn::Lit::Int(lit),
142 ..
143 }) => Ok(lit.base10_digits().to_owned()),
144 _ => Err(syn::Error::new_spanned(expr, "expected integer")),
145 }
146}