sqlx_askama_template_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote};
4use std::collections::BTreeSet;
5use syn::{
6    DeriveInput, LifetimeParam, LitStr, Meta, Path, Token, parse::Parser, parse_macro_input,
7    punctuated::Punctuated,
8};
9
/// String-based key used to deduplicate types when collecting trait bounds.
///
/// Wraps the printed form of a type; equality and ordering are exactly those of
/// the inner `String`, which lets it live in a `BTreeSet`.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TypeIdentifier(String);
13
14fn get_type_identifier(ty: &syn::Type) -> TypeIdentifier {
15    TypeIdentifier(quote!(#ty).to_string())
16}
17/// 处理并增强 `#[template]` 属性,添加必要的默认值
18fn process_template_attr(input: &DeriveInput) -> Punctuated<Meta, Token![,]> {
19    let mut args = Punctuated::<Meta, Token![,]>::new();
20    for attr in &input.attrs {
21        if !attr.path().is_ident("template") {
22            continue;
23        }
24        // 处理template属性
25        let mut has_askama = false;
26        let mut has_source = false;
27        let mut has_ext = false;
28
29        let nested = match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
30            Ok(n) => n,
31            Err(_) => continue,
32        };
33        for meta in nested {
34            if meta.path().is_ident("source") {
35                has_source = true;
36            }
37            if meta.path().is_ident("ext") {
38                has_ext = true;
39            }
40            if meta.path().is_ident("askama") {
41                has_askama = true;
42            }
43            args.push(meta);
44        }
45
46        // 设置默认值
47
48        if !has_askama {
49            let askama_meta = Meta::NameValue(syn::MetaNameValue {
50                path: syn::Path::from(syn::Ident::new("askama", Span::call_site())),
51                eq_token: <syn::Token![=]>::default(),
52                value: syn::Expr::Path(syn::ExprPath {
53                    attrs: Vec::new(),
54                    qself: None,
55                    path: syn::parse_str::<Path>("::sqlx_askama_template::askama").unwrap(),
56                }),
57            });
58            args.push_punct(Token![,](Span::call_site()));
59            args.push_value(askama_meta);
60        }
61
62        if has_source && !has_ext {
63            // 添加 ext = "txt"
64            let ext_meta = Meta::NameValue(syn::MetaNameValue {
65                path: syn::Path::from(syn::Ident::new("ext", Span::call_site())),
66                eq_token: <syn::Token![=]>::default(),
67                value: syn::Expr::Lit(syn::ExprLit {
68                    attrs: Vec::new(),
69                    lit: syn::Lit::Str(LitStr::new("txt", Span::call_site())),
70                }),
71            });
72            args.push_punct(Token![,](Span::call_site()));
73            args.push_value(ext_meta);
74        }
75    }
76
77    args
78}
79
80#[proc_macro_derive(SqlTemplate, attributes(template, add_type, ignore_type))]
81pub fn sql_template(input: TokenStream) -> TokenStream {
82    let input = parse_macro_input!(input as DeriveInput);
83    let name = &input.ident;
84    let generics = &input.generics;
85    //处理template
86    let template_attrs = process_template_attr(&input);
87
88    // 处理生命周期参数
89    let (mut wrapper_generics, data_lifetime) = if let Some(lt) = generics.lifetimes().next() {
90        let generics = generics.clone();
91        let lt_ident = &lt.lifetime;
92        (generics, quote! { #lt_ident })
93    } else {
94        let mut generics = generics.clone();
95        let lifetime = LifetimeParam::new(syn::Lifetime::new("'q", proc_macro2::Span::call_site()));
96        generics
97            .params
98            .insert(0, syn::GenericParam::Lifetime(lifetime));
99        (generics, quote! { 'q })
100    };
101
102    // 添加DB类型参数
103    wrapper_generics
104        .params
105        .push(syn::GenericParam::Type(syn::TypeParam {
106            attrs: Vec::new(),
107            ident: format_ident!("DB"),
108            colon_token: None,
109            bounds: syn::punctuated::Punctuated::new(),
110            eq_token: None,
111            default: None,
112        }));
113
114    // 收集需要绑定的类型
115    let mut seen_types = BTreeSet::new();
116    let mut bound_types = proc_macro2::TokenStream::new();
117
118    // 处理字段类型
119    if let syn::Data::Struct(data_struct) = &input.data {
120        for field in &data_struct.fields {
121            let has_ignore = field
122                .attrs
123                .iter()
124                .any(|attr| attr.path().is_ident("ignore_type"));
125            if !has_ignore {
126                let ty = &field.ty;
127                let ident = get_type_identifier(ty);
128                if seen_types.insert(ident) {
129                    bound_types.extend(quote! {
130                        #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
131                    });
132                }
133            }
134        }
135    }
136
137    // 处理addtype属性
138    for attr in &input.attrs {
139        if attr.path().is_ident("add_type") {
140            if let Meta::List(meta_list) = &attr.meta {
141                let parser =
142                    syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated;
143                if let Ok(types) = parser.parse2(meta_list.tokens.clone()) {
144                    for ty in types {
145                        let ident = get_type_identifier(&ty);
146                        if seen_types.insert(ident) {
147                            bound_types.extend(quote! {
148                                #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
149                            });
150                        }
151                    }
152                }
153            }
154        }
155    }
156
157    let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
158    let where_clause = where_clause.map_or_else(|| quote! { where }, |wc| quote! { #wc });
159    let (wrapper_impl_generics, _, _) = wrapper_generics.split_for_impl();
160
161    let expanded = quote! {
162        impl #wrapper_impl_generics ::sqlx_askama_template::SqlTemplate<#data_lifetime, DB>
163            for &#data_lifetime #name #ty_generics
164            #where_clause
165            DB: ::sqlx::Database,
166            #bound_types
167        {
168            fn render_sql_with_encode_placeholder_fn(
169                self,
170                f: ::std::option::Option<fn(usize, &mut String)>,
171                sql_buffer: &mut String,
172            ) -> ::std::result::Result<
173                ::std::option::Option<DB::Arguments<#data_lifetime>>,
174                ::sqlx::Error,
175            > {
176                #[derive(::sqlx_askama_template::askama::Template)]
177                #[template(#template_attrs)]
178                struct Wrapper #wrapper_generics (
179                    ::sqlx_askama_template::TemplateArg<#data_lifetime, DB, #name #ty_generics>
180                ) #where_clause
181                    DB: ::sqlx::Database,
182                    #bound_types;
183
184                impl #wrapper_impl_generics ::std::ops::Deref for Wrapper #wrapper_generics
185                    #where_clause
186                    DB: ::sqlx::Database,
187                    #bound_types
188                {
189                    type Target = ::sqlx_askama_template::TemplateArg<#data_lifetime, DB, #name #ty_generics>;
190                    fn deref(&self) -> &Self::Target {
191                        &self.0
192                    }
193                }
194
195                let mut wrapper = Wrapper(::sqlx_askama_template::TemplateArg::new(self));
196                if let Some(f) = f {
197                    wrapper.0.set_encode_placeholder_fn(f);
198                }
199                let render_res = ::sqlx_askama_template::askama::Template::render_into(&wrapper, sql_buffer)
200                    .map_err(|e| ::sqlx::Error::Encode(::std::boxed::Box::new(e)))?;
201                let arg = wrapper.get_arguments();
202                let encode_err = wrapper.get_err();
203
204                if let Some(e) = encode_err {
205                    return ::std::result::Result::Err(e);
206                }
207                ::std::result::Result::Ok(arg)
208            }
209        }
210    };
211
212    expanded.into()
213}