sqlxplus_derive/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2;
3use quote::quote;
4use syn::{parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta};
5
6/// 生成 Model trait 的实现
7///
8/// 自动生成 `TABLE`、`PK` 和可选的 `SOFT_DELETE_FIELD` 常量
9///
10/// 使用示例:
11/// ```ignore
12/// // 物理删除模式(默认)
13/// #[derive(ModelMeta)]
14/// #[model(table = "users", pk = "id")]
15/// struct User {
16///     id: i64,
17///     name: String,
18/// }
19///
20/// // 逻辑删除模式
21/// #[derive(ModelMeta)]
22/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
23/// struct UserWithSoftDelete {
24///     id: i64,
25///     name: String,
26///     is_deleted: i32, // 逻辑删除字段:0=未删除,1=已删除
27/// }
28/// ```
29#[proc_macro_derive(ModelMeta, attributes(model))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    let name = &input.ident;
33
34    // 解析属性
35    let mut table_name = None;
36    let mut pk_field = None;
37    let mut soft_delete_field = None;
38
39    for attr in &input.attrs {
40        if attr.path().is_ident("model") {
41            // 在 syn 2.0 中,使用 meta() 方法获取元数据
42            if let syn::Meta::List(list) = &attr.meta {
43                // 解析列表中的每个 Meta::NameValue
44                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46                    for meta in metas {
47                        if let Meta::NameValue(nv) = meta {
48                            if nv.path.is_ident("table") {
49                                if let syn::Expr::Lit(syn::ExprLit {
50                                    lit: syn::Lit::Str(s),
51                                    ..
52                                }) = nv.value
53                                {
54                                    table_name = Some(s.value());
55                                }
56                            } else if nv.path.is_ident("pk") {
57                                if let syn::Expr::Lit(syn::ExprLit {
58                                    lit: syn::Lit::Str(s),
59                                    ..
60                                }) = nv.value
61                                {
62                                    pk_field = Some(s.value());
63                                }
64                            } else if nv.path.is_ident("soft_delete") {
65                                if let syn::Expr::Lit(syn::ExprLit {
66                                    lit: syn::Lit::Str(s),
67                                    ..
68                                }) = nv.value
69                                {
70                                    soft_delete_field = Some(s.value());
71                                }
72                            }
73                        }
74                    }
75                }
76            } else if let syn::Meta::NameValue(nv) = &attr.meta {
77                // 单个 NameValue 的情况
78                if nv.path.is_ident("table") {
79                    if let syn::Expr::Lit(syn::ExprLit {
80                        lit: syn::Lit::Str(s),
81                        ..
82                    }) = &nv.value
83                    {
84                        table_name = Some(s.value());
85                    }
86                } else if nv.path.is_ident("pk") {
87                    if let syn::Expr::Lit(syn::ExprLit {
88                        lit: syn::Lit::Str(s),
89                        ..
90                    }) = &nv.value
91                    {
92                        pk_field = Some(s.value());
93                    }
94                } else if nv.path.is_ident("soft_delete") {
95                    if let syn::Expr::Lit(syn::ExprLit {
96                        lit: syn::Lit::Str(s),
97                        ..
98                    }) = &nv.value
99                    {
100                        soft_delete_field = Some(s.value());
101                    }
102                }
103            }
104        }
105    }
106
107    // 如果没有指定表名,使用结构体名称的小写蛇形命名方式
108    let table = table_name.unwrap_or_else(|| {
109        let s = name.to_string();
110        // 将 PascalCase 转换为 snake_case
111        let mut result = String::new();
112        for (i, c) in s.chars().enumerate() {
113            if c.is_uppercase() && i > 0 {
114                result.push('_');
115            }
116            result.push(c.to_ascii_lowercase());
117        }
118        result
119    });
120
121    // 如果没有指定主键,默认使用 "id"
122    let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124    // 生成实现代码
125    let expanded = if let Some(soft_delete) = soft_delete_field {
126        // 如果指定了逻辑删除字段,生成包含 SOFT_DELETE_FIELD 的实现
127        let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128        quote! {
129            impl sqlxplus::Model for #name {
130                const TABLE: &'static str = #table;
131                const PK: &'static str = #pk;
132                const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133            }
134        }
135    } else {
136        // 如果没有指定逻辑删除字段,SOFT_DELETE_FIELD 为 None
137        quote! {
138            impl sqlxplus::Model for #name {
139                const TABLE: &'static str = #table;
140                const PK: &'static str = #pk;
141                const SOFT_DELETE_FIELD: Option<&'static str> = None;
142            }
143        }
144    };
145
146    TokenStream::from(expanded)
147}
148
149/// 生成 CRUD trait 的实现
150///
151/// 自动生成 insert 和 update 方法的实现
152///
153/// 使用示例:
154/// ```ignore
155/// // 物理删除模式
156/// #[derive(CRUD, FromRow, ModelMeta)]
157/// #[model(table = "users", pk = "id")]
158/// struct User {
159///     id: i64,
160///     name: String,
161///     email: String,
162/// }
163///
164/// // 逻辑删除模式
165/// #[derive(CRUD, FromRow, ModelMeta)]
166/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
167/// struct UserWithSoftDelete {
168///     id: i64,
169///     name: String,
170///     email: String,
171///     is_deleted: i32, // 逻辑删除字段
172/// }
173/// ```
174#[proc_macro_derive(CRUD, attributes(model, skip))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176    let input = parse_macro_input!(input as DeriveInput);
177    let name = &input.ident;
178
179    // 解析 #[model(pk = "...")],获取主键字段名,默认 "id"
180    let mut pk_field = None;
181    for attr in &input.attrs {
182        if attr.path().is_ident("model") {
183            if let syn::Meta::List(list) = &attr.meta {
184                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186                    for meta in metas {
187                        if let Meta::NameValue(nv) = meta {
188                            if nv.path.is_ident("pk") {
189                                if let syn::Expr::Lit(syn::ExprLit {
190                                    lit: syn::Lit::Str(s),
191                                    ..
192                                }) = nv.value
193                                {
194                                    pk_field = Some(s.value());
195                                }
196                            }
197                        }
198                    }
199                }
200            } else if let syn::Meta::NameValue(nv) = &attr.meta {
201                if nv.path.is_ident("pk") {
202                    if let syn::Expr::Lit(syn::ExprLit {
203                        lit: syn::Lit::Str(s),
204                        ..
205                    }) = &nv.value
206                    {
207                        pk_field = Some(s.value());
208                    }
209                }
210            }
211        }
212    }
213    // 如果没有指定主键,默认使用 "id"
214    let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216    // 获取字段列表(必须是具名字段的结构体)
217    let fields = match &input.data {
218        Data::Struct(DataStruct {
219            fields: Fields::Named(fields),
220            ..
221        }) => &fields.named,
222        _ => {
223            return syn::Error::new_spanned(
224                name,
225                "CRUD derive only supports structs with named fields",
226            )
227            .to_compile_error()
228            .into();
229        }
230    };
231
232    // 收集字段信息
233    // - pk_ident: 主键字段 Ident
234    // - insert_*/update_*: 非主键字段(INSERT / UPDATE 使用)
235    let mut pk_ident_opt: Option<&syn::Ident> = None;
236
237    // INSERT 使用的字段(非主键)
238    let mut insert_normal_field_names: Vec<&syn::Ident> = Vec::new();
239    let mut insert_normal_field_columns: Vec<syn::LitStr> = Vec::new();
240    let mut insert_option_field_names: Vec<&syn::Ident> = Vec::new();
241    let mut insert_option_field_columns: Vec<syn::LitStr> = Vec::new();
242
243    // UPDATE 使用的字段(非主键)
244    let mut update_normal_field_names: Vec<&syn::Ident> = Vec::new();
245    let mut update_normal_field_columns: Vec<syn::LitStr> = Vec::new();
246    let mut update_option_field_names: Vec<&syn::Ident> = Vec::new();
247    let mut update_option_field_columns: Vec<syn::LitStr> = Vec::new();
248
249    for field in fields {
250        let field_name = field.ident.as_ref().unwrap();
251        let field_name_str = field_name.to_string();
252
253        // 检查属性:skip / model
254        let mut skip = false;
255        for attr in &field.attrs {
256            if attr.path().is_ident("skip") || attr.path().is_ident("model") {
257                skip = true;
258                break;
259            }
260        }
261
262        if !skip {
263            if field_name_str == pk {
264                // 记录主键字段
265                pk_ident_opt = Some(field_name);
266            } else {
267                // 非主键字段用于 INSERT / UPDATE
268                let is_opt = is_option_type(&field.ty);
269                let col_lit = syn::LitStr::new(&field_name_str, proc_macro2::Span::call_site());
270
271                if is_opt {
272                    insert_option_field_names.push(field_name);
273                    insert_option_field_columns.push(col_lit.clone());
274
275                    update_option_field_names.push(field_name);
276                    update_option_field_columns.push(col_lit);
277                } else {
278                    insert_normal_field_names.push(field_name);
279                    insert_normal_field_columns.push(col_lit.clone());
280
281                    update_normal_field_names.push(field_name);
282                    update_normal_field_columns.push(col_lit);
283                }
284            }
285        }
286    }
287
288    // 编译期确保主键字段存在
289    let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
290
291    // 生成实现代码
292    let expanded = quote! {
293        #[async_trait::async_trait]
294        impl sqlxplus::Crud for #name {
295            async fn insert(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<sqlxplus::crud::Id> {
296                use sqlxplus::Model;
297                use sqlxplus::utils::escape_identifier;
298                let table = Self::TABLE;
299                let driver = pool.driver();
300                let escaped_table = escape_identifier(driver, table);
301
302                // 构建列名和占位符:
303                // - 非 Option 字段:始终参与 INSERT
304                // - Option 字段:仅在 Some 时参与 INSERT(None 则跳过,让数据库用默认值)
305                let mut columns: Vec<&str> = Vec::new();
306                let mut placeholders: Vec<&str> = Vec::new();
307
308                // 非 Option 字段:始终参与 INSERT
309                #(
310                    columns.push(#insert_normal_field_columns);
311                    placeholders.push("?");
312                )*
313
314                // Option 字段:仅当为 Some 时参与 INSERT
315                #(
316                    if self.#insert_option_field_names.is_some() {
317                        columns.push(#insert_option_field_columns);
318                        placeholders.push("?");
319                    }
320                )*
321
322                let raw_sql = format!(
323                    "INSERT INTO {} ({}) VALUES ({})",
324                    escaped_table,
325                    columns.join(", "),
326                    placeholders.join(", ")
327                );
328                let sql = pool.convert_sql(&raw_sql);
329
330                match pool.driver() {
331                    sqlxplus::db_pool::DbDriver::MySql => {
332                        let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
333                        let mut query = sqlx::query(&sql);
334                        // 非 Option 字段:始终绑定
335                        #(
336                            query = query.bind(&self.#insert_normal_field_names);
337                        )*
338                        // Option 字段:仅当为 Some 时绑定
339                        #(
340                            if self.#insert_option_field_names.is_some() {
341                                query = query.bind(&self.#insert_option_field_names);
342                            }
343                        )*
344                        let result = query
345                            .execute(pool_ref)
346                            .await?;
347                        Ok(result.last_insert_id() as i64)
348                    }
349                    sqlxplus::db_pool::DbDriver::Postgres => {
350                        let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
351                        let pk = Self::PK;
352                        use sqlxplus::utils::escape_identifier;
353                        let escaped_pk = escape_identifier(sqlxplus::db_pool::DbDriver::Postgres, pk);
354                        // 为 PostgreSQL 添加 RETURNING 子句
355                        let sql_with_returning = format!("{} RETURNING {}", sql, escaped_pk);
356                        let mut query = sqlx::query_scalar::<_, i64>(&sql_with_returning);
357                        // 非 Option 字段:始终绑定
358                        #(
359                            query = query.bind(&self.#insert_normal_field_names);
360                        )*
361                        // Option 字段:仅当为 Some 时绑定
362                        #(
363                            if self.#insert_option_field_names.is_some() {
364                                query = query.bind(&self.#insert_option_field_names);
365                            }
366                        )*
367                        let id: i64 = query
368                            .fetch_one(pool_ref)
369                            .await?;
370                        Ok(id)
371                    }
372                    sqlxplus::db_pool::DbDriver::Sqlite => {
373                        let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
374                        let mut query = sqlx::query(&sql);
375                        // 非 Option 字段:始终绑定
376                        #(
377                            query = query.bind(&self.#insert_normal_field_names);
378                        )*
379                        // Option 字段:仅当为 Some 时绑定
380                        #(
381                            if self.#insert_option_field_names.is_some() {
382                                query = query.bind(&self.#insert_option_field_names);
383                            }
384                        )*
385                        let result = query
386                            .execute(pool_ref)
387                            .await?;
388                        Ok(result.last_insert_rowid() as i64)
389                    }
390                }
391            }
392
393            async fn update(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<()> {
394                use sqlxplus::Model;
395                use sqlxplus::utils::escape_identifier;
396                let table = Self::TABLE;
397                let pk = Self::PK;
398                let driver = pool.driver();
399                let escaped_table = escape_identifier(driver, table);
400                let escaped_pk = escape_identifier(driver, pk);
401
402                // 构建 UPDATE SET 子句(Patch 语义):
403                // - 非 Option 字段:始终参与更新
404                // - Option 字段:仅当为 Some 时参与更新
405                let mut set_parts: Vec<String> = Vec::new();
406
407                // 非 Option 字段
408                #(
409                    set_parts.push(format!("{} = ?", #update_normal_field_columns));
410                )*
411
412                // Option 字段
413                #(
414                    if self.#update_option_field_names.is_some() {
415                        set_parts.push(format!("{} = ?", #update_option_field_columns));
416                    }
417                )*
418
419                let raw_sql = if !set_parts.is_empty() {
420                    format!(
421                        "UPDATE {} SET {} WHERE {} = ?",
422                        escaped_table,
423                        set_parts.join(", "),
424                        escaped_pk,
425                    )
426                } else {
427                    // 没有需要更新的字段,直接返回 Ok(())
428                    return Ok(());
429                };
430
431                let sql = pool.convert_sql(&raw_sql);
432
433                match pool.driver() {
434                    sqlxplus::db_pool::DbDriver::MySql => {
435                        let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
436                        let mut query = sqlx::query(&sql);
437                        // 非 Option 字段:始终绑定
438                        #(
439                            query = query.bind(&self.#update_normal_field_names);
440                        )*
441                        // Option 字段:仅当为 Some 时绑定
442                        #(
443                            if self.#update_option_field_names.is_some() {
444                                query = query.bind(&self.#update_option_field_names);
445                            }
446                        )*
447                        query
448                            .bind(&self.#pk_ident)
449                            .execute(pool_ref)
450                            .await?;
451                    }
452                    sqlxplus::db_pool::DbDriver::Postgres => {
453                        let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
454                        let mut query = sqlx::query(&sql);
455                        // 非 Option 字段:始终绑定
456                        #(
457                            query = query.bind(&self.#update_normal_field_names);
458                        )*
459                        // Option 字段:仅当为 Some 时绑定
460                        #(
461                            if self.#update_option_field_names.is_some() {
462                                query = query.bind(&self.#update_option_field_names);
463                            }
464                        )*
465                        query
466                            .bind(&self.#pk_ident)
467                            .execute(pool_ref)
468                            .await?;
469                    }
470                    sqlxplus::db_pool::DbDriver::Sqlite => {
471                        let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
472                        let mut query = sqlx::query(&sql);
473                        // 非 Option 字段:始终绑定
474                        #(
475                            query = query.bind(&self.#update_normal_field_names);
476                        )*
477                        // Option 字段:仅当为 Some 时绑定
478                        #(
479                            if self.#update_option_field_names.is_some() {
480                                query = query.bind(&self.#update_option_field_names);
481                            }
482                        )*
483                        query
484                            .bind(&self.#pk_ident)
485                            .execute(pool_ref)
486                            .await?;
487                    }
488                }
489                Ok(())
490            }
491
492            /// 更新记录(包含 Option 字段为 None 的重置)
493            ///
494            /// - 非 Option 字段:始终参与更新(col = ?)
495            /// - Option 字段:
496            ///   - Some(v):col = ?
497            ///   - None:col = DEFAULT(由数据库决定置空或使用默认值)
498            async fn update_with_none(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<()> {
499                use sqlxplus::Model;
500                use sqlxplus::utils::escape_identifier;
501                let table = Self::TABLE;
502                let pk = Self::PK;
503                let driver = pool.driver();
504                let escaped_table = escape_identifier(driver, table);
505                let escaped_pk = escape_identifier(driver, pk);
506
507                // 构建 UPDATE SET 子句(Reset 语义)
508                let mut set_parts: Vec<String> = Vec::new();
509
510                // 非 Option 字段:始终更新为当前值
511                #(
512                    set_parts.push(format!("{} = ?", #update_normal_field_columns));
513                )*
514
515                // Option 字段:Some -> ?,None -> DEFAULT/NULL(根据驱动类型)
516                // SQLite 不支持 DEFAULT,且不可空字段不能设置为 NULL,所以跳过 None 字段的更新
517                match driver {
518                    sqlxplus::db_pool::DbDriver::Sqlite => {
519                        // SQLite: 仅更新 Some 值的字段,跳过 None 字段(因为 SQLite 不支持 DEFAULT)
520                        #(
521                            if self.#update_option_field_names.is_some() {
522                                set_parts.push(format!("{} = ?", #update_option_field_columns));
523                            }
524                            // None 字段跳过,不包含在 SET 子句中
525                        )*
526                    }
527                    _ => {
528                        // MySQL/PostgreSQL: None -> DEFAULT
529                        #(
530                            if self.#update_option_field_names.is_some() {
531                                set_parts.push(format!("{} = ?", #update_option_field_columns));
532                            } else {
533                                set_parts.push(format!("{} = DEFAULT", #update_option_field_columns));
534                            }
535                        )*
536                    }
537                }
538
539                if set_parts.is_empty() {
540                    return Ok(());
541                }
542
543                let raw_sql = format!(
544                    "UPDATE {} SET {} WHERE {} = ?",
545                    escaped_table,
546                    set_parts.join(", "),
547                    escaped_pk,
548                );
549
550                let sql = pool.convert_sql(&raw_sql);
551
552                match pool.driver() {
553                    sqlxplus::db_pool::DbDriver::MySql => {
554                        let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
555                        let mut query = sqlx::query(&sql);
556                        // 非 Option 字段:始终绑定
557                        #(
558                            query = query.bind(&self.#update_normal_field_names);
559                        )*
560                        // Option 字段:仅当为 Some 时绑定(None 使用 DEFAULT)
561                        #(
562                            if self.#update_option_field_names.is_some() {
563                                query = query.bind(&self.#update_option_field_names);
564                            }
565                        )*
566                        query
567                            .bind(&self.#pk_ident)
568                            .execute(pool_ref)
569                            .await?;
570                    }
571                    sqlxplus::db_pool::DbDriver::Postgres => {
572                        let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
573                        let mut query = sqlx::query(&sql);
574                        // 非 Option 字段:始终绑定
575                        #(
576                            query = query.bind(&self.#update_normal_field_names);
577                        )*
578                        // Option 字段:仅当为 Some 时绑定
579                        #(
580                            if self.#update_option_field_names.is_some() {
581                                query = query.bind(&self.#update_option_field_names);
582                            }
583                        )*
584                        query
585                            .bind(&self.#pk_ident)
586                            .execute(pool_ref)
587                            .await?;
588                    }
589                    sqlxplus::db_pool::DbDriver::Sqlite => {
590                        let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
591                        let mut query = sqlx::query(&sql);
592                        // 非 Option 字段:始终绑定
593                        #(
594                            query = query.bind(&self.#update_normal_field_names);
595                        )*
596                        // Option 字段:仅当为 Some 时绑定
597                        #(
598                            if self.#update_option_field_names.is_some() {
599                                query = query.bind(&self.#update_option_field_names);
600                            }
601                        )*
602                        query
603                            .bind(&self.#pk_ident)
604                            .execute(pool_ref)
605                            .await?;
606                    }
607                }
608                Ok(())
609            }
610        }
611    };
612
613    TokenStream::from(expanded)
614}
615
616/// 判断字段类型是否为 Option<T>
617fn is_option_type(ty: &syn::Type) -> bool {
618    if let syn::Type::Path(type_path) = ty {
619        if let Some(seg) = type_path.path.segments.last() {
620            if seg.ident == "Option" {
621                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
622                    return args.args.len() == 1;
623                }
624            }
625        }
626    }
627    false
628}
629