sqlx_conditional_queries_core/
expand.rs

1use std::collections::HashMap;
2
3use crate::{lower::LoweredConditionalQueryAs, DatabaseType};
4
5#[derive(Debug, thiserror::Error)]
6pub enum ExpandError {
7    #[error("missing compile-time binding: {0}")]
8    MissingCompileTimeBinding(String, proc_macro2::Span),
9    #[error("missing binding closing brace")]
10    MissingBindingClosingBrace(proc_macro2::Span),
11    #[error("failed to parse type override in binding reference: {0}")]
12    BindingReferenceTypeOverrideParseError(proc_macro2::LexError, proc_macro2::Span),
13}
14
15#[derive(Debug)]
16pub(crate) struct ExpandedConditionalQueryAs {
17    pub(crate) output_type: syn::Ident,
18    pub(crate) match_expressions: Vec<syn::Expr>,
19    pub(crate) match_arms: Vec<MatchArm>,
20}
21
22#[derive(Debug)]
23pub(crate) struct MatchArm {
24    pub(crate) patterns: Vec<syn::Pat>,
25    pub(crate) query_fragments: Vec<syn::LitStr>,
26    pub(crate) run_time_bindings: Vec<(syn::Ident, Option<proc_macro2::TokenStream>)>,
27}
28
29/// Corresponds to a single run-time binding name.
30#[derive(Debug)]
31struct RunTimeBinding {
32    /// List of all argument index positions at which this binding needs to be bound.
33    ///
34    /// - For PostgreSQL only contains one element.
35    /// - For MySQL and SQLite it contains one index for each time the binding was referenced.
36    indices: Vec<usize>,
37
38    /// Type-override fragment to pass on To SQLx
39    type_override: Option<proc_macro2::TokenStream>,
40}
41
42#[derive(Debug)]
43struct RunTimeBindings {
44    database_type: DatabaseType,
45    counter: usize,
46    bindings: HashMap<syn::LitStr, RunTimeBinding>,
47}
48
49impl RunTimeBindings {
50    fn new(database_type: DatabaseType) -> Self {
51        Self {
52            database_type,
53            counter: 0,
54            bindings: Default::default(),
55        }
56    }
57
58    /// Returns a database-appropriate run-time binding string for the given binding name.
59    ///
60    /// Database type selection is done based on the features this crate was built with.
61    ///
62    /// - PostgreSQL uses 1-indexed references such as `$1`, which means that multiple references
63    ///   to the same parameter only need to be bound once.
64    /// - MySQL and SQLite always use `?` which means that the arguments need to specified in
65    ///   order and be duplicated for as many times as they're used.
66    fn get_binding_string(
67        &mut self,
68        binding_name: syn::LitStr,
69        type_override: Option<proc_macro2::TokenStream>,
70    ) -> syn::LitStr {
71        match self.database_type {
72            DatabaseType::PostgreSql => {
73                let span = binding_name.span();
74                let binding = self.bindings.entry(binding_name).or_insert_with(|| {
75                    self.counter += 1;
76                    RunTimeBinding {
77                        indices: vec![self.counter],
78                        type_override,
79                    }
80                });
81                syn::LitStr::new(&format!("${}", binding.indices.first().unwrap()), span)
82            }
83            DatabaseType::MySql | DatabaseType::Sqlite => {
84                let span = binding_name.span();
85                self.counter += 1;
86
87                // For MySQL and SQLite bindings we need to specify the same argument multiple
88                // times if it's reused and so generate a unique index every time.  This ensures
89                // that `get_run_time_bindings` will generate the arguments in the correct order.
90                self.bindings
91                    .entry(binding_name)
92                    .and_modify(|binding| binding.indices.push(self.counter))
93                    .or_insert_with(|| RunTimeBinding {
94                        indices: vec![self.counter],
95                        type_override,
96                    });
97                syn::LitStr::new("?", span)
98            }
99        }
100    }
101
102    /// Returns the `query_as!` arguments for all referenced run-time bindings.
103    fn get_arguments(self) -> Vec<(syn::Ident, Option<proc_macro2::TokenStream>)> {
104        let mut run_time_bindings: Vec<_> = self
105            .bindings
106            .into_iter()
107            .flat_map(|(name, binding)| {
108                binding
109                    .indices
110                    .into_iter()
111                    .map(|index| {
112                        (
113                            syn::Ident::new(&name.value(), name.span()),
114                            binding.type_override.clone(),
115                            index,
116                        )
117                    })
118                    .collect::<Vec<_>>()
119            })
120            .collect();
121
122        run_time_bindings.sort_by_key(|(_, _, index)| *index);
123
124        run_time_bindings
125            .into_iter()
126            .map(|(ident, type_override, _)| (ident, type_override))
127            .collect()
128    }
129}
130
131/// This function takes the original query string that was supplied to the macro and adjusts it for
132/// each arm of the previously generated cartesian product of all bindings' match arms.
133///
134/// The `{#binding_name}` placeholder are then replaced with the string literals from match clauses
135/// and  all `{scope_variable} placeholder are replaced with the positional variables of the respective
136/// database engine whose feature is enabled. For more info take a look at [RunTimeBindings].
137pub(crate) fn expand(
138    database_type: DatabaseType,
139    lowered: LoweredConditionalQueryAs,
140) -> Result<ExpandedConditionalQueryAs, ExpandError> {
141    let mut match_arms = Vec::new();
142
143    for arm in lowered.match_arms {
144        let mut fragments = vec![lowered.query_string.clone()];
145        while fragments
146            .iter()
147            .any(|fragment| fragment.value().contains("{#"))
148        {
149            fragments = expand_compile_time_bindings(fragments, &arm.compile_time_bindings)?;
150        }
151
152        // Substitute
153        let mut run_time_bindings = RunTimeBindings::new(database_type);
154        let expanded = expand_run_time_bindings(fragments, &mut run_time_bindings)?;
155
156        match_arms.push(MatchArm {
157            patterns: arm.patterns,
158            query_fragments: expanded,
159            run_time_bindings: run_time_bindings.get_arguments(),
160        });
161    }
162
163    Ok(ExpandedConditionalQueryAs {
164        output_type: lowered.output_type,
165        match_expressions: lowered.match_expressions,
166        match_arms,
167    })
168}
169
170/// This function takes the list of query fragments and substitutes all `{#binding_name}`
171/// occurrences with their literal strings from the respective match statements.
172///
173/// These literal strings however, can once again contain another `{#binding_name}`, which is why
174/// this function is called from a while loop.
175/// Since this function might get called multiple times, some fragments might already be expanded
176/// at this point, despite the variable name.
177fn expand_compile_time_bindings(
178    unexpanded_fragments: Vec<syn::LitStr>,
179    compile_time_bindings: &HashMap<String, syn::LitStr>,
180) -> Result<Vec<syn::LitStr>, ExpandError> {
181    let mut expanded_fragments = Vec::new();
182
183    for fragment in unexpanded_fragments {
184        let fragment_string = fragment.value();
185        let mut fragment_str = fragment_string.as_str();
186
187        while let Some(start_of_binding) = fragment_str.find('{') {
188            // We've hit either a compile-time or a run-time binding, so first we push any prefix
189            // before the binding.
190            if !fragment_str[..start_of_binding].is_empty() {
191                expanded_fragments.push(syn::LitStr::new(
192                    &fragment_str[..start_of_binding],
193                    fragment.span(),
194                ));
195                fragment_str = &fragment_str[start_of_binding..];
196            }
197
198            // Then we find the matching closing brace.
199            let end_of_binding = if let Some(end_of_binding) = fragment_str.find('}') {
200                end_of_binding
201            } else {
202                return Err(ExpandError::MissingBindingClosingBrace(fragment.span()));
203            };
204
205            if fragment_str.chars().nth(1) == Some('#') {
206                // If the binding is a compile-time binding, expand it.
207                let binding_name = &fragment_str[2..end_of_binding];
208                if let Some(binding) = compile_time_bindings.get(binding_name) {
209                    expanded_fragments.push(binding.clone());
210                } else {
211                    return Err(ExpandError::MissingCompileTimeBinding(
212                        binding_name.to_string(),
213                        fragment.span(),
214                    ));
215                }
216            } else {
217                // Otherwise push it as-is for the next pass.
218                expanded_fragments.push(syn::LitStr::new(
219                    &fragment_str[..end_of_binding + 1],
220                    fragment.span(),
221                ));
222            }
223
224            fragment_str = &fragment_str[end_of_binding + 1..];
225        }
226
227        // Push trailing query fragment.
228        if !fragment_str.is_empty() {
229            expanded_fragments.push(syn::LitStr::new(fragment_str, fragment.span()));
230        }
231    }
232
233    Ok(expanded_fragments)
234}
235
236/// Take all fragments and substitute any `{name}` occurrences with the respective database
237/// binding. Since the parameter syntax is different for various databases, [RunTimeBinding] is
238/// used in combination with feature flags to abstract this variance away.
239fn expand_run_time_bindings(
240    unexpanded_fragments: Vec<syn::LitStr>,
241    run_time_bindings: &mut RunTimeBindings,
242) -> Result<Vec<syn::LitStr>, ExpandError> {
243    let mut expanded_query = Vec::new();
244
245    for fragment in unexpanded_fragments {
246        let fragment_string = fragment.value();
247        let mut fragment_str = fragment_string.as_str();
248
249        while let Some(start_of_binding) = fragment_str.find('{') {
250            // Otherwise we've hit a run-time binding, so first we push any prefix before the
251            // binding.
252            expanded_query.push(syn::LitStr::new(
253                &fragment_str[..start_of_binding],
254                fragment.span(),
255            ));
256
257            // Then we find the matching closing brace.
258            fragment_str = &fragment_str[start_of_binding + 1..];
259            let end_of_binding = if let Some(end_of_binding) = fragment_str.find('}') {
260                end_of_binding
261            } else {
262                return Err(ExpandError::MissingBindingClosingBrace(fragment.span()));
263            };
264
265            let binding_name = &fragment_str[..end_of_binding];
266            let (binding_name, type_override) = if let Some(offset) = binding_name.find(':') {
267                let (binding_name, type_override) = binding_name.split_at(offset);
268                let type_override = type_override[1..]
269                    .parse::<proc_macro2::TokenStream>()
270                    .map_err(|err| {
271                        ExpandError::BindingReferenceTypeOverrideParseError(err, fragment.span())
272                    })?;
273                (binding_name.trim(), Some(type_override))
274            } else {
275                (binding_name, None)
276            };
277
278            // And finally we push a bound parameter argument
279            let binding = run_time_bindings.get_binding_string(
280                syn::LitStr::new(binding_name, fragment.span()),
281                type_override,
282            );
283            expanded_query.push(binding);
284
285            fragment_str = &fragment_str[end_of_binding + 1..];
286        }
287
288        // Push trailing query fragment.
289        if !fragment_str.is_empty() {
290            expanded_query.push(syn::LitStr::new(fragment_str, fragment.span()));
291        }
292    }
293
294    Ok(expanded_query)
295}
296
#[cfg(test)]
mod tests {
    use quote::ToTokens;

