// sqlxplus_derive/lib.rs

use proc_macro::TokenStream;
use proc_macro2;
use quote::quote;
use syn::{
    ext::IdentExt, parse::Parser, parse_macro_input, Data, DataStruct, DeriveInput, Fields, Meta,
};

6/// 生成 Model trait 的实现
7///
8/// 自动生成 `TABLE`、`PK` 和可选的 `SOFT_DELETE_FIELD` 常量
9///
10/// 使用示例:
11/// ```ignore
12/// // 物理删除模式(默认)
13/// #[derive(ModelMeta)]
14/// #[model(table = "users", pk = "id")]
15/// struct User {
16///     id: i64,
17///     name: String,
18/// }
19///
20/// // 逻辑删除模式
21/// #[derive(ModelMeta)]
22/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
23/// struct UserWithSoftDelete {
24///     id: i64,
25///     name: String,
26///     is_deleted: i32, // 逻辑删除字段:0=未删除,1=已删除
27/// }
28/// ```
29#[proc_macro_derive(ModelMeta, attributes(model))]
30pub fn derive_model_meta(input: TokenStream) -> TokenStream {
31    let input = parse_macro_input!(input as DeriveInput);
32    let name = &input.ident;
33
34    // 解析属性
35    let mut table_name = None;
36    let mut pk_field = None;
37    let mut soft_delete_field = None;
38
39    for attr in &input.attrs {
40        if attr.path().is_ident("model") {
41            // 在 syn 2.0 中,使用 meta() 方法获取元数据
42            if let syn::Meta::List(list) = &attr.meta {
43                // 解析列表中的每个 Meta::NameValue
44                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
45                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
46                    for meta in metas {
47                        if let Meta::NameValue(nv) = meta {
48                            if nv.path.is_ident("table") {
49                                if let syn::Expr::Lit(syn::ExprLit {
50                                    lit: syn::Lit::Str(s),
51                                    ..
52                                }) = nv.value
53                                {
54                                    table_name = Some(s.value());
55                                }
56                            } else if nv.path.is_ident("pk") {
57                                if let syn::Expr::Lit(syn::ExprLit {
58                                    lit: syn::Lit::Str(s),
59                                    ..
60                                }) = nv.value
61                                {
62                                    pk_field = Some(s.value());
63                                }
64                            } else if nv.path.is_ident("soft_delete") {
65                                if let syn::Expr::Lit(syn::ExprLit {
66                                    lit: syn::Lit::Str(s),
67                                    ..
68                                }) = nv.value
69                                {
70                                    soft_delete_field = Some(s.value());
71                                }
72                            }
73                        }
74                    }
75                }
76            } else if let syn::Meta::NameValue(nv) = &attr.meta {
77                // 单个 NameValue 的情况
78                if nv.path.is_ident("table") {
79                    if let syn::Expr::Lit(syn::ExprLit {
80                        lit: syn::Lit::Str(s),
81                        ..
82                    }) = &nv.value
83                    {
84                        table_name = Some(s.value());
85                    }
86                } else if nv.path.is_ident("pk") {
87                    if let syn::Expr::Lit(syn::ExprLit {
88                        lit: syn::Lit::Str(s),
89                        ..
90                    }) = &nv.value
91                    {
92                        pk_field = Some(s.value());
93                    }
94                } else if nv.path.is_ident("soft_delete") {
95                    if let syn::Expr::Lit(syn::ExprLit {
96                        lit: syn::Lit::Str(s),
97                        ..
98                    }) = &nv.value
99                    {
100                        soft_delete_field = Some(s.value());
101                    }
102                }
103            }
104        }
105    }
106
107    // 如果没有指定表名,使用结构体名称的小写蛇形命名方式
108    let table = table_name.unwrap_or_else(|| {
109        let s = name.to_string();
110        // 将 PascalCase 转换为 snake_case
111        let mut result = String::new();
112        for (i, c) in s.chars().enumerate() {
113            if c.is_uppercase() && i > 0 {
114                result.push('_');
115            }
116            result.push(c.to_ascii_lowercase());
117        }
118        result
119    });
120
121    // 如果没有指定主键,默认使用 "id"
122    let pk = pk_field.unwrap_or_else(|| "id".to_string());
123
124    // 生成实现代码
125    let expanded = if let Some(soft_delete) = soft_delete_field {
126        // 如果指定了逻辑删除字段,生成包含 SOFT_DELETE_FIELD 的实现
127        let soft_delete_lit = syn::LitStr::new(&soft_delete, proc_macro2::Span::call_site());
128        quote! {
129            impl sqlxplus::Model for #name {
130                const TABLE: &'static str = #table;
131                const PK: &'static str = #pk;
132                const SOFT_DELETE_FIELD: Option<&'static str> = Some(#soft_delete_lit);
133            }
134        }
135    } else {
136        // 如果没有指定逻辑删除字段,SOFT_DELETE_FIELD 为 None
137        quote! {
138            impl sqlxplus::Model for #name {
139                const TABLE: &'static str = #table;
140                const PK: &'static str = #pk;
141                const SOFT_DELETE_FIELD: Option<&'static str> = None;
142            }
143        }
144    };
145
146    TokenStream::from(expanded)
147}
148
149/// 生成 CRUD trait 的实现
150///
151/// 自动生成 insert 和 update 方法的实现
152///
153/// 使用示例:
154/// ```ignore
155/// // 物理删除模式
156/// #[derive(CRUD, FromRow, ModelMeta)]
157/// #[model(table = "users", pk = "id")]
158/// struct User {
159///     id: i64,
160///     name: String,
161///     email: String,
162/// }
163///
164/// // 逻辑删除模式
165/// #[derive(CRUD, FromRow, ModelMeta)]
166/// #[model(table = "users", pk = "id", soft_delete = "is_deleted")]
167/// struct UserWithSoftDelete {
168///     id: i64,
169///     name: String,
170///     email: String,
171///     is_deleted: i32, // 逻辑删除字段
172/// }
173/// ```
174#[proc_macro_derive(CRUD, attributes(model, skip))]
175pub fn derive_crud(input: TokenStream) -> TokenStream {
176    let input = parse_macro_input!(input as DeriveInput);
177    let name = &input.ident;
178
179    // 解析 #[model(pk = "...")],获取主键字段名,默认 "id"
180    let mut pk_field = None;
181    for attr in &input.attrs {
182        if attr.path().is_ident("model") {
183            if let syn::Meta::List(list) = &attr.meta {
184                let parser = syn::punctuated::Punctuated::<Meta, syn::Token![,]>::parse_terminated;
185                if let Ok(metas) = parser.parse2(list.tokens.clone()) {
186                    for meta in metas {
187                        if let Meta::NameValue(nv) = meta {
188                            if nv.path.is_ident("pk") {
189                                if let syn::Expr::Lit(syn::ExprLit {
190                                    lit: syn::Lit::Str(s),
191                                    ..
192                                }) = nv.value
193                                {
194                                    pk_field = Some(s.value());
195                                }
196                            }
197                        }
198                    }
199                }
200            } else if let syn::Meta::NameValue(nv) = &attr.meta {
201                if nv.path.is_ident("pk") {
202                    if let syn::Expr::Lit(syn::ExprLit {
203                        lit: syn::Lit::Str(s),
204                        ..
205                    }) = &nv.value
206                    {
207                        pk_field = Some(s.value());
208                    }
209                }
210            }
211        }
212    }
213    // 如果没有指定主键,默认使用 "id"
214    let pk = pk_field.unwrap_or_else(|| "id".to_string());
215
216    // 获取字段列表(必须是具名字段的结构体)
217    let fields = match &input.data {
218        Data::Struct(DataStruct {
219            fields: Fields::Named(fields),
220            ..
221        }) => &fields.named,
222        _ => {
223            return syn::Error::new_spanned(
224                name,
225                "CRUD derive only supports structs with named fields",
226            )
227            .to_compile_error()
228            .into();
229        }
230    };
231
232    // 收集字段信息
233    // - field_names / field_columns: 所有字段(包括主键)
234    // - insert_field_*: 用于 INSERT(排除主键)
235    // - update_field_*: 用于 UPDATE SET 子句(排除主键)
236    let mut field_names = Vec::new();
237    let mut field_columns = Vec::new();
238    let mut skip_fields = Vec::new();
239    let mut insert_field_names = Vec::new();
240    let mut insert_field_columns = Vec::new();
241    let mut update_field_names = Vec::new();
242    // 主键字段的 Ident,用于绑定 WHERE pk = ?
243    let mut pk_ident_opt: Option<&syn::Ident> = None;
244
245    for field in fields {
246        let field_name = field.ident.as_ref().unwrap();
247        let field_name_str = field_name.to_string();
248
249        // 检查是否有 skip 属性
250        let mut skip = false;
251        for attr in &field.attrs {
252            if attr.path().is_ident("skip") || attr.path().is_ident("model") {
253                skip = true;
254                break;
255            }
256        }
257
258        if !skip {
259            field_names.push(field_name);
260            field_columns.push(field_name_str.clone());
261            if field_name_str == pk {
262                // 记录主键字段
263                pk_ident_opt = Some(field_name);
264            } else {
265                // 非主键字段用于 INSERT / UPDATE
266                insert_field_names.push(field_name);
267                insert_field_columns.push(field_name_str.clone());
268                update_field_names.push(field_name);
269            }
270        } else {
271            skip_fields.push(field_name_str);
272        }
273    }
274
275    // 编译期确保主键字段存在
276    let pk_ident = pk_ident_opt.expect("Primary key field not found in struct");
277
278    // 生成 insert SQL(排除主键列,依赖数据库自增 / identity)
279    let insert_fields: Vec<String> = insert_field_columns.clone();
280    let insert_placeholders: Vec<String> =
281        (0..insert_fields.len()).map(|_| "?".to_string()).collect();
282    let insert_sql = format!(
283        "INSERT INTO {} ({}) VALUES ({})",
284        format!("{{TABLE}}"), // 占位符,运行时替换
285        insert_fields.join(", "),
286        insert_placeholders.join(", ")
287    );
288
289    // 生成 update SQL(排除主键列,只更新非主键列)
290    let update_fields: Vec<String> = field_columns
291        .iter()
292        .filter(|f| *f != &pk)
293        .cloned()
294        .collect();
295    let update_sql = if !update_fields.is_empty() {
296        format!(
297            "UPDATE {} SET {} WHERE {} = ?",
298            format!("{{TABLE}}"),
299            update_fields
300                .iter()
301                .map(|f| format!("{} = ?", f))
302                .collect::<Vec<_>>()
303                .join(", "),
304            format!("{{PK}}")
305        )
306    } else {
307        String::new()
308    };
309
310    // 生成实现代码
311    let expanded = quote! {
312        #[async_trait::async_trait]
313        impl sqlxplus::Crud for #name {
314            async fn insert(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<sqlxplus::crud::Id> {
315                use sqlxplus::Model;
316                use sqlxplus::utils::escape_identifier;
317                let table = Self::TABLE;
318                let driver = pool.driver();
319                let escaped_table = escape_identifier(driver, table);
320                let sql = #insert_sql.replace("{TABLE}", &escaped_table);
321                let sql = pool.convert_sql(&sql);
322
323                match pool.driver() {
324                    sqlxplus::db_pool::DbDriver::MySql => {
325                        let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
326                        let result = sqlx::query(&sql)
327                            #( .bind(&self.#insert_field_names) )*
328                            .execute(pool_ref)
329                            .await?;
330                        Ok(result.last_insert_id() as i64)
331                    }
332                    sqlxplus::db_pool::DbDriver::Postgres => {
333                        let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
334                        let pk = Self::PK;
335                        use sqlxplus::utils::escape_identifier;
336                        let escaped_pk = escape_identifier(sqlxplus::db_pool::DbDriver::Postgres, pk);
337                        // 为 PostgreSQL 添加 RETURNING 子句
338                        let sql_with_returning = format!("{} RETURNING {}", sql, escaped_pk);
339                        let id: i64 = sqlx::query_scalar(&sql_with_returning)
340                            #( .bind(&self.#insert_field_names) )*
341                            .fetch_one(pool_ref)
342                            .await?;
343                        Ok(id)
344                    }
345                    sqlxplus::db_pool::DbDriver::Sqlite => {
346                        let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
347                        let result = sqlx::query(&sql)
348                            #( .bind(&self.#insert_field_names) )*
349                            .execute(pool_ref)
350                            .await?;
351                        Ok(result.last_insert_rowid() as i64)
352                    }
353                }
354            }
355
356            async fn update(&self, pool: &sqlxplus::DbPool) -> sqlxplus::db_pool::Result<()> {
357                use sqlxplus::Model;
358                use sqlxplus::utils::escape_identifier;
359                let table = Self::TABLE;
360                let pk = Self::PK;
361                let driver = pool.driver();
362                let escaped_table = escape_identifier(driver, table);
363                let escaped_pk = escape_identifier(driver, pk);
364                let sql = #update_sql.replace("{TABLE}", &escaped_table).replace("{PK}", &escaped_pk);
365                let sql = pool.convert_sql(&sql);
366
367                match pool.driver() {
368                    sqlxplus::db_pool::DbDriver::MySql => {
369                        let pool_ref = pool.mysql_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
370                        sqlx::query(&sql)
371                            #( .bind(&self.#update_field_names) )*
372                            .bind(&self.#pk_ident)
373                            .execute(pool_ref)
374                            .await?;
375                    }
376                    sqlxplus::db_pool::DbDriver::Postgres => {
377                        let pool_ref = pool.pg_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
378                        sqlx::query(&sql)
379                            #( .bind(&self.#update_field_names) )*
380                            .bind(&self.#pk_ident)
381                            .execute(pool_ref)
382                            .await?;
383                    }
384                    sqlxplus::db_pool::DbDriver::Sqlite => {
385                        let pool_ref = pool.sqlite_pool().ok_or(sqlxplus::db_pool::DbPoolError::NoPoolAvailable)?;
386                        sqlx::query(&sql)
387                            #( .bind(&self.#update_field_names) )*
388                            .bind(&self.#pk_ident)
389                            .execute(pool_ref)
390                            .await?;
391                    }
392                }
393                Ok(())
394            }
395        }
396    };
397
398    TokenStream::from(expanded)
399}