1use inflector::Inflector;
2use proc_macro::{self, TokenStream};
3use proc_macro2::TokenStream as TokenStream2;
4use quote::{format_ident, quote};
5use syn::punctuated::Punctuated;
6use syn::token::Comma;
7use syn::{
8 parse_macro_input, Attribute, Data, DataStruct, DeriveInput, Field, Fields, FieldsNamed, Ident,
9 Lit, LitStr, Meta, MetaNameValue,
10};
11
12#[proc_macro_derive(SqlxMeta, attributes(database, external_id, id))]
18pub fn sqlx_meta(input: TokenStream) -> TokenStream {
19 let DeriveInput {
20 ident, data, attrs, ..
21 } = parse_macro_input!(input);
22 match data {
23 Data::Struct(DataStruct {
24 fields: Fields::Named(FieldsNamed { named, .. }),
25 ..
26 }) => {
27 let config = Config::new(&attrs, &ident, &named);
28 let static_model_schema = build_static_model_schema(&config);
29 let sqlx_crud_impl = build_sqlx_crud_impl(&config);
30
31 quote! {
32 #static_model_schema
33 #sqlx_crud_impl
34 }
35 .into()
36 }
37 _ => panic!("this derive macro only works on structs with named fields"),
38 }
39}
40
41fn build_static_model_schema(config: &Config<'_>) -> TokenStream2 {
42 let crate_name = &config.crate_name;
43 let model_schema_ident = &config.model_schema_ident;
44 let table_name = &config.table_name;
45
46 let id_column = config.id_column_ident.to_string();
47 let columns_len = config.named.iter().count();
48 let columns = config
49 .named
50 .iter()
51 .flat_map(|f| &f.ident)
52 .map(|f| LitStr::new(format!("{f}").as_str(), f.span()));
53
54 quote! {
55 #[automatically_derived]
56 static #model_schema_ident: #crate_name::schema::Metadata<'static, #columns_len> = #crate_name::schema::Metadata {
57 table_name: #table_name,
58 id_column: #id_column,
59 columns: [#(#columns),*],
60 };
61 }
62}
63
64fn build_sqlx_crud_impl(config: &Config<'_>) -> TokenStream2 {
65 let crate_name = &config.crate_name;
66 let ident = &config.ident;
67 let model_schema_ident = &config.model_schema_ident;
68 let id_column_ident = &config.id_column_ident;
69 let id_ty = config
70 .named
71 .iter()
72 .find(|f| f.ident.as_ref() == Some(id_column_ident))
73 .map(|f| &f.ty)
74 .expect("the id type");
75
76 let insert_binds = config
77 .named
78 .iter()
79 .flat_map(|f| &f.ident)
80 .map(|i| quote! { .bind(&self.#i) });
81 let update_binds = config
82 .named
83 .iter()
84 .flat_map(|f| &f.ident)
85 .filter(|i| *i != id_column_ident)
86 .map(|i| quote! { .bind(&self.#i) });
87
88 let db_ty = config.db_ty.sqlx_db();
89
90 quote! {
91 #[automatically_derived]
92 impl #crate_name::traits::Schema for #ident {
93 type Id = #id_ty;
94
95 fn table_name() -> &'static str {
96 #model_schema_ident.table_name
97 }
98
99 fn id(&self) -> Self::Id {
100 self.#id_column_ident
101 }
102
103 fn id_column() -> &'static str {
104 #model_schema_ident.id_column
105 }
106
107 fn columns() -> &'static [&'static str] {
108 &#model_schema_ident.columns
109 }
110 }
111
112 #[automatically_derived]
113 impl<'e> #crate_name::traits::Binds<'e, &'e ::sqlx::pool::Pool<#db_ty>> for #ident {
114 fn insert_binds(
115 &'e self,
116 query: ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments>
117 ) -> ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments> {
118 query
119 #(#insert_binds)*
120 }
121
122 fn update_binds(
123 &'e self,
124 query: ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments>
125 ) -> ::sqlx::query::QueryAs<'e, #db_ty, Self, <#db_ty as ::sqlx::database::HasArguments<'e>>::Arguments> {
126 query
127 #(#update_binds)*
128 .bind(&self.#id_column_ident)
129 }
130 }
131 }
132}
133
134#[allow(dead_code)] struct Config<'a> {
136 ident: &'a Ident,
137 named: &'a Punctuated<Field, Comma>,
138 crate_name: TokenStream2,
139 db_ty: DbType,
140 model_schema_ident: Ident,
141 table_name: String,
142 id_column_ident: Ident,
143 external_id: bool,
144}
145
146impl<'a> Config<'a> {
147 fn new(attrs: &[Attribute], ident: &'a Ident, named: &'a Punctuated<Field, Comma>) -> Self {
148 let crate_name = std::env::var("CARGO_PKG_NAME").unwrap();
149 let is_doctest = std::env::vars()
150 .any(|(k, _)| k == "UNSTABLE_RUSTDOC_TEST_LINE" || k == "UNSTABLE_RUSTDOC_TEST_PATH");
151 let crate_name = if !is_doctest && crate_name == "sqlx-meta" {
152 quote! { crate }
153 } else {
154 quote! { ::sqlx_meta }
155 };
156
157 let db_ty = DbType::new(attrs);
158
159 let model_schema_ident =
160 format_ident!("{}_SCHEMA", ident.to_string().to_screaming_snake_case());
161
162 let table_name = ident.to_string().to_table_case();
163
164 let id_attr = &named
166 .iter()
167 .find(|f| f.attrs.iter().any(|a| a.path.is_ident("id")))
168 .and_then(|f| f.ident.as_ref());
169 let id_column_ident = id_attr
171 .unwrap_or_else(|| {
172 named
173 .iter()
174 .flat_map(|f| &f.ident)
175 .next()
176 .expect("the first field")
177 })
178 .clone();
179
180 let external_id = attrs.iter().any(|a| a.path.is_ident("external_id"));
181
182 Self {
183 ident,
184 named,
185 crate_name,
186 db_ty,
187 model_schema_ident,
188 table_name,
189 id_column_ident,
190 external_id,
191 }
192 }
193}
194
/// Database backend the generated `Binds` implementation targets.
///
/// Each variant corresponds to the identically named `::sqlx` database type.
enum DbType {
    /// `::sqlx::Any`
    Any,
    /// `::sqlx::Mssql`
    Mssql,
    /// `::sqlx::MySql`
    MySql,
    /// `::sqlx::Postgres`
    Postgres,
    /// `::sqlx::Sqlite`
    Sqlite,
}

// The conversion panics on purpose: an unknown name is a user error in the
// `#[database]` attribute and must stop macro expansion.
#[allow(clippy::fallible_impl_from)]
impl From<&str> for DbType {
    /// Maps the exact (case-sensitive) variant name to its `DbType`.
    ///
    /// # Panics
    ///
    /// Panics when `db_type` is not one of `Any`, `Mssql`, `MySql`,
    /// `Postgres` or `Sqlite`.
    fn from(db_type: &str) -> Self {
        let known = match db_type {
            "Any" => Some(Self::Any),
            "Mssql" => Some(Self::Mssql),
            "MySql" => Some(Self::MySql),
            "Postgres" => Some(Self::Postgres),
            "Sqlite" => Some(Self::Sqlite),
            _ => None,
        };

        known.unwrap_or_else(|| panic!("unknown #[database] type {db_type}"))
    }
}
216
217impl DbType {
218 fn new(attrs: &[Attribute]) -> Self {
219 match attrs
220 .iter()
221 .find(|a| a.path.is_ident("database"))
222 .map(syn::Attribute::parse_meta)
223 {
224 Some(Ok(Meta::NameValue(MetaNameValue {
225 lit: Lit::Str(s), ..
226 }))) => Self::from(&*s.value()),
227 _ => Self::Sqlite,
228 }
229 }
230
231 fn sqlx_db(&self) -> TokenStream2 {
232 match self {
233 Self::Any => quote! { ::sqlx::Any },
234 Self::Mssql => quote! { ::sqlx::Mssql },
235 Self::MySql => quote! { ::sqlx::MySql },
236 Self::Postgres => quote! { ::sqlx::Postgres },
237 Self::Sqlite => quote! { ::sqlx::Sqlite },
238 }
239 }
240}