1use inflector::Inflector;
2use proc_macro::{self, TokenStream};
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8 parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, Ident,
9 LitStr, Meta, MetaNameValue, Lit, ExprLit,
10};
11
12#[proc_macro_derive(SqlxCrud, attributes(database, external_id, id))]
13pub fn derive(input: TokenStream) -> TokenStream {
14 let DeriveInput {
15 ident, data, attrs, ..
16 } = parse_macro_input!(input);
17 match data {
18 Data::Struct(DataStruct {
19 fields: Fields::Named(FieldsNamed { named, .. }),
20 ..
21 }) => {
22 let config = Config::new(&attrs, &ident, &named);
23 let static_model_schema = build_static_model_schema(&config);
24 let sqlx_crud_impl = build_sqlx_crud_impl(&config);
25
26 quote! {
27 #static_model_schema
28 #sqlx_crud_impl
29 }
30 .into()
31 }
32 _ => panic!("this derive macro only works on structs with named fields"),
33 }
34}
35
36fn build_static_model_schema(config: &Config) -> TokenStream2 {
37 let crate_name = &config.crate_name;
38 let model_schema_ident = &config.model_schema_ident;
39 let table_name = &config.table_name;
40
41 let id_column = config.id_column_ident.to_string();
42 let columns_len = config.named.iter().count();
43 let columns = config
44 .named
45 .iter()
46 .flat_map(|f| &f.ident)
47 .map(|f| LitStr::new(format!("{}", f).as_str(), f.span()));
48
49 let sql_queries = build_sql_queries(config);
50
51 quote! {
52 #[automatically_derived]
53 static #model_schema_ident: #crate_name::schema::Metadata<'static, #columns_len> = #crate_name::schema::Metadata {
54 table_name: #table_name,
55 id_column: #id_column,
56 columns: [#(#columns),*],
57 #sql_queries
58 };
59 }
60}
61
62fn build_sql_queries(config: &Config) -> TokenStream2 {
63 let table_name = config.quote_ident(&config.table_name);
64 let id_column = format!(
65 "{}.{}",
66 &table_name,
67 config.quote_ident(&config.id_column_ident.to_string())
68 );
69
70 let insert_bind_cnt = if config.external_id {
71 config.named.iter().count()
72 } else {
73 config.named.iter().count() - 1
74 };
75 let insert_sql_binds = (0..insert_bind_cnt)
76 .map(|_| "?")
77 .collect::<Vec<_>>()
78 .join(", ");
79
80 let update_sql_binds = config
81 .named
82 .iter()
83 .flat_map(|f| &f.ident)
84 .filter(|i| *i != &config.id_column_ident)
85 .map(|i| format!("{} = ?", config.quote_ident(&i.to_string())))
86 .collect::<Vec<_>>()
87 .join(", ");
88
89 let insert_column_list = config
90 .named
91 .iter()
92 .flat_map(|f| &f.ident)
93 .filter(|i| config.external_id || *i != &config.id_column_ident)
94 .map(|i| config.quote_ident(&i.to_string()))
95 .collect::<Vec<_>>()
96 .join(", ");
97 let column_list = config
98 .named
99 .iter()
100 .flat_map(|f| &f.ident)
101 .map(|i| format!("{}.{}", &table_name, config.quote_ident(&i.to_string())))
102 .collect::<Vec<_>>()
103 .join(", ");
104
105 let select_sql = format!("SELECT {} FROM {}", column_list, table_name);
106 let select_by_id_sql = format!(
107 "SELECT {} FROM {} WHERE {} = ? LIMIT 1",
108 column_list, table_name, id_column
109 );
110 let insert_sql = format!(
111 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
112 table_name, insert_column_list, insert_sql_binds, column_list
113 );
114 let update_by_id_sql = format!(
115 "UPDATE {} SET {} WHERE {} = ? RETURNING {}",
116 table_name, update_sql_binds, id_column, column_list
117 );
118 let delete_by_id_sql = format!("DELETE FROM {} WHERE {} = ?", table_name, id_column);
119
120 quote! {
121 select_sql: #select_sql,
122 select_by_id_sql: #select_by_id_sql,
123 insert_sql: #insert_sql,
124 update_by_id_sql: #update_by_id_sql,
125 delete_by_id_sql: #delete_by_id_sql,
126 }
127}
128
129fn build_sqlx_crud_impl(config: &Config) -> TokenStream2 {
130 let crate_name = &config.crate_name;
131 let ident = &config.ident;
132 let model_schema_ident = &config.model_schema_ident;
133 let db_ty = config.db_ty.sqlx_db();
134 let id_column_ident = &config.id_column_ident;
135
136 let id_ty = config
137 .named
138 .iter()
139 .find(|f| f.ident.as_ref() == Some(id_column_ident))
140 .map(|f| &f.ty)
141 .expect("the id type");
142
143 let insert_query_args = config
144 .named
145 .iter()
146 .flat_map(|f| &f.ident)
147 .filter(|i| config.external_id || *i != &config.id_column_ident)
148 .map(|i| quote! { args.add(self.#i); });
149
150 let insert_query_size = config
151 .named
152 .iter()
153 .flat_map(|f| &f.ident)
154 .filter(|i| config.external_id || *i != &config.id_column_ident)
155 .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
156
157 let update_query_args = config
158 .named
159 .iter()
160 .flat_map(|f| &f.ident)
161 .filter(|i| *i != &config.id_column_ident)
162 .map(|i| quote! { args.add(self.#i); });
163
164 let update_query_args_id = quote! { args.add(self.#id_column_ident); };
165
166 let update_query_size = config
167 .named
168 .iter()
169 .flat_map(|f| &f.ident)
170 .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
171
172 quote! {
173 #[automatically_derived]
174 impl #crate_name::traits::Schema for #ident {
175 type Id = #id_ty;
176
177 fn table_name() -> &'static str {
178 #model_schema_ident.table_name
179 }
180
181 fn id(&self) -> Self::Id {
182 self.#id_column_ident
183 }
184
185 fn id_column() -> &'static str {
186 #model_schema_ident.id_column
187 }
188
189 fn columns() -> &'static [&'static str] {
190 &#model_schema_ident.columns
191 }
192
193 fn select_sql() -> &'static str {
194 #model_schema_ident.select_sql
195 }
196
197 fn select_by_id_sql() -> &'static str {
198 #model_schema_ident.select_by_id_sql
199 }
200
201 fn insert_sql() -> &'static str {
202 #model_schema_ident.insert_sql
203 }
204
205 fn update_by_id_sql() -> &'static str {
206 #model_schema_ident.update_by_id_sql
207 }
208
209 fn delete_by_id_sql() -> &'static str {
210 #model_schema_ident.delete_by_id_sql
211 }
212 }
213
214 #[automatically_derived]
215 impl<'e> #crate_name::traits::Crud<'e, &'e ::sqlx::pool::Pool<#db_ty>> for #ident {
216 fn insert_args(self) -> <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments {
217 use ::sqlx::Arguments as _;
218 let mut args = <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments::default();
219 args.reserve(1usize, #(#insert_query_size)+*);
220 #(#insert_query_args)*
221 args
222 }
223
224 fn update_args(self) -> <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments {
225 use ::sqlx::Arguments as _;
226 let mut args = <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments::default();
227 args.reserve(1usize, #(#update_query_size)+*);
228 #(#update_query_args)*
229 #update_query_args_id
230 args
231 }
232 }
233 }
234}
235
236#[allow(dead_code)] struct Config<'a> {
238 ident: &'a Ident,
239 named: &'a Punctuated<Field, Comma>,
240 crate_name: TokenStream2,
241 db_ty: DbType,
242 model_schema_ident: Ident,
243 table_name: String,
244 id_column_ident: Ident,
245 external_id: bool,
246}
247
248impl<'a> Config<'a> {
249 fn new(attrs: &[Attribute], ident: &'a Ident, named: &'a Punctuated<Field, Comma>) -> Self {
250 let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
251 let is_doctest = std::env::vars()
252 .any(|(k, _)| k == "UNSTABLE_RUSTDOC_TEST_LINE" || k == "UNSTABLE_RUSTDOC_TEST_PATH");
253 let crate_name = if !is_doctest && crate_name == "sqlx-crud" {
254 quote! { crate }
255 } else {
256 quote! { ::sqlx_crud }
257 };
258
259 let db_ty = DbType::new(attrs);
260
261 let model_schema_ident =
262 format_ident!("{}_SCHEMA", ident.to_string().to_screaming_snake_case());
263
264 let table_name = ident.to_string().to_table_case();
265
266 let id_attr = &named
268 .iter()
269 .find(|f| f.attrs.iter().any(|a| a.path().is_ident("id")))
270 .and_then(|f| f.ident.as_ref());
271 let id_column_ident = id_attr
273 .unwrap_or_else(|| {
274 named
275 .iter()
276 .flat_map(|f| &f.ident)
277 .next()
278 .expect("the first field")
279 })
280 .clone();
281
282 let external_id = attrs.iter().any(|a| a.path().is_ident("external_id"));
283
284 Self {
285 ident,
286 named,
287 crate_name,
288 db_ty,
289 model_schema_ident,
290 table_name,
291 id_column_ident,
292 external_id,
293 }
294 }
295
296 fn quote_ident(&self, ident: &str) -> String {
297 self.db_ty.quote_ident(ident)
298 }
299}
300
/// Database backend targeted by the generated code.
///
/// Each variant corresponds to the sqlx database type of the same name and
/// determines identifier quoting in the generated SQL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbType {
    /// `sqlx::Any`
    Any,
    /// `sqlx::Mssql`
    Mssql,
    /// `sqlx::MySql`
    MySql,
    /// `sqlx::Postgres`
    Postgres,
    /// `sqlx::Sqlite` (the default when no `#[database]` attribute is given)
    Sqlite,
}
308
309impl From<&str> for DbType {
310 fn from(db_type: &str) -> Self {
311 match db_type {
312 "Any" => Self::Any,
313 "Mssql" => Self::Mssql,
314 "MySql" => Self::MySql,
315 "Postgres" => Self::Postgres,
316 "Sqlite" => Self::Sqlite,
317 _ => panic!("unknown #[database] type {}", db_type),
318 }
319 }
320}
321
322impl DbType {
323 fn new(attrs: &[Attribute]) -> Self {
324 let mut db_type = DbType::Sqlite;
325 attrs.iter()
326 .find(|a| a.path().is_ident("database"))
327 .map(|a| a.parse_nested_meta(|m| {
328 if let Some(path) = m.path.get_ident() {
329 db_type = DbType::from(path.to_string().as_str());
330 }
331 Ok(())
332 }));
333
334 db_type
335 }
336
337 fn sqlx_db(&self) -> TokenStream2 {
338 match self {
339 Self::Any => quote! { ::sqlx::Any },
340 Self::Mssql => quote! { ::sqlx::Mssql },
341 Self::MySql => quote! { ::sqlx::MySql },
342 Self::Postgres => quote! { ::sqlx::Postgres },
343 Self::Sqlite => quote! { ::sqlx::Sqlite },
344 }
345 }
346
347 fn quote_ident(&self, ident: &str) -> String {
348 match self {
349 Self::Any => format!(r#""{}""#, &ident),
350 Self::Mssql => format!(r#""{}""#, &ident),
351 Self::MySql => format!("`{}`", &ident),
352 Self::Postgres => format!(r#""{}""#, &ident),
353 Self::Sqlite => format!(r#""{}""#, &ident),
354 }
355 }
356}