sqlx_conditional_queries_core/
expand.rs1use std::collections::HashMap;
2
3use crate::{lower::LoweredConditionalQueryAs, DatabaseType};
4
5#[derive(Debug, thiserror::Error)]
6pub enum ExpandError {
7 #[error("missing compile-time binding: {0}")]
8 MissingCompileTimeBinding(String, proc_macro2::Span),
9 #[error("missing binding closing brace")]
10 MissingBindingClosingBrace(proc_macro2::Span),
11 #[error("failed to parse type override in binding reference: {0}")]
12 BindingReferenceTypeOverrideParseError(proc_macro2::LexError, proc_macro2::Span),
13}
14
15#[derive(Debug)]
16pub(crate) struct ExpandedConditionalQueryAs {
17 pub(crate) output_type: syn::Ident,
18 pub(crate) match_expressions: Vec<syn::Expr>,
19 pub(crate) match_arms: Vec<MatchArm>,
20}
21
22#[derive(Debug)]
23pub(crate) struct MatchArm {
24 pub(crate) patterns: Vec<syn::Pat>,
25 pub(crate) query_fragments: Vec<syn::LitStr>,
26 pub(crate) run_time_bindings: Vec<(syn::Ident, Option<proc_macro2::TokenStream>)>,
27}
28
29#[derive(Debug)]
31struct RunTimeBinding {
32 indices: Vec<usize>,
37
38 type_override: Option<proc_macro2::TokenStream>,
40}
41
42#[derive(Debug)]
43struct RunTimeBindings {
44 database_type: DatabaseType,
45 counter: usize,
46 bindings: HashMap<syn::LitStr, RunTimeBinding>,
47}
48
49impl RunTimeBindings {
50 fn new(database_type: DatabaseType) -> Self {
51 Self {
52 database_type,
53 counter: 0,
54 bindings: Default::default(),
55 }
56 }
57
58 fn get_binding_string(
67 &mut self,
68 binding_name: syn::LitStr,
69 type_override: Option<proc_macro2::TokenStream>,
70 ) -> syn::LitStr {
71 match self.database_type {
72 DatabaseType::PostgreSql => {
73 let span = binding_name.span();
74 let binding = self.bindings.entry(binding_name).or_insert_with(|| {
75 self.counter += 1;
76 RunTimeBinding {
77 indices: vec![self.counter],
78 type_override,
79 }
80 });
81 syn::LitStr::new(&format!("${}", binding.indices.first().unwrap()), span)
82 }
83 DatabaseType::MySql | DatabaseType::Sqlite => {
84 let span = binding_name.span();
85 self.counter += 1;
86
87 self.bindings
91 .entry(binding_name)
92 .and_modify(|binding| binding.indices.push(self.counter))
93 .or_insert_with(|| RunTimeBinding {
94 indices: vec![self.counter],
95 type_override,
96 });
97 syn::LitStr::new("?", span)
98 }
99 }
100 }
101
102 fn get_arguments(self) -> Vec<(syn::Ident, Option<proc_macro2::TokenStream>)> {
104 let mut run_time_bindings: Vec<_> = self
105 .bindings
106 .into_iter()
107 .flat_map(|(name, binding)| {
108 binding
109 .indices
110 .into_iter()
111 .map(|index| {
112 (
113 syn::Ident::new(&name.value(), name.span()),
114 binding.type_override.clone(),
115 index,
116 )
117 })
118 .collect::<Vec<_>>()
119 })
120 .collect();
121
122 run_time_bindings.sort_by_key(|(_, _, index)| *index);
123
124 run_time_bindings
125 .into_iter()
126 .map(|(ident, type_override, _)| (ident, type_override))
127 .collect()
128 }
129}
130
131pub(crate) fn expand(
138 database_type: DatabaseType,
139 lowered: LoweredConditionalQueryAs,
140) -> Result<ExpandedConditionalQueryAs, ExpandError> {
141 let mut match_arms = Vec::new();
142
143 for arm in lowered.match_arms {
144 let mut fragments = vec![lowered.query_string.clone()];
145 while fragments
146 .iter()
147 .any(|fragment| fragment.value().contains("{#"))
148 {
149 fragments = expand_compile_time_bindings(fragments, &arm.compile_time_bindings)?;
150 }
151
152 let mut run_time_bindings = RunTimeBindings::new(database_type);
154 let expanded = expand_run_time_bindings(fragments, &mut run_time_bindings)?;
155
156 match_arms.push(MatchArm {
157 patterns: arm.patterns,
158 query_fragments: expanded,
159 run_time_bindings: run_time_bindings.get_arguments(),
160 });
161 }
162
163 Ok(ExpandedConditionalQueryAs {
164 output_type: lowered.output_type,
165 match_expressions: lowered.match_expressions,
166 match_arms,
167 })
168}
169
170fn expand_compile_time_bindings(
178 unexpanded_fragments: Vec<syn::LitStr>,
179 compile_time_bindings: &HashMap<String, syn::LitStr>,
180) -> Result<Vec<syn::LitStr>, ExpandError> {
181 let mut expanded_fragments = Vec::new();
182
183 for fragment in unexpanded_fragments {
184 let fragment_string = fragment.value();
185 let mut fragment_str = fragment_string.as_str();
186
187 while let Some(start_of_binding) = fragment_str.find('{') {
188 if !fragment_str[..start_of_binding].is_empty() {
191 expanded_fragments.push(syn::LitStr::new(
192 &fragment_str[..start_of_binding],
193 fragment.span(),
194 ));
195 fragment_str = &fragment_str[start_of_binding..];
196 }
197
198 let end_of_binding = if let Some(end_of_binding) = fragment_str.find('}') {
200 end_of_binding
201 } else {
202 return Err(ExpandError::MissingBindingClosingBrace(fragment.span()));
203 };
204
205 if fragment_str.chars().nth(1) == Some('#') {
206 let binding_name = &fragment_str[2..end_of_binding];
208 if let Some(binding) = compile_time_bindings.get(binding_name) {
209 expanded_fragments.push(binding.clone());
210 } else {
211 return Err(ExpandError::MissingCompileTimeBinding(
212 binding_name.to_string(),
213 fragment.span(),
214 ));
215 }
216 } else {
217 expanded_fragments.push(syn::LitStr::new(
219 &fragment_str[..end_of_binding + 1],
220 fragment.span(),
221 ));
222 }
223
224 fragment_str = &fragment_str[end_of_binding + 1..];
225 }
226
227 if !fragment_str.is_empty() {
229 expanded_fragments.push(syn::LitStr::new(fragment_str, fragment.span()));
230 }
231 }
232
233 Ok(expanded_fragments)
234}
235
236fn expand_run_time_bindings(
240 unexpanded_fragments: Vec<syn::LitStr>,
241 run_time_bindings: &mut RunTimeBindings,
242) -> Result<Vec<syn::LitStr>, ExpandError> {
243 let mut expanded_query = Vec::new();
244
245 for fragment in unexpanded_fragments {
246 let fragment_string = fragment.value();
247 let mut fragment_str = fragment_string.as_str();
248
249 while let Some(start_of_binding) = fragment_str.find('{') {
250 expanded_query.push(syn::LitStr::new(
253 &fragment_str[..start_of_binding],
254 fragment.span(),
255 ));
256
257 fragment_str = &fragment_str[start_of_binding + 1..];
259 let end_of_binding = if let Some(end_of_binding) = fragment_str.find('}') {
260 end_of_binding
261 } else {
262 return Err(ExpandError::MissingBindingClosingBrace(fragment.span()));
263 };
264
265 let binding_name = &fragment_str[..end_of_binding];
266 let (binding_name, type_override) = if let Some(offset) = binding_name.find(':') {
267 let (binding_name, type_override) = binding_name.split_at(offset);
268 let type_override = type_override[1..]
269 .parse::<proc_macro2::TokenStream>()
270 .map_err(|err| {
271 ExpandError::BindingReferenceTypeOverrideParseError(err, fragment.span())
272 })?;
273 (binding_name.trim(), Some(type_override))
274 } else {
275 (binding_name, None)
276 };
277
278 let binding = run_time_bindings.get_binding_string(
280 syn::LitStr::new(binding_name, fragment.span()),
281 type_override,
282 );
283 expanded_query.push(binding);
284
285 fragment_str = &fragment_str[end_of_binding + 1..];
286 }
287
288 if !fragment_str.is_empty() {
290 expanded_query.push(syn::LitStr::new(fragment_str, fragment.span()));
291 }
292 }
293
294 Ok(expanded_query)
295}
296
#[cfg(test)]
mod tests {
    use quote::ToTokens;

