sqlx_batch/
lib.rs

1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5/// Get an optionally defined key-value pair from annotation
6fn get_attr(attrs: &[Attribute], name: &str) -> Option<String> {
7    attrs
8        .iter()
9        .flat_map(|elem| {
10            if let Meta::NameValue(nv) = &elem.meta {
11                if nv.path.is_ident(name) {
12                    let value = &nv.value;
13                    return Some(quote!(#value).to_string());
14                }
15            }
16            if let Meta::Path(p) = &elem.meta {
17                if p.get_ident().unwrap() == name {
18                    return Some(String::new());
19                }
20            }
21            None
22        })
23        .last()
24}
25
26/// Generates a {stuct_name}Inserter struct. This struct allows you to easily
27/// add many of this struct to a given Postgres table.
28///
29/// Use #\[pgtable = "{table_name}"\] on the main struct to define the postgres
30/// table name.
31///
32/// Use #\[colname = "{renamed_elem}"\] on a field to specify that this field
33/// will be renamed when added to the database.
34///
35/// Use #\[key\] on a field to specify that you want to upsert when inserting
36/// this struct in the database, and that this field is a primary key / unique
37/// constraint which is used for checking if the row already exists.
38#[proc_macro_derive(BatchInserter, attributes(pgtable, colname, key))]
39pub fn derive_batch_inserter_from(input: TokenStream) -> TokenStream {
40    let input = parse_macro_input!(input as DeriveInput);
41
42    let table_name = get_attr(&input.attrs, "pgtable");
43
44    let struct_name = input.ident;
45    let new_struct_name = format_ident!("{}Inserter", struct_name);
46
47    let mut field_names = Vec::new();
48    let mut field_types = Vec::new();
49
50    let mut psql_names = Vec::new();
51    let mut psql_types = Vec::new();
52
53    let mut keyed_names = Vec::new();
54
55    match &input.data {
56        Data::Struct(data) => {
57            match &data.fields {
58                Fields::Named(fields) => {
59                    for field in &fields.named {
60                        let ident = &field.ident;
61                        let ty = &field.ty;
62                        let attrs = &field.attrs;
63
64                        field_names.push(ident);
65                        field_types.push(ty);
66
67                        // See if we have specified a custom postgres column name for this field
68                        if let Some(name) = get_attr(attrs, "colname") {
69                            psql_names.push(name.clone());
70                            if get_attr(attrs, "key").is_some() {
71                                keyed_names.push(name);
72                            }
73                        } else {
74                            psql_names.push(quote!(#ident).to_string());
75                            if get_attr(attrs, "key").is_some() {
76                                keyed_names.push(quote!(#ident).to_string());
77                            }
78                        }
79
80                        match quote!(#ty).to_string().as_ref() {
81                            "String" | "&str" | "Option < &str >" | "Option < String >" => {
82                                psql_types.push("text[]");
83                            }
84                            "bool" | "Option < bool >" => psql_types.push("bool[]"),
85                            "f8" | "f16" | "f32" | "f64" | "Option < f8 >" | "Option < f16 >"
86                            | "Option < f32 >" | "Option < f64 >" => psql_types.push("float[]"),
87                            "i16" | "i32" | "i64" | "Option < i16 >" | "Option < i32 >"
88                            | "Option < i64 >" => psql_types.push("integer[]"),
89                            "NaiveDateTime" | "Option < NaiveDateTime >" => {
90                                psql_types.push("timestamp[]")
91                            }
92                            other => panic!("Type {other} can not be directly converted to a Postgres array type!"),
93                        }
94                    }
95                }
96                _ => unimplemented!(),
97            }
98        }
99        _ => unimplemented!(),
100    }
101
102    let query_build_fn = if table_name.is_some() {
103        let mut content = format!("INSERT INTO {} (", table_name.as_deref().unwrap());
104        for name in psql_names.iter() {
105            content.push_str(name);
106            content.push(',');
107        }
108        content.pop();
109
110        content.push_str(") SELECT * FROM UNNEST (");
111
112        for (idx, ty) in psql_types.iter().enumerate() {
113            content.push_str(&format!("${}::{},", idx + 1, ty));
114        }
115        content.pop();
116        content.push(')');
117
118        if !keyed_names.is_empty() {
119            content.push_str(" ON CONFLICT (");
120            for name in &keyed_names {
121                content.push_str(name);
122                content.push(',');
123            }
124            content.pop();
125            content.push_str(") DO ");
126
127            if keyed_names.len() == psql_names.len() {
128                content.push_str("NOTHING");
129            } else {
130                content.push_str("UPDATE SET ");
131                for name in &psql_names {
132                    if !keyed_names.contains(name) {
133                        content.push_str(&format!("{name}=excluded.{name},"));
134                    }
135                }
136                content.pop();
137            }
138        }
139
140        let cast_tokens = field_types.iter().map(|&ty| {
141            if quote!(#ty).to_string().starts_with("Option") {
142                quote!(as &[#ty])
143            } else {
144                quote!()
145            }
146        });
147
148        quote!(
149            fn build(
150                self,
151            ) -> sqlx::query::Query<'static, sqlx::Postgres, sqlx::postgres::PgArguments>
152            {
153                sqlx::query!(#content, #(&self.#field_names[..] #cast_tokens),*)
154            }
155        )
156    } else {
157        quote!()
158    };
159
160    let expanded = quote! {
161        #[derive(Default, Debug, PartialEq)]
162        struct #new_struct_name {
163            #(#field_names: Vec<#field_types>),*
164        }
165
166        impl #new_struct_name {
167            fn new() -> Self {
168                Self::default()
169            }
170
171            fn from(items: Vec<#struct_name>) -> Self {
172                items.into_iter().fold(Self::default(),|mut inserter, item| {
173                    inserter.add(item);
174                    inserter
175                })
176            }
177
178            fn add(&mut self, item: #struct_name) {
179                #(self.#field_names.push(item.#field_names));*
180            }
181
182            #query_build_fn
183        }
184    };
185
186    expanded.into()
187}