1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5fn get_attr(attrs: &[Attribute], name: &str) -> Option<String> {
7 attrs
8 .iter()
9 .flat_map(|elem| {
10 if let Meta::NameValue(nv) = &elem.meta {
11 if nv.path.is_ident(name) {
12 let value = &nv.value;
13 return Some(quote!(#value).to_string());
14 }
15 }
16 if let Meta::Path(p) = &elem.meta {
17 if p.get_ident().unwrap() == name {
18 return Some(String::new());
19 }
20 }
21 None
22 })
23 .last()
24}
25
26#[proc_macro_derive(BatchInserter, attributes(pgtable, colname, key))]
39pub fn derive_batch_inserter_from(input: TokenStream) -> TokenStream {
40 let input = parse_macro_input!(input as DeriveInput);
41
42 let table_name = get_attr(&input.attrs, "pgtable");
43
44 let struct_name = input.ident;
45 let new_struct_name = format_ident!("{}Inserter", struct_name);
46
47 let mut field_names = Vec::new();
48 let mut field_types = Vec::new();
49
50 let mut psql_names = Vec::new();
51 let mut psql_types = Vec::new();
52
53 let mut keyed_names = Vec::new();
54
55 match &input.data {
56 Data::Struct(data) => {
57 match &data.fields {
58 Fields::Named(fields) => {
59 for field in &fields.named {
60 let ident = &field.ident;
61 let ty = &field.ty;
62 let attrs = &field.attrs;
63
64 field_names.push(ident);
65 field_types.push(ty);
66
67 if let Some(name) = get_attr(attrs, "colname") {
69 psql_names.push(name.clone());
70 if get_attr(attrs, "key").is_some() {
71 keyed_names.push(name);
72 }
73 } else {
74 psql_names.push(quote!(#ident).to_string());
75 if get_attr(attrs, "key").is_some() {
76 keyed_names.push(quote!(#ident).to_string());
77 }
78 }
79
80 match quote!(#ty).to_string().as_ref() {
81 "String" | "&str" | "Option < &str >" | "Option < String >" => {
82 psql_types.push("text[]");
83 }
84 "bool" | "Option < bool >" => psql_types.push("bool[]"),
85 "f8" | "f16" | "f32" | "f64" | "Option < f8 >" | "Option < f16 >"
86 | "Option < f32 >" | "Option < f64 >" => psql_types.push("float[]"),
87 "i16" | "i32" | "i64" | "Option < i16 >" | "Option < i32 >"
88 | "Option < i64 >" => psql_types.push("integer[]"),
89 "NaiveDateTime" | "Option < NaiveDateTime >" => {
90 psql_types.push("timestamp[]")
91 }
92 other => panic!("Type {other} can not be directly converted to a Postgres array type!"),
93 }
94 }
95 }
96 _ => unimplemented!(),
97 }
98 }
99 _ => unimplemented!(),
100 }
101
102 let query_build_fn = if table_name.is_some() {
103 let mut content = format!("INSERT INTO {} (", table_name.as_deref().unwrap());
104 for name in psql_names.iter() {
105 content.push_str(name);
106 content.push(',');
107 }
108 content.pop();
109
110 content.push_str(") SELECT * FROM UNNEST (");
111
112 for (idx, ty) in psql_types.iter().enumerate() {
113 content.push_str(&format!("${}::{},", idx + 1, ty));
114 }
115 content.pop();
116 content.push(')');
117
118 if !keyed_names.is_empty() {
119 content.push_str(" ON CONFLICT (");
120 for name in &keyed_names {
121 content.push_str(name);
122 content.push(',');
123 }
124 content.pop();
125 content.push_str(") DO ");
126
127 if keyed_names.len() == psql_names.len() {
128 content.push_str("NOTHING");
129 } else {
130 content.push_str("UPDATE SET ");
131 for name in &psql_names {
132 if !keyed_names.contains(name) {
133 content.push_str(&format!("{name}=excluded.{name},"));
134 }
135 }
136 content.pop();
137 }
138 }
139
140 let cast_tokens = field_types.iter().map(|&ty| {
141 if quote!(#ty).to_string().starts_with("Option") {
142 quote!(as &[#ty])
143 } else {
144 quote!()
145 }
146 });
147
148 quote!(
149 fn build(
150 self,
151 ) -> sqlx::query::Query<'static, sqlx::Postgres, sqlx::postgres::PgArguments>
152 {
153 sqlx::query!(#content, #(&self.#field_names[..] #cast_tokens),*)
154 }
155 )
156 } else {
157 quote!()
158 };
159
160 let expanded = quote! {
161 #[derive(Default, Debug, PartialEq)]
162 struct #new_struct_name {
163 #(#field_names: Vec<#field_types>),*
164 }
165
166 impl #new_struct_name {
167 fn new() -> Self {
168 Self::default()
169 }
170
171 fn from(items: Vec<#struct_name>) -> Self {
172 items.into_iter().fold(Self::default(),|mut inserter, item| {
173 inserter.add(item);
174 inserter
175 })
176 }
177
178 fn add(&mut self, item: #struct_name) {
179 #(self.#field_names.push(item.#field_names));*
180 }
181
182 #query_build_fn
183 }
184 };
185
186 expanded.into()
187}