// sqlx_meta_macros/lib.rs

1use inflector::Inflector;
2use proc_macro::{self, TokenStream};
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8    parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident,
9    Lit, LitStr, Meta, MetaNameValue,
10};
11
12/// Derive metadata for the entity
13///
14/// # Panics
15///
16/// Panics if conversions fail
17#[proc_macro_derive(SqlxMeta, attributes(database, external_id, id))]
18pub fn sqlx_meta(input: TokenStream) -> TokenStream {
19    let DeriveInput {
20        ident, data, attrs, ..
21    } = parse_macro_input!(input);
22    match data {
23        Data::Struct(DataStruct {
24            fields: Fields::Named(FieldsNamed { named, .. }),
25            ..
26        }) => {
27            let config = Config::new(&attrs, &ident, &named);
28            let static_model_schema = build_static_model_schema(&config);
29            let sqlx_crud_impl = build_sqlx_crud_impl(&config);
30
31            quote! {
32                #static_model_schema
33                #sqlx_crud_impl
34            }
35            .into()
36        }
37        _ => panic!("this derive macro only works on structs with named fields"),
38    }
39}
40
41fn build_static_model_schema(config: &Config<'_>) -> TokenStream2 {
42    let crate_name = &config.crate_name;
43    let model_schema_ident = &config.model_schema_ident;
44    let table_name = &config.table_name;
45
46    let id_column = config.id_column_ident.to_string();
47    let columns_len = config.named.iter().count();
48    let columns = config
49        .named
50        .iter()
51        .flat_map(|f| &f.ident)
52        .map(|f| LitStr::new(format!("{f}").as_str(), f.span()));
53
54    quote! {
55        #[automatically_derived]
56        static #model_schema_ident: #crate_name::schema::Metadata<'static, #columns_len> = #crate_name::schema::Metadata {
57            table_name: #table_name,
58            id_column: #id_column,
59            columns: [#(#columns),*],
60        };
61    }
62}
63
64fn build_sqlx_crud_impl(config: &Config<'_>) -> TokenStream2 {
65    let crate_name = &config.crate_name;
66    let ident = &config.ident;
67    let model_schema_ident = &config.model_schema_ident;
68    let id_column_ident = &config.id_column_ident;
69    let id_ty = config
70        .named
71        .iter()
72        .find(|f| f.ident.as_ref() == Some(id_column_ident))
73        .map(|f| &f.ty)
74        .expect("the id type");
75
76    let insert_binds = config
77        .named
78        .iter()
79        .flat_map(|f| &f.ident)
80        .map(|i| quote! { .bind(&self.#i) });
81    let update_binds = config
82        .named
83        .iter()
84        .flat_map(|f| &f.ident)
85        .filter(|i| *i != id_column_ident)
86        .map(|i| quote! { .bind(&self.#i) });
87
88    let db_ty = config.db_ty.sqlx_db();
89
90    quote! {
91        #[automatically_derived]
92        impl #crate_name::traits::Schema for #ident {
93            type Id = #id_ty;
94
95            fn table_name() -> &'static str {
96                #model_schema_ident.table_name
97            }
98
99            fn id(&self) -> Self::Id {
100                self.#id_column_ident
101            }
102
103            fn id_column() -> &'static str {
104                #model_schema_ident.id_column
105            }
106
107            fn columns() -> &'static [&'static str] {
108                &#model_schema_ident.columns
109            }
110        }
111
112        #[automatically_derived]
113        impl<'e> #crate_name::traits::Binds<'e, &'e ::sqlx::pool::Pool<#db_ty>> for #ident {
114            fn insert_binds(
115                &'e self,
116                query: ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments>
117            ) -> ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments> {
118                query
119                    #(#insert_binds)*
120            }
121
122            fn update_binds(
123                &'e self,
124                query: ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments>
125            ) -> ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments> {
126                query
127                    #(#update_binds)*
128                    .bind(&self.#id_column_ident)
129            }
130        }
131    }
132}
133
134#[allow(dead_code)] // Usage in quote macros aren't flagged as used
135struct Config<'a> {
136    ident: &'a Ident,
137    named: &'a Punctuated<Field, Comma>,
138    crate_name: TokenStream2,
139    db_ty: DbType,
140    model_schema_ident: Ident,
141    table_name: String,
142    id_column_ident: Ident,
143    external_id: bool,
144}
145
146impl<'a> Config<'a> {
147    fn new(attrs: &[Attribute], ident: &'a Ident, named: &'a Punctuated<Field, Comma>) -> Self {
148        let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
149        let is_doctest = std::env::vars()
150            .any(|(k, _)| k == "UNSTABLE_RUSTDOC_TEST_LINE" || k == "UNSTABLE_RUSTDOC_TEST_PATH");
151        let crate_name = if !is_doctest && crate_name == "sqlx-meta" {
152            quote! { crate }
153        } else {
154            quote! { ::sqlx_meta }
155        };
156
157        let db_ty = DbType::new(attrs);
158
159        let model_schema_ident =
160            format_ident!("{}_SCHEMA", ident.to_string().to_screaming_snake_case());
161
162        let table_name = ident.to_string().to_table_case();
163
164        // Search for a field with the #[id] attribute
165        let id_attr = &named
166            .iter()
167            .find(|f| f.attrs.iter().any(|a| a.path.is_ident("id")))
168            .and_then(|f| f.ident.as_ref());
169        // Otherwise default to the first field as the "id" column
170        let id_column_ident = id_attr
171            .unwrap_or_else(|| {
172                named
173                    .iter()
174                    .flat_map(|f| &f.ident)
175                    .next()
176                    .expect("the first field")
177            })
178            .clone();
179
180        let external_id = attrs.iter().any(|a| a.path.is_ident("external_id"));
181
182        Self {
183            ident,
184            named,
185            crate_name,
186            db_ty,
187            model_schema_ident,
188            table_name,
189            id_column_ident,
190            external_id,
191        }
192    }
193}
194
/// Database backends the derive can target, chosen with
/// `#[database = "..."]` on the struct.
enum DbType {
    /// Maps to `sqlx::Any`.
    Any,
    /// Maps to `sqlx::Mssql`.
    Mssql,
    /// Maps to `sqlx::MySql`.
    MySql,
    /// Maps to `sqlx::Postgres`.
    Postgres,
    /// Maps to `sqlx::Sqlite`.
    Sqlite,
}
202
203#[allow(clippy::fallible_impl_from)]
204impl From<&str> for DbType {
205    fn from(db_type: &str) -> Self {
206        match db_type {
207            "Any" => Self::Any,
208            "Mssql" => Self::Mssql,
209            "MySql" => Self::MySql,
210            "Postgres" => Self::Postgres,
211            "Sqlite" => Self::Sqlite,
212            _ => panic!("unknown #[database] type {db_type}"),
213        }
214    }
215}
216
217impl DbType {
218    fn new(attrs: &[Attribute]) -> Self {
219        match attrs
220            .iter()
221            .find(|a| a.path.is_ident("database"))
222            .map(syn::Attribute::parse_meta)
223        {
224            Some(Ok(Meta::NameValue(MetaNameValue {
225                lit: Lit::Str(s), ..
226            }))) => Self::from(&*s.value()),
227            _ => Self::Sqlite,
228        }
229    }
230
231    fn sqlx_db(&self) -> TokenStream2 {
232        match self {
233            Self::Any => quote! { ::sqlx::Any },
234            Self::Mssql => quote! { ::sqlx::Mssql },
235            Self::MySql => quote! { ::sqlx::MySql },
236            Self::Postgres => quote! { ::sqlx::Postgres },
237            Self::Sqlite => quote! { ::sqlx::Sqlite },
238        }
239    }
240}