    use super::*;

    /// Runs `source` through parsing, analysis and lowering and expands the result.
    fn expand_source(database_type: DatabaseType, source: &str) -> ExpandedConditionalQueryAs {
        let parsed = syn::parse_str::<crate::parse::ParsedConditionalQueryAs>(source).unwrap();
        let analyzed = crate::analyze::analyze(parsed).unwrap();
        let lowered = crate::lower::lower(analyzed);
        expand(database_type, lowered).unwrap()
    }

    /// Renders the query fragments of `arm` as the source text of their string literals.
    fn rendered_fragments(arm: &MatchArm) -> Vec<String> {
        arm.query_fragments
            .iter()
            .map(|fragment| fragment.to_token_stream().to_string())
            .collect()
    }

    #[rstest::rstest]
    #[case(DatabaseType::PostgreSql)]
    #[case(DatabaseType::MySql)]
    #[case(DatabaseType::Sqlite)]
    fn expands_compile_time_bindings(#[case] database_type: DatabaseType) {
        let expanded = expand_source(
            database_type,
            r#"
                SomeType,
                "some {#a} {#b} {#j} query",
                #(a, b) = match c {
                    d => ("e", "f"),
                    g => ("h", "i"),
                },
                #j = match i {
                    k => "l",
                    m => "n",
                },
            "#,
        );

        // The first arm combines the first arm of each of the two matches.
        let expected = [
            r#""some ""#,
            r#""e""#,
            r#"" ""#,
            r#""f""#,
            r#"" ""#,
            r#""l""#,
            r#"" query""#,
        ];
        assert_eq!(rendered_fragments(&expanded.match_arms[0]), expected);
    }

    #[rstest::rstest]
    #[case(DatabaseType::PostgreSql)]
    #[case(DatabaseType::MySql)]
    #[case(DatabaseType::Sqlite)]
    fn expands_run_time_bindings(#[case] database_type: DatabaseType) {
        let expanded = expand_source(
            database_type,
            r#"
                SomeType,
                "some {foo:ty} {bar} {foo} query",
            "#,
        );
        let arm = &expanded.match_arms[0];

        // Check that run-time binding references are generated properly.
        let expected_fragments = match database_type {
            DatabaseType::PostgreSql => [
                r#""some ""#,
                r#""$1""#,
                r#"" ""#,
                r#""$2""#,
                r#"" ""#,
                r#""$1""#,
                r#"" query""#,
            ],
            DatabaseType::MySql | DatabaseType::Sqlite => [
                r#""some ""#,
                r#""?""#,
                r#"" ""#,
                r#""?""#,
                r#"" ""#,
                r#""?""#,
                r#"" query""#,
            ],
        };
        assert_eq!(rendered_fragments(arm), expected_fragments);

        // Check that type overrides are parsed properly.
        let actual_bindings: Vec<(String, Option<String>)> = arm
            .run_time_bindings
            .iter()
            .map(|(ident, type_override)| {
                (
                    ident.to_string(),
                    type_override.as_ref().map(ToString::to_string),
                )
            })
            .collect();

        let foo = ("foo".to_string(), Some("ty".to_string()));
        let bar = ("bar".to_string(), None);
        let expected_bindings = match database_type {
            DatabaseType::PostgreSql => vec![foo, bar],
            DatabaseType::MySql | DatabaseType::Sqlite => vec![foo.clone(), bar, foo],
        };
        assert_eq!(actual_bindings, expected_bindings);
    }
}