sqlx_askama_template_macro/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote};
4use std::collections::BTreeSet;
5use syn::{
6 DeriveInput, LifetimeParam, LitStr, Meta, Path, Token, parse::Parser, parse_macro_input,
7 punctuated::Punctuated,
8};
9
/// Token-level spelling of a Rust type.
///
/// Used as the key in an ordered set so that each distinct type receives at most one
/// generated `Encode`/`Type` bound.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TypeIdentifier(String);
13
14fn get_type_identifier(ty: &syn::Type) -> TypeIdentifier {
15 TypeIdentifier(quote!(#ty).to_string())
16}
17fn process_template_attr(input: &DeriveInput) -> Punctuated<Meta, Token![,]> {
19 let mut args = Punctuated::<Meta, Token![,]>::new();
20 for attr in &input.attrs {
21 if !attr.path().is_ident("template") {
22 continue;
23 }
24 let mut has_askama = false;
26 let mut has_source = false;
27 let mut has_ext = false;
28
29 let nested = match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
30 Ok(n) => n,
31 Err(_) => continue,
32 };
33 for meta in nested {
34 if meta.path().is_ident("source") {
35 has_source = true;
36 }
37 if meta.path().is_ident("ext") {
38 has_ext = true;
39 }
40 if meta.path().is_ident("askama") {
41 has_askama = true;
42 }
43 args.push(meta);
44 }
45
46 if !has_askama {
49 let askama_meta = Meta::NameValue(syn::MetaNameValue {
50 path: syn::Path::from(syn::Ident::new("askama", Span::call_site())),
51 eq_token: <syn::Token![=]>::default(),
52 value: syn::Expr::Path(syn::ExprPath {
53 attrs: Vec::new(),
54 qself: None,
55 path: syn::parse_str::<Path>("::sqlx_askama_template::askama").unwrap(),
56 }),
57 });
58 args.push_punct(Token));
59 args.push_value(askama_meta);
60 }
61
62 if has_source && !has_ext {
63 let ext_meta = Meta::NameValue(syn::MetaNameValue {
65 path: syn::Path::from(syn::Ident::new("ext", Span::call_site())),
66 eq_token: <syn::Token![=]>::default(),
67 value: syn::Expr::Lit(syn::ExprLit {
68 attrs: Vec::new(),
69 lit: syn::Lit::Str(LitStr::new("txt", Span::call_site())),
70 }),
71 });
72 args.push_punct(Token));
73 args.push_value(ext_meta);
74 }
75 }
76
77 args
78}
79
80#[proc_macro_derive(SqlTemplate, attributes(template, add_type, ignore_type))]
81pub fn sql_template(input: TokenStream) -> TokenStream {
82 let input = parse_macro_input!(input as DeriveInput);
83 let name = &input.ident;
84 let generics = &input.generics;
85 let template_attrs = process_template_attr(&input);
87
88 let (mut wrapper_generics, data_lifetime) = if let Some(lt) = generics.lifetimes().next() {
90 let generics = generics.clone();
91 let lt_ident = <.lifetime;
92 (generics, quote! { #lt_ident })
93 } else {
94 let mut generics = generics.clone();
95 let lifetime = LifetimeParam::new(syn::Lifetime::new("'q", proc_macro2::Span::call_site()));
96 generics
97 .params
98 .insert(0, syn::GenericParam::Lifetime(lifetime));
99 (generics, quote! { 'q })
100 };
101
102 wrapper_generics
104 .params
105 .push(syn::GenericParam::Type(syn::TypeParam {
106 attrs: Vec::new(),
107 ident: format_ident!("DB"),
108 colon_token: None,
109 bounds: syn::punctuated::Punctuated::new(),
110 eq_token: None,
111 default: None,
112 }));
113
114 let mut seen_types = BTreeSet::new();
116 let mut bound_types = proc_macro2::TokenStream::new();
117
118 if let syn::Data::Struct(data_struct) = &input.data {
120 for field in &data_struct.fields {
121 let has_ignore = field
122 .attrs
123 .iter()
124 .any(|attr| attr.path().is_ident("ignore_type"));
125 if !has_ignore {
126 let ty = &field.ty;
127 let ident = get_type_identifier(ty);
128 if seen_types.insert(ident) {
129 bound_types.extend(quote! {
130 #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB>,
131 });
132 }
133 }
134 }
135 }
136
137 for attr in &input.attrs {
139 if attr.path().is_ident("add_type")
140 && let Meta::List(meta_list) = &attr.meta
141 {
142 let parser = syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated;
143 if let Ok(types) = parser.parse2(meta_list.tokens.clone()) {
144 for ty in types {
145 let ident = get_type_identifier(&ty);
146 let have_lifetime = ident.0.contains('\'');
147 if seen_types.insert(ident) {
148 if have_lifetime {
149 bound_types.extend(quote! {
151 #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB>,
152 });
153 } else {
154 bound_types.extend(quote! {
155 #ty: for<'template_local_lifetime> ::sqlx::Encode<'template_local_lifetime, DB> + ::sqlx::Type<DB>,
156 });
157 }
158 }
159 }
160 }
161 }
162 }
163
164 let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
165 let where_clause = where_clause.map_or_else(|| quote! { where }, |wc| quote! { #wc });
166 let (wrapper_impl_generics, _, _) = wrapper_generics.split_for_impl();
167
168 let expanded = quote! {
169 impl #wrapper_impl_generics ::sqlx_askama_template::SqlTemplate<#data_lifetime, DB>
170 for &#data_lifetime #name #ty_generics
171 #where_clause
172 DB: ::sqlx::Database,
173 #bound_types
174 {
175 fn render_sql_with_encode_placeholder_fn(
176 self,
177 f: ::std::option::Option<fn(usize, &mut String)>,
178 sql_buffer: &mut String,
179 ) -> ::std::result::Result<
180 ::std::option::Option<DB::Arguments>,
181 ::sqlx::Error,
182 > {
183 #[derive(::sqlx_askama_template::askama::Template)]
184 #[template(#template_attrs)]
185 struct Wrapper #wrapper_generics (
186 ::sqlx_askama_template::TemplateArg<#data_lifetime,DB, #name #ty_generics>
187 ) #where_clause
188 DB: ::sqlx::Database,
189 #bound_types;
190
191 impl #wrapper_impl_generics ::std::ops::Deref for Wrapper #wrapper_generics
192 #where_clause
193 DB: ::sqlx::Database,
194 #bound_types
195 {
196 type Target = ::sqlx_askama_template::TemplateArg<#data_lifetime, DB, #name #ty_generics>;
197 fn deref(&self) -> &Self::Target {
198 &self.0
199 }
200 }
201
202 let mut wrapper = Wrapper(::sqlx_askama_template::TemplateArg::new(self));
203 if let Some(f) = f {
204 wrapper.0.set_encode_placeholder_fn(f);
205 }
206 let render_res = ::sqlx_askama_template::askama::Template::render_into(&wrapper, sql_buffer)
207 .map_err(|e| ::sqlx::Error::Encode(::std::boxed::Box::new(e)))?;
208 let arg = wrapper.get_arguments();
209 let encode_err = wrapper.get_err();
210
211 if let Some(e) = encode_err {
212 return ::std::result::Result::Err(e);
213 }
214 ::std::result::Result::Ok(arg)
215 }
216 }
217 };
218
219 expanded.into()
220}