sqlx_askama_template_macro/
lib.rs

1use std::collections::BTreeSet;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::parse::Parser;
6use syn::{DeriveInput, LifetimeParam, Meta, parse_macro_input};
/// Comparison key for a type: wraps the stringified token form of a `syn::Type`.
///
/// Equality and ordering are exactly those of the wrapped `String`, which lets
/// the key live in a `BTreeSet` to de-duplicate generated trait bounds.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TypeIdentifier(String);
10
11fn get_type_identifier(ty: &syn::Type) -> TypeIdentifier {
12    TypeIdentifier(quote!(#ty).to_string())
13}
14/// Derive macro for generating type-safe SQL templates using Askama.
15///
16/// This macro generates boilerplate code to integrate Askama templates with SQLx queries,
17/// providing compile-time SQL validation and parameter binding.
18///
19/// # Attributes
20///
21/// ## `#[template(...)]` (Required)
22/// Defines the SQL template configuration. Accepts these parameters:
23/// - `source`: Inline SQL template content (supports Askama syntax)
24/// - `ext`: File extension for Askama template engine
25/// - `print`: Debug output options (none|ast|code|all)
26/// - `config`: Path to custom Askama configuration file
27///
28/// ## `#[addtype(...)]` (Optional)
29/// Specifies additional type constraints for template variables:
30/// - Accepts comma-separated types implementing `sqlx::Type + sqlx::Encode`
31/// - Required when using non-field types in template logic
32///
33/// ## `#[ignore_type]` (Optional)
34/// Marks struct fields to skip SQLx type validation:
35/// - Use for fields that shouldn't participate in parameter binding
36/// - Typically used for helper fields or complex types
37///
38/// # Example
39/// ```
40/// use sqlx_askama_template::SqlTemplate;
41///
42/// #[derive(SqlTemplate)]
43/// #[template(
44///     source = r#"
45///     SELECT * FROM users
46///     WHERE name = {{e(name)}}
47///     AND age > {{e(min_age)}}
48///     "#,
49///     ext = "sql"
50/// )]
51/// #[addtype(i32)]
52/// struct UserQuery<'a> {
53///     name: &'a str,
54///     #[ignore_type]
55///     min_age: i32,
56/// }
57/// ```
58///
59/// # Generated Implementation
60/// Implements `SqlTemplate` trait with these methods:
61/// - `render_sql() -> Result<(String, Arguments<DB>)>`
62/// - `render_execute() -> Result<RenderExecute<DB>>`
63///
64/// # Panics
65/// - If required `source` attribute is missing
66/// - If template syntax errors are detected at compile time
67/// - If type constraints for template variables are unsatisfied
68///
69/// # Note
70/// The generated code requires these dependencies in scope:
71/// - `sqlx::{Encode, Type, Arguments}`
72/// - `askama::Template`
73#[proc_macro_derive(SqlTemplate, attributes(template, addtype, ignore_type))]
74pub fn sql_template(input: TokenStream) -> TokenStream {
75    let input = parse_macro_input!(input as DeriveInput);
76    let name: &syn::Ident = &input.ident;
77    let generics = &input.generics;
78
79    //let impl_generics = generics.clone();
80
81    let wrapper_name = format_ident!("{}Wrapper", name);
82
83    // 收集所有template属性
84    let template_attrs: Vec<_> = input
85        .attrs
86        .iter()
87        .filter(|attr| attr.path().is_ident("template"))
88        .collect();
89
90    // 处理生命周期参数
91    let (mut wrapper_generics, data_lifetime) = if let Some(lt) = generics.lifetimes().next() {
92        let generics = generics.clone();
93        let lt_ident = &lt.lifetime;
94        (generics, quote! { #lt_ident })
95    } else {
96        let mut generics = generics.clone();
97        let lifetime = LifetimeParam::new(syn::Lifetime::new("'q", proc_macro2::Span::call_site()));
98        generics
99            .params
100            .insert(0, syn::GenericParam::Lifetime(lifetime));
101        (generics, quote! { 'q })
102    };
103
104    // 添加DB类型参数
105    wrapper_generics
106        .params
107        .push(syn::GenericParam::Type(syn::TypeParam {
108            attrs: Vec::new(),
109            ident: format_ident!("DB"),
110            colon_token: None,
111            bounds: syn::punctuated::Punctuated::new(),
112            eq_token: None,
113            default: None,
114        }));
115    // 使用 BTreeSet 存储唯一类型标识
116    let mut seen_types = BTreeSet::new();
117    // 收集需要绑定的类型
118    let mut bound_types = proc_macro2::TokenStream::new();
119
120    // 1. 处理默认绑定(非ignore_type字段)
121    if let syn::Data::Struct(data_struct) = &input.data {
122        for field in &data_struct.fields {
123            let has_ignore = field
124                .attrs
125                .iter()
126                .any(|attr| attr.path().is_ident("ignore_type"));
127            if !has_ignore {
128                let ty = &field.ty;
129                let ident = get_type_identifier(ty);
130                if seen_types.insert(ident) {
131                    bound_types.extend(quote! {
132                        #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
133                    });
134                }
135            }
136        }
137    }
138    // 2. 处理addtype属性添加的类型
139    for attr in &input.attrs {
140        if attr.path().is_ident("addtype") {
141            match &attr.meta {
142                Meta::List(meta_list) => {
143                    let parser =
144                        syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated;
145                    if let Ok(types) = parser.parse2(meta_list.tokens.clone()) {
146                        for ty in types {
147                            let ident = get_type_identifier(&ty);
148                            if seen_types.insert(ident) {
149                                bound_types.extend(quote! {
150                                    #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
151                                });
152                            }
153                        }
154                    }
155                }
156                _ => continue,
157            }
158        }
159    }
160
161    let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
162    let where_clause = if where_clause.is_some() {
163        quote! {#where_clause}
164    } else {
165        quote! {where  }
166    };
167    let (wrapper_impl_generics, wrapper_ty_generics, _) = wrapper_generics.split_for_impl();
168
169    let expanded = quote! {
170        #[derive(::askama::Template)]
171        #(#template_attrs)*
172        pub struct #wrapper_name #wrapper_generics
173            #where_clause
174            DB: ::sqlx::Database,
175            #bound_types
176        {
177            pub data: &#data_lifetime #name #ty_generics,
178            pub arguments: ::sqlx_askama_template::TemplateArg<#data_lifetime, DB>,
179        }
180
181        impl #wrapper_impl_generics ::std::ops::Deref for #wrapper_name #wrapper_ty_generics
182            #where_clause
183            DB: ::sqlx::Database,
184            #bound_types
185        {
186            type Target = &#data_lifetime #name #ty_generics;
187
188            fn deref(&self) -> &Self::Target {
189                &self.data
190            }
191        }
192
193        impl #wrapper_impl_generics #wrapper_name #wrapper_ty_generics
194            #where_clause
195            DB: ::sqlx::Database,
196            #bound_types
197        {
198            pub fn e<ImplEncode>(&self, arg: ImplEncode) -> ::std::string::String
199            where
200                ImplEncode: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
201            {
202                self.arguments.encode(arg)
203            }
204
205            pub fn el<ImplEncode>(
206                &self,
207                args: impl ::std::iter::IntoIterator<Item = ImplEncode>,
208            ) -> ::std::string::String
209            where
210                ImplEncode: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
211            {
212                self.arguments.encode_list(args.into_iter())
213            }
214
215            pub fn et<ImplEncode>(&self, t: &ImplEncode) -> ::std::string::String
216            where
217                ImplEncode: ::sqlx::Encode<#data_lifetime, DB>
218                    + ::sqlx::Type<DB>
219                    + ::std::clone::Clone
220                    + #data_lifetime,
221            {
222                self.arguments.encode(t.clone())
223            }
224
225            pub fn etl<'arg_b, ImplEncode>(
226                &self,
227                args: impl ::std::iter::IntoIterator<Item = &'arg_b ImplEncode>,
228            ) -> ::std::string::String
229            where
230                #data_lifetime: 'arg_b,
231                ImplEncode: ::sqlx::Encode<#data_lifetime, DB>
232                    + ::sqlx::Type<DB>
233                    + ::std::clone::Clone
234                    + #data_lifetime,
235            {
236                let args = args.into_iter().cloned();
237                self.arguments.encode_list(args)
238            }
239        }
240
241        impl #wrapper_impl_generics ::sqlx_askama_template::SqlTemplate<#data_lifetime, DB>
242            for &#data_lifetime #name #ty_generics
243            #where_clause
244            DB: ::sqlx::Database,
245            #bound_types
246
247        {
248            fn render_sql(
249                self,
250            ) -> ::std::result::Result<
251                (
252                    ::std::string::String,
253                    ::std::option::Option<DB::Arguments<#data_lifetime>>,
254                ),
255                ::askama::Error,
256            > {
257                let wrapper: #wrapper_name #wrapper_ty_generics = #wrapper_name {
258                    data: self,
259                    arguments: ::std::default::Default::default(),
260                };
261
262                let sql = ::askama::Template::render(&wrapper)?;
263                if let ::std::option::Option::Some(e) = wrapper.arguments.get_err() {
264                    return ::std::result::Result::Err(::askama::Error::Custom(e));
265                }
266                let arg = wrapper.arguments.get_arguments();
267
268                ::std::result::Result::Ok((sql, arg))
269            }
270        }
271    };
272
273    expanded.into()
274}