sqlx_crud_macros/
lib.rs

1use inflector::Inflector;
2use proc_macro::{self, TokenStream};
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8    parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Expr, Field, Fields, FieldsNamed, Ident,
9    LitStr, Meta, MetaNameValue, Lit, ExprLit,
10};
11
12#[proc_macro_derive(SqlxCrud, attributes(database, external_id, id))]
13pub fn derive(input: TokenStream) -> TokenStream {
14    let DeriveInput {
15        ident, data, attrs, ..
16    } = parse_macro_input!(input);
17    match data {
18        Data::Struct(DataStruct {
19            fields: Fields::Named(FieldsNamed { named, .. }),
20            ..
21        }) => {
22            let config = Config::new(&attrs, &ident, &named);
23            let static_model_schema = build_static_model_schema(&config);
24            let sqlx_crud_impl = build_sqlx_crud_impl(&config);
25
26            quote! {
27                #static_model_schema
28                #sqlx_crud_impl
29            }
30            .into()
31        }
32        _ => panic!("this derive macro only works on structs with named fields"),
33    }
34}
35
36fn build_static_model_schema(config: &Config) -> TokenStream2 {
37    let crate_name = &config.crate_name;
38    let model_schema_ident = &config.model_schema_ident;
39    let table_name = &config.table_name;
40
41    let id_column = config.id_column_ident.to_string();
42    let columns_len = config.named.iter().count();
43    let columns = config
44        .named
45        .iter()
46        .flat_map(|f| &f.ident)
47        .map(|f| LitStr::new(format!("{}", f).as_str(), f.span()));
48
49    let sql_queries = build_sql_queries(config);
50
51    quote! {
52        #[automatically_derived]
53        static #model_schema_ident: #crate_name::schema::Metadata<'static, #columns_len> = #crate_name::schema::Metadata {
54            table_name: #table_name,
55            id_column: #id_column,
56            columns: [#(#columns),*],
57            #sql_queries
58        };
59    }
60}
61
62fn build_sql_queries(config: &Config) -> TokenStream2 {
63    let table_name = config.quote_ident(&config.table_name);
64    let id_column = format!(
65        "{}.{}",
66        &table_name,
67        config.quote_ident(&config.id_column_ident.to_string())
68    );
69
70    let insert_bind_cnt = if config.external_id {
71        config.named.iter().count()
72    } else {
73        config.named.iter().count() - 1
74    };
75    let insert_sql_binds = (0..insert_bind_cnt)
76        .map(|_| "?")
77        .collect::<Vec<_>>()
78        .join(", ");
79
80    let update_sql_binds = config
81        .named
82        .iter()
83        .flat_map(|f| &f.ident)
84        .filter(|i| *i != &config.id_column_ident)
85        .map(|i| format!("{} = ?", config.quote_ident(&i.to_string())))
86        .collect::<Vec<_>>()
87        .join(", ");
88
89    let insert_column_list = config
90        .named
91        .iter()
92        .flat_map(|f| &f.ident)
93        .filter(|i| config.external_id || *i != &config.id_column_ident)
94        .map(|i| config.quote_ident(&i.to_string()))
95        .collect::<Vec<_>>()
96        .join(", ");
97    let column_list = config
98        .named
99        .iter()
100        .flat_map(|f| &f.ident)
101        .map(|i| format!("{}.{}", &table_name, config.quote_ident(&i.to_string())))
102        .collect::<Vec<_>>()
103        .join(", ");
104
105    let select_sql = format!("SELECT {} FROM {}", column_list, table_name);
106    let select_by_id_sql = format!(
107        "SELECT {} FROM {} WHERE {} = ? LIMIT 1",
108        column_list, table_name, id_column
109    );
110    let insert_sql = format!(
111        "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
112        table_name, insert_column_list, insert_sql_binds, column_list
113    );
114    let update_by_id_sql = format!(
115        "UPDATE {} SET {} WHERE {} = ? RETURNING {}",
116        table_name, update_sql_binds, id_column, column_list
117    );
118    let delete_by_id_sql = format!("DELETE FROM {} WHERE {} = ?", table_name, id_column);
119
120    quote! {
121        select_sql: #select_sql,
122        select_by_id_sql: #select_by_id_sql,
123        insert_sql: #insert_sql,
124        update_by_id_sql: #update_by_id_sql,
125        delete_by_id_sql: #delete_by_id_sql,
126    }
127}
128
129fn build_sqlx_crud_impl(config: &Config) -> TokenStream2 {
130    let crate_name = &config.crate_name;
131    let ident = &config.ident;
132    let model_schema_ident = &config.model_schema_ident;
133    let db_ty = config.db_ty.sqlx_db();
134    let id_column_ident = &config.id_column_ident;
135
136    let id_ty = config
137        .named
138        .iter()
139        .find(|f| f.ident.as_ref() == Some(id_column_ident))
140        .map(|f| &f.ty)
141        .expect("the id type");
142
143    let insert_query_args = config
144        .named
145        .iter()
146        .flat_map(|f| &f.ident)
147        .filter(|i| config.external_id || *i != &config.id_column_ident)
148        .map(|i| quote! { args.add(self.#i); });
149
150    let insert_query_size = config
151        .named
152        .iter()
153        .flat_map(|f| &f.ident)
154        .filter(|i| config.external_id || *i != &config.id_column_ident)
155        .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
156
157    let update_query_args = config
158        .named
159        .iter()
160        .flat_map(|f| &f.ident)
161        .filter(|i| *i != &config.id_column_ident)
162        .map(|i| quote! { args.add(self.#i); });
163
164    let update_query_args_id = quote! { args.add(self.#id_column_ident); };
165
166    let update_query_size = config
167        .named
168        .iter()
169        .flat_map(|f| &f.ident)
170        .map(|i| quote! { ::sqlx::encode::Encode::<#db_ty>::size_hint(&self.#i) });
171
172    quote! {
173        #[automatically_derived]
174        impl #crate_name::traits::Schema for #ident {
175            type Id = #id_ty;
176
177            fn table_name() -> &'static str {
178                #model_schema_ident.table_name
179            }
180
181            fn id(&self) -> Self::Id {
182                self.#id_column_ident
183            }
184
185            fn id_column() -> &'static str {
186                #model_schema_ident.id_column
187            }
188
189            fn columns() -> &'static [&'static str] {
190                &#model_schema_ident.columns
191            }
192
193            fn select_sql() -> &'static str {
194                #model_schema_ident.select_sql
195            }
196
197            fn select_by_id_sql() -> &'static str {
198                #model_schema_ident.select_by_id_sql
199            }
200
201            fn insert_sql() -> &'static str {
202                #model_schema_ident.insert_sql
203            }
204
205            fn update_by_id_sql() -> &'static str {
206                #model_schema_ident.update_by_id_sql
207            }
208
209            fn delete_by_id_sql() -> &'static str {
210                #model_schema_ident.delete_by_id_sql
211            }
212        }
213
214        #[automatically_derived]
215        impl<'e> #crate_name::traits::Crud<'e, &'e ::sqlx::pool::Pool<#db_ty>> for #ident {
216            fn insert_args(self) -> <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments {
217                use ::sqlx::Arguments as _;
218                let mut args = <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments::default();
219                args.reserve(1usize, #(#insert_query_size)+*);
220                #(#insert_query_args)*
221                args
222            }
223
224            fn update_args(self) -> <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments {
225                use ::sqlx::Arguments as _;
226                let mut args = <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments::default();
227                args.reserve(1usize, #(#update_query_size)+*);
228                #(#update_query_args)*
229                #update_query_args_id
230                args
231            }
232        }
233    }
234}
235
236#[allow(dead_code)] // Usage in quote macros aren't flagged as used
237struct Config<'a> {
238    ident: &'a Ident,
239    named: &'a Punctuated<Field, Comma>,
240    crate_name: TokenStream2,
241    db_ty: DbType,
242    model_schema_ident: Ident,
243    table_name: String,
244    id_column_ident: Ident,
245    external_id: bool,
246}
247
248impl<'a> Config<'a> {
249    fn new(attrs: &[Attribute], ident: &'a Ident, named: &'a Punctuated<Field, Comma>) -> Self {
250        let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
251        let is_doctest = std::env::vars()
252            .any(|(k, _)| k == "UNSTABLE_RUSTDOC_TEST_LINE" || k == "UNSTABLE_RUSTDOC_TEST_PATH");
253        let crate_name = if !is_doctest && crate_name == "sqlx-crud" {
254            quote! { crate }
255        } else {
256            quote! { ::sqlx_crud }
257        };
258
259        let db_ty = DbType::new(attrs);
260
261        let model_schema_ident =
262            format_ident!("{}_SCHEMA", ident.to_string().to_screaming_snake_case());
263
264        let table_name = ident.to_string().to_table_case();
265
266        // Search for a field with the #[id] attribute
267        let id_attr = &named
268            .iter()
269            .find(|f| f.attrs.iter().any(|a| a.path().is_ident("id")))
270            .and_then(|f| f.ident.as_ref());
271        // Otherwise default to the first field as the "id" column
272        let id_column_ident = id_attr
273            .unwrap_or_else(|| {
274                named
275                    .iter()
276                    .flat_map(|f| &f.ident)
277                    .next()
278                    .expect("the first field")
279            })
280            .clone();
281
282        let external_id = attrs.iter().any(|a| a.path().is_ident("external_id"));
283
284        Self {
285            ident,
286            named,
287            crate_name,
288            db_ty,
289            model_schema_ident,
290            table_name,
291            id_column_ident,
292            external_id,
293        }
294    }
295
296    fn quote_ident(&self, ident: &str) -> String {
297        self.db_ty.quote_ident(ident)
298    }
299}
300
/// SQL backend targeted by the generated code, selected with `#[database(...)]`.
///
/// Each variant corresponds to the `sqlx` database type of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DbType {
    /// `sqlx::Any`
    Any,
    /// `sqlx::Mssql`
    Mssql,
    /// `sqlx::MySql`
    MySql,
    /// `sqlx::Postgres`
    Postgres,
    /// `sqlx::Sqlite` (the default when no `#[database]` attribute is given)
    Sqlite,
}
308
309impl From<&str> for DbType {
310    fn from(db_type: &str) -> Self {
311        match db_type {
312            "Any" => Self::Any,
313            "Mssql" => Self::Mssql,
314            "MySql" => Self::MySql,
315            "Postgres" => Self::Postgres,
316            "Sqlite" => Self::Sqlite,
317            _ => panic!("unknown #[database] type {}", db_type),
318        }
319    }
320}
321
322impl DbType {
323    fn new(attrs: &[Attribute]) -> Self {
324        let mut db_type = DbType::Sqlite;
325        attrs.iter()
326            .find(|a| a.path().is_ident("database"))
327            .map(|a| a.parse_nested_meta(|m| {
328                if let Some(path) = m.path.get_ident() {
329                    db_type = DbType::from(path.to_string().as_str());
330                }
331                Ok(())
332            }));
333
334        db_type
335    }
336
337    fn sqlx_db(&self) -> TokenStream2 {
338        match self {
339            Self::Any => quote! { ::sqlx::Any },
340            Self::Mssql => quote! { ::sqlx::Mssql },
341            Self::MySql => quote! { ::sqlx::MySql },
342            Self::Postgres => quote! { ::sqlx::Postgres },
343            Self::Sqlite => quote! { ::sqlx::Sqlite },
344        }
345    }
346
347    fn quote_ident(&self, ident: &str) -> String {
348        match self {
349            Self::Any => format!(r#""{}""#, &ident),
350            Self::Mssql => format!(r#""{}""#, &ident),
351            Self::MySql => format!("`{}`", &ident),
352            Self::Postgres => format!(r#""{}""#, &ident),
353            Self::Sqlite => format!(r#""{}""#, &ident),
354        }
355    }
356}