sqlx_plus_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn;
4
5#[proc_macro_derive(Insertable, attributes(insertable))]
6pub fn insertable_derive(input: TokenStream) -> TokenStream {
7    let ast = syn::parse(input).unwrap();
8    impl_insertable(&ast)
9}
10
11fn impl_insertable(ast: &syn::DeriveInput) -> TokenStream {
12    let name = &ast.ident;
13
14    let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
15    let fields = get_struct_fields(&ast);
16    let attr = get_insertable_attribute(&ast);
17    let InsertableAttr { db, table_name } = attr.parse_args().unwrap();
18
19    let gen = quote! {
20        impl #impl_generics sqlx_plus::Insertable for #name #ty_generics #where_clause {
21            type Database = #db;
22
23            fn table_name() -> &'static str {
24                #table_name
25            }
26
27            fn insert_columns() -> Vec<&'static str> {
28                vec![ #( stringify!(#fields) ),* ]
29            }
30
31            fn bind_fields<'q, Q>(&'q self, q: Q) -> Q
32            where
33                Q: sqlx_plus::QueryBindExt<'q, Self::Database>
34            {
35                q #( .bind(&self.#fields) )*
36            }
37        }
38    };
39
40    gen.into()
41}
42
43fn get_struct_fields(ast: &syn::DeriveInput) -> Vec<syn::Ident> {
44    match ast.data {
45        syn::Data::Struct(ref data_struct) => match data_struct.fields {
46            syn::Fields::Named(ref fields_named) => fields_named
47                .named
48                .iter()
49                .map(|field| field.ident.clone().unwrap())
50                .collect::<Vec<_>>(),
51            syn::Fields::Unnamed(_) => panic!("Can not tuple structs derive Insertable trait"),
52            syn::Fields::Unit => panic!("Can not unit structs derive Insertable trait"),
53        },
54        _ => panic!("Only structs can derive Insertable trait"),
55    }
56}
57
58fn get_insertable_attribute(ast: &syn::DeriveInput) -> &syn::Attribute {
59    ast.attrs
60        .iter()
61        .filter(|x| x.path.is_ident("insertable"))
62        .next()
63        .expect("The insertable attribute is required for specifying DB type and table name")
64}
65
66struct InsertableAttr {
67    db: syn::Path,
68    table_name: String,
69}
70
71impl syn::parse::Parse for InsertableAttr {
72    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
73        let db: syn::Path = input.parse()?;
74        input.parse::<syn::Token![,]>()?;
75        let table: syn::LitStr = input.parse()?;
76
77        Ok(InsertableAttr {
78            db,
79            table_name: table.value(),
80        })
81    }
82}