sqlx_askama_template_macro/
lib.rs1use std::collections::BTreeSet;
2
3use proc_macro::TokenStream;
4use quote::{format_ident, quote};
5use syn::parse::Parser;
6use syn::{DeriveInput, LifetimeParam, Meta, parse_macro_input};
/// Ordered key used to deduplicate types while collecting trait bounds.
///
/// Wraps the string form of a type so that values can be stored in a
/// `BTreeSet`; two identifiers compare equal exactly when their inner
/// strings are equal.
#[derive(PartialEq, Eq, PartialOrd, Ord)]
struct TypeIdentifier(String);
10
11fn get_type_identifier(ty: &syn::Type) -> TypeIdentifier {
12 TypeIdentifier(quote!(#ty).to_string())
13}
14#[proc_macro_derive(SqlTemplate, attributes(template, addtype, ignore_type))]
74pub fn sql_template(input: TokenStream) -> TokenStream {
75 let input = parse_macro_input!(input as DeriveInput);
76 let name: &syn::Ident = &input.ident;
77 let generics = &input.generics;
78
79 let wrapper_name = format_ident!("{}Wrapper", name);
82
83 let template_attrs: Vec<_> = input
85 .attrs
86 .iter()
87 .filter(|attr| attr.path().is_ident("template"))
88 .collect();
89
90 let (mut wrapper_generics, data_lifetime) = if let Some(lt) = generics.lifetimes().next() {
92 let generics = generics.clone();
93 let lt_ident = <.lifetime;
94 (generics, quote! { #lt_ident })
95 } else {
96 let mut generics = generics.clone();
97 let lifetime = LifetimeParam::new(syn::Lifetime::new("'q", proc_macro2::Span::call_site()));
98 generics
99 .params
100 .insert(0, syn::GenericParam::Lifetime(lifetime));
101 (generics, quote! { 'q })
102 };
103
104 wrapper_generics
106 .params
107 .push(syn::GenericParam::Type(syn::TypeParam {
108 attrs: Vec::new(),
109 ident: format_ident!("DB"),
110 colon_token: None,
111 bounds: syn::punctuated::Punctuated::new(),
112 eq_token: None,
113 default: None,
114 }));
115 let mut seen_types = BTreeSet::new();
117 let mut bound_types = proc_macro2::TokenStream::new();
119
120 if let syn::Data::Struct(data_struct) = &input.data {
122 for field in &data_struct.fields {
123 let has_ignore = field
124 .attrs
125 .iter()
126 .any(|attr| attr.path().is_ident("ignore_type"));
127 if !has_ignore {
128 let ty = &field.ty;
129 let ident = get_type_identifier(ty);
130 if seen_types.insert(ident) {
131 bound_types.extend(quote! {
132 #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
133 });
134 }
135 }
136 }
137 }
138 for attr in &input.attrs {
140 if attr.path().is_ident("addtype") {
141 match &attr.meta {
142 Meta::List(meta_list) => {
143 let parser =
144 syn::punctuated::Punctuated::<syn::Type, syn::Token![,]>::parse_terminated;
145 if let Ok(types) = parser.parse2(meta_list.tokens.clone()) {
146 for ty in types {
147 let ident = get_type_identifier(&ty);
148 if seen_types.insert(ident) {
149 bound_types.extend(quote! {
150 #ty: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
151 });
152 }
153 }
154 }
155 }
156 _ => continue,
157 }
158 }
159 }
160
161 let (_impl_generics, ty_generics, where_clause) = generics.split_for_impl();
162 let where_clause = if where_clause.is_some() {
163 quote! {#where_clause}
164 } else {
165 quote! {where }
166 };
167 let (wrapper_impl_generics, wrapper_ty_generics, _) = wrapper_generics.split_for_impl();
168
169 let expanded = quote! {
170 #[derive(::askama::Template)]
171 #(#template_attrs)*
172 pub struct #wrapper_name #wrapper_generics
173 #where_clause
174 DB: ::sqlx::Database,
175 #bound_types
176 {
177 pub data: &#data_lifetime #name #ty_generics,
178 pub arguments: ::sqlx_askama_template::TemplateArg<#data_lifetime, DB>,
179 }
180
181 impl #wrapper_impl_generics ::std::ops::Deref for #wrapper_name #wrapper_ty_generics
182 #where_clause
183 DB: ::sqlx::Database,
184 #bound_types
185 {
186 type Target = &#data_lifetime #name #ty_generics;
187
188 fn deref(&self) -> &Self::Target {
189 &self.data
190 }
191 }
192
193 impl #wrapper_impl_generics #wrapper_name #wrapper_ty_generics
194 #where_clause
195 DB: ::sqlx::Database,
196 #bound_types
197 {
198 pub fn e<ImplEncode>(&self, arg: ImplEncode) -> ::std::string::String
199 where
200 ImplEncode: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
201 {
202 self.arguments.encode(arg)
203 }
204
205 pub fn el<ImplEncode>(
206 &self,
207 args: impl ::std::iter::IntoIterator<Item = ImplEncode>,
208 ) -> ::std::string::String
209 where
210 ImplEncode: ::sqlx::Encode<#data_lifetime, DB> + ::sqlx::Type<DB> + #data_lifetime,
211 {
212 self.arguments.encode_list(args.into_iter())
213 }
214
215 pub fn et<ImplEncode>(&self, t: &ImplEncode) -> ::std::string::String
216 where
217 ImplEncode: ::sqlx::Encode<#data_lifetime, DB>
218 + ::sqlx::Type<DB>
219 + ::std::clone::Clone
220 + #data_lifetime,
221 {
222 self.arguments.encode(t.clone())
223 }
224
225 pub fn etl<'arg_b, ImplEncode>(
226 &self,
227 args: impl ::std::iter::IntoIterator<Item = &'arg_b ImplEncode>,
228 ) -> ::std::string::String
229 where
230 #data_lifetime: 'arg_b,
231 ImplEncode: ::sqlx::Encode<#data_lifetime, DB>
232 + ::sqlx::Type<DB>
233 + ::std::clone::Clone
234 + #data_lifetime,
235 {
236 let args = args.into_iter().cloned();
237 self.arguments.encode_list(args)
238 }
239 }
240
241 impl #wrapper_impl_generics ::sqlx_askama_template::SqlTemplate<#data_lifetime, DB>
242 for &#data_lifetime #name #ty_generics
243 #where_clause
244 DB: ::sqlx::Database,
245 #bound_types
246
247 {
248 fn render_sql(
249 self,
250 ) -> ::std::result::Result<
251 (
252 ::std::string::String,
253 ::std::option::Option<DB::Arguments<#data_lifetime>>,
254 ),
255 ::askama::Error,
256 > {
257 let wrapper: #wrapper_name #wrapper_ty_generics = #wrapper_name {
258 data: self,
259 arguments: ::std::default::Default::default(),
260 };
261
262 let sql = ::askama::Template::render(&wrapper)?;
263 if let ::std::option::Option::Some(e) = wrapper.arguments.get_err() {
264 return ::std::result::Result::Err(::askama::Error::Custom(e));
265 }
266 let arg = wrapper.arguments.get_arguments();
267
268 ::std::result::Result::Ok((sql, arg))
269 }
270 }
271 };
272
273 expanded.into()
274}