1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6#[proc_macro_derive(ModelMeta, attributes(model, column))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let name = &input.ident;
33
34 let mut table_name = None;
36 let mut pk_field = None;
37 let mut soft_delete_field = None;
38
39 for attr in &input.attrs {
40 if attr.path().is_ident("model") {
41 if let syn::Meta::List(list) = &attr.meta {
43 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46 for meta in metas {
47 if let Meta::NameValue(nv) = meta {
48 if nv.path.is_ident("table") {
49 if let syn::Expr::Lit(syn::ExprLit {
50 lit: syn::Lit::Str(s),
51 ..
52 }) = nv.value
53 {
54 table_name = Some(s.value());
55 }
56 } else if nv.path.is_ident("pk") {
57 if let syn::Expr::Lit(syn::ExprLit {
58 lit: syn::Lit::Str(s),
59 ..
60 }) = nv.value
61 {
62 pk_field = Some(s.value());
63 }
64 } else if nv.path.is_ident("soft_delete") {
65 if let syn::Expr::Lit(syn::ExprLit {
66 lit: syn::Lit::Str(s),
67 ..
68 }) = nv.value
69 {
70 soft_delete_field = Some(s.value());
71 }
72 }
73 }
74 }
75 }
76 } else if let syn::Meta::NameValue(nv) = &attr.meta {
77 if nv.path.is_ident("table") {
79 if let syn::Expr::Lit(syn::ExprLit {
80 lit: syn::Lit::Str(s),
81 ..
82 }) = &nv.value
83 {
84 table_name = Some(s.value());
85 }
86 } else if nv.path.is_ident("pk") {
87 if let syn::Expr::Lit(syn::ExprLit {
88 lit: syn::Lit::Str(s),
89 ..
90 }) = &nv.value
91 {
92 pk_field = Some(s.value());
93 }
94 } else if nv.path.is_ident("soft_delete") {
95 if let syn::Expr::Lit(syn::ExprLit {
96 lit: syn::Lit::Str(s),
97 ..
98 }) = &nv.value
99 {
100 soft_delete_field = Some(s.value());
101 }
102 }
103 }
104 }
105 }
106
107 let table = table_name.unwrap_or_else(|| {
109 let s = name.to_string();
110 let mut result = String::new();
112 for (i, c) in s.chars().enumerate() {
113 if c.is_uppercase() && i > 0 {
114 result.push('_');
115 }
116 result.push(c.to_ascii_lowercase());
117 }
118 result
119 });
120
121 let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124 let expanded = if let Some(soft_delete) = soft_delete_field {
126 let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128 quote! {
129 impl sqlxplus::Model for #name {
130 const TABLE: &'static str = #table;
131 const PK: &'static str = #pk;
132 const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133 }
134 }
135 } else {
136 quote! {
138 impl sqlxplus::Model for #name {
139 const TABLE: &'static str = #table;
140 const PK: &'static str = #pk;
141 const SOFT_DELETE_FIELD: Option<&'static str> = None;
142 }
143 }
144 };
145
146 TokenStream::from(expanded)
147}
148
149#[proc_macro_derive(CRUD, attributes(model, skip, column))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176 let input = parse_macro_input!(input as DeriveInput);
177 let name = &input.ident;
178
179 let mut pk_field = None;
181 for attr in &input.attrs {
182 if attr.path().is_ident("model") {
183 if let syn::Meta::List(list) = &attr.meta {
184 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186 for meta in metas {
187 if let Meta::NameValue(nv) = meta {
188 if nv.path.is_ident("pk") {
189 if let syn::Expr::Lit(syn::ExprLit {
190 lit: syn::Lit::Str(s),
191 ..
192 }) = nv.value
193 {
194 pk_field = Some(s.value());
195 }
196 }
197 }
198 }
199 }
200 } else if let syn::Meta::NameValue(nv) = &attr.meta {
201 if nv.path.is_ident("pk") {
202 if let syn::Expr::Lit(syn::ExprLit {
203 lit: syn::Lit::Str(s),
204 ..
205 }) = &nv.value
206 {
207 pk_field = Some(s.value());
208 }
209 }
210 }
211 }
212 }
213 let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216 let fields = match &input.data {
218 Data::Struct(DataStruct {
219 fields: Fields::Named(fields),
220 ..
221 }) => &fields.named,
222 _ => {
223 return syn::Error::new_spanned(
224 name,
225 "CRUD derive only supports structs with named fields",
226 )
227 .to_compile_error()
228 .into();
229 }
230 };
231
232 let mut pk_ident_opt: Option<&syn::Ident> = None;
236
237 let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
239 let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
240 let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
241 let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
242
243 let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
245 let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
246 let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
247 let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
248
249 for field in fields {
250 let field_name = field.ident.as_ref().unwrap();
251 let field_name_str = field_name.to_string();
252
253 let mut skip = false;
255 for attr in &field.attrs {
256 if attr.path().is_ident("skip") || attr.path().is_ident("model") {
257 skip = true;
258 break;
259 }
260 }
261
262 if !skip {
263 if field_name_str == pk {
264 pk_ident_opt = Some(field_name);
266 } else {
267 let is_opt = is_option_type(&field.ty);
269 let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
270
271 if is_opt {
272 insert_option_field_names.push(field_name);
273 insert_option_field_columns.push(col_lit.clone());
274
275 update_option_field_names.push(field_name);
276 update_option_field_columns.push(col_lit);
277 } else {
278 insert_normal_field_names.push(field_name);
279 insert_normal_field_columns.push(col_lit.clone());
280
281 update_normal_field_names.push(field_name);
282 update_normal_field_columns.push(col_lit);
283 }
284 }
285 }
286 }
287
288 let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
290
291 let expanded = quote! {
293 #[async_trait::async_trait]
295 impl sqlxplus::Crud for #name {
296 async fn insert<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<sqlxplus::crud::Id>
298 where
299 DB: sqlx::Database + sqlxplus::DatabaseInfo,
300 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
301 E: sqlxplus::DatabaseType<DB = DB>
302 + sqlx::Executor<'c, Database = DB>
303 + Send,
304 i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB>,
305 usize: sqlx::ColumnIndex<DB::Row>,
306 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
309 i64: for<'b> sqlx::Encode<'b, DB>,
310 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
311 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
312 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
313 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
314 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
315 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
316 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
317 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
318 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
319 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
320 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
321 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
322 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
323 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
324 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
325 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
326 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
327 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
328 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
329 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
330 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
331 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
332 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
333 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
334 {
335 use sqlxplus::Model;
336 use sqlxplus::DatabaseInfo;
337 use sqlxplus::db_pool::DbDriver;
338 let table = Self::TABLE;
339 let escaped_table = DB::escape_identifier(table);
340
341 let mut columns: Vec<&str> = Vec::new();
343 let mut placeholders: Vec<String> = Vec::new();
344 let mut placeholder_index = 0;
345
346 #(
348 columns.push(#insert_normal_field_columns);
349 placeholders.push(DB::placeholder(placeholder_index));
350 placeholder_index += 1;
351 )*
352
353 #(
355 if self.#insert_option_field_names.is_some() {
356 columns.push(#insert_option_field_columns);
357 placeholders.push(DB::placeholder(placeholder_index));
358 placeholder_index += 1;
359 }
360 )*
361
362 let sql = match DB::get_driver() {
364 DbDriver::Postgres => {
365 let pk = Self::PK;
366 let escaped_pk = DB::escape_identifier(pk);
367 format!(
368 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
369 escaped_table,
370 columns.join(", "),
371 placeholders.join(", "),
372 escaped_pk
373 )
374 }
375 _ => {
376 format!(
377 "INSERT INTO {} ({}) VALUES ({})",
378 escaped_table,
379 columns.join(", "),
380 placeholders.join(", ")
381 )
382 }
383 };
384
385 match DB::get_driver() {
387 DbDriver::Postgres => {
388 let mut query = sqlx::query_scalar::<_, i64>(&sql);
389 #(
391 query = query.bind(&self.#insert_normal_field_names);
392 )*
393 #(
395 if let Some(ref val) = self.#insert_option_field_names {
396 query = query.bind(val);
397 }
398 )*
399 let id: i64 = query.fetch_one(executor).await?;
400 Ok(id)
401 }
402 DbDriver::MySql => {
403 let mut query = sqlx::query(&sql);
404 #(
406 query = query.bind(&self.#insert_normal_field_names);
407 )*
408 #(
410 if let Some(ref val) = self.#insert_option_field_names {
411 query = query.bind(val);
412 }
413 )*
414 let result = query.execute(executor).await?;
415 unsafe {
419 use sqlx::mysql::MySqlQueryResult;
420 let ptr: *const DB::QueryResult = &result;
421 let mysql_ptr = ptr as *const MySqlQueryResult;
422 Ok((*mysql_ptr).last_insert_id() as i64)
423 }
424 }
425 DbDriver::Sqlite => {
426 let mut query = sqlx::query(&sql);
427 #(
429 query = query.bind(&self.#insert_normal_field_names);
430 )*
431 #(
433 if let Some(ref val) = self.#insert_option_field_names {
434 query = query.bind(val);
435 }
436 )*
437 let result = query.execute(executor).await?;
438 unsafe {
440 use sqlx::sqlite::SqliteQueryResult;
441 let ptr: *const DB::QueryResult = &result;
442 let sqlite_ptr = ptr as *const SqliteQueryResult;
443 Ok((*sqlite_ptr).last_insert_rowid() as i64)
444 }
445 }
446 }
447 }
448
449 async fn update<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
451 where
452 DB: sqlx::Database + sqlxplus::DatabaseInfo,
453 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
454 E: sqlxplus::DatabaseType<DB = DB>
455 + sqlx::Executor<'c, Database = DB>
456 + Send,
457 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
460 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
461 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
462 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
463 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
464 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
465 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
466 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
467 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
468 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
469 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
470 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
471 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
472 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
473 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
474 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
475 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
476 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
477 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
478 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
479 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
480 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
481 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
482 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
483 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
484 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
485 {
486 use sqlxplus::Model;
487 use sqlxplus::DatabaseInfo;
488 let table = Self::TABLE;
489 let pk = Self::PK;
490 let escaped_table = DB::escape_identifier(table);
491 let escaped_pk = DB::escape_identifier(pk);
492
493 let mut set_parts: Vec<String> = Vec::new();
495 let mut placeholder_index = 0;
496
497 #(
499 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
500 placeholder_index += 1;
501 )*
502
503 #(
505 if self.#update_option_field_names.is_some() {
506 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
507 placeholder_index += 1;
508 }
509 )*
510
511 if set_parts.is_empty() {
512 return Ok(());
513 }
514
515 let sql = format!(
516 "UPDATE {} SET {} WHERE {} = {}",
517 escaped_table,
518 set_parts.join(", "),
519 escaped_pk,
520 DB::placeholder(placeholder_index)
521 );
522
523 let mut query = sqlx::query(&sql);
524 #(
526 query = query.bind(&self.#update_normal_field_names);
527 )*
528 #(
530 if let Some(ref val) = self.#update_option_field_names {
531 query = query.bind(val);
532 }
533 )*
534 query = query.bind(&self.#pk_ident);
535 query.execute(executor).await?;
536 Ok(())
537 }
538
539 async fn update_with_none<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
541 where
542 DB: sqlx::Database + sqlxplus::DatabaseInfo,
543 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
544 E: sqlxplus::DatabaseType<DB = DB>
545 + sqlx::Executor<'c, Database = DB>
546 + Send,
547 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
550 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
551 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
552 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
553 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
554 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
555 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
556 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
557 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
558 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
559 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
560 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
561 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
562 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
563 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
564 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
565 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
566 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
567 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
568 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
569 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
570 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
571 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
572 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
573 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
574 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
575 {
576 use sqlxplus::Model;
577 use sqlxplus::DatabaseInfo;
578 use sqlxplus::db_pool::DbDriver;
579 let table = Self::TABLE;
580 let pk = Self::PK;
581 let escaped_table = DB::escape_identifier(table);
582 let escaped_pk = DB::escape_identifier(pk);
583
584 let mut set_parts: Vec<String> = Vec::new();
586 let mut placeholder_index = 0;
587
588 #(
590 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
591 placeholder_index += 1;
592 )*
593
594 match DB::get_driver() {
596 DbDriver::Sqlite => {
597 #(
599 if self.#update_option_field_names.is_some() {
600 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
601 placeholder_index += 1;
602 }
603 )*
604 }
605 _ => {
606 #(
608 if self.#update_option_field_names.is_some() {
609 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
610 placeholder_index += 1;
611 } else {
612 set_parts.push(format!("{} = DEFAULT", DB::escape_identifier(#update_option_field_columns)));
613 }
614 )*
615 }
616 }
617
618 if set_parts.is_empty() {
619 return Ok(());
620 }
621
622 let sql = format!(
623 "UPDATE {} SET {} WHERE {} = {}",
624 escaped_table,
625 set_parts.join(", "),
626 escaped_pk,
627 DB::placeholder(placeholder_index)
628 );
629
630 let mut query = sqlx::query(&sql);
631 #(
633 query = query.bind(&self.#update_normal_field_names);
634 )*
635 #(
637 if let Some(ref val) = self.#update_option_field_names {
638 query = query.bind(val);
639 }
640 )*
641 query = query.bind(&self.#pk_ident);
642 query.execute(executor).await?;
643 Ok(())
644 }
645 }
646 };
647
648 TokenStream::from(expanded)
649}
650
651fn is_option_type(ty: &syn::Type) -> bool {
653 if let syn::Type::Path(type_path) = ty {
654 if let Some(seg) = type_path.path.segments.last() {
655 if seg.ident == "Option" {
656 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
657 return args.args.len() == 1;
658 }
659 }
660 }
661 }
662 false
663}