    use super::*;

    /// Runs the parse -> analyze -> lower -> expand pipeline on `input`.
    fn expand_input(database_type: DatabaseType, input: &str) -> ExpandedConditionalQueryAs {
        let parsed = syn::parse_str::<crate::parse::ParsedConditionalQueryAs>(input).unwrap();
        let analyzed = crate::analyze::analyze(parsed).unwrap();
        let lowered = crate::lower::lower(analyzed);
        expand(database_type, lowered).unwrap()
    }

    /// Token text of every query fragment in the first match arm.
    fn first_arm_fragments(expanded: &ExpandedConditionalQueryAs) -> Vec<String> {
        expanded.match_arms[0]
            .query_fragments
            .iter()
            .map(|fragment| fragment.to_token_stream().to_string())
            .collect()
    }

    #[rstest::rstest]
    #[case(DatabaseType::PostgreSql)]
    #[case(DatabaseType::MySql)]
    #[case(DatabaseType::Sqlite)]
    fn expands_compile_time_bindings(#[case] database_type: DatabaseType) {
        let expanded = expand_input(
            database_type,
            r#"
                SomeType,
                "some {#a} {#b} {#j} query",
                #(a, b) = match c {
                    d => ("e", "f"),
                    g => ("h", "i"),
                },
                #j = match i {
                    k => "l",
                    m => "n",
                },
            "#,
        );

        // Compile-time values become their own fragments; the expectation is
        // the same for every database.
        assert_eq!(
            first_arm_fragments(&expanded),
            &[
                "\"some \"",
                "\"e\"",
                "\" \"",
                "\"f\"",
                "\" \"",
                "\"l\"",
                "\" query\""
            ],
        );
    }

    #[rstest::rstest]
    #[case(DatabaseType::PostgreSql)]
    #[case(DatabaseType::MySql)]
    #[case(DatabaseType::Sqlite)]
    fn expands_run_time_bindings(#[case] database_type: DatabaseType) {
        let expanded = expand_input(
            database_type,
            r#"
                SomeType,
                "some {foo:ty} {bar} {foo} query",
            "#,
        );

        let expected_fragments = match database_type {
            // PostgreSQL reuses the numbered placeholder for a repeated name.
            DatabaseType::PostgreSql => &[
                "\"some \"",
                "\"$1\"",
                "\" \"",
                "\"$2\"",
                "\" \"",
                "\"$1\"",
                "\" query\"",
            ],
            // MySQL and SQLite emit one positional `?` per occurrence.
            DatabaseType::MySql | DatabaseType::Sqlite => &[
                "\"some \"",
                "\"?\"",
                "\" \"",
                "\"?\"",
                "\" \"",
                "\"?\"",
                "\" query\"",
            ],
        };
        assert_eq!(first_arm_fragments(&expanded), expected_fragments);

        let bindings = expanded.match_arms[0]
            .run_time_bindings
            .iter()
            .map(|(ident, type_override)| {
                (ident.to_string(), type_override.as_ref().map(ToString::to_string))
            })
            .collect::<Vec<_>>();

        let owned = |name: &str, type_override: Option<&str>| {
            (name.to_string(), type_override.map(str::to_string))
        };
        let expected_bindings = match database_type {
            DatabaseType::PostgreSql => vec![owned("foo", Some("ty")), owned("bar", None)],
            DatabaseType::MySql | DatabaseType::Sqlite => vec![
                owned("foo", Some("ty")),
                owned("bar", None),
                owned("foo", Some("ty")),
            ],
        };
        assert_eq!(bindings, expected_bindings);
    }
}