1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6#[proc_macro_derive(ModelMeta, attributes(model, column))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31 let input = parse_macro_input!(input as DeriveInput);
32 let name = &input.ident;
33
34 let mut table_name = None;
36 let mut pk_field = None;
37 let mut soft_delete_field = None;
38
39 for attr in &input.attrs {
40 if attr.path().is_ident("model") {
41 if let syn::Meta::List(list) = &attr.meta {
43 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46 for meta in metas {
47 if let Meta::NameValue(nv) = meta {
48 if nv.path.is_ident("table") {
49 if let syn::Expr::Lit(syn::ExprLit {
50 lit: syn::Lit::Str(s),
51 ..
52 }) = nv.value
53 {
54 table_name = Some(s.value());
55 }
56 } else if nv.path.is_ident("pk") {
57 if let syn::Expr::Lit(syn::ExprLit {
58 lit: syn::Lit::Str(s),
59 ..
60 }) = nv.value
61 {
62 pk_field = Some(s.value());
63 }
64 } else if nv.path.is_ident("soft_delete") {
65 if let syn::Expr::Lit(syn::ExprLit {
66 lit: syn::Lit::Str(s),
67 ..
68 }) = nv.value
69 {
70 soft_delete_field = Some(s.value());
71 }
72 }
73 }
74 }
75 }
76 } else if let syn::Meta::NameValue(nv) = &attr.meta {
77 if nv.path.is_ident("table") {
79 if let syn::Expr::Lit(syn::ExprLit {
80 lit: syn::Lit::Str(s),
81 ..
82 }) = &nv.value
83 {
84 table_name = Some(s.value());
85 }
86 } else if nv.path.is_ident("pk") {
87 if let syn::Expr::Lit(syn::ExprLit {
88 lit: syn::Lit::Str(s),
89 ..
90 }) = &nv.value
91 {
92 pk_field = Some(s.value());
93 }
94 } else if nv.path.is_ident("soft_delete") {
95 if let syn::Expr::Lit(syn::ExprLit {
96 lit: syn::Lit::Str(s),
97 ..
98 }) = &nv.value
99 {
100 soft_delete_field = Some(s.value());
101 }
102 }
103 }
104 }
105 }
106
107 let table = table_name.unwrap_or_else(|| {
109 let s = name.to_string();
110 let mut result = String::new();
112 for (i, c) in s.chars().enumerate() {
113 if c.is_uppercase() && i > 0 {
114 result.push('_');
115 }
116 result.push(c.to_ascii_lowercase());
117 }
118 result
119 });
120
121 let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124 let expanded = if let Some(soft_delete) = soft_delete_field {
126 let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128 quote! {
129 impl sqlxplus::Model for #name {
130 const TABLE: &'static str = #table;
131 const PK: &'static str = #pk;
132 const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133 }
134 }
135 } else {
136 quote! {
138 impl sqlxplus::Model for #name {
139 const TABLE: &'static str = #table;
140 const PK: &'static str = #pk;
141 const SOFT_DELETE_FIELD: Option<&'static str> = None;
142 }
143 }
144 };
145
146 TokenStream::from(expanded)
147}
148
149#[proc_macro_derive(CRUD, attributes(model, skip, column))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176 let input = parse_macro_input!(input as DeriveInput);
177 let name = &input.ident;
178
179 let mut pk_field = None;
181 for attr in &input.attrs {
182 if attr.path().is_ident("model") {
183 if let syn::Meta::List(list) = &attr.meta {
184 let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185 if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186 for meta in metas {
187 if let Meta::NameValue(nv) = meta {
188 if nv.path.is_ident("pk") {
189 if let syn::Expr::Lit(syn::ExprLit {
190 lit: syn::Lit::Str(s),
191 ..
192 }) = nv.value
193 {
194 pk_field = Some(s.value());
195 }
196 }
197 }
198 }
199 }
200 } else if let syn::Meta::NameValue(nv) = &attr.meta {
201 if nv.path.is_ident("pk") {
202 if let syn::Expr::Lit(syn::ExprLit {
203 lit: syn::Lit::Str(s),
204 ..
205 }) = &nv.value
206 {
207 pk_field = Some(s.value());
208 }
209 }
210 }
211 }
212 }
213 let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216 let fields = match &input.data {
218 Data::Struct(DataStruct {
219 fields: Fields::Named(fields),
220 ..
221 }) => &fields.named,
222 _ => {
223 return syn::Error::new_spanned(
224 name,
225 "CRUD derive only supports structs with named fields",
226 )
227 .to_compile_error()
228 .into();
229 }
230 };
231
232 let mut pk_ident_opt: Option<&syn::Ident> = None;
236
237 let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
239 let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
240 let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
241 let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
242
243 let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
245 let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
246 let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
247 let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
248
249 let mut update_fields_normal_field_names: Vec<&syn::Ident> = Vec::new();
251 let mut update_fields_normal_field_columns: Vec<syn::LitStr> = Vec::new();
252 let mut update_fields_option_field_names: Vec<&syn::Ident> = Vec::new();
253 let mut update_fields_option_field_columns: Vec<syn::LitStr> = Vec::new();
254
255 for field in fields {
256 let field_name = field.ident.as_ref().unwrap();
257 let field_name_str = field_name.to_string();
258
259 let mut skip = false;
261 for attr in &field.attrs {
262 if attr.path().is_ident("skip") || attr.path().is_ident("model") {
263 skip = true;
264 break;
265 }
266 }
267
268 if !skip {
269 if field_name_str == pk {
270 pk_ident_opt = Some(field_name);
272 let is_opt = is_option_type(&field.ty);
274 let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
275 let is_supported = if is_opt {
276 if let Some(inner_ty) = get_option_inner_type(&field.ty) {
277 is_bind_value_supported_type(inner_ty)
278 } else {
279 false
280 }
281 } else {
282 is_bind_value_supported_type(&field.ty)
283 };
284 if is_supported {
286 if is_opt {
287 update_fields_option_field_names.push(field_name);
288 update_fields_option_field_columns.push(col_lit);
289 } else {
290 update_fields_normal_field_names.push(field_name);
291 update_fields_normal_field_columns.push(col_lit);
292 }
293 }
294 } else {
295 let is_opt = is_option_type(&field.ty);
297 let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
298
299 let is_supported = if is_opt {
301 if let Some(inner_ty) = get_option_inner_type(&field.ty) {
302 is_bind_value_supported_type(inner_ty)
303 } else {
304 false
305 }
306 } else {
307 is_bind_value_supported_type(&field.ty)
308 };
309
310 if is_opt {
311 insert_option_field_names.push(field_name);
312 insert_option_field_columns.push(col_lit.clone());
313
314 update_option_field_names.push(field_name);
315 update_option_field_columns.push(col_lit.clone());
316
317 if is_supported {
319 update_fields_option_field_names.push(field_name);
320 update_fields_option_field_columns.push(col_lit);
321 }
322 } else {
323 insert_normal_field_names.push(field_name);
324 insert_normal_field_columns.push(col_lit.clone());
325
326 update_normal_field_names.push(field_name);
327 update_normal_field_columns.push(col_lit.clone());
328
329 if is_supported {
331 update_fields_normal_field_names.push(field_name);
332 update_fields_normal_field_columns.push(col_lit);
333 }
334 }
335 }
336 }
337 }
338
339 let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
341
342 let expanded = quote! {
344 #[async_trait::async_trait]
346 impl sqlxplus::Crud for #name {
347 async fn insert<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<sqlxplus::crud::Id>
349 where
350 DB: sqlx::Database + sqlxplus::DatabaseInfo,
351 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
352 E: sqlxplus::DatabaseType<DB = DB>
353 + sqlx::Executor<'c, Database = DB>
354 + Send,
355 i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB>,
356 usize: sqlx::ColumnIndex<DB::Row>,
357 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
360 i64: for<'b> sqlx::Encode<'b, DB>,
361 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
362 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
363 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
364 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
365 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
366 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
367 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
368 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
369 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
370 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
371 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
372 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
373 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
374 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
375 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
376 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
377 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
378 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
379 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
380 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
381 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
382 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
383 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
384 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
385 {
386 use sqlxplus::Model;
387 use sqlxplus::DatabaseInfo;
388 use sqlxplus::db_pool::DbDriver;
389 let table = Self::TABLE;
390 let escaped_table = DB::escape_identifier(table);
391
392 let mut columns: Vec<&str> = Vec::new();
394 let mut placeholders: Vec<String> = Vec::new();
395 let mut placeholder_index = 0;
396
397 #(
399 columns.push(#insert_normal_field_columns);
400 placeholders.push(DB::placeholder(placeholder_index));
401 placeholder_index += 1;
402 )*
403
404 #(
406 if self.#insert_option_field_names.is_some() {
407 columns.push(#insert_option_field_columns);
408 placeholders.push(DB::placeholder(placeholder_index));
409 placeholder_index += 1;
410 }
411 )*
412
413 let sql = match DB::get_driver() {
415 DbDriver::Postgres => {
416 let pk = Self::PK;
417 let escaped_pk = DB::escape_identifier(pk);
418 format!(
419 "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
420 escaped_table,
421 columns.join(", "),
422 placeholders.join(", "),
423 escaped_pk
424 )
425 }
426 _ => {
427 format!(
428 "INSERT INTO {} ({}) VALUES ({})",
429 escaped_table,
430 columns.join(", "),
431 placeholders.join(", ")
432 )
433 }
434 };
435
436 match DB::get_driver() {
438 DbDriver::Postgres => {
439 let mut query = sqlx::query_scalar::<_, i64>(&sql);
440 #(
442 query = query.bind(&self.#insert_normal_field_names);
443 )*
444 #(
446 if let Some(ref val) = self.#insert_option_field_names {
447 query = query.bind(val);
448 }
449 )*
450 let id: i64 = query.fetch_one(executor).await?;
451 Ok(id)
452 }
453 DbDriver::MySql => {
454 let mut query = sqlx::query(&sql);
455 #(
457 query = query.bind(&self.#insert_normal_field_names);
458 )*
459 #(
461 if let Some(ref val) = self.#insert_option_field_names {
462 query = query.bind(val);
463 }
464 )*
465 let result = query.execute(executor).await?;
466 unsafe {
470 use sqlx::mysql::MySqlQueryResult;
471 let ptr: *const DB::QueryResult = &result;
472 let mysql_ptr = ptr as *const MySqlQueryResult;
473 Ok((*mysql_ptr).last_insert_id() as i64)
474 }
475 }
476 DbDriver::Sqlite => {
477 let mut query = sqlx::query(&sql);
478 #(
480 query = query.bind(&self.#insert_normal_field_names);
481 )*
482 #(
484 if let Some(ref val) = self.#insert_option_field_names {
485 query = query.bind(val);
486 }
487 )*
488 let result = query.execute(executor).await?;
489 unsafe {
491 use sqlx::sqlite::SqliteQueryResult;
492 let ptr: *const DB::QueryResult = &result;
493 let sqlite_ptr = ptr as *const SqliteQueryResult;
494 Ok((*sqlite_ptr).last_insert_rowid() as i64)
495 }
496 }
497 }
498 }
499
500 async fn update<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
502 where
503 DB: sqlx::Database + sqlxplus::DatabaseInfo,
504 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
505 E: sqlxplus::DatabaseType<DB = DB>
506 + sqlx::Executor<'c, Database = DB>
507 + Send,
508 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
511 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
512 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
513 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
514 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
515 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
516 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
517 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
518 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
519 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
520 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
521 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
522 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
523 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
524 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
525 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
526 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
527 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
528 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
529 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
530 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
531 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
532 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
533 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
534 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
535 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
536 {
537 use sqlxplus::Model;
538 use sqlxplus::DatabaseInfo;
539 let table = Self::TABLE;
540 let pk = Self::PK;
541 let escaped_table = DB::escape_identifier(table);
542 let escaped_pk = DB::escape_identifier(pk);
543
544 let mut set_parts: Vec<String> = Vec::new();
546 let mut placeholder_index = 0;
547
548 #(
550 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
551 placeholder_index += 1;
552 )*
553
554 #(
556 if self.#update_option_field_names.is_some() {
557 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
558 placeholder_index += 1;
559 }
560 )*
561
562 if set_parts.is_empty() {
563 return Ok(());
564 }
565
566 let sql = format!(
567 "UPDATE {} SET {} WHERE {} = {}",
568 escaped_table,
569 set_parts.join(", "),
570 escaped_pk,
571 DB::placeholder(placeholder_index)
572 );
573
574 let mut query = sqlx::query(&sql);
575 #(
577 query = query.bind(&self.#update_normal_field_names);
578 )*
579 #(
581 if let Some(ref val) = self.#update_option_field_names {
582 query = query.bind(val);
583 }
584 )*
585 query = query.bind(&self.#pk_ident);
586 query.execute(executor).await?;
587 Ok(())
588 }
589
590 async fn update_with_none<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
592 where
593 DB: sqlx::Database + sqlxplus::DatabaseInfo,
594 for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
595 E: sqlxplus::DatabaseType<DB = DB>
596 + sqlx::Executor<'c, Database = DB>
597 + Send,
598 String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
601 i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
602 i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
603 i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
604 f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
605 f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
606 bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
607 Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
608 Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
609 Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
610 Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
611 Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
612 Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
613 Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
614 chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
615 Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
616 chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
617 Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
618 chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
619 Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
620 chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
621 Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
622 Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
623 Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
624 serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
625 Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
626 {
627 use sqlxplus::Model;
628 use sqlxplus::DatabaseInfo;
629 use sqlxplus::db_pool::DbDriver;
630 let table = Self::TABLE;
631 let pk = Self::PK;
632 let escaped_table = DB::escape_identifier(table);
633 let escaped_pk = DB::escape_identifier(pk);
634
635 let mut set_parts: Vec<String> = Vec::new();
637 let mut placeholder_index = 0;
638
639 #(
641 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
642 placeholder_index += 1;
643 )*
644
645 match DB::get_driver() {
647 DbDriver::Sqlite => {
648 #(
650 if self.#update_option_field_names.is_some() {
651 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
652 placeholder_index += 1;
653 }
654 )*
655 }
656 _ => {
657 #(
659 if self.#update_option_field_names.is_some() {
660 set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
661 placeholder_index += 1;
662 } else {
663 set_parts.push(format!("{} = DEFAULT", DB::escape_identifier(#update_option_field_columns)));
664 }
665 )*
666 }
667 }
668
669 if set_parts.is_empty() {
670 return Ok(());
671 }
672
673 let sql = format!(
674 "UPDATE {} SET {} WHERE {} = {}",
675 escaped_table,
676 set_parts.join(", "),
677 escaped_pk,
678 DB::placeholder(placeholder_index)
679 );
680
681 let mut query = sqlx::query(&sql);
682 #(
684 query = query.bind(&self.#update_normal_field_names);
685 )*
686 #(
688 if let Some(ref val) = self.#update_option_field_names {
689 query = query.bind(val);
690 }
691 )*
692 query = query.bind(&self.#pk_ident);
693 query.execute(executor).await?;
694 Ok(())
695 }
696 }
697 };
698
699 let update_fields_impl = quote! {
704 impl sqlxplus::builder::update_builder::UpdateFields for #name {
705 fn get_field_value(&self, field_name: &str) -> Option<sqlxplus::builder::query_builder::BindValue> {
706 match field_name {
707 #(
708 #update_fields_normal_field_columns => {
709 Some(sqlxplus::builder::query_builder::BindValue::from(self.#update_fields_normal_field_names.clone()))
711 }
712 )*
713 #(
714 #update_fields_option_field_columns => {
715 self.#update_fields_option_field_names.as_ref().map(|v| {
717 sqlxplus::builder::query_builder::BindValue::from(v.clone())
718 })
719 }
720 )*
721 _ => None, }
723 }
724
725 fn get_all_field_names() -> &'static [&'static str] {
726 &[
727 #(#update_normal_field_columns,)*
728 #(#update_option_field_columns,)*
729 ]
730 }
731
732 fn has_field(field_name: &str) -> bool {
733 matches!(field_name, #(#update_normal_field_columns)|* | #(#update_option_field_columns)|*)
734 }
735 }
736 };
737
738 let expanded = quote! {
739 #expanded
740 #update_fields_impl
741 };
742
743 TokenStream::from(expanded)
744}
745
746fn is_option_type(ty: &syn::Type) -> bool {
748 if let syn::Type::Path(type_path) = ty {
749 if let Some(seg) = type_path.path.segments.last() {
750 if seg.ident == "Option" {
751 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
752 return args.args.len() == 1;
753 }
754 }
755 }
756 }
757 false
758}
759
760fn is_bind_value_supported_type(ty: &syn::Type) -> bool {
763 if let syn::Type::Path(type_path) = ty {
764 if let Some(seg) = type_path.path.segments.last() {
765 let type_name = seg.ident.to_string();
766 match type_name.as_str() {
768 "String" | "i64" | "i32" | "i16" | "f64" | "f32" | "bool" => true,
769 "Vec" => {
770 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
772 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
773 if let syn::Type::Path(inner_path) = inner_ty {
774 if let Some(inner_seg) = inner_path.path.segments.last() {
775 return inner_seg.ident == "u8";
776 }
777 }
778 }
779 }
780 false
781 }
782 _ => false,
783 }
784 } else {
785 false
786 }
787 } else {
788 false
789 }
790}
791
792fn get_option_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
794 if let syn::Type::Path(type_path) = ty {
795 if let Some(seg) = type_path.path.segments.last() {
796 if seg.ident == "Option" {
797 if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
798 if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
799 return Some(inner_ty);
800 }
801 }
802 }
803 }
804 }
805 None
806}