1use proc_macro::TokenStream;
2use quote::quote;
3use syn;
4
5#[proc_macro_derive(Insertable, attributes(insertable))]
6pub fn insertable_derive(input: TokenStream) -> TokenStream {
7 let ast = syn::parse(input).unwrap();
8 impl_insertable(&ast)
9}
10
11fn impl_insertable(ast: &syn::DeriveInput) -> TokenStream {
12 let name = &ast.ident;
13
14 let (impl_generics, ty_generics, where_clause) = ast.generics.split_for_impl();
15 let fields = get_struct_fields(&ast);
16 let attr = get_insertable_attribute(&ast);
17 let InsertableAttr { db, table_name } = attr.parse_args().unwrap();
18
19 let gen = quote! {
20 impl #impl_generics sqlx_plus::Insertable for #name #ty_generics #where_clause {
21 type Database = #db;
22
23 fn table_name() -> &'static str {
24 #table_name
25 }
26
27 fn insert_columns() -> Vec<&'static str> {
28 vec![ #( stringify!(#fields) ),* ]
29 }
30
31 fn bind_fields<'q, Q>(&'q self, q: Q) -> Q
32 where
33 Q: sqlx_plus::QueryBindExt<'q, Self::Database>
34 {
35 q #( .bind(&self.#fields) )*
36 }
37 }
38 };
39
40 gen.into()
41}
42
43fn get_struct_fields(ast: &syn::DeriveInput) -> Vec<syn::Ident> {
44 match ast.data {
45 syn::Data::Struct(ref data_struct) => match data_struct.fields {
46 syn::Fields::Named(ref fields_named) => fields_named
47 .named
48 .iter()
49 .map(|field| field.ident.clone().unwrap())
50 .collect::<Vec<_>>(),
51 syn::Fields::Unnamed(_) => panic!("Can not tuple structs derive Insertable trait"),
52 syn::Fields::Unit => panic!("Can not unit structs derive Insertable trait"),
53 },
54 _ => panic!("Only structs can derive Insertable trait"),
55 }
56}
57
58fn get_insertable_attribute(ast: &syn::DeriveInput) -> &syn::Attribute {
59 ast.attrs
60 .iter()
61 .filter(|x| x.path.is_ident("insertable"))
62 .next()
63 .expect("The insertable attribute is required for specifying DB type and table name")
64}
65
66struct InsertableAttr {
67 db: syn::Path,
68 table_name: String,
69}
70
71impl syn::parse::Parse for InsertableAttr {
72 fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
73 let db: syn::Path = input.parse()?;
74 input.parse::<syn::Token![,]>()?;
75 let table: syn::LitStr = input.parse()?;
76
77 Ok(InsertableAttr {
78 db,
79 table_name: table.value(),
80 })
81 }
82}