1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6#[proc_macro_derive(ModelMeta, attributes(model))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let name = &input.ident;
33
34 let mut table_name = None;
36 let mut pk_field = None;
37 let mut soft_delete_field = None;
38
39 for attr in &input.attrs {
40 if attr.path().is_ident("model") {
41 if let syn::Meta::List(list) = &attr.meta {
43 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46 for meta in metas {
47 if let Meta::NameValue(nv) = meta {
48 if nv.path.is_ident("table") {
49 if let syn::Expr::Lit(syn::ExprLit {
50 lit: syn::Lit::Str(s),
51 ..
52 }) = nv.value
53 {
54 table_name = Some(s.value());
55 }
56 } else if nv.path.is_ident("pk") {
57 if let syn::Expr::Lit(syn::ExprLit {
58 lit: syn::Lit::Str(s),
59 ..
60 }) = nv.value
61 {
62 pk_field = Some(s.value());
63 }
64 } else if nv.path.is_ident("soft_delete") {
65 if let syn::Expr::Lit(syn::ExprLit {
66 lit: syn::Lit::Str(s),
67 ..
68 }) = nv.value
69 {
70 soft_delete_field = Some(s.value());
71 }
72 }
73 }
74 }
75 }
76 } else if let syn::Meta::NameValue(nv) = &attr.meta {
77 if nv.path.is_ident("table") {
79 if let syn::Expr::Lit(syn::ExprLit {
80 lit: syn::Lit::Str(s),
81 ..
82 }) = &nv.value
83 {
84 table_name = Some(s.value());
85 }
86 } else if nv.path.is_ident("pk") {
87 if let syn::Expr::Lit(syn::ExprLit {
88 lit: syn::Lit::Str(s),
89 ..
90 }) = &nv.value
91 {
92 pk_field = Some(s.value());
93 }
94 } else if nv.path.is_ident("soft_delete") {
95 if let syn::Expr::Lit(syn::ExprLit {
96 lit: syn::Lit::Str(s),
97 ..
98 }) = &nv.value
99 {
100 soft_delete_field = Some(s.value());
101 }
102 }
103 }
104 }
105 }
106
107 let table = table_name.unwrap_or_else(|| {
109 let s = name.to_string();
110 let mut result = String::new();
112 for (i, c) in s.chars().enumerate() {
113 if c.is_uppercase() && i > 0 {
114 result.push('_');
115 }
116 result.push(c.to_ascii_lowercase());
117 }
118 result
119 });
120
121 let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124 let expanded = if let Some(soft_delete) = soft_delete_field {
126 let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128 quote! {
129 impl sqlxplus::Model for #name {
130 const TABLE: &'static str = #table;
131 const PK: &'static str = #pk;
132 const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133 }
134 }
135 } else {
136 quote! {
138 impl sqlxplus::Model for #name {
139 const TABLE: &'static str = #table;
140 const PK: &'static str = #pk;
141 const SOFT_DELETE_FIELD: Option<&'static str> = None;
142 }
143 }
144 };
145
146 TokenStream::from(expanded)
147}
148
149#[proc_macro_derive(CRUD, attributes(model, skip))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176 let input = parse_macro_input!(input as DeriveInput);
177 let name = &input.ident;
178
179 let mut pk_field = None;
181 for attr in &input.attrs {
182 if attr.path().is_ident("model") {
183 if let syn::Meta::List(list) = &attr.meta {
184 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186 for meta in metas {
187 if let Meta::NameValue(nv) = meta {
188 if nv.path.is_ident("pk") {
189 if let syn::Expr::Lit(syn::ExprLit {
190 lit: syn::Lit::Str(s),
191 ..
192 }) = nv.value
193 {
194 pk_field = Some(s.value());
195 }
196 }
197 }
198 }
199 }
200 } else if let syn::Meta::NameValue(nv) = &attr.meta {
201 if nv.path.is_ident("pk") {
202 if let syn::Expr::Lit(syn::ExprLit {
203 lit: syn::Lit::Str(s),
204 ..
205 }) = &nv.value
206 {
207 pk_field = Some(s.value());
208 }
209 }
210 }
211 }
212 }
213 let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216 let fields = match &input.data {
218 Data::Struct(DataStruct {
219 fields: Fields::Named(fields),
220 ..
221 }) => &fields.named,
222 _ => {
223 return syn::Error::new_spanned(
224 name,
225 "CRUD derive only supports structs with named fields",
226 )
227 .to_compile_error()
228 .into();
229 }
230 };
231
232 let mut field_names = Vec::new();
237 let mut field_columns = Vec::new();
238 let mut skip_fields = Vec::new();
239 let mut insert_field_names = Vec::new();
240 let mut insert_field_columns = Vec::new();
241 let mut update_field_names = Vec::new();
242 let mut pk_ident_opt: Option<&syn::Ident> = None;
244
245 for field in fields {
246 let field_name = field.ident.as_ref().unwrap();
247 let field_name_str = field_name.to_string();
248
249 let mut skip = false;
251 for attr in &field.attrs {
252 if attr.path().is_ident("skip") || attr.path().is_ident("model") {
253 skip = true;
254 break;
255 }
256 }
257
258 if !skip {
259 field_names.push(field_name);
260 field_columns.push(field_name_str.clone());
261 if field_name_str == pk {
262 pk_ident_opt = Some(field_name);
264 } else {
265 insert_field_names.push(field_name);
267 insert_field_columns.push(field_name_str.clone());
268 update_field_names.push(field_name);
269 }
270 } else {
271 skip_fields.push(field_name_str);
272 }
273 }
274
275 let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
277
278 let insert_fields: Vec<String> = insert_field_columns.clone();
280 let insert_placeholders: Vec<String> =
281 (0..insert_fields.len()).map(|_| "?".to_string()).collect();
282 let insert_sql = format!(
283 "INSERT INTO {} ({}) VALUES ({})",
284 format!("{{TABLE}}"), insert_fields.join(", "),
286 insert_placeholders.join(", ")
287 );
288
289 let update_fields: Vec<String> = field_columns
291 .iter()
292 .filter(|f| *f != &pk)
293 .cloned()
294 .collect();
295 let update_sql = if !update_fields.is_empty() {
296 format!(
297 "UPDATE {} SET {} WHERE {} = ?",
298 format!("{{TABLE}}"),
299 update_fields
300 .iter()
301 .map(|f| format!("{} = ?", f))
302 .collect::<Vec<_>>()
303 .join(", "),
304 format!("{{PK}}")
305 )
306 } else {
307 String::new()
308 };
309
310 let expanded = quote! {
312 #[async_trait::async_trait]
313 impl sqlxplus::Crud for #name {
314 async fn insert(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<sqlxplus::crud::Id> {
315 use sqlxplus::Model;
316 use sqlxplus::utils::escape_identifier;
317 let table = Self::TABLE;
318 let driver = pool.driver();
319 let escaped_table = escape_identifier(driver, table);
320 let sql = #insert_sql.replace("{TABLE}", &escaped_table);
321 let sql = pool.convert_sql(&sql);
322
323 match pool.driver() {
324 sqlxplus::db_pool::DbDriver::MySql => {
325 let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
326 let result = sqlx::query(&sql)
327 #( .bind(&self.#insert_field_names) )*
328 .execute(pool_ref)
329 .await?;
330 Ok(result.last_insert_id() as i64)
331 }
332 sqlxplus::db_pool::DbDriver::Postgres => {
333 let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
334 let pk = Self::PK;
335 use sqlxplus::utils::escape_identifier;
336 let escaped_pk = escape_identifier(sqlxplus::db_pool::DbDriver::Postgres, pk);
337 let sql_with_returning = format!("{} RETURNING {}", sql, escaped_pk);
339 let id: i64 = sqlx::query_scalar(&sql_with_returning)
340 #( .bind(&self.#insert_field_names) )*
341 .fetch_one(pool_ref)
342 .await?;
343 Ok(id)
344 }
345 sqlxplus::db_pool::DbDriver::Sqlite => {
346 let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
347 let result = sqlx::query(&sql)
348 #( .bind(&self.#insert_field_names) )*
349 .execute(pool_ref)
350 .await?;
351 Ok(result.last_insert_rowid() as i64)
352 }
353 }
354 }
355
356 async fn update(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<()> {
357 use sqlxplus::Model;
358 use sqlxplus::utils::escape_identifier;
359 let table = Self::TABLE;
360 let pk = Self::PK;
361 let driver = pool.driver();
362 let escaped_table = escape_identifier(driver, table);
363 let escaped_pk = escape_identifier(driver, pk);
364 let sql = #update_sql.replace("{TABLE}", &escaped_table).replace("{PK}", &escaped_pk);
365 let sql = pool.convert_sql(&sql);
366
367 match pool.driver() {
368 sqlxplus::db_pool::DbDriver::MySql => {
369 let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
370 sqlx::query(&sql)
371 #( .bind(&self.#update_field_names) )*
372 .bind(&self.#pk_ident)
373 .execute(pool_ref)
374 .await?;
375 }
376 sqlxplus::db_pool::DbDriver::Postgres => {
377 let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
378 sqlx::query(&sql)
379 #( .bind(&self.#update_field_names) )*
380 .bind(&self.#pk_ident)
381 .execute(pool_ref)
382 .await?;
383 }
384 sqlxplus::db_pool::DbDriver::Sqlite => {
385 let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
386 sqlx::query(&sql)
387 #( .bind(&self.#update_field_names) )*
388 .bind(&self.#pk_ident)
389 .execute(pool_ref)
390 .await?;
391 }
392 }
393 Ok(())
394 }
395 }
396 };
397
398 TokenStream::from(expanded)
399}