sqlx_askama_template_macro/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::{format_ident, quote};
4use std::collections::BTreeSet;
5use syn::{
6 DeriveInput, LifetimeParam, LitStr, Meta, Path, Token, parse::Parser, parse_macro_input,
7 punctuated::Punctuated,
8};
9
/// Textual key for a `syn::Type` (its rendered token string).
///
/// Ordered and comparable so it can live in a `BTreeSet`, which the derive uses to emit at
/// most one trait bound per distinct type spelling.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TypeIdentifier(String);
13
14fn get_type_identifier(ty: &syn::Type) -> TypeIdentifier {
15 TypeIdentifier(quote!(#ty).to_string())
16}
17fn process_template_attr(input: &DeriveInput) -> Punctuated<Meta, Token![,]> {
19 let mut args = Punctuated::<Meta, Token![,]>::new();
20 for attr in &input.attrs {
21 if !attr.path().is_ident("template") {
22 continue;
23 }
24 let mut has_askama = false;
26 let mut has_source = false;
27 let mut has_ext = false;
28
29 let nested = match attr.parse_args_with(Punctuated::<Meta, Token![,]>::parse_terminated) {
30 Ok(n) => n,
31 Err(_) => continue,
32 };
33 for meta in &nested {
34 if meta.path().is_ident("source") {
35 has_source = true;
36 }
37 if meta.path().is_ident("ext") {
38 has_ext = true;
39 }
40 if meta.path().is_ident("askama") {
41 has_askama = true;
42 }
43 args.push(meta.clone());
44 }
45
46 if !has_askama {
49 let askama_meta = Meta::NameValue(syn::MetaNameValue {
50 path: syn::Path::from(syn::Ident::new("askama", Span::call_site())),
51 eq_token: <syn::Token![=]>::default(),
52 value: syn::Expr::Path(syn::ExprPath {
53 attrs: Vec::new(),
54 qself: None,
55 path: syn::parse_str::<Path>("::sqlx_askama_template::askama").unwrap(),
56 }),
57 });
58 args.push_punct(Token));
59 args.push_value(askama_meta);
60 }
61
62 if has_source && !has_ext {
63 let ext_meta = Meta::NameValue(syn::MetaNameValue {
65 path: syn::Path::from(syn::Ident::new("ext", Span::call_site())),
66 eq_token: <syn::Token![=]>::default(),
67 value: syn::Expr::Lit(syn::ExprLit {
68 attrs: Vec::new(),
69 lit: syn::Lit::Str(LitStr::new("txt", Span::call_site())),
70 }),
71 });
72 args.push_punct(Token));
73 args.push_value(ext_meta);
74 }
75 }
76
77 args
78}
79
80#[proc_macro_derive(SqlTemplate, attributes(template, add_type, ignore_type))]
81pub fn sql_template(input: TokenStream) -> TokenStream {
82 let input = parse_macro_input!(input as DeriveInput);
83 let name = &input.ident;
84 let generics = &input.generics;
85 let template_attrs = process_template_attr(&input);
87
88 let (mut wrapper_generics, data_lifetime) = if let Some(lt) = generics.lifetimes().next() {
90 let generics = generics.clone();
91 let lt_ident = <.lifetime;
92 (generics, quote! { #lt_ident })
93 } else {
94 let mut generics = generics.clone();
95 let lifetime = LifetimeParam::new(syn::Lifetime::new("'q", proc_macro2::Span::call_site()));
96 generics
97 .params
98 .insert(0, syn::GenericParam::Lifetime(lifetime));
99 (generics, quote! { 'q })
100 };
101
102 wrapper_generics
104 .params
105 .push(syn::GenericParam::Type(syn::TypeParam {
106 attrs: Vec::new(),
107 ident: format_ident!("DB"),
108 colon_token: None,
109 bounds: syn::punctuated::Punctuated::new(),
110 eq_token: None,
111 default: None,
112 }));
113
114 let mut seen_types = BTreeSet::new();
116 let mut bound_types = proc_macro2::TokenStream::new();
117
118 if let syn::Data::Struct(data_struct) = &input.data {
120 for field in &data_struct.fields {
121 let has_ignore = field
122 .attrs
123 .iter()
124 .any(|attr| attr.path().is_ident("ignore_type"));
125 if !has_ignore {
126 let ty = &field.ty;
127 let ident = get_type_identifier(ty);
128 if seen_types.insert(ident) {
129 bound_types.extend(quote! {
130 #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
131 });
132 }
133 }
134 }
135 }
136
137 for attr in &input.attrs {
139 if attr.path().is_ident("add_type") {
140 if let Meta::List(meta_list) = &attr.meta {
141 let parser =
142 syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated;
143 if let Ok(types) = parser.parse2(meta_list.tokens.clone()) {
144 for ty in types {
145 let ident = get_type_identifier(&ty);
146 if seen_types.insert(ident) {
147 bound_types.extend(quote! {
148 #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
149 });
150 }
151 }
152 }
153 }
154 }
155 }
156
157 let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
158 let where_clause = where_clause.map_or_else(|| quote! { where }, |wc| quote! { #wc });
159 let (wrapper_impl_generics, _, _) = wrapper_generics.split_for_impl();
160
161 let expanded = quote! {
162 impl #wrapper_impl_generics ::sqlx_askama_template::SqlTemplate<#data_lifetime, DB>
163 for &#data_lifetime #name #ty_generics
164 #where_clause
165 DB: ::sqlx::Database,
166 #bound_types
167 {
168 fn render_sql_with_encode_placeholder_fn(
169 self,
170 f: ::std::option::Option<fn(usize, &mut String)>,
171 sql_buffer: &mut String,
172 ) -> ::std::result::Result<
173 ::std::option::Option<DB::Arguments<#data_lifetime>>,
174 ::sqlx::Error,
175 > {
176 #[derive(::sqlx_askama_template::askama::Template)]
177 #[template(#template_attrs)]
178 struct Wrapper #wrapper_generics (
179 ::sqlx_askama_template::TemplateArg<#data_lifetime, DB, #name #ty_generics>
180 ) #where_clause
181 DB: ::sqlx::Database,
182 #bound_types;
183
184 impl #wrapper_impl_generics ::std::ops::Deref for Wrapper #wrapper_generics
185 #where_clause
186 DB: ::sqlx::Database,
187 #bound_types
188 {
189 type Target = ::sqlx_askama_template::TemplateArg<#data_lifetime, DB, #name #ty_generics>;
190 fn deref(&self) -> &Self::Target {
191 &self.0
192 }
193 }
194
195 let mut wrapper = Wrapper(::sqlx_askama_template::TemplateArg::new(self));
196 if let Some(f) = f {
197 wrapper.0.set_encode_placeholder_fn(f);
198 }
199 let render_res = ::sqlx_askama_template::askama::Template::render_into(&wrapper, sql_buffer)
200 .map_err(|e| ::sqlx::Error::Encode(::std::boxed::Box::new(e)))?;
201 let arg = wrapper.get_arguments();
202 let encode_err = wrapper.get_err();
203
204 if let Some(e) = encode_err {
205 return ::std::result::Result::Err(e);
206 }
207 ::std::result::Result::Ok(arg)
208 }
209 }
210 };
211
212 expanded.into()
213}