sqlxplus_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6/// 生成 Model trait 的实现
7///
8/// 自动生成 `TABLE`、`PK` 和可选的 `SOFT_DELETE_FIELD` 常量
9///
10/// 使用示例:
11/// ```ignore
12/// // 物理删除模式(默认)
13/// #[derive(ModelMeta)]
14/// #[model(table = "users", pk = "id")]
15/// struct User {
16///     id: i64,
17///     name: String,
18/// }
19///
20/// // 逻辑删除模式
21/// #[derive(ModelMeta)]
22/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
23/// struct UserWithSoftDelete {
24///     id: i64,
25///     name: String,
26///     is_deleted: i32, // 逻辑删除字段:0=未删除,1=已删除
27/// }
28/// ```
29#[proc_macro_derive(ModelMeta, attributes(model))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    let name = &input.ident;
33
34    // 解析属性
35    let mut table_name = None;
36    let mut pk_field = None;
37    let mut soft_delete_field = None;
38
39    for attr in &input.attrs {
40        if attr.path().is_ident("model") {
41            // 在 syn 2.0 中,使用 meta() 方法获取元数据
42            if let syn::Meta::List(list) = &attr.meta {
43                // 解析列表中的每个 Meta::NameValue
44                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46                    for meta in metas {
47                        if let Meta::NameValue(nv) = meta {
48                            if nv.path.is_ident("table") {
49                                if let syn::Expr::Lit(syn::ExprLit {
50                                    lit: syn::Lit::Str(s),
51                                    ..
52                                }) = nv.value
53                                {
54                                    table_name = Some(s.value());
55                                }
56                            } else if nv.path.is_ident("pk") {
57                                if let syn::Expr::Lit(syn::ExprLit {
58                                    lit: syn::Lit::Str(s),
59                                    ..
60                                }) = nv.value
61                                {
62                                    pk_field = Some(s.value());
63                                }
64                            } else if nv.path.is_ident("soft_delete") {
65                                if let syn::Expr::Lit(syn::ExprLit {
66                                    lit: syn::Lit::Str(s),
67                                    ..
68                                }) = nv.value
69                                {
70                                    soft_delete_field = Some(s.value());
71                                }
72                            }
73                        }
74                    }
75                }
76            } else if let syn::Meta::NameValue(nv) = &attr.meta {
77                // 单个 NameValue 的情况
78                if nv.path.is_ident("table") {
79                    if let syn::Expr::Lit(syn::ExprLit {
80                        lit: syn::Lit::Str(s),
81                        ..
82                    }) = &nv.value
83                    {
84                        table_name = Some(s.value());
85                    }
86                } else if nv.path.is_ident("pk") {
87                    if let syn::Expr::Lit(syn::ExprLit {
88                        lit: syn::Lit::Str(s),
89                        ..
90                    }) = &nv.value
91                    {
92                        pk_field = Some(s.value());
93                    }
94                } else if nv.path.is_ident("soft_delete") {
95                    if let syn::Expr::Lit(syn::ExprLit {
96                        lit: syn::Lit::Str(s),
97                        ..
98                    }) = &nv.value
99                    {
100                        soft_delete_field = Some(s.value());
101                    }
102                }
103            }
104        }
105    }
106
107    // 如果没有指定表名,使用结构体名称的小写蛇形命名方式
108    let table = table_name.unwrap_or_else(|| {
109        let s = name.to_string();
110        // 将 PascalCase 转换为 snake_case
111        let mut result = String::new();
112        for (i, c) in s.chars().enumerate() {
113            if c.is_uppercase() && i > 0 {
114                result.push('_');
115            }
116            result.push(c.to_ascii_lowercase());
117        }
118        result
119    });
120
121    // 如果没有指定主键,默认使用 "id"
122    let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124    // 生成实现代码
125    let expanded = if let Some(soft_delete) = soft_delete_field {
126        // 如果指定了逻辑删除字段,生成包含 SOFT_DELETE_FIELD 的实现
127        let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128        quote! {
129            impl sqlxplus::Model for #name {
130                const TABLE: &'static str = #table;
131                const PK: &'static str = #pk;
132                const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133            }
134        }
135    } else {
136        // 如果没有指定逻辑删除字段,SOFT_DELETE_FIELD 为 None
137        quote! {
138            impl sqlxplus::Model for #name {
139                const TABLE: &'static str = #table;
140                const PK: &'static str = #pk;
141                const SOFT_DELETE_FIELD: Option<&'static str> = None;
142            }
143        }
144    };
145
146    TokenStream::from(expanded)
147}
148
149/// 生成 CRUD trait 的实现
150///
151/// 自动生成 insert 和 update 方法的实现
152///
153/// 使用示例:
154/// ```ignore
155/// // 物理删除模式
156/// #[derive(CRUD, FromRow, ModelMeta)]
157/// #[model(table = "users", pk = "id")]
158/// struct User {
159///     id: i64,
160///     name: String,
161///     email: String,
162/// }
163///
164/// // 逻辑删除模式
165/// #[derive(CRUD, FromRow, ModelMeta)]
166/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
167/// struct UserWithSoftDelete {
168///     id: i64,
169///     name: String,
170///     email: String,
171///     is_deleted: i32, // 逻辑删除字段
172/// }
173/// ```
174#[proc_macro_derive(CRUD, attributes(model, skip))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176    let input = parse_macro_input!(input as DeriveInput);
177    let name = &input.ident;
178
179    // 解析 #[model(pk = "...")],获取主键字段名,默认 "id"
180    let mut pk_field = None;
181    for attr in &input.attrs {
182        if attr.path().is_ident("model") {
183            if let syn::Meta::List(list) = &attr.meta {
184                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186                    for meta in metas {
187                        if let Meta::NameValue(nv) = meta {
188                            if nv.path.is_ident("pk") {
189                                if let syn::Expr::Lit(syn::ExprLit {
190                                    lit: syn::Lit::Str(s),
191                                    ..
192                                }) = nv.value
193                                {
194                                    pk_field = Some(s.value());
195                                }
196                            }
197                        }
198                    }
199                }
200            } else if let syn::Meta::NameValue(nv) = &attr.meta {
201                if nv.path.is_ident("pk") {
202                    if let syn::Expr::Lit(syn::ExprLit {
203                        lit: syn::Lit::Str(s),
204                        ..
205                    }) = &nv.value
206                    {
207                        pk_field = Some(s.value());
208                    }
209                }
210            }
211        }
212    }
213    // 如果没有指定主键,默认使用 "id"
214    let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216    // 获取字段列表(必须是具名字段的结构体)
217    let fields = match &input.data {
218        Data::Struct(DataStruct {
219            fields: Fields::Named(fields),
220            ..
221        }) => &fields.named,
222        _ => {
223            return syn::Error::new_spanned(
224                name,
225                "CRUD derive only supports structs with named fields",
226            )
227            .to_compile_error()
228            .into();
229        }
230    };
231
232    // 收集字段信息
233    // - pk_ident: 主键字段 Ident
234    // - insert_*/update_*: 非主键字段(INSERT / UPDATE 使用)
235    let mut pk_ident_opt: Option<&syn::Ident> = None;
236
237    // INSERT 使用的字段(非主键)
238    let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
239    let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
240    let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
241    let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
242
243    // UPDATE 使用的字段(非主键)
244    let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
245    let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
246    let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
247    let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
248
249    for field in fields {
250        let field_name = field.ident.as_ref().unwrap();
251        let field_name_str = field_name.to_string();
252
253        // 检查属性:skip / model
254        let mut skip = false;
255        for attr in &field.attrs {
256            if attr.path().is_ident("skip") || attr.path().is_ident("model") {
257                skip = true;
258                break;
259            }
260        }
261
262        if !skip {
263            if field_name_str == pk {
264                // 记录主键字段
265                pk_ident_opt = Some(field_name);
266            } else {
267                // 非主键字段用于 INSERT / UPDATE
268                let is_opt = is_option_type(&field.ty);
269                let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
270
271                if is_opt {
272                    insert_option_field_names.push(field_name);
273                    insert_option_field_columns.push(col_lit.clone());
274
275                    update_option_field_names.push(field_name);
276                    update_option_field_columns.push(col_lit);
277                } else {
278                    insert_normal_field_names.push(field_name);
279                    insert_normal_field_columns.push(col_lit.clone());
280
281                    update_normal_field_names.push(field_name);
282                    update_normal_field_columns.push(col_lit);
283                }
284            }
285        }
286    }
287
288    // 编译期确保主键字段存在
289    let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
290
291    // 生成实现代码
292    let expanded = quote! {
293        // Trait 方法实现
294        #[async_trait::async_trait]
295        impl sqlxplus::Crud for #name {
296            // 泛型版本的 insert(自动类型推断)
297            async fn insert<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<sqlxplus::crud::Id>
298            where
299                DB: sqlx::Database + sqlxplus::DatabaseInfo,
300                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
301                E: sqlxplus::DatabaseType<DB = DB>
302                    + sqlx::Executor<'c, Database = DB>
303                    + Send,
304                i64: sqlx::Type<DB> + for<'r> sqlx::Decode<'r, DB>,
305                usize: sqlx::ColumnIndex<DB::Row>,
306                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
307                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
308                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
309                i64: for<'b> sqlx::Encode<'b, DB>,
310                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
311                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
312                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
313                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
314                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
315                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
316                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
317                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
318                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
319                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
320                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
321                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
322                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
323                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
324                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
325                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
326                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
327                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
328                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
329                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
330                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
331                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
332                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
333                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
334            {
335                use sqlxplus::Model;
336                use sqlxplus::DatabaseInfo;
337                use sqlxplus::db_pool::DbDriver;
338                let table = Self::TABLE;
339                let escaped_table = DB::escape_identifier(table);
340
341                // 构建列名和占位符
342                let mut columns: Vec<&str> = Vec::new();
343                let mut placeholders: Vec<String> = Vec::new();
344                let mut placeholder_index = 0;
345
346                // 非 Option 字段:始终参与 INSERT
347                #(
348                    columns.push(#insert_normal_field_columns);
349                    placeholders.push(DB::placeholder(placeholder_index));
350                    placeholder_index += 1;
351                )*
352
353                // Option 字段:仅当为 Some 时参与 INSERT
354                #(
355                    if self.#insert_option_field_names.is_some() {
356                        columns.push(#insert_option_field_columns);
357                        placeholders.push(DB::placeholder(placeholder_index));
358                        placeholder_index += 1;
359                    }
360                )*
361
362                // 根据数据库类型构建 SQL
363                let sql = match DB::get_driver() {
364                    DbDriver::Postgres => {
365                        let pk = Self::PK;
366                        let escaped_pk = DB::escape_identifier(pk);
367                        format!(
368                            "INSERT INTO {} ({}) VALUES ({}) RETURNING {}",
369                            escaped_table,
370                            columns.join(", "),
371                            placeholders.join(", "),
372                            escaped_pk
373                        )
374                    }
375                    _ => {
376                        format!(
377                            "INSERT INTO {} ({}) VALUES ({})",
378                            escaped_table,
379                            columns.join(", "),
380                            placeholders.join(", ")
381                        )
382                    }
383                };
384
385                // 根据数据库类型执行查询
386                match DB::get_driver() {
387                    DbDriver::Postgres => {
388                        let mut query = sqlx::query_scalar::<_, i64>(&sql);
389                        // 非 Option 字段:始终绑定
390                        #(
391                            query = query.bind(&self.#insert_normal_field_names);
392                        )*
393                        // Option 字段:仅当为 Some 时绑定
394                        #(
395                            if let Some(ref val) = self.#insert_option_field_names {
396                                query = query.bind(val);
397                            }
398                        )*
399                        let id: i64 = query.fetch_one(executor).await?;
400                        Ok(id)
401                    }
402                    DbDriver::MySql => {
403                        let mut query = sqlx::query(&sql);
404                        // 非 Option 字段:始终绑定
405                        #(
406                            query = query.bind(&self.#insert_normal_field_names);
407                        )*
408                        // Option 字段:仅当为 Some 时绑定
409                        #(
410                            if let Some(ref val) = self.#insert_option_field_names {
411                                query = query.bind(val);
412                            }
413                        )*
414                        let result = query.execute(executor).await?;
415                        // 在泛型上下文中,我们需要使用 unsafe 转换来访问数据库特定的方法
416                        // 这是安全的,因为我们已经通过 DB::get_driver() 确认了数据库类型
417                        // 并且我们知道 DB = MySql,所以 result 的类型是 MySqlQueryResult
418                        unsafe {
419                            use sqlx::mysql::MySqlQueryResult;
420                            let ptr: *const DB::QueryResult = &result;
421                            let mysql_ptr = ptr as *const MySqlQueryResult;
422                            Ok((*mysql_ptr).last_insert_id() as i64)
423                        }
424                    }
425                    DbDriver::Sqlite => {
426                        let mut query = sqlx::query(&sql);
427                        // 非 Option 字段:始终绑定
428                        #(
429                            query = query.bind(&self.#insert_normal_field_names);
430                        )*
431                        // Option 字段:仅当为 Some 时绑定
432                        #(
433                            if let Some(ref val) = self.#insert_option_field_names {
434                                query = query.bind(val);
435                            }
436                        )*
437                        let result = query.execute(executor).await?;
438                        // 在泛型上下文中,我们需要使用 unsafe 转换来访问数据库特定的方法
439                        unsafe {
440                            use sqlx::sqlite::SqliteQueryResult;
441                            let ptr: *const DB::QueryResult = &result;
442                            let sqlite_ptr = ptr as *const SqliteQueryResult;
443                            Ok((*sqlite_ptr).last_insert_rowid() as i64)
444                        }
445                    }
446                }
447            }
448
449            // 泛型版本的 update(自动类型推断)
450            async fn update<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
451            where
452                DB: sqlx::Database + sqlxplus::DatabaseInfo,
453                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
454                E: sqlxplus::DatabaseType<DB = DB>
455                    + sqlx::Executor<'c, Database = DB>
456                    + Send,
457                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
458                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
459                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
460                i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
461                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
462                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
463                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
464                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
465                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
466                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
467                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
468                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
469                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
470                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
471                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
472                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
473                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
474                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
475                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
476                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
477                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
478                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
479                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
480                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
481                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
482                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
483                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
484                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
485            {
486                use sqlxplus::Model;
487                use sqlxplus::DatabaseInfo;
488                let table = Self::TABLE;
489                let pk = Self::PK;
490                let escaped_table = DB::escape_identifier(table);
491                let escaped_pk = DB::escape_identifier(pk);
492
493                // 构建 UPDATE SET 子句(Patch 语义)
494                let mut set_parts: Vec<String> = Vec::new();
495                let mut placeholder_index = 0;
496
497                // 非 Option 字段
498                #(
499                    set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
500                    placeholder_index += 1;
501                )*
502
503                // Option 字段
504                #(
505                    if self.#update_option_field_names.is_some() {
506                        set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
507                        placeholder_index += 1;
508                    }
509                )*
510
511                if set_parts.is_empty() {
512                    return Ok(());
513                }
514
515                let sql = format!(
516                    "UPDATE {} SET {} WHERE {} = {}",
517                    escaped_table,
518                    set_parts.join(", "),
519                    escaped_pk,
520                    DB::placeholder(placeholder_index)
521                );
522
523                let mut query = sqlx::query(&sql);
524                // 非 Option 字段:始终绑定
525                #(
526                    query = query.bind(&self.#update_normal_field_names);
527                )*
528                // Option 字段:仅当为 Some 时绑定
529                #(
530                    if let Some(ref val) = self.#update_option_field_names {
531                        query = query.bind(val);
532                    }
533                )*
534                query = query.bind(&self.#pk_ident);
535                query.execute(executor).await?;
536                Ok(())
537            }
538
539            // 泛型版本的 update_with_none(自动类型推断)
540            async fn update_with_none<'e, 'c: 'e, DB, E>(&self, executor: E) -> sqlxplus::Result<()>
541            where
542                DB: sqlx::Database + sqlxplus::DatabaseInfo,
543                for<'a> DB::Arguments<'a>: sqlx::IntoArguments<'a, DB>,
544                E: sqlxplus::DatabaseType<DB = DB>
545                    + sqlx::Executor<'c, Database = DB>
546                    + Send,
547                // 基本类型必须实现 Type<DB> 和 Encode<DB>(用于绑定值)
548                // 注意:只包含三种数据库(MySQL、PostgreSQL、SQLite)都支持的类型
549                String: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
550                i64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
551                i32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
552                i16: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
553                f64: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
554                f32: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
555                bool: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
556                Option<String>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
557                Option<i64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
558                Option<i32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
559                Option<i16>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
560                Option<f64>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
561                Option<f32>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
562                Option<bool>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
563                chrono::DateTime<chrono::Utc>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
564                Option<chrono::DateTime<chrono::Utc>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
565                chrono::NaiveDateTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
566                Option<chrono::NaiveDateTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
567                chrono::NaiveDate: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
568                Option<chrono::NaiveDate>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
569                chrono::NaiveTime: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
570                Option<chrono::NaiveTime>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
571                Vec<u8>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
572                Option<Vec<u8>>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
573                serde_json::Value: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
574                Option<serde_json::Value>: sqlx::Type<DB> + for<'b> sqlx::Encode<'b, DB>,
575            {
576                use sqlxplus::Model;
577                use sqlxplus::DatabaseInfo;
578                use sqlxplus::db_pool::DbDriver;
579                let table = Self::TABLE;
580                let pk = Self::PK;
581                let escaped_table = DB::escape_identifier(table);
582                let escaped_pk = DB::escape_identifier(pk);
583
584                // 构建 UPDATE SET 子句(Reset 语义)
585                let mut set_parts: Vec<String> = Vec::new();
586                let mut placeholder_index = 0;
587
588                // 非 Option 字段:始终更新为当前值
589                #(
590                    set_parts.push(format!("{} = {}", DB::escape_identifier(#update_normal_field_columns), DB::placeholder(placeholder_index)));
591                    placeholder_index += 1;
592                )*
593
594                // Option 字段:根据数据库类型处理
595                match DB::get_driver() {
596                    DbDriver::Sqlite => {
597                        // SQLite 不支持 DEFAULT,跳过 None 字段
598                        #(
599                            if self.#update_option_field_names.is_some() {
600                                set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
601                                placeholder_index += 1;
602                            }
603                        )*
604                    }
605                    _ => {
606                        // MySQL 和 PostgreSQL 使用 DEFAULT
607                        #(
608                            if self.#update_option_field_names.is_some() {
609                                set_parts.push(format!("{} = {}", DB::escape_identifier(#update_option_field_columns), DB::placeholder(placeholder_index)));
610                                placeholder_index += 1;
611                            } else {
612                                set_parts.push(format!("{} = DEFAULT", DB::escape_identifier(#update_option_field_columns)));
613                            }
614                        )*
615                    }
616                }
617
618                if set_parts.is_empty() {
619                    return Ok(());
620                }
621
622                let sql = format!(
623                    "UPDATE {} SET {} WHERE {} = {}",
624                    escaped_table,
625                    set_parts.join(", "),
626                    escaped_pk,
627                    DB::placeholder(placeholder_index)
628                );
629
630                let mut query = sqlx::query(&sql);
631                // 非 Option 字段:始终绑定
632                #(
633                    query = query.bind(&self.#update_normal_field_names);
634                )*
635                // Option 字段:仅当为 Some 时绑定(None 使用 DEFAULT 或跳过)
636                #(
637                    if let Some(ref val) = self.#update_option_field_names {
638                        query = query.bind(val);
639                    }
640                )*
641                query = query.bind(&self.#pk_ident);
642                query.execute(executor).await?;
643                Ok(())
644            }
645        }
646    };
647
648    TokenStream::from(expanded)
649}
650
651/// 判断字段类型是否为 Option<T>
652fn is_option_type(ty: &syn::Type) -> bool {
653    if let syn::Type::Path(type_path) = ty {
654        if let Some(seg) = type_path.path.segments.last() {
655            if seg.ident == "Option" {
656                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
657                    return args.args.len() == 1;
658                }
659            }
660        }
661    }
662    false
663}