1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6#[proc_macro_derive(ModelMeta, attributes(model))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let name = &input.ident;
33
34 let mut table_name = None;
36 let mut pk_field = None;
37 let mut soft_delete_field = None;
38
39 for attr in &input.attrs {
40 if attr.path().is_ident("model") {
41 if let syn::Meta::List(list) = &attr.meta {
43 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46 for meta in metas {
47 if let Meta::NameValue(nv) = meta {
48 if nv.path.is_ident("table") {
49 if let syn::Expr::Lit(syn::ExprLit {
50 lit: syn::Lit::Str(s),
51 ..
52 }) = nv.value
53 {
54 table_name = Some(s.value());
55 }
56 } else if nv.path.is_ident("pk") {
57 if let syn::Expr::Lit(syn::ExprLit {
58 lit: syn::Lit::Str(s),
59 ..
60 }) = nv.value
61 {
62 pk_field = Some(s.value());
63 }
64 } else if nv.path.is_ident("soft_delete") {
65 if let syn::Expr::Lit(syn::ExprLit {
66 lit: syn::Lit::Str(s),
67 ..
68 }) = nv.value
69 {
70 soft_delete_field = Some(s.value());
71 }
72 }
73 }
74 }
75 }
76 } else if let syn::Meta::NameValue(nv) = &attr.meta {
77 if nv.path.is_ident("table") {
79 if let syn::Expr::Lit(syn::ExprLit {
80 lit: syn::Lit::Str(s),
81 ..
82 }) = &nv.value
83 {
84 table_name = Some(s.value());
85 }
86 } else if nv.path.is_ident("pk") {
87 if let syn::Expr::Lit(syn::ExprLit {
88 lit: syn::Lit::Str(s),
89 ..
90 }) = &nv.value
91 {
92 pk_field = Some(s.value());
93 }
94 } else if nv.path.is_ident("soft_delete") {
95 if let syn::Expr::Lit(syn::ExprLit {
96 lit: syn::Lit::Str(s),
97 ..
98 }) = &nv.value
99 {
100 soft_delete_field = Some(s.value());
101 }
102 }
103 }
104 }
105 }
106
107 let table = table_name.unwrap_or_else(|| {
109 let s = name.to_string();
110 let mut result = String::new();
112 for (i, c) in s.chars().enumerate() {
113 if c.is_uppercase() && i > 0 {
114 result.push('_');
115 }
116 result.push(c.to_ascii_lowercase());
117 }
118 result
119 });
120
121 let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124 let expanded = if let Some(soft_delete) = soft_delete_field {
126 let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128 quote! {
129 impl sqlxplus::Model for #name {
130 const TABLE: &'static str = #table;
131 const PK: &'static str = #pk;
132 const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133 }
134 }
135 } else {
136 quote! {
138 impl sqlxplus::Model for #name {
139 const TABLE: &'static str = #table;
140 const PK: &'static str = #pk;
141 const SOFT_DELETE_FIELD: Option<&'static str> = None;
142 }
143 }
144 };
145
146 TokenStream::from(expanded)
147}
148
149#[proc_macro_derive(CRUD, attributes(model, skip))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176 let input = parse_macro_input!(input as DeriveInput);
177 let name = &input.ident;
178
179 let mut pk_field = None;
181 for attr in &input.attrs {
182 if attr.path().is_ident("model") {
183 if let syn::Meta::List(list) = &attr.meta {
184 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186 for meta in metas {
187 if let Meta::NameValue(nv) = meta {
188 if nv.path.is_ident("pk") {
189 if let syn::Expr::Lit(syn::ExprLit {
190 lit: syn::Lit::Str(s),
191 ..
192 }) = nv.value
193 {
194 pk_field = Some(s.value());
195 }
196 }
197 }
198 }
199 }
200 } else if let syn::Meta::NameValue(nv) = &attr.meta {
201 if nv.path.is_ident("pk") {
202 if let syn::Expr::Lit(syn::ExprLit {
203 lit: syn::Lit::Str(s),
204 ..
205 }) = &nv.value
206 {
207 pk_field = Some(s.value());
208 }
209 }
210 }
211 }
212 }
213 let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216 let fields = match &input.data {
218 Data::Struct(DataStruct {
219 fields: Fields::Named(fields),
220 ..
221 }) => &fields.named,
222 _ => {
223 return syn::Error::new_spanned(
224 name,
225 "CRUD derive only supports structs with named fields",
226 )
227 .to_compile_error()
228 .into();
229 }
230 };
231
232 let mut pk_ident_opt: Option<&syn::Ident> = None;
236
237 let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
239 let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
240 let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
241 let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
242
243 let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
245 let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
246 let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
247 let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
248
249 for field in fields {
250 let field_name = field.ident.as_ref().unwrap();
251 let field_name_str = field_name.to_string();
252
253 let mut skip = false;
255 for attr in &field.attrs {
256 if attr.path().is_ident("skip") || attr.path().is_ident("model") {
257 skip = true;
258 break;
259 }
260 }
261
262 if !skip {
263 if field_name_str == pk {
264 pk_ident_opt = Some(field_name);
266 } else {
267 let is_opt = is_option_type(&field.ty);
269 let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
270
271 if is_opt {
272 insert_option_field_names.push(field_name);
273 insert_option_field_columns.push(col_lit.clone());
274
275 update_option_field_names.push(field_name);
276 update_option_field_columns.push(col_lit);
277 } else {
278 insert_normal_field_names.push(field_name);
279 insert_normal_field_columns.push(col_lit.clone());
280
281 update_normal_field_names.push(field_name);
282 update_normal_field_columns.push(col_lit);
283 }
284 }
285 }
286 }
287
288 let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
290
291 let expanded = quote! {
293 #[async_trait::async_trait]
294 impl sqlxplus::Crud for #name {
295 async fn insert<E>(&self, executor: &mut E) -> sqlxplus::db_pool::Result<sqlxplus::crud::Id>
296 where
297 E: sqlxplus::executor::DbExecutor + Send,
298 {
299 use sqlxplus::Model;
300 use sqlxplus::utils::escape_identifier;
301 let table = Self::TABLE;
302 let driver = executor.driver();
303 let escaped_table = escape_identifier(driver, table);
304
305 let mut columns: Vec<&str> = Vec::new();
309 let mut placeholders: Vec<&str> = Vec::new();
310
311 #(
313 columns.push(#insert_normal_field_columns);
314 placeholders.push("?");
315 )*
316
317 #(
319 if self.#insert_option_field_names.is_some() {
320 columns.push(#insert_option_field_columns);
321 placeholders.push("?");
322 }
323 )*
324
325 let raw_sql = format!(
326 "INSERT INTO {} ({}) VALUES ({})",
327 escaped_table,
328 columns.join(", "),
329 placeholders.join(", ")
330 );
331 let sql = executor.convert_sql(&raw_sql);
332
333 match executor.driver() {
334 sqlxplus::db_pool::DbDriver::MySql => {
335 let mut query = sqlx::query(&sql);
336 #(
338 query = query.bind(&self.#insert_normal_field_names);
339 )*
340 #(
342 if self.#insert_option_field_names.is_some() {
343 query = query.bind(&self.#insert_option_field_names);
344 }
345 )*
346 if let Some(tx_ref) = executor.mysql_transaction_ref() {
347 let result = query
348 .execute(&mut **tx_ref)
349 .await?;
350 Ok(result.last_insert_id() as i64)
351 } else if let Some(pool_ref) = executor.mysql_pool() {
352 let result = query
353 .execute(pool_ref)
354 .await?;
355 Ok(result.last_insert_id() as i64)
356 } else {
357 Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)
358 }
359 }
360 sqlxplus::db_pool::DbDriver::Postgres => {
361 let pk = Self::PK;
362 let escaped_pk = escape_identifier(sqlxplus::db_pool::DbDriver::Postgres, pk);
363 let sql_with_returning = format!("{} RETURNING {}", sql, escaped_pk);
365 let mut query = sqlx::query_scalar::<_, i64>(&sql_with_returning);
366 #(
368 query = query.bind(&self.#insert_normal_field_names);
369 )*
370 #(
372 if self.#insert_option_field_names.is_some() {
373 query = query.bind(&self.#insert_option_field_names);
374 }
375 )*
376 if let Some(tx_ref) = executor.postgres_transaction_ref() {
377 let id: i64 = query
378 .fetch_one(&mut **tx_ref)
379 .await?;
380 Ok(id)
381 } else if let Some(pool_ref) = executor.pg_pool() {
382 let id: i64 = query
383 .fetch_one(pool_ref)
384 .await?;
385 Ok(id)
386 } else {
387 Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)
388 }
389 }
390 sqlxplus::db_pool::DbDriver::Sqlite => {
391 let mut query = sqlx::query(&sql);
392 #(
394 query = query.bind(&self.#insert_normal_field_names);
395 )*
396 #(
398 if self.#insert_option_field_names.is_some() {
399 query = query.bind(&self.#insert_option_field_names);
400 }
401 )*
402 if let Some(tx_ref) = executor.sqlite_transaction_ref() {
403 let result = query
404 .execute(&mut **tx_ref)
405 .await?;
406 Ok(result.last_insert_rowid() as i64)
407 } else if let Some(pool_ref) = executor.sqlite_pool() {
408 let result = query
409 .execute(pool_ref)
410 .await?;
411 Ok(result.last_insert_rowid() as i64)
412 } else {
413 Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)
414 }
415 }
416 }
417 }
418
419 async fn update<E>(&self, executor: &mut E) -> sqlxplus::db_pool::Result<()>
420 where
421 E: sqlxplus::executor::DbExecutor + Send,
422 {
423 use sqlxplus::Model;
424 use sqlxplus::utils::escape_identifier;
425 let table = Self::TABLE;
426 let pk = Self::PK;
427 let driver = executor.driver();
428 let escaped_table = escape_identifier(driver, table);
429 let escaped_pk = escape_identifier(driver, pk);
430
431 let mut set_parts: Vec<String> = Vec::new();
435
436 #(
438 set_parts.push(format!("{} = ?", #update_normal_field_columns));
439 )*
440
441 #(
443 if self.#update_option_field_names.is_some() {
444 set_parts.push(format!("{} = ?", #update_option_field_columns));
445 }
446 )*
447
448 let raw_sql = if !set_parts.is_empty() {
449 format!(
450 "UPDATE {} SET {} WHERE {} = ?",
451 escaped_table,
452 set_parts.join(", "),
453 escaped_pk,
454 )
455 } else {
456 return Ok(());
458 };
459
460 let sql = executor.convert_sql(&raw_sql);
461
462 match executor.driver() {
463 sqlxplus::db_pool::DbDriver::MySql => {
464 let mut query = sqlx::query(&sql);
465 #(
467 query = query.bind(&self.#update_normal_field_names);
468 )*
469 #(
471 if self.#update_option_field_names.is_some() {
472 query = query.bind(&self.#update_option_field_names);
473 }
474 )*
475 query = query.bind(&self.#pk_ident);
476 if let Some(tx_ref) = executor.mysql_transaction_ref() {
477 query.execute(&mut **tx_ref).await?;
478 } else if let Some(pool_ref) = executor.mysql_pool() {
479 query.execute(pool_ref).await?;
480 } else {
481 return Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable);
482 }
483 }
484 sqlxplus::db_pool::DbDriver::Postgres => {
485 let mut query = sqlx::query(&sql);
486 #(
488 query = query.bind(&self.#update_normal_field_names);
489 )*
490 #(
492 if self.#update_option_field_names.is_some() {
493 query = query.bind(&self.#update_option_field_names);
494 }
495 )*
496 query = query.bind(&self.#pk_ident);
497 if let Some(tx_ref) = executor.postgres_transaction_ref() {
498 query.execute(&mut **tx_ref).await?;
499 } else if let Some(pool_ref) = executor.pg_pool() {
500 query.execute(pool_ref).await?;
501 } else {
502 return Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable);
503 }
504 }
505 sqlxplus::db_pool::DbDriver::Sqlite => {
506 let mut query = sqlx::query(&sql);
507 #(
509 query = query.bind(&self.#update_normal_field_names);
510 )*
511 #(
513 if self.#update_option_field_names.is_some() {
514 query = query.bind(&self.#update_option_field_names);
515 }
516 )*
517 query = query.bind(&self.#pk_ident);
518 if let Some(tx_ref) = executor.sqlite_transaction_ref() {
519 query.execute(&mut **tx_ref).await?;
520 } else if let Some(pool_ref) = executor.sqlite_pool() {
521 query.execute(pool_ref).await?;
522 } else {
523 return Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable);
524 }
525 }
526 }
527 Ok(())
528 }
529
530 async fn update_with_none<E>(&self, executor: &mut E) -> sqlxplus::db_pool::Result<()>
537 where
538 E: sqlxplus::executor::DbExecutor + Send,
539 {
540 use sqlxplus::Model;
541 use sqlxplus::utils::escape_identifier;
542 let table = Self::TABLE;
543 let pk = Self::PK;
544 let driver = executor.driver();
545 let escaped_table = escape_identifier(driver, table);
546 let escaped_pk = escape_identifier(driver, pk);
547
548 let mut set_parts: Vec<String> = Vec::new();
550
551 #(
553 set_parts.push(format!("{} = ?", #update_normal_field_columns));
554 )*
555
556 match driver {
559 sqlxplus::db_pool::DbDriver::Sqlite => {
560 #(
562 if self.#update_option_field_names.is_some() {
563 set_parts.push(format!("{} = ?", #update_option_field_columns));
564 }
565 )*
567 }
568 _ => {
569 #(
571 if self.#update_option_field_names.is_some() {
572 set_parts.push(format!("{} = ?", #update_option_field_columns));
573 } else {
574 set_parts.push(format!("{} = DEFAULT", #update_option_field_columns));
575 }
576 )*
577 }
578 }
579
580 if set_parts.is_empty() {
581 return Ok(());
582 }
583
584 let raw_sql = format!(
585 "UPDATE {} SET {} WHERE {} = ?",
586 escaped_table,
587 set_parts.join(", "),
588 escaped_pk,
589 );
590
591 let sql = executor.convert_sql(&raw_sql);
592
593 match executor.driver() {
594 sqlxplus::db_pool::DbDriver::MySql => {
595 let mut query = sqlx::query(&sql);
596 #(
598 query = query.bind(&self.#update_normal_field_names);
599 )*
600 #(
602 if self.#update_option_field_names.is_some() {
603 query = query.bind(&self.#update_option_field_names);
604 }
605 )*
606 query = query.bind(&self.#pk_ident);
607 if let Some(tx_ref) = executor.mysql_transaction_ref() {
608 query.execute(&mut **tx_ref).await?;
609 } else if let Some(pool_ref) = executor.mysql_pool() {
610 query.execute(pool_ref).await?;
611 } else {
612 return Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable);
613 }
614 }
615 sqlxplus::db_pool::DbDriver::Postgres => {
616 let mut query = sqlx::query(&sql);
617 #(
619 query = query.bind(&self.#update_normal_field_names);
620 )*
621 #(
623 if self.#update_option_field_names.is_some() {
624 query = query.bind(&self.#update_option_field_names);
625 }
626 )*
627 query = query.bind(&self.#pk_ident);
628 if let Some(tx_ref) = executor.postgres_transaction_ref() {
629 query.execute(&mut **tx_ref).await?;
630 } else if let Some(pool_ref) = executor.pg_pool() {
631 query.execute(pool_ref).await?;
632 } else {
633 return Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable);
634 }
635 }
636 sqlxplus::db_pool::DbDriver::Sqlite => {
637 let mut query = sqlx::query(&sql);
638 #(
640 query = query.bind(&self.#update_normal_field_names);
641 )*
642 #(
644 if self.#update_option_field_names.is_some() {
645 query = query.bind(&self.#update_option_field_names);
646 }
647 )*
648 query = query.bind(&self.#pk_ident);
649 if let Some(tx_ref) = executor.sqlite_transaction_ref() {
650 query.execute(&mut **tx_ref).await?;
651 } else if let Some(pool_ref) = executor.sqlite_pool() {
652 query.execute(pool_ref).await?;
653 } else {
654 return Err(sqlxplus::db_pool::DbPoolError::NoPoolAvailable);
655 }
656 }
657 }
658 Ok(())
659 }
660 }
661 };
662
663 TokenStream::from(expanded)
664}
665
666fn is_option_type(ty: &syn::Type) -> bool {
668 if let syn::Type::Path(type_path) = ty {
669 if let Some(seg) = type_path.path.segments.last() {
670 if seg.ident == "Option" {
671 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
672 return args.args.len() == 1;
673 }
674 }
675 }
676 }
677 false
678}