1use proc_macro::TokenStream;
2use quote::{format_ident, quote};
3use syn::{parse_macro_input, Attribute, Data, DeriveInput, Fields, Meta};
4
5fn get_attr(attrs: &[Attribute], name: &str) -> Option<String> {
7 attrs
8 .iter()
9 .flat_map(|elem| {
10 if let Meta::NameValue(nv) = &elem.meta {
11 if nv.path.is_ident(name) {
12 let value = &nv.value;
13 return Some(quote!(#value).to_string());
14 }
15 }
16 if let Meta::Path(p) = &elem.meta {
17 if p.get_ident().unwrap() == name {
18 return Some(String::new());
19 }
20 }
21 None
22 })
23 .last()
24}
25
26#[proc_macro_derive(BatchInserter, attributes(pgtable, colname, key))]
39pub fn derive_batch_inserter_from(input: TokenStream) -> TokenStream {
40 let input = parse_macro_input!(input as DeriveInput);
41
42 let table_name = get_attr(&input.attrs, "pgtable");
43
44 let struct_name = input.ident;
45 let new_struct_name = format_ident!("{}Inserter", struct_name);
46
47 let mut field_names = Vec::new();
48 let mut field_types = Vec::new();
49
50 let mut psql_names = Vec::new();
51 let mut psql_types = Vec::new();
52
53 let mut keyed_names = Vec::new();
54
55 match &input.data {
56 Data::Struct(data) => {
57 match &data.fields {
58 Fields::Named(fields) => {
59 for field in &fields.named {
60 let ident = &field.ident;
61 let ty = &field.ty;
62 let attrs = &field.attrs;
63
64 field_names.push(ident);
65 field_types.push(ty);
66
67 if let Some(name) = get_attr(attrs, "colname") {
69 psql_names.push(name.clone());
70 if get_attr(attrs, "key").is_some() {
71 keyed_names.push(name);
72 }
73 } else {
74 psql_names.push(quote!(#ident).to_string());
75 if get_attr(attrs, "key").is_some() {
76 keyed_names.push(quote!(#ident).to_string());
77 }
78 }
79
80 match quote!(#ty).to_string().as_ref() {
81 "String" | "&str" | "Option < &str >" | "Option < String >" => {
82 psql_types.push("text[]");
83 }
84 "bool" | "Option < bool >" => psql_types.push("bool[]"),
85 "f8" | "f16" | "f32" | "f64" | "Option < f8 >" | "Option < f16 >"
86 | "Option < f32 >" | "Option < f64 >" => psql_types.push("float[]"),
87 "i16" | "i32" | "i64" | "Option < i16 >" | "Option < i32 >"
88 | "Option < i64 >" => psql_types.push("integer[]"),
89 "NaiveDateTime" | "Option < NaiveDateTime >" => {
90 psql_types.push("timestamp[]")
91 }
92 other => panic!("Type {other} can not be directly converted to a Postgres array type!"),
93 }
94 }
95 }
96 _ => unimplemented!(),
97 }
98 }
99 _ => unimplemented!(),
100 }
101
102 let query_build_fn = if table_name.is_some() {
103 let mut content = format!("INSERT INTO {} (", table_name.as_deref().unwrap());
104 for name in psql_names.iter() {
105 content.push_str(name);
106 content.push(',');
107 }
108 content.pop();
109
110 content.push_str(") SELECT * FROM UNNEST (");
111
112 for (idx, ty) in psql_types.iter().enumerate() {
113 content.push_str(&format!("${}::{},", idx + 1, ty));
114 }
115 content.pop();
116 content.push(')');
117
118 if !keyed_names.is_empty() {
119 content.push_str(" ON CONFLICT (");
120 for name in &keyed_names {
121 content.push_str(name);
122 content.push(',');
123 }
124 content.pop();
125 content.push_str(") DO UPDATE SET ");
126
127 for name in &psql_names {
128 if !keyed_names.contains(name) {
129 content.push_str(&format!("{name}=excluded.{name},"));
130 }
131 }
132 content.pop();
133 }
134
135 let cast_tokens = field_types.iter().map(|&ty| {
136 if quote!(#ty).to_string().starts_with("Option") {
137 quote!(as &[#ty])
138 } else {
139 quote!()
140 }
141 });
142
143 quote!(
144 fn build(
145 self,
146 ) -> sqlx::query::Query<'static, sqlx::Postgres, sqlx::postgres::PgArguments>
147 {
148 sqlx::query!(#content, #(&self.#field_names[..] #cast_tokens),*)
149 }
150 )
151 } else {
152 quote!()
153 };
154
155 let expanded = quote! {
156 #[derive(Default, Debug, PartialEq)]
157 struct #new_struct_name {
158 #(#field_names: Vec<#field_types>),*
159 }
160
161 impl #new_struct_name {
162 fn new() -> Self {
163 Self::default()
164 }
165
166 fn from(items: Vec<#struct_name>) -> Self {
167 items.into_iter().fold(Self::default(),|mut inserter, item| {
168 inserter.add(item);
169 inserter
170 })
171 }
172
173 fn add(&mut self, item: #struct_name) {
174 #(self.#field_names.push(item.#field_names));*
175 }
176
177 #query_build_fn
178 }
179 };
180
181 expanded.into()
182}