Skip to main content

wae_macros/
lib.rs

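//! WAE Macros - procedural macro crate.
//!
//! Procedural macros that cut down on boilerplate:
//! - `#[derive(ToSchema)]` - generates the Schema definition automatically
//! - `api_doc!` - OpenAPI route documentation macro
//! - `query!` - compile-time SQL query macro
//! - `query_as!` - compile-time SQL query macro (maps rows onto a struct automatically)
//! - `execute!` - runs INSERT/UPDATE/DELETE style statements
//! - `query_scalar!` - query macro returning a single value
//! - `use_effect!` - convenience macro for fetching dependencies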
7
8#![warn(missing_docs)]
9
10use proc_macro::TokenStream;
11use proc_macro2::TokenStream as TokenStream2;
12use quote::quote;
13use syn::{Data, DeriveInput, Expr, Fields, GenericArgument, Ident, Lit, Meta, PathArguments, Type, parse_macro_input};
14
15/// 获取类型的 Schema 定义
16///
17/// 正确处理基础类型、`Vec<T>`、`Option<T>` 等泛型类型。
18fn get_type_schema(ty: &Type) -> TokenStream2 {
19    match ty {
20        Type::Path(type_path) => {
21            let last_segment = type_path.path.segments.last();
22            if let Some(segment) = last_segment {
23                let ident = &segment.ident;
24
25                match &segment.arguments {
26                    PathArguments::AngleBracketed(angle_bracketed) => {
27                        if let Some(GenericArgument::Type(inner_ty)) = angle_bracketed.args.first() {
28                            let inner_schema = get_type_schema(inner_ty);
29                            if ident == "Vec" {
30                                return quote! { wae_schema::Schema::array(#inner_schema) };
31                            }
32                            else if ident == "Option" {
33                                return quote! { #inner_schema.nullable(true) };
34                            }
35                        }
36                    }
37                    _ => {}
38                }
39
40                let ident_str = ident.to_string();
41                if ident_str == "String" || ident_str == "&str" {
42                    return quote! { wae_schema::Schema::string() };
43                }
44                else if matches!(
45                    ident_str.as_str(),
46                    "i8" | "i16" | "i32" | "i64" | "i128" | "isize" | "u8" | "u16" | "u32" | "u64" | "u128" | "usize"
47                ) {
48                    return quote! { wae_schema::Schema::integer() };
49                }
50                else if ident_str == "f32" || ident_str == "f64" {
51                    return quote! { wae_schema::Schema::number() };
52                }
53                else if ident_str == "bool" {
54                    return quote! { wae_schema::Schema::boolean() };
55                }
56            }
57        }
58        Type::Reference(type_ref) => {
59            return get_type_schema(&type_ref.elem);
60        }
61        _ => {}
62    }
63    quote! { <#ty as wae_schema::ToSchema>::schema() }
64}
65
66/// 提取文档注释
67///
68/// 合并多行文档注释为单个字符串。
69fn extract_doc_comment(attrs: &[syn::Attribute]) -> Option<String> {
70    let mut docs = Vec::new();
71    for attr in attrs {
72        if attr.path().is_ident("doc") {
73            if let Meta::NameValue(meta) = &attr.meta {
74                if let Expr::Lit(expr_lit) = &meta.value {
75                    if let Lit::Str(lit_str) = &expr_lit.lit {
76                        let doc = lit_str.value().trim().to_string();
77                        if !doc.is_empty() {
78                            docs.push(doc);
79                        }
80                    }
81                }
82            }
83        }
84    }
85    if docs.is_empty() { None } else { Some(docs.join("\n")) }
86}
87
88/// 检查类型是否为 Option 类型
89fn is_option_type(ty: &Type) -> bool {
90    if let Type::Path(type_path) = ty {
91        if let Some(segment) = type_path.path.segments.last() {
92            return segment.ident == "Option";
93        }
94    }
95    false
96}
97
98/// 生成结构体的 Schema 实现
99fn generate_struct_schema(name: &Ident, fields: &Fields, attrs: &[syn::Attribute]) -> TokenStream2 {
100    let mut properties = Vec::new();
101    let mut required_fields = Vec::new();
102    let type_doc = extract_doc_comment(attrs);
103
104    match fields {
105        Fields::Named(fields_named) => {
106            for field in &fields_named.named {
107                let field_name = field.ident.as_ref().unwrap();
108                let field_name_str = field_name.to_string();
109                let field_schema = get_type_schema(&field.ty);
110
111                let doc_comment = extract_doc_comment(&field.attrs);
112                let property = if let Some(doc) = doc_comment {
113                    quote! {
114                        .property(#field_name_str, #field_schema.description(#doc))
115                    }
116                }
117                else {
118                    quote! {
119                        .property(#field_name_str, #field_schema)
120                    }
121                };
122                properties.push(property);
123
124                if !is_option_type(&field.ty) {
125                    required_fields.push(field_name_str);
126                }
127            }
128        }
129        Fields::Unnamed(fields_unnamed) => {
130            for (i, field) in fields_unnamed.unnamed.iter().enumerate() {
131                let field_name_str = format!("field_{}", i);
132                let field_schema = get_type_schema(&field.ty);
133                let property = quote! {
134                    .property(#field_name_str, #field_schema)
135                };
136                properties.push(property);
137            }
138        }
139        Fields::Unit => {}
140    }
141
142    let required = if required_fields.is_empty() {
143        quote! {}
144    }
145    else {
146        quote! { .required(vec![#(#required_fields.to_string()),*]) }
147    };
148
149    let description = if let Some(doc) = type_doc {
150        quote! { .description(#doc) }
151    }
152    else {
153        quote! {}
154    };
155
156    quote! {
157        impl wae_schema::ToSchema for #name {
158            fn schema() -> wae_schema::Schema {
159                wae_schema::Schema::object()
160                    #description
161                    #(#properties)*
162                    #required
163            }
164        }
165    }
166}
167
168/// 生成枚举的 Schema 实现
169fn generate_enum_schema(name: &Ident, data: &Data, attrs: &[syn::Attribute]) -> TokenStream2 {
170    let variants = match data {
171        Data::Enum(data_enum) => &data_enum.variants,
172        _ => return quote! {},
173    };
174
175    let mut enum_values = Vec::new();
176    for variant in variants {
177        let variant_name = variant.ident.to_string();
178        enum_values.push(quote! { serde_json::Value::String(#variant_name.to_string()) });
179    }
180
181    let type_doc = extract_doc_comment(attrs);
182    let description = if let Some(doc) = type_doc {
183        quote! { .description(#doc) }
184    }
185    else {
186        quote! {}
187    };
188
189    quote! {
190        impl wae_schema::ToSchema for #name {
191            fn schema() -> wae_schema::Schema {
192                wae_schema::Schema::string()
193                    #description
194                    .enum_values(vec![#(#enum_values),*])
195            }
196        }
197    }
198}
199
200/// 自动生成 Schema 的派生宏
201///
202/// 为结构体或枚举自动生成 `ToSchema` trait 实现。
203///
204/// # Example
205///
206/// ```rust,ignore
207/// use serde::{Deserialize, Serialize};
208/// use wae_schema::{Schema, ToSchema};
209///
210/// /// 用户信息
211/// #[derive(Debug, Serialize, Deserialize, ToSchema)]
212/// pub struct User {
213///     /// 用户 ID
214///     pub id: u64,
215///     /// 用户名
216///     pub name: String,
217///     /// 邮箱(可选)
218///     pub email: Option<String>,
219/// }
220/// ```
221#[proc_macro_derive(ToSchema)]
222pub fn derive_schema(input: TokenStream) -> TokenStream {
223    let input = parse_macro_input!(input as DeriveInput);
224    let name = &input.ident;
225    let attrs = &input.attrs;
226
227    let expanded = match &input.data {
228        Data::Struct(data_struct) => generate_struct_schema(name, &data_struct.fields, attrs),
229        Data::Enum(_) => generate_enum_schema(name, &input.data, attrs),
230        Data::Union(_) => {
231            return syn::Error::new_spanned(name, "Unions are not supported").to_compile_error().into();
232        }
233    };
234
235    TokenStream::from(expanded)
236}
237
238/// 生成 OpenAPI 路由文档的宏
239///
240/// # Example
241///
242/// ```rust,ignore
243/// use wae_macros::api_doc;
244///
245/// let doc = api_doc! {
246///     "/users" => {
247///         GET => {
248///             summary: "获取用户列表",
249///             response: 200 => "成功",
250///         },
251///         POST => {
252///             summary: "创建用户",
253///             body: "User",
254///             response: 201 => "创建成功",
255///         },
256///     },
257/// };
258/// ```
259#[proc_macro]
260pub fn api_doc(input: TokenStream) -> TokenStream {
261    let _ = input;
262    TokenStream::from(quote! {
263        wae_schema::OpenApiDoc::new("API", "1.0.0")
264    })
265}
266
267mod query;
268
/// SQL query macro - returns raw row data.
///
/// Runs a SQL query and returns `DatabaseRows`; the caller iterates over the result manually.
/// The expansion itself is produced by `query::expand_query`.
///
/// # Example
///
/// ```rust,ignore
/// use wae_macros::query;
/// use wae_database::{DatabaseConnection, DatabaseRow};
///
/// async fn get_users(conn: &dyn DatabaseConnection) -> Result<Vec<DatabaseRow>, Box<dyn std::error::Error>> {
///     let mut rows = query!(conn, "SELECT id, name, email FROM users WHERE active = ?", true).await?;
///     let mut results = Vec::new();
///     while let Some(row) = rows.next().await? {
///         results.push(row);
///     }
///     Ok(results)
/// }
/// ```
#[proc_macro]
pub fn query(input: TokenStream) -> TokenStream {
    query::expand_query(input)
}
292
/// SQL query macro - maps rows onto a struct automatically.
///
/// Runs a SQL query and maps the result onto the given struct type.
/// The struct must implement the `FromRow` trait.
/// The expansion itself is produced by `query::expand_query_as`.
///
/// # Example
///
/// ```rust,ignore
/// use wae_macros::query_as;
/// use wae_database::{DatabaseConnection, Entity, FromRow, DatabaseRow, DatabaseResult};
/// use serde::{Deserialize, Serialize};
///
/// #[derive(Debug, Clone, Serialize, Deserialize)]
/// struct User {
///     id: i64,
///     name: String,
///     email: String,
/// }
///
/// impl FromRow for User {
///     fn from_row(row: &DatabaseRow) -> DatabaseResult<Self> {
///         Ok(Self {
///             id: row.get(0)?,
///             name: row.get(1)?,
///             email: row.get(2)?,
///         })
///     }
/// }
///
/// async fn get_users(conn: &dyn DatabaseConnection) -> Result<Vec<User>, Box<dyn std::error::Error>> {
///     let users = query_as!(User, conn, "SELECT id, name, email FROM users WHERE active = ?", true).await?;
///     Ok(users)
/// }
/// ```
#[proc_macro]
pub fn query_as(input: TokenStream) -> TokenStream {
    query::expand_query_as(input)
}
331
/// Execute macro - runs INSERT/UPDATE/DELETE and similar SQL statements.
///
/// Executes a SQL statement and returns the number of affected rows.
/// The expansion itself is produced by `query::expand_execute`.
///
/// # Example
///
/// ```rust,ignore
/// use wae_macros::execute;
/// use wae_database::DatabaseConnection;
///
/// async fn insert_user(conn: &dyn DatabaseConnection) -> Result<u64, Box<dyn std::error::Error>> {
///     let affected = execute!(conn, "INSERT INTO users (name, email) VALUES (?, ?)", "Alice", "alice@example.com").await?;
///     Ok(affected)
/// }
/// ```
#[proc_macro]
pub fn execute(input: TokenStream) -> TokenStream {
    query::expand_execute(input)
}
351
/// Scalar query macro - returns a single value.
///
/// Runs a query and returns one value; suited to aggregate queries such as COUNT or SUM.
/// The expansion itself is produced by `query::expand_query_scalar`.
///
/// # Example
///
/// ```rust,ignore
/// use wae_macros::query_scalar;
/// use wae_database::DatabaseConnection;
///
/// async fn count_users(conn: &dyn DatabaseConnection) -> Result<i64, Box<dyn std::error::Error>> {
///     let count = query_scalar!(i64, conn, "SELECT COUNT(*) FROM users").await?;
///     Ok(count)
/// }
/// ```
#[proc_macro]
pub fn query_scalar(input: TokenStream) -> TokenStream {
    query::expand_query_scalar(input)
}
371
372/// 使用效果宏 - 获取依赖的便捷宏
373///
374/// 支持多种语法:
375/// - `use_effect!(effectful, MyType)` - 按类型获取依赖
376/// - `use_effect!(effectful, "name", MyType)` - 按字符串键和类型获取依赖
377/// - `use_effect!(effectful, config, MyConfig)` - 便捷地获取配置
378/// - `use_effect!(effectful, auth, MyAuthService)` - 便捷地获取认证服务
379///
380/// # Example
381///
382/// ```rust,ignore
383/// use wae_macros::use_effect;
384///
385/// async fn handler(effectful: Effectful) -> WaeResult<()> {
386///     let config: MyConfig = use_effect!(effectful, MyConfig)?;
387///     let auth: Arc<dyn AuthService> = use_effect!(effectful, auth, Arc<dyn AuthService>)?;
388///     Ok(())
389/// }
390/// ```
391#[proc_macro]
392pub fn use_effect(input: TokenStream) -> TokenStream {
393    let parsed = syn::parse_macro_input!(input as UseEffectInput);
394
395    let expanded = match parsed {
396        UseEffectInput::TypeOnly { effectful, ty } => {
397            quote! {
398                #effectful.get_type::<#ty>()
399            }
400        }
401        UseEffectInput::Named { effectful, name, ty } => {
402            quote! {
403                #effectful.get::<#ty>(#name)
404            }
405        }
406        UseEffectInput::Config { effectful, ty } => {
407            quote! {
408                #effectful.use_config::<#ty>()
409            }
410        }
411        UseEffectInput::Auth { effectful, ty } => {
412            quote! {
413                #effectful.use_auth::<#ty>()
414            }
415        }
416    };
417
418    TokenStream::from(expanded)
419}
420
/// Parsed arguments of the `use_effect!` macro, one variant per accepted call form.
///
/// In every variant `effectful` is the first macro argument (the expression the generated
/// method call is made on) and `ty` is the requested type.
enum UseEffectInput {
    /// `use_effect!(effectful, Type)` - expands to `get_type::<Type>()`.
    TypeOnly { effectful: syn::Expr, ty: syn::Type },
    /// `use_effect!(effectful, "name", Type)` - expands to `get::<Type>("name")`.
    Named { effectful: syn::Expr, name: syn::LitStr, ty: syn::Type },
    /// `use_effect!(effectful, config, Type)` - expands to `use_config::<Type>()`.
    Config { effectful: syn::Expr, ty: syn::Type },
    /// `use_effect!(effectful, auth, Type)` - expands to `use_auth::<Type>()`.
    Auth { effectful: syn::Expr, ty: syn::Type },
}
427
428impl syn::parse::Parse for UseEffectInput {
429    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
430        let effectful: syn::Expr = input.parse()?;
431        let _: syn::Token![,] = input.parse()?;
432
433        if input.peek(syn::Ident) {
434            let ident: syn::Ident = input.parse()?;
435            let _: syn::Token![,] = input.parse()?;
436            let ty: syn::Type = input.parse()?;
437
438            if ident == "config" {
439                Ok(UseEffectInput::Config { effectful, ty })
440            }
441            else if ident == "auth" {
442                Ok(UseEffectInput::Auth { effectful, ty })
443            }
444            else {
445                Err(syn::Error::new_spanned(ident, "Expected 'config' or 'auth'"))
446            }
447        }
448        else if input.peek(syn::LitStr) {
449            let name: syn::LitStr = input.parse()?;
450            let _: syn::Token![,] = input.parse()?;
451            let ty: syn::Type = input.parse()?;
452            Ok(UseEffectInput::Named { effectful, name, ty })
453        }
454        else {
455            let ty: syn::Type = input.parse()?;
456            Ok(UseEffectInput::TypeOnly { effectful, ty })
457        }
458    }
459}