1use inflector::Inflector;
2use proc_macro::{self, TokenStream};
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8 parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident,
9 LitStr,
10};
11
12#[proc_macro_derive(SqlxCrud, attributes(database, table, external_id, id))]
13pub fn derive(input: TokenStream) -> TokenStream {
14 let DeriveInput {
15 ident, data, attrs, ..
16 } = parse_macro_input!(input);
17 match data {
18 Data::Struct(DataStruct {
19 fields: Fields::Named(FieldsNamed { named, .. }),
20 ..
21 }) => {
22 let config = Config::new(&attrs, &ident, &named);
23 let static_model_schema = build_static_model_schema(&config);
24 let sqlx_crud_impl = build_sqlx_crud_impl(&config);
25
26 quote! {
27 #static_model_schema
28 #sqlx_crud_impl
29 }
30 .into()
31 }
32 _ => panic!("this derive macro only works on structs with named fields"),
33 }
34}
35
36fn build_static_model_schema(config: &Config) -> TokenStream2 {
37 let crate_name = &config.crate_name;
38 let model_schema_ident = &config.model_schema_ident;
39 let table_name = &config.table_name;
40
41 let id_column = config.id_column_ident.to_string();
42 let columns_len = config.named.iter().count();
43 let columns = config
44 .named
45 .iter()
46 .flat_map(|f| &f.ident)
47 .map(|f| LitStr::new(format!("{}", f).as_str(), f.span()));
48
49 let sql_queries = build_sql_queries(config);
50
51 quote! {
52 #[automatically_derived]
53 static #model_schema_ident: #crate_name::schema::Metadata<'static, #columns_len> = #crate_name::schema::Metadata {
54 table_name: #table_name,
55 id_column: #id_column,
56 columns: [#(#columns),*],
57 #sql_queries
58 };
59 }
60}
61
62fn build_sql_queries(config: &Config) -> TokenStream2 {
63 let table_name = config.quote_ident(&config.table_name);
64 let id_column = format!(
65 "{}.{}",
66 &table_name,
67 config.quote_ident(&config.id_column_ident.to_string())
68 );
69
70 let insert_bind_cnt = if config.external_id {
71 config.named.iter().count()
72 } else {
73 config.named.iter().count() - 1
74 };
75 let insert_sql_binds = (0..insert_bind_cnt)
76 .map(|i| format!("${}", i + 1))
77 .collect::<Vec<_>>()
78 .join(", ");
79
80 let update_sql_binds = config
81 .named
82 .iter()
83 .flat_map(|f| &f.ident)
84 .filter(|i| *i != &config.id_column_ident)
85 .enumerate()
86 .map(|(i, ident)| format!("{} = ${}", config.quote_ident(&ident.to_string()), i + 1))
87 .collect::<Vec<_>>();
88
89 let insert_column_list = config
90 .named
91 .iter()
92 .flat_map(|f| &f.ident)
93 .filter(|i| config.external_id || *i != &config.id_column_ident)
94 .map(|i| config.quote_ident(&i.to_string()))
95 .collect::<Vec<_>>()
96 .join(", ");
97 let column_list = config
98 .named
99 .iter()
100 .flat_map(|f| &f.ident)
101 .map(|i| format!("{}.{}", &table_name, config.quote_ident(&i.to_string())))
102 .collect::<Vec<_>>()
103 .join(", ");
104
105 let select_sql = format!("SELECT {} FROM {}", column_list, table_name);
106 let paginated_sql = format!(
107 "SELECT {} FROM {} LIMIT $1 OFFSET $2",
108 column_list, table_name
109 );
110 let select_by_id_sql = format!(
111 "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
112 column_list, table_name, id_column
113 );
114 let insert_sql = format!(
115 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
116 table_name, insert_column_list, insert_sql_binds, column_list
117 );
118 let update_by_id_sql = format!(
119 "UPDATE {} SET {} WHERE {} = ${} RETURNING {}",
120 table_name,
121 update_sql_binds.join(", "),
122 id_column,
123 update_sql_binds.len() + 1,
124 column_list
125 );
126 let delete_by_id_sql = format!("DELETE FROM {} WHERE {} = $1", table_name, id_column);
127
128 quote! {
129 select_sql: #select_sql,
130 select_by_id_sql: #select_by_id_sql,
131 insert_sql: #insert_sql,
132 update_by_id_sql: #update_by_id_sql,
133 delete_by_id_sql: #delete_by_id_sql,
134 paginated_sql: #paginated_sql,
135 }
136}
137
138fn build_sqlx_crud_impl(config: &Config) -> TokenStream2 {
139 let crate_name = &config.crate_name;
140 let ident = &config.ident;
141 let model_schema_ident = &config.model_schema_ident;
142 let db_ty = config.db_ty.sqlx_db();
143 let id_column_ident = &config.id_column_ident;
144
145 let id_ty = config
146 .named
147 .iter()
148 .find(|f| f.ident.as_ref() == Some(id_column_ident))
149 .map(|f| &f.ty)
150 .expect("the id type");
151
152 let insert_query_args = config
153 .named
154 .iter()
155 .flat_map(|f| &f.ident)
156 .filter(|i| config.external_id || *i != &config.id_column_ident)
157 .map(|i| quote! { args.add(self.#i).map_err(sqlx::Error::Encode)?; });
158
159 let insert_query_size = config
160 .named
161 .iter()
162 .flat_map(|f| &f.ident)
163 .filter(|i| config.external_id || *i != &config.id_column_ident)
164 .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
165
166 let update_query_args = config
167 .named
168 .iter()
169 .flat_map(|f| &f.ident)
170 .filter(|i| *i != &config.id_column_ident)
171 .map(|i| quote! { args.add(self.#i).map_err(sqlx::Error::Encode)?; });
172
173 let update_query_args_id = quote! { args.add(self.#id_column_ident).map_err(sqlx::Error::Encode)?; };
174
175 let update_query_size = config
176 .named
177 .iter()
178 .flat_map(|f| &f.ident)
179 .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
180
181 quote! {
182 #[automatically_derived]
183 impl #crate_name::traits::Schema for #ident {
184 type Id = #id_ty;
185
186 fn table_name() -> &'static str {
187 #model_schema_ident.table_name
188 }
189
190 fn id(&self) -> Self::Id {
191 self.#id_column_ident
192 }
193
194 fn id_column() -> &'static str {
195 #model_schema_ident.id_column
196 }
197
198 fn columns() -> &'static [&'static str] {
199 &#model_schema_ident.columns
200 }
201
202 fn select_sql() -> &'static str {
203 #model_schema_ident.select_sql
204 }
205
206 fn select_by_id_sql() -> &'static str {
207 #model_schema_ident.select_by_id_sql
208 }
209
210 fn insert_sql() -> &'static str {
211 #model_schema_ident.insert_sql
212 }
213
214 fn update_by_id_sql() -> &'static str {
215 #model_schema_ident.update_by_id_sql
216 }
217
218 fn delete_by_id_sql() -> &'static str {
219 #model_schema_ident.delete_by_id_sql
220 }
221
222 fn paginated_sql() -> &'static str {
223 #model_schema_ident.paginated_sql
224 }
225 }
226
227 #[automatically_derived]
228 impl<'e> #crate_name::traits::Crud<'e, &'e ::sqlx::pool::Pool<#db_ty>> for #ident {
229 fn insert_args(self) -> ::sqlx::Result<<#db_ty as ::sqlx::database::Database>::Arguments<'e>> {
230 use ::sqlx::Arguments as _;
231 let mut args = <#db_ty as ::sqlx::database::Database>::Arguments::default();
232 args.reserve(1usize, #(#insert_query_size)+*);
233 #(#insert_query_args)*
234 Ok(args)
235 }
236
237 fn update_args(self) -> ::sqlx::Result<<#db_ty as ::sqlx::database::Database>::Arguments<'e>> {
238 use ::sqlx::Arguments as _;
239 let mut args = <#db_ty as ::sqlx::database::Database>::Arguments::default();
240 args.reserve(1usize, #(#update_query_size)+*);
241 #(#update_query_args)*
242 #update_query_args_id
243 Ok(args)
244 }
245
246 fn paginated_args(limit: i64, offset: i64) -> <#db_ty as ::sqlx::database::Database>::Arguments<'e> {
247 use ::sqlx::Arguments as _;
248 let mut args = <#db_ty as ::sqlx::database::Database>::Arguments::default();
249 args.reserve(2usize,
250 ::sqlx::encode::Encode::<#db_ty>::size_hint(&limit) +
251 ::sqlx::encode::Encode::<#db_ty>::size_hint(&offset)
252 );
253 let _ = args.add(limit);
254 let _ = args.add(offset);
255 args
256 }
257 }
258 }
259}
260
261#[allow(dead_code)] struct Config<'a> {
263 ident: &'a Ident,
264 named: &'a Punctuated<Field, Comma>,
265 crate_name: TokenStream2,
266 db_ty: DbType,
267 model_schema_ident: Ident,
268 table_name: String,
269 id_column_ident: Ident,
270 external_id: bool,
271}
272
273impl<'a> Config<'a> {
274 fn new(attrs: &[Attribute], ident: &'a Ident, named: &'a Punctuated<Field, Comma>) -> Self {
275 let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
276 let is_doctest = std::env::vars()
277 .any(|(k, _)| k == "UNSTABLE_RUSTDOC_TEST_LINE" || k == "UNSTABLE_RUSTDOC_TEST_PATH");
278 let crate_name = if !is_doctest && crate_name == "sqlx-crud" {
279 quote! { crate }
280 } else {
281 quote! { ::souchy_sqlx_crud }
282 };
283
284 let db_ty = DbType::new(attrs);
285
286 let model_schema_ident =
287 format_ident!("{}_SCHEMA", ident.to_string().to_screaming_snake_case());
288
289 let id_attr = &named
291 .iter()
292 .find(|f| f.attrs.iter().any(|a| a.path().is_ident("id")))
293 .and_then(|f| f.ident.as_ref());
294 let id_column_ident = id_attr
296 .unwrap_or_else(|| {
297 named
298 .iter()
299 .flat_map(|f| &f.ident)
300 .next()
301 .expect("the first field")
302 })
303 .clone();
304
305 let external_id = attrs.iter().any(|a| a.path().is_ident("external_id"));
306
307 let table_name = attrs
308 .iter()
309 .find(|a| a.path().is_ident("table"))
310 .and_then(|attr| {
311 let mut table = None;
312 attr.parse_nested_meta(|meta| {
313 if let Some(ident) = meta.path.get_ident() {
314 table = Some(ident.to_string());
315 }
316 Ok(())
317 })
318 .ok();
319 table
320 })
321 .unwrap_or_else(|| ident.to_string().to_table_case());
322
323 Self {
324 ident,
325 named,
326 crate_name,
327 db_ty,
328 model_schema_ident,
329 table_name,
330 id_column_ident,
331 external_id,
332 }
333 }
334
335 fn quote_ident(&self, ident: &str) -> String {
336 self.db_ty.quote_ident(ident)
337 }
338}
339
/// Database backends the derive can target, selected with `#[database(...)]`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbType {
    /// Generates code against `::sqlx::Any`.
    Any,
    /// Generates code against `::sqlx::Mssql`.
    Mssql,
    /// Generates code against `::sqlx::MySql`.
    MySql,
    /// Generates code against `::sqlx::Postgres`.
    Postgres,
    /// Generates code against `::sqlx::Sqlite` (the default).
    Sqlite,
}
347
348impl From<&str> for DbType {
349 fn from(db_type: &str) -> Self {
350 match db_type {
351 "Any" => Self::Any,
352 "Mssql" => Self::Mssql,
353 "MySql" => Self::MySql,
354 "Postgres" => Self::Postgres,
355 "Sqlite" => Self::Sqlite,
356 _ => panic!("unknown #[database] type {}", db_type),
357 }
358 }
359}
360
361impl DbType {
362 fn new(attrs: &[Attribute]) -> Self {
363 let mut db_type = DbType::Sqlite;
364 attrs
365 .iter()
366 .find(|a| a.path().is_ident("database"))
367 .map(|a| {
368 a.parse_nested_meta(|m| {
369 if let Some(path) = m.path.get_ident() {
370 db_type = DbType::from(path.to_string().as_str());
371 }
372 Ok(())
373 })
374 });
375
376 db_type
377 }
378
379 fn sqlx_db(&self) -> TokenStream2 {
380 match self {
381 Self::Any => quote! { ::sqlx::Any },
382 Self::Mssql => quote! { ::sqlx::Mssql },
383 Self::MySql => quote! { ::sqlx::MySql },
384 Self::Postgres => quote! { ::sqlx::Postgres },
385 Self::Sqlite => quote! { ::sqlx::Sqlite },
386 }
387 }
388
389 fn quote_ident(&self, ident: &str) -> String {
390 match self {
391 Self::Any => format!(r#""{}""#, &ident),
392 Self::Mssql => format!(r#""{}""#, &ident),
393 Self::MySql => format!("`{}`", &ident),
394 Self::Postgres => format!(r#""{}""#, &ident),
395 Self::Sqlite => format!(r#""{}""#, &ident),
396 }
397 }
398}