Skip to main content

souchy_sqlx_crud_macros/
lib.rs

1use inflector::Inflector;
2use proc_macro::{self, TokenStream};
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8    parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident,
9    LitStr,
10};
11
12#[proc_macro_derive(SqlxCrud, attributes(database, table, external_id, id))]
13pub fn derive(input: TokenStream) -> TokenStream {
14    let DeriveInput {
15        ident, data, attrs, ..
16    } = parse_macro_input!(input);
17    match data {
18        Data::Struct(DataStruct {
19            fields: Fields::Named(FieldsNamed { named, .. }),
20            ..
21        }) => {
22            let config = Config::new(&attrs, &ident, &named);
23            let static_model_schema = build_static_model_schema(&config);
24            let sqlx_crud_impl = build_sqlx_crud_impl(&config);
25
26            quote! {
27                #static_model_schema
28                #sqlx_crud_impl
29            }
30            .into()
31        }
32        _ => panic!("this derive macro only works on structs with named fields"),
33    }
34}
35
36fn build_static_model_schema(config: &Config) -> TokenStream2 {
37    let crate_name = &config.crate_name;
38    let model_schema_ident = &config.model_schema_ident;
39    let table_name = &config.table_name;
40
41    let id_column = config.id_column_ident.to_string();
42    let columns_len = config.named.iter().count();
43    let columns = config
44        .named
45        .iter()
46        .flat_map(|f| &f.ident)
47        .map(|f| LitStr::new(format!("{}", f).as_str(), f.span()));
48
49    let sql_queries = build_sql_queries(config);
50
51    quote! {
52        #[automatically_derived]
53        static #model_schema_ident: #crate_name::schema::Metadata<'static, #columns_len> = #crate_name::schema::Metadata {
54            table_name: #table_name,
55            id_column: #id_column,
56            columns: [#(#columns),*],
57            #sql_queries
58        };
59    }
60}
61
62fn build_sql_queries(config: &Config) -> TokenStream2 {
63    let table_name = config.quote_ident(&config.table_name);
64    let id_column = format!(
65        "{}.{}",
66        &table_name,
67        config.quote_ident(&config.id_column_ident.to_string())
68    );
69
70    let insert_bind_cnt = if config.external_id {
71        config.named.iter().count()
72    } else {
73        config.named.iter().count() - 1
74    };
75    let insert_sql_binds = (0..insert_bind_cnt)
76        .map(|i| format!("${}", i + 1))
77        .collect::<Vec<_>>()
78        .join(", ");
79
80    let update_sql_binds = config
81        .named
82        .iter()
83        .flat_map(|f| &f.ident)
84        .filter(|i| *i != &config.id_column_ident)
85        .enumerate()
86        .map(|(i, ident)| format!("{} = ${}", config.quote_ident(&ident.to_string()), i + 1))
87        .collect::<Vec<_>>();
88
89    let insert_column_list = config
90        .named
91        .iter()
92        .flat_map(|f| &f.ident)
93        .filter(|i| config.external_id || *i != &config.id_column_ident)
94        .map(|i| config.quote_ident(&i.to_string()))
95        .collect::<Vec<_>>()
96        .join(", ");
97    let column_list = config
98        .named
99        .iter()
100        .flat_map(|f| &f.ident)
101        .map(|i| format!("{}.{}", &table_name, config.quote_ident(&i.to_string())))
102        .collect::<Vec<_>>()
103        .join(", ");
104
105    let select_sql = format!("SELECT {} FROM {}", column_list, table_name);
106    let paginated_sql = format!(
107        "SELECT {} FROM {} LIMIT $1 OFFSET $2",
108        column_list, table_name
109    );
110    let select_by_id_sql = format!(
111        "SELECT {} FROM {} WHERE {} = $1 LIMIT 1",
112        column_list, table_name, id_column
113    );
114    let insert_sql = format!(
115        "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
116        table_name, insert_column_list, insert_sql_binds, column_list
117    );
118    let update_by_id_sql = format!(
119        "UPDATE {} SET {} WHERE {} = ${} RETURNING {}",
120        table_name,
121        update_sql_binds.join(", "),
122        id_column,
123        update_sql_binds.len() + 1,
124        column_list
125    );
126    let delete_by_id_sql = format!("DELETE FROM {} WHERE {} = $1", table_name, id_column);
127
128    quote! {
129        select_sql: #select_sql,
130        select_by_id_sql: #select_by_id_sql,
131        insert_sql: #insert_sql,
132        update_by_id_sql: #update_by_id_sql,
133        delete_by_id_sql: #delete_by_id_sql,
134        paginated_sql: #paginated_sql,
135    }
136}
137
138fn build_sqlx_crud_impl(config: &Config) -> TokenStream2 {
139    let crate_name = &config.crate_name;
140    let ident = &config.ident;
141    let model_schema_ident = &config.model_schema_ident;
142    let db_ty = config.db_ty.sqlx_db();
143    let id_column_ident = &config.id_column_ident;
144
145    let id_ty = config
146        .named
147        .iter()
148        .find(|f| f.ident.as_ref() == Some(id_column_ident))
149        .map(|f| &f.ty)
150        .expect("the id type");
151
152    let insert_query_args = config
153        .named
154        .iter()
155        .flat_map(|f| &f.ident)
156        .filter(|i| config.external_id || *i != &config.id_column_ident)
157        .map(|i| quote! { args.add(self.#i).map_err(sqlx::Error::Encode)?; });
158
159    let insert_query_size = config
160        .named
161        .iter()
162        .flat_map(|f| &f.ident)
163        .filter(|i| config.external_id || *i != &config.id_column_ident)
164        .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
165
166    let update_query_args = config
167        .named
168        .iter()
169        .flat_map(|f| &f.ident)
170        .filter(|i| *i != &config.id_column_ident)
171        .map(|i| quote! { args.add(self.#i).map_err(sqlx::Error::Encode)?; });
172
173    let update_query_args_id = quote! { args.add(self.#id_column_ident).map_err(sqlx::Error::Encode)?; };
174
175    let update_query_size = config
176        .named
177        .iter()
178        .flat_map(|f| &f.ident)
179        .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
180
181    quote! {
182        #[automatically_derived]
183        impl #crate_name::traits::Schema for #ident {
184            type Id = #id_ty;
185
186            fn table_name() -> &'static str {
187                #model_schema_ident.table_name
188            }
189
190            fn id(&self) -> Self::Id {
191                self.#id_column_ident
192            }
193
194            fn id_column() -> &'static str {
195                #model_schema_ident.id_column
196            }
197
198            fn columns() -> &'static [&'static str] {
199                &#model_schema_ident.columns
200            }
201
202            fn select_sql() -> &'static str {
203                #model_schema_ident.select_sql
204            }
205
206            fn select_by_id_sql() -> &'static str {
207                #model_schema_ident.select_by_id_sql
208            }
209
210            fn insert_sql() -> &'static str {
211                #model_schema_ident.insert_sql
212            }
213
214            fn update_by_id_sql() -> &'static str {
215                #model_schema_ident.update_by_id_sql
216            }
217
218            fn delete_by_id_sql() -> &'static str {
219                #model_schema_ident.delete_by_id_sql
220            }
221
222            fn paginated_sql() -> &'static str {
223                #model_schema_ident.paginated_sql
224            }
225        }
226
227        #[automatically_derived]
228        impl<'e> #crate_name::traits::Crud<'e, &'e ::sqlx::pool::Pool<#db_ty>> for #ident {
229            fn insert_args(self) -> ::sqlx::Result<<#db_ty as ::sqlx::database::Database>::Arguments<'e>> {
230                use ::sqlx::Arguments as _;
231                let mut args = <#db_ty as ::sqlx::database::Database>::Arguments::default();
232                args.reserve(1usize, #(#insert_query_size)+*);
233                #(#insert_query_args)*
234                Ok(args)
235            }
236
237            fn update_args(self) -> ::sqlx::Result<<#db_ty as ::sqlx::database::Database>::Arguments<'e>> {
238                use ::sqlx::Arguments as _;
239                let mut args = <#db_ty as ::sqlx::database::Database>::Arguments::default();
240                args.reserve(1usize, #(#update_query_size)+*);
241                #(#update_query_args)*
242                #update_query_args_id
243                Ok(args)
244            }
245
246            fn paginated_args(limit: i64, offset: i64) -> <#db_ty as ::sqlx::database::Database>::Arguments<'e> {
247                use ::sqlx::Arguments as _;
248                let mut args = <#db_ty as ::sqlx::database::Database>::Arguments::default();
249                args.reserve(2usize,
250                    ::sqlx::encode::Encode::<#db_ty>::size_hint(&limit) +
251                    ::sqlx::encode::Encode::<#db_ty>::size_hint(&offset)
252                );
253                let _ = args.add(limit);
254                let _ = args.add(offset);
255                args
256            }
257        }
258    }
259}
260
261#[allow(dead_code)] // Usage in quote macros aren't flagged as used
262struct Config<'a> {
263    ident: &'a Ident,
264    named: &'a Punctuated<Field, Comma>,
265    crate_name: TokenStream2,
266    db_ty: DbType,
267    model_schema_ident: Ident,
268    table_name: String,
269    id_column_ident: Ident,
270    external_id: bool,
271}
272
273impl<'a> Config<'a> {
274    fn new(attrs: &[Attribute], ident: &'a Ident, named: &'a Punctuated<Field, Comma>) -> Self {
275        let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
276        let is_doctest = std::env::vars()
277            .any(|(k, _)| k == "UNSTABLE_RUSTDOC_TEST_LINE" || k == "UNSTABLE_RUSTDOC_TEST_PATH");
278        let crate_name = if !is_doctest && crate_name == "sqlx-crud" {
279            quote! { crate }
280        } else {
281            quote! { ::souchy_sqlx_crud }
282        };
283
284        let db_ty = DbType::new(attrs);
285
286        let model_schema_ident =
287            format_ident!("{}_SCHEMA", ident.to_string().to_screaming_snake_case());
288
289        // Search for a field with the #[id] attribute
290        let id_attr = &named
291            .iter()
292            .find(|f| f.attrs.iter().any(|a| a.path().is_ident("id")))
293            .and_then(|f| f.ident.as_ref());
294        // Otherwise default to the first field as the "id" column
295        let id_column_ident = id_attr
296            .unwrap_or_else(|| {
297                named
298                    .iter()
299                    .flat_map(|f| &f.ident)
300                    .next()
301                    .expect("the first field")
302            })
303            .clone();
304
305        let external_id = attrs.iter().any(|a| a.path().is_ident("external_id"));
306
307        let table_name = attrs
308            .iter()
309            .find(|a| a.path().is_ident("table"))
310            .and_then(|attr| {
311                let mut table = None;
312                attr.parse_nested_meta(|meta| {
313                    if let Some(ident) = meta.path.get_ident() {
314                        table = Some(ident.to_string());
315                    }
316                    Ok(())
317                })
318                .ok();
319                table
320            })
321            .unwrap_or_else(|| ident.to_string().to_table_case());
322
323        Self {
324            ident,
325            named,
326            crate_name,
327            db_ty,
328            model_schema_ident,
329            table_name,
330            id_column_ident,
331            external_id,
332        }
333    }
334
335    fn quote_ident(&self, ident: &str) -> String {
336        self.db_ty.quote_ident(ident)
337    }
338}
339
/// Database backend targeted by the generated code, selected with `#[database(...)]`.
///
/// It determines two things:
/// - the `sqlx` database type named in the generated impls;
/// - the identifier quoting used in the generated SQL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbType {
    /// `sqlx::Any`.
    Any,
    /// `sqlx::Mssql`.
    Mssql,
    /// `sqlx::MySql`.
    MySql,
    /// `sqlx::Postgres`.
    Postgres,
    /// `sqlx::Sqlite`; used when the struct has no `#[database]` attribute.
    Sqlite,
}
347
348impl From<&str> for DbType {
349    fn from(db_type: &str) -> Self {
350        match db_type {
351            "Any" => Self::Any,
352            "Mssql" => Self::Mssql,
353            "MySql" => Self::MySql,
354            "Postgres" => Self::Postgres,
355            "Sqlite" => Self::Sqlite,
356            _ => panic!("unknown #[database] type {}", db_type),
357        }
358    }
359}
360
361impl DbType {
362    fn new(attrs: &[Attribute]) -> Self {
363        let mut db_type = DbType::Sqlite;
364        attrs
365            .iter()
366            .find(|a| a.path().is_ident("database"))
367            .map(|a| {
368                a.parse_nested_meta(|m| {
369                    if let Some(path) = m.path.get_ident() {
370                        db_type = DbType::from(path.to_string().as_str());
371                    }
372                    Ok(())
373                })
374            });
375
376        db_type
377    }
378
379    fn sqlx_db(&self) -> TokenStream2 {
380        match self {
381            Self::Any => quote! { ::sqlx::Any },
382            Self::Mssql => quote! { ::sqlx::Mssql },
383            Self::MySql => quote! { ::sqlx::MySql },
384            Self::Postgres => quote! { ::sqlx::Postgres },
385            Self::Sqlite => quote! { ::sqlx::Sqlite },
386        }
387    }
388
389    fn quote_ident(&self, ident: &str) -> String {
390        match self {
391            Self::Any => format!(r#""{}""#, &ident),
392            Self::Mssql => format!(r#""{}""#, &ident),
393            Self::MySql => format!("`{}`", &ident),
394            Self::Postgres => format!(r#""{}""#, &ident),
395            Self::Sqlite => format!(r#""{}""#, &ident),
396        }
397    }
398}