Skip to main content

sqlx_askama_template_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote};
4use std::collections::BTreeSet;
5use syn::{
6    DeriveInput, LifetimeParam, LitStr, Meta, Path, Token, parse::Parser, parse_macro_input,
7    punctuated::Punctuated,
8};
9
/// Helper key used to compare types with one another.
///
/// Wraps the string form of a type's tokens so that it can be ordered and
/// stored in a sorted set when de-duplicating types.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TypeIdentifier(String);
13
14fn get_type_identifier(ty: &syn::Type) -> TypeIdentifier {
15    TypeIdentifier(quote!(#ty).to_string())
16}
17/// 处理并增强 `#[template]` 属性,添加必要的默认值
18fn process_template_attr(input: &DeriveInput) -> Punctuated<Meta, Token![,]> {
19    let mut args = Punctuated::<Meta, Token![,]>::new();
20    for attr in &input.attrs {
21        if !attr.path().is_ident("template") {
22            continue;
23        }
24        // 处理template属性
25        let mut has_askama = false;
26        let mut has_source = false;
27        let mut has_ext = false;
28
29        let nested = match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
30            Ok(n) => n,
31            Err(_) => continue,
32        };
33        for meta in nested {
34            if meta.path().is_ident("source") {
35                has_source = true;
36            }
37            if meta.path().is_ident("ext") {
38                has_ext = true;
39            }
40            if meta.path().is_ident("askama") {
41                has_askama = true;
42            }
43            args.push(meta);
44        }
45
46        // 设置默认值
47
48        if !has_askama {
49            let askama_meta = Meta::NameValue(syn::MetaNameValue {
50                path: syn::Path::from(syn::Ident::new("askama", Span::call_site())),
51                eq_token: <syn::Token![=]>::default(),
52                value: syn::Expr::Path(syn::ExprPath {
53                    attrs: Vec::new(),
54                    qself: None,
55                    path: syn::parse_str::<Path>("::sqlx_askama_template::askama").unwrap(),
56                }),
57            });
58            args.push_punct(Token![,](Span::call_site()));
59            args.push_value(askama_meta);
60        }
61
62        if has_source && !has_ext {
63            // 添加 ext = "txt"
64            let ext_meta = Meta::NameValue(syn::MetaNameValue {
65                path: syn::Path::from(syn::Ident::new("ext", Span::call_site())),
66                eq_token: <syn::Token![=]>::default(),
67                value: syn::Expr::Lit(syn::ExprLit {
68                    attrs: Vec::new(),
69                    lit: syn::Lit::Str(LitStr::new("txt", Span::call_site())),
70                }),
71            });
72            args.push_punct(Token![,](Span::call_site()));
73            args.push_value(ext_meta);
74        }
75    }
76
77    args
78}
79
80#[proc_macro_derive(SqlTemplate, attributes(template, add_type, ignore_type))]
81pub fn sql_template(input: TokenStream) -> TokenStream {
82    let input = parse_macro_input!(input as DeriveInput);
83    let name = &input.ident;
84    let generics = &input.generics;
85    //处理template
86    let template_attrs = process_template_attr(&input);
87
88    // 处理生命周期参数
89    let (mut wrapper_generics, data_lifetime) = if let Some(lt) = generics.lifetimes().next() {
90        let generics = generics.clone();
91        let lt_ident = &lt.lifetime;
92        (generics, quote! { #lt_ident })
93    } else {
94        let mut generics = generics.clone();
95        let lifetime = LifetimeParam::new(syn::Lifetime::new("'q", proc_macro2::Span::call_site()));
96        generics
97            .params
98            .insert(0, syn::GenericParam::Lifetime(lifetime));
99        (generics, quote! { 'q })
100    };
101
102    // 添加DB类型参数
103    wrapper_generics
104        .params
105        .push(syn::GenericParam::Type(syn::TypeParam {
106            attrs: Vec::new(),
107            ident: format_ident!("DB"),
108            colon_token: None,
109            bounds: syn::punctuated::Punctuated::new(),
110            eq_token: None,
111            default: None,
112        }));
113
114    // 收集需要绑定的类型
115    let mut seen_types = BTreeSet::new();
116    let mut bound_types = proc_macro2::TokenStream::new();
117
118    // 处理字段类型
119    if let syn::Data::Struct(data_struct) = &input.data {
120        for field in &data_struct.fields {
121            let has_ignore = field
122                .attrs
123                .iter()
124                .any(|attr| attr.path().is_ident("ignore_type"));
125            if !has_ignore {
126                let ty = &field.ty;
127                let ident = get_type_identifier(ty);
128                if seen_types.insert(ident) {
129                    bound_types.extend(quote! {
130                        #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB>,
131                    });
132                }
133            }
134        }
135    }
136
137    // 处理addtype属性
138    for attr in &input.attrs {
139        if attr.path().is_ident("add_type")
140            && let Meta::List(meta_list) = &attr.meta
141        {
142            let parser = syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated;
143            if let Ok(types) = parser.parse2(meta_list.tokens.clone()) {
144                for ty in types {
145                    let ident = get_type_identifier(&ty);
146                    let have_lifetime = ident.0.contains('\'');
147                    if seen_types.insert(ident) {
148                        if have_lifetime {
149                            //非引用类型且包含生命周期如slef.Vec<i64>.first()->Option<&'a i64>数据来源自结构体本身的字段生命周期相同;或者如&str这样的静态引用,使用结构体本身生命周期
150                            bound_types.extend(quote! {
151                                #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB>,
152                            });
153                        } else {
154                            bound_types.extend(quote! {
155                            #ty: for<'template_local_lifetime> ::sqlx::Encode<'template_local_lifetime, DB> + ::sqlx::Type<DB>,
156                        });
157                        }
158                    }
159                }
160            }
161        }
162    }
163
164    let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
165    let where_clause = where_clause.map_or_else(|| quote! { where }, |wc| quote! { #wc });
166    let (wrapper_impl_generics, _, _) = wrapper_generics.split_for_impl();
167
168    let expanded = quote! {
169        impl #wrapper_impl_generics ::sqlx_askama_template::SqlTemplate<#data_lifetime, DB>
170            for &#data_lifetime #name #ty_generics
171            #where_clause
172            DB: ::sqlx::Database,
173            #bound_types
174        {
175            fn render_sql_with_encode_placeholder_fn(
176                self,
177                f: ::std::option::Option<fn(usize, &mut String)>,
178                sql_buffer: &mut String,
179            ) -> ::std::result::Result<
180                ::std::option::Option<DB::Arguments>,
181                ::sqlx::Error,
182            > {
183                #[derive(::sqlx_askama_template::askama::Template)]
184                #[template(#template_attrs)]
185                struct Wrapper #wrapper_generics (
186                    ::sqlx_askama_template::TemplateArg<#data_lifetime,DB, #name #ty_generics>
187                ) #where_clause
188                    DB: ::sqlx::Database,
189                    #bound_types;
190
191                impl #wrapper_impl_generics ::std::ops::Deref for Wrapper #wrapper_generics
192                    #where_clause
193                    DB: ::sqlx::Database,
194                    #bound_types
195                {
196                    type Target = ::sqlx_askama_template::TemplateArg<#data_lifetime, DB, #name #ty_generics>;
197                    fn deref(&self) -> &Self::Target {
198                        &self.0
199                    }
200                }
201
202                let mut wrapper = Wrapper(::sqlx_askama_template::TemplateArg::new(self));
203                if let Some(f) = f {
204                    wrapper.0.set_encode_placeholder_fn(f);
205                }
206                let render_res = ::sqlx_askama_template::askama::Template::render_into(&wrapper, sql_buffer)
207                    .map_err(|e| ::sqlx::Error::Encode(::std::boxed::Box::new(e)))?;
208                let arg = wrapper.get_arguments();
209                let encode_err = wrapper.get_err();
210
211                if let Some(e) = encode_err {
212                    return ::std::result::Result::Err(e);
213                }
214                ::std::result::Result::Ok(arg)
215            }
216        }
217    };
218
219    expanded.into()
220}