sqlxplus_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6/// 生成 Model trait 的实现
7///
8/// 自动生成 `TABLE`、`PK` 和可选的 `SOFT_DELETE_FIELD` 常量
9///
10/// 使用示例:
11/// ```ignore
12/// // 物理删除模式(默认)
13/// #[derive(ModelMeta)]
14/// #[model(table = "users", pk = "id")]
15/// struct User {
16///     id: i64,
17///     name: String,
18/// }
19///
20/// // 逻辑删除模式
21/// #[derive(ModelMeta)]
22/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
23/// struct UserWithSoftDelete {
24///     id: i64,
25///     name: String,
26///     is_deleted: i32, // 逻辑删除字段:0=未删除,1=已删除
27/// }
28/// ```
29#[proc_macro_derive(ModelMeta, attributes(model, column))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    let name = &input.ident;
33
34    // 解析属性
35    let mut table_name = None;
36    let mut pk_field = None;
37    let mut soft_delete_field = None;
38
39    for attr in &input.attrs {
40        if attr.path().is_ident("model") {
41            // 在 syn 2.0 中,使用 meta() 方法获取元数据
42            if let syn::Meta::List(list) = &attr.meta {
43                // 解析列表中的每个 Meta::NameValue
44                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46                    for meta in metas {
47                        if let Meta::NameValue(nv) = meta {
48                            if nv.path.is_ident("table") {
49                                if let syn::Expr::Lit(syn::ExprLit {
50                                    lit: syn::Lit::Str(s),
51                                    ..
52                                }) = nv.value
53                                {
54                                    table_name = Some(s.value());
55                                }
56                            } else if nv.path.is_ident("pk") {
57                                if let syn::Expr::Lit(syn::ExprLit {
58                                    lit: syn::Lit::Str(s),
59                                    ..
60                                }) = nv.value
61                                {
62                                    pk_field = Some(s.value());
63                                }
64                            } else if nv.path.is_ident("soft_delete") {
65                                if let syn::Expr::Lit(syn::ExprLit {
66                                    lit: syn::Lit::Str(s),
67                                    ..
68                                }) = nv.value
69                                {
70                                    soft_delete_field = Some(s.value());
71                                }
72                            }
73                        }
74                    }
75                }
76            } else if let syn::Meta::NameValue(nv) = &attr.meta {
77                // 单个 NameValue 的情况
78                if nv.path.is_ident("table") {
79                    if let syn::Expr::Lit(syn::ExprLit {
80                        lit: syn::Lit::Str(s),
81                        ..
82                    }) = &nv.value
83                    {
84                        table_name = Some(s.value());
85                    }
86                } else if nv.path.is_ident("pk") {
87                    if let syn::Expr::Lit(syn::ExprLit {
88                        lit: syn::Lit::Str(s),
89                        ..
90                    }) = &nv.value
91                    {
92                        pk_field = Some(s.value());
93                    }
94                } else if nv.path.is_ident("soft_delete") {
95                    if let syn::Expr::Lit(syn::ExprLit {
96                        lit: syn::Lit::Str(s),
97                        ..
98                    }) = &nv.value
99                    {
100                        soft_delete_field = Some(s.value());
101                    }
102                }
103            }
104        }
105    }
106
107    // 如果没有指定表名,使用结构体名称的小写蛇形命名方式
108    let table = table_name.unwrap_or_else(|| {
109        let s = name.to_string();
110        // 将 PascalCase 转换为 snake_case
111        let mut result = String::new();
112        for (i, c) in s.chars().enumerate() {
113            if c.is_uppercase() && i > 0 {
114                result.push('_');
115            }
116            result.push(c.to_ascii_lowercase());
117        }
118        result
119    });
120
121    // 如果没有指定主键,默认使用 "id"
122    let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124    // 生成实现代码
125    let expanded = if let Some(soft_delete) = soft_delete_field {
126        // 如果指定了逻辑删除字段,生成包含 SOFT_DELETE_FIELD 的实现
127        let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128        quote! {
129            impl sqlxplus::Model for #name {
130                const TABLE: &'static str = #table;
131                const PK: &'static str = #pk;
132                const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133            }
134        }
135    } else {
136        // 如果没有指定逻辑删除字段,SOFT_DELETE_FIELD 为 None
137        quote! {
138            impl sqlxplus::Model for #name {
139                const TABLE: &'static str = #table;
140                const PK: &'static str = #pk;
141                const SOFT_DELETE_FIELD: Option<&'static str> = None;
142            }
143        }
144    };
145
146    TokenStream::from(expanded)
147}
148
149/// 生成 CRUD trait 的实现
150///
151/// 自动生成 insert 和 update 方法的实现
152///
153/// 使用示例:
154/// ```ignore
155/// // 物理删除模式
156/// #[derive(CRUD, FromRow, ModelMeta)]
157/// #[model(table = "users", pk = "id")]
158/// struct User {
159///     id: i64,
160///     name: String,
161///     email: String,
162/// }
163///
164/// // 逻辑删除模式
165/// #[derive(CRUD, FromRow, ModelMeta)]
166/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
167/// struct UserWithSoftDelete {
168///     id: i64,
169///     name: String,
170///     email: String,
171///     is_deleted: i32, // 逻辑删除字段
172/// }
173/// ```
174#[proc_macro_derive(CRUD, attributes(model, skip, column))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176    let input = parse_macro_input!(input as DeriveInput);
177    let name = &input.ident;
178
179    // 解析 #[model(pk = "...")],获取主键字段名,默认 "id"
180    let mut pk_field = None;
181    for attr in &input.attrs {
182        if attr.path().is_ident("model") {
183            if let syn::Meta::List(list) = &attr.meta {
184                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186                    for meta in metas {
187                        if let Meta::NameValue(nv) = meta {
188                            if nv.path.is_ident("pk") {
189                                if let syn::Expr::Lit(syn::ExprLit {
190                                    lit: syn::Lit::Str(s),
191                                    ..
192                                }) = nv.value
193                                {
194                                    pk_field = Some(s.value());
195                                }
196                            }
197                        }
198                    }
199                }
200            } else if let syn::Meta::NameValue(nv) = &attr.meta {
201                if nv.path.is_ident("pk") {
202                    if let syn::Expr::Lit(syn::ExprLit {
203                        lit: syn::Lit::Str(s),
204                        ..
205                    }) = &nv.value
206                    {
207                        pk_field = Some(s.value());
208                    }
209                }
210            }
211        }
212    }
213    // 如果没有指定主键,默认使用 "id"
214    let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216    // 获取字段列表(必须是具名字段的结构体)
217    let fields = match &input.data {
218        Data::Struct(DataStruct {
219            fields: Fields::Named(fields),
220            ..
221        }) => &fields.named,
222        _ => {
223            return syn::Error::new_spanned(
224                name,
225                "CRUD derive only supports structs with named fields",
226            )
227            .to_compile_error()
228            .into();
229        }
230    };
231
232    // 收集字段信息
233    // - pk_ident: 主键字段 Ident
234    // - insert_*/update_*: 非主键字段(INSERT / UPDATE 使用)
235    let mut pk_ident_opt: Option<&syn::Ident> = None;
236
237    // INSERT 使用的字段(非主键)
238    let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
239    let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
240    let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
241    let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
242
243    // UPDATE 使用的字段(非主键)
244    let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
245    let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
246    let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
247    let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
248
249    // 用于 UpdateFields trait:只包含 BindValue 支持的类型
250    let mut update_fields_normal_field_names: Vec<&syn::Ident> = Vec::new();
251    let mut update_fields_normal_field_columns: Vec<syn::LitStr> = Vec::new();
252    let mut update_fields_option_field_names: Vec<&syn::Ident> = Vec::new();
253    let mut update_fields_option_field_columns: Vec<syn::LitStr> = Vec::new();
254
255    for field in fields {
256        let field_name = field.ident.as_ref().unwrap();
257        let field_name_str = field_name.to_string();
258
259        // 检查属性:skip / model
260        let mut skip = false;
261        for attr in &field.attrs {
262            if attr.path().is_ident("skip") || attr.path().is_ident("model") {
263                skip = true;
264                break;
265            }
266        }
267
268        if !skip {
269            if field_name_str == pk {
270                // 记录主键字段
271                pk_ident_opt = Some(field_name);
272                // 主键字段也需要添加到 UpdateFields,因为 UpdateBuilder 需要获取主键值
273                let is_opt = is_option_type(&field.ty);
274                let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
275                let is_supported = if is_opt {
276                    if let Some(inner_ty) = get_option_inner_type(&field.ty) {
277                        is_bind_value_supported_type(inner_ty)
278                    } else {
279                        false
280                    }
281                } else {
282                    is_bind_value_supported_type(&field.ty)
283                };
284                // 如果主键类型是支持的类型,添加到 UpdateFields
285                if is_supported {
286                    if is_opt {
287                        update_fields_option_field_names.push(field_name);
288                        update_fields_option_field_columns.push(col_lit);
289                    } else {
290                        update_fields_normal_field_names.push(field_name);
291                        update_fields_normal_field_columns.push(col_lit);
292                    }
293                }
294            } else {
295                // 非主键字段用于 INSERT / UPDATE
296                let is_opt = is_option_type(&field.ty);
297                let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
298
299                // 检查是否是 BindValue 支持的类型
300                let is_supported = if is_opt {
301                    if let Some(inner_ty) = get_option_inner_type(&field.ty) {
302                        is_bind_value_supported_type(inner_ty)
303                    } else {
304                        false
305                    }
306                } else {
307                    is_bind_value_supported_type(&field.ty)
308                };
309
310                if is_opt {
311                    insert_option_field_names.push(field_name);
312                    insert_option_field_columns.push(col_lit.clone());
313
314                    update_option_field_names.push(field_name);
315                    update_option_field_columns.push(col_lit.clone());
316
317                    // 只为支持的类型添加到 UpdateFields
318                    if is_supported {
319                        update_fields_option_field_names.push(field_name);
320                        update_fields_option_field_columns.push(col_lit);
321                    }
322                } else {
323                    insert_normal_field_names.push(field_name);
324                    insert_normal_field_columns.push(col_lit.clone());
325
326                    update_normal_field_names.push(field_name);
327                    update_normal_field_columns.push(col_lit.clone());
328
329                    // 只为支持的类型添加到 UpdateFields
330                    if is_supported {
331                        update_fields_normal_field_names.push(field_name);
332                        update_fields_normal_field_columns.push(col_lit);
333                    }
334                }
335            }
336        }
337    }
338
339    // 编译期确保主键字段存在
340    let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
341
342    // 生成实现代码
343    let expanded = quote! {
344        // Trait 方法实现
345        #[async_trait::async_trait]
346        impl sqlxplus::Crud for #name {
347            // 泛型版本的 insert(自动类型推断)
348            async fn insert<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<sqlxplus::crud::Id>
349            where
350                DB: sqlx::Database + sqlxplus::DatabaseInfo,
351                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
352                E: sqlxplus::DatabaseType<DB = DB>
353                    + sqlx::Executor<'c, Database = DB>
354                    + Send,
355                i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB>,
356                usize: sqlx::ColumnIndex<DB::Row>,
357                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
358                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
359                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
360                i64: for<'b> sqlx::Encode<'b, DB>,
361                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
362                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
363                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
364                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
365                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
366                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
367                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
368                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
369                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
370                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
371                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
372                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
373                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
374                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
375                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
376                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
377                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
378                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
379                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
380                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
381                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
382                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
383                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
384                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
385            {
386                use sqlxplus::Model;
387                use sqlxplus::DatabaseInfo;
388                use sqlxplus::db_pool::DbDriver;
389                let table = Self::TABLE;
390                let escaped_table = DB::escape_identifier(table);
391
392                // 构建列名和占位符
393                let mut columns: Vec<&str> = Vec::new();
394                let mut placeholders: Vec<String> = Vec::new();
395                let mut placeholder_index = 0;
396
397                // 非 Option 字段:始终参与 INSERT
398                #(
399                    columns.push(#insert_normal_field_columns);
400                    placeholders.push(DB::placeholder(placeholder_index));
401                    placeholder_index += 1;
402                )*
403
404                // Option 字段:仅当为 Some 时参与 INSERT
405                #(
406                    if self.#insert_option_field_names.is_some() {
407                        columns.push(#insert_option_field_columns);
408                        placeholders.push(DB::placeholder(placeholder_index));
409                        placeholder_index += 1;
410                    }
411                )*
412
413                // 根据数据库类型构建 SQL
414                let sql = match DB::get_driver() {
415                    DbDriver::Postgres => {
416                        let pk = Self::PK;
417                        let escaped_pk = DB::escape_identifier(pk);
418                        format!(
419                            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
420                            escaped_table,
421                            columns.join(", "),
422                            placeholders.join(", "),
423                            escaped_pk
424                        )
425                    }
426                    _ => {
427                        format!(
428                            "INSERT INTO {} ({}) VALUES ({})",
429                            escaped_table,
430                            columns.join(", "),
431                            placeholders.join(", ")
432                        )
433                    }
434                };
435
436                // 根据数据库类型执行查询
437                match DB::get_driver() {
438                    DbDriver::Postgres => {
439                        let mut query = sqlx::query_scalar::<_, i64>(&sql);
440                        // 非 Option 字段:始终绑定
441                        #(
442                            query = query.bind(&self.#insert_normal_field_names);
443                        )*
444                        // Option 字段:仅当为 Some 时绑定
445                        #(
446                            if let Some(ref val) = self.#insert_option_field_names {
447                                query = query.bind(val);
448                            }
449                        )*
450                        let id: i64 = query.fetch_one(executor).await?;
451                        Ok(id)
452                    }
453                    DbDriver::MySql => {
454                        let mut query = sqlx::query(&sql);
455                        // 非 Option 字段:始终绑定
456                        #(
457                            query = query.bind(&self.#insert_normal_field_names);
458                        )*
459                        // Option 字段:仅当为 Some 时绑定
460                        #(
461                            if let Some(ref val) = self.#insert_option_field_names {
462                                query = query.bind(val);
463                            }
464                        )*
465                        let result = query.execute(executor).await?;
466                        // 在泛型上下文中,我们需要使用 unsafe 转换来访问数据库特定的方法
467                        // 这是安全的,因为我们已经通过 DB::get_driver() 确认了数据库类型
468                        // 并且我们知道 DB = MySql,所以 result 的类型是 MySqlQueryResult
469                        unsafe {
470                            use sqlx::mysql::MySqlQueryResult;
471                            let ptr: *const DB::QueryResult = &result;
472                            let mysql_ptr = ptr as *const MySqlQueryResult;
473                            Ok((*mysql_ptr).last_insert_id() as i64)
474                        }
475                    }
476                    DbDriver::Sqlite => {
477                        let mut query = sqlx::query(&sql);
478                        // 非 Option 字段:始终绑定
479                        #(
480                            query = query.bind(&self.#insert_normal_field_names);
481                        )*
482                        // Option 字段:仅当为 Some 时绑定
483                        #(
484                            if let Some(ref val) = self.#insert_option_field_names {
485                                query = query.bind(val);
486                            }
487                        )*
488                        let result = query.execute(executor).await?;
489                        // 在泛型上下文中,我们需要使用 unsafe 转换来访问数据库特定的方法
490                        unsafe {
491                            use sqlx::sqlite::SqliteQueryResult;
492                            let ptr: *const DB::QueryResult = &result;
493                            let sqlite_ptr = ptr as *const SqliteQueryResult;
494                            Ok((*sqlite_ptr).last_insert_rowid() as i64)
495                        }
496                    }
497                }
498            }
499
500            // 泛型版本的 update(自动类型推断)
501            async fn update<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
502            where
503                DB: sqlx::Database + sqlxplus::DatabaseInfo,
504                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
505                E: sqlxplus::DatabaseType<DB = DB>
506                    + sqlx::Executor<'c, Database = DB>
507                    + Send,
508                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
509                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
510                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
511                i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
512                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
513                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
514                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
515                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
516                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
517                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
518                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
519                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
520                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
521                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
522                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
523                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
524                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
525                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
526                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
527                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
528                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
529                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
530                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
531                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
532                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
533                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
534                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
535                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
536            {
537                use sqlxplus::Model;
538                use sqlxplus::DatabaseInfo;
539                let table = Self::TABLE;
540                let pk = Self::PK;
541                let escaped_table = DB::escape_identifier(table);
542                let escaped_pk = DB::escape_identifier(pk);
543
544                // 构建 UPDATE SET 子句(Patch 语义)
545                let mut set_parts: Vec<String> = Vec::new();
546                let mut placeholder_index = 0;
547
548                // 非 Option 字段
549                #(
550                    set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
551                    placeholder_index += 1;
552                )*
553
554                // Option 字段
555                #(
556                    if self.#update_option_field_names.is_some() {
557                        set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
558                        placeholder_index += 1;
559                    }
560                )*
561
562                if set_parts.is_empty() {
563                    return Ok(());
564                }
565
566                let sql = format!(
567                    "UPDATE {} SET {} WHERE {} = {}",
568                    escaped_table,
569                    set_parts.join(", "),
570                    escaped_pk,
571                    DB::placeholder(placeholder_index)
572                );
573
574                let mut query = sqlx::query(&sql);
575                // 非 Option 字段:始终绑定
576                #(
577                    query = query.bind(&self.#update_normal_field_names);
578                )*
579                // Option 字段:仅当为 Some 时绑定
580                #(
581                    if let Some(ref val) = self.#update_option_field_names {
582                        query = query.bind(val);
583                    }
584                )*
585                query = query.bind(&self.#pk_ident);
586                query.execute(executor).await?;
587                Ok(())
588            }
589
590            // 泛型版本的 update_with_none(自动类型推断)
591            async fn update_with_none<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
592            where
593                DB: sqlx::Database + sqlxplus::DatabaseInfo,
594                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
595                E: sqlxplus::DatabaseType<DB = DB>
596                    + sqlx::Executor<'c, Database = DB>
597                    + Send,
598                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
599                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
600                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
601                i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
602                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
603                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
604                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
605                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
606                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
607                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
608                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
609                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
610                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
611                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
612                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
613                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
614                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
615                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
616                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
617                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
618                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
619                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
620                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
621                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
622                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
623                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
624                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
625                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
626            {
627                use sqlxplus::Model;
628                use sqlxplus::DatabaseInfo;
629                use sqlxplus::db_pool::DbDriver;
630                let table = Self::TABLE;
631                let pk = Self::PK;
632                let escaped_table = DB::escape_identifier(table);
633                let escaped_pk = DB::escape_identifier(pk);
634
635                // 构建 UPDATE SET 子句(Reset 语义)
636                let mut set_parts: Vec<String> = Vec::new();
637                let mut placeholder_index = 0;
638
639                // 非 Option 字段:始终更新为当前值
640                #(
641                    set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
642                    placeholder_index += 1;
643                )*
644
645                // Option 字段:根据数据库类型处理
646                match DB::get_driver() {
647                    DbDriver::Sqlite => {
648                        // SQLite 不支持 DEFAULT,跳过 None 字段
649                        #(
650                            if self.#update_option_field_names.is_some() {
651                                set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
652                                placeholder_index += 1;
653                            }
654                        )*
655                    }
656                    _ => {
657                        // MySQL 和 PostgreSQL 使用 DEFAULT
658                        #(
659                            if self.#update_option_field_names.is_some() {
660                                set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
661                                placeholder_index += 1;
662                            } else {
663                                set_parts.push(format!("{} = DEFAULT", DB::escape_identifier(#update_option_field_columns)));
664                            }
665                        )*
666                    }
667                }
668
669                if set_parts.is_empty() {
670                    return Ok(());
671                }
672
673                let sql = format!(
674                    "UPDATE {} SET {} WHERE {} = {}",
675                    escaped_table,
676                    set_parts.join(", "),
677                    escaped_pk,
678                    DB::placeholder(placeholder_index)
679                );
680
681                let mut query = sqlx::query(&sql);
682                // 非 Option 字段:始终绑定
683                #(
684                    query = query.bind(&self.#update_normal_field_names);
685                )*
686                // Option 字段:仅当为 Some 时绑定(None 使用 DEFAULT 或跳过)
687                #(
688                    if let Some(ref val) = self.#update_option_field_names {
689                        query = query.bind(val);
690                    }
691                )*
692                query = query.bind(&self.#pk_ident);
693                query.execute(executor).await?;
694                Ok(())
695            }
696        }
697    };
698
699    // 生成 UpdateFields trait 实现(用于 UpdateBuilder 和 InsertBuilder)
700    // 注意:只对 BindValue 支持的基本类型生成转换代码
701    // 对于复杂类型(如 DateTime、JsonValue 等),get_field_value 返回 None
702    // InsertBuilder 和 UpdateBuilder 需要直接使用 sqlx::bind 来处理这些类型
703    let update_fields_impl = quote! {
704        impl sqlxplus::builder::update_builder::UpdateFields for #name {
705            fn get_field_value(&self, field_name: &str) -> Option<sqlxplus::builder::query_builder::BindValue> {
706                match field_name {
707                    #(
708                        #update_fields_normal_field_columns => {
709                            // 对于非 Option 类型,转换为 BindValue(只包含支持的类型)
710                            Some(sqlxplus::builder::query_builder::BindValue::from(self.#update_fields_normal_field_names.clone()))
711                        }
712                    )*
713                    #(
714                        #update_fields_option_field_columns => {
715                            // 对于 Option 类型,如果是 Some 则转换,None 则返回 None(只包含支持的类型)
716                            self.#update_fields_option_field_names.as_ref().map(|v| {
717                                sqlxplus::builder::query_builder::BindValue::from(v.clone())
718                            })
719                        }
720                    )*
721                    _ => None, // 不支持的类型或未包含的字段返回 None
722                }
723            }
724
725            fn get_all_field_names() -> &'static [&'static str] {
726                &[
727                    #(#update_normal_field_columns,)*
728                    #(#update_option_field_columns,)*
729                ]
730            }
731
732            fn has_field(field_name: &str) -> bool {
733                matches!(field_name, #(#update_normal_field_columns)|* | #(#update_option_field_columns)|*)
734            }
735        }
736    };
737
738    let expanded = quote! {
739        #expanded
740        #update_fields_impl
741    };
742
743    TokenStream::from(expanded)
744}
745
746/// 判断字段类型是否为 Option<T>
747fn is_option_type(ty: &syn::Type) -> bool {
748    if let syn::Type::Path(type_path) = ty {
749        if let Some(seg) = type_path.path.segments.last() {
750            if seg.ident == "Option" {
751                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
752                    return args.args.len() == 1;
753                }
754            }
755        }
756    }
757    false
758}
759
760/// 检查类型是否是 BindValue 支持的基本类型
761/// 支持的类型:String, i64, i32, i16, f64, f32, bool, Vec<u8>
762fn is_bind_value_supported_type(ty: &syn::Type) -> bool {
763    if let syn::Type::Path(type_path) = ty {
764        if let Some(seg) = type_path.path.segments.last() {
765            let type_name = seg.ident.to_string();
766            // 检查是否是支持的基本类型
767            match type_name.as_str() {
768                "String" | "i64" | "i32" | "i16" | "f64" | "f32" | "bool" => true,
769                "Vec" => {
770                    // 对于 Vec,检查是否是 Vec<u8>
771                    if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
772                        if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
773                            if let syn::Type::Path(inner_path) = inner_ty {
774                                if let Some(inner_seg) = inner_path.path.segments.last() {
775                                    return inner_seg.ident == "u8";
776                                }
777                            }
778                        }
779                    }
780                    false
781                }
782                _ => false,
783            }
784        } else {
785            false
786        }
787    } else {
788        false
789    }
790}
791
792/// 获取 Option 内部的类型
793fn get_option_inner_type(ty: &syn::Type) -> Option<&syn::Type> {
794    if let syn::Type::Path(type_path) = ty {
795        if let Some(seg) = type_path.path.segments.last() {
796            if seg.ident == "Option" {
797                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
798                    if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
799                        return Some(inner_ty);
800                    }
801                }
802            }
803        }
804    }
805    None
806}