sqlx_conditional_queries_core/
analyze.rs

1use std::collections::HashSet;
2
3use syn::spanned::Spanned;
4
5use crate::parse::ParsedConditionalQueryAs;
6
7#[derive(Debug, thiserror::Error)]
8pub enum AnalyzeError {
9    #[error("expected string literal")]
10    ExpectedStringLiteral(proc_macro2::Span),
11    #[error("mismatch between number of names ({names}) and values ({values})")]
12    BindingNameValueLengthMismatch {
13        names: usize,
14        names_span: proc_macro2::Span,
15        values: usize,
16        values_span: proc_macro2::Span,
17    },
18    #[error("found two compile-time bindings with the same binding: {first}")]
19    DuplicatedCompileTimeBindingsFound {
20        first: proc_macro2::Ident,
21        second: proc_macro2::Ident,
22    },
23    #[error("found cycle in compile-time bindings: {path}")]
24    CompileTimeBindingCycleDetected {
25        root_ident: proc_macro2::Ident,
26        path: String,
27    },
28}
29
30/// This represents the finished second step in the processing pipeline.
31/// The compile time bindings have been further processed to a form that allows us to easily create
32/// the cartesian product and thereby all query variations in the next step.
33#[derive(Debug)]
34pub(crate) struct AnalyzedConditionalQueryAs {
35    pub(crate) output_type: syn::Ident,
36    pub(crate) query_string: syn::LitStr,
37    pub(crate) compile_time_bindings: Vec<CompileTimeBinding>,
38}
39
40/// This represents a single combination of a single compiletime binding of a query.
41#[derive(Debug)]
42pub(crate) struct CompileTimeBinding {
43    /// The actual expression used in the match statement.
44    /// E.g. for `match something`, this would be `something`.
45    pub(crate) expression: syn::Expr,
46    /// Each entry in this Vec represents a single expanded `match` and the
47    /// binding names with the binding values from that specific arm.
48    /// (`match arm pattern`, Vec(binding_name, binding_value)`
49    pub(crate) arms: Vec<(syn::Pat, Vec<(syn::Ident, syn::LitStr)>)>,
50}
51
52/// Further parse and analyze all compiletime binding statements.
53/// Each binding is split into individual entries of this form:
54/// (`match arm pattern`, Vec(binding_name, binding_value)`
55pub(crate) fn analyze(
56    parsed: ParsedConditionalQueryAs,
57) -> Result<AnalyzedConditionalQueryAs, AnalyzeError> {
58    let mut compile_time_bindings = Vec::new();
59
60    let mut known_binding_names = HashSet::new();
61
62    for (names, match_expr) in parsed.compile_time_bindings {
63        let binding_names_span = names.span();
64        // Convert the OneOrPunctuated enum in a list of `Ident`s.
65        // `One(T)` will be converted into a Vec with a single entry.
66        let binding_names: Vec<_> = names.into_iter().collect();
67
68        // Find duplicate compile-time bindings.
69        for name in &binding_names {
70            let Some(first) = known_binding_names.get(name) else {
71                known_binding_names.insert(name.clone());
72                continue;
73            };
74            return Err(AnalyzeError::DuplicatedCompileTimeBindingsFound {
75                first: first.clone(),
76                second: name.clone(),
77            });
78        }
79
80        let mut bindings = Vec::new();
81        for arm in match_expr.arms {
82            let arm_span = arm.body.span();
83
84            let binding_values = match *arm.body {
85                // If the match arm expression just contains a literal, use that.
86                syn::Expr::Lit(syn::ExprLit {
87                    lit: syn::Lit::Str(literal),
88                    ..
89                }) => vec![literal],
90
91                // If there's a tuple, treat each literal inside that tuple as a binding value.
92                syn::Expr::Tuple(tuple) => {
93                    let mut values = Vec::new();
94                    for elem in tuple.elems {
95                        match elem {
96                            syn::Expr::Lit(syn::ExprLit {
97                                lit: syn::Lit::Str(literal),
98                                ..
99                            }) => values.push(literal),
100
101                            _ => return Err(AnalyzeError::ExpectedStringLiteral(elem.span())),
102                        }
103                    }
104                    values
105                }
106
107                body => return Err(AnalyzeError::ExpectedStringLiteral(body.span())),
108            };
109
110            // There must always be a matching amount of binding values in each match arm.
111            // Error if there are more or fewer values than binding names.
112            if binding_names.len() != binding_values.len() {
113                return Err(AnalyzeError::BindingNameValueLengthMismatch {
114                    names: binding_names.len(),
115                    names_span: binding_names_span,
116                    values: binding_values.len(),
117                    values_span: arm_span,
118                });
119            }
120
121            bindings.push((
122                arm.pat,
123                binding_names
124                    .iter()
125                    .cloned()
126                    .zip(binding_values)
127                    .collect::<Vec<_>>(),
128            ));
129        }
130
131        compile_time_bindings.push(CompileTimeBinding {
132            expression: *match_expr.expr,
133            arms: bindings,
134        });
135    }
136
137    compile_time_bindings::validate_compile_time_bindings(&compile_time_bindings)?;
138
139    Ok(AnalyzedConditionalQueryAs {
140        output_type: parsed.output_type,
141        query_string: parsed.query_string,
142        compile_time_bindings,
143    })
144}
145
146mod compile_time_bindings {
147    use std::collections::{HashMap, HashSet};
148
149    use super::{AnalyzeError, CompileTimeBinding};
150
151    pub(super) fn validate_compile_time_bindings(
152        compile_time_bindings: &[CompileTimeBinding],
153    ) -> Result<(), AnalyzeError> {
154        let mut bindings = HashMap::new();
155
156        for (_, binding_values) in compile_time_bindings
157            .iter()
158            .flat_map(|bindings| &bindings.arms)
159        {
160            for (binding, value) in binding_values {
161                let name = binding.to_string();
162
163                let (_, references) = bindings
164                    .entry(name)
165                    .or_insert_with(|| (binding, HashSet::new()));
166                fill_references(references, &value.value());
167            }
168        }
169
170        for (name, (ident, _)) in &bindings {
171            validate_references(&bindings, ident, &[], name)?;
172        }
173
174        Ok(())
175    }
176
177    fn fill_references(references: &mut HashSet<String>, mut fragment: &str) {
178        while let Some(start_idx) = fragment.find("{#") {
179            fragment = &fragment[start_idx + 2..];
180            if let Some(end_idx) = fragment.find("}") {
181                references.insert(fragment[..end_idx].to_string());
182                fragment = &fragment[end_idx + 1..];
183            } else {
184                break;
185            }
186        }
187    }
188
189    fn validate_references(
190        bindings: &HashMap<String, (&syn::Ident, HashSet<String>)>,
191        root_ident: &syn::Ident,
192        path: &[&str],
193        name: &str,
194    ) -> Result<(), AnalyzeError> {
195        let mut path = path.to_vec();
196        path.push(name);
197
198        if path.iter().filter(|component| **component == name).count() > 1 {
199            return Err(AnalyzeError::CompileTimeBindingCycleDetected {
200                root_ident: root_ident.clone(),
201                path: path.join(" -> "),
202            });
203        }
204
205        let Some((_, references)) = bindings.get(name) else {
206            // This error is caught and handled in all contexts in the expand stage.
207            return Ok(());
208        };
209
210        for reference in references {
211            validate_references(bindings, root_ident, &path, reference)?;
212        }
213
214        Ok(())
215    }
216}
217
#[cfg(test)]
mod tests {
    use quote::ToTokens;

    use super::*;

    /// Renders a syntax node as the string form of its token stream.
    fn rendered(node: &impl ToTokens) -> String {
        node.to_token_stream().to_string()
    }

    /// Renders every `(binding name, binding value)` pair of a match arm as strings.
    fn rendered_pairs(pairs: &[(syn::Ident, syn::LitStr)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(name, value)| (rendered(name), rendered(value)))
            .collect()
    }

    /// Builds an expected `(binding name, binding value)` pair.
    fn pair(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    /// A single-name binding and a tuple binding are both split into per-arm
    /// `(pattern, [(name, value)])` entries, and output type and query string are passed
    /// through unchanged.
    #[test]
    fn valid_syntax() {
        let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
            r#"
                SomeType,
                "some SQL query",
                #binding = match foo {
                    bar => "baz",
                },
                #(a, b) = match c {
                    d => ("e", "f"),
                },
            "#,
        )
        .unwrap();
        // `parsed` is still needed for the comparisons below, hence the clone.
        let analyzed = analyze(parsed.clone()).unwrap();

        assert_eq!(parsed.output_type, analyzed.output_type);
        assert_eq!(parsed.query_string, analyzed.query_string);

        assert_eq!(analyzed.compile_time_bindings.len(), 2);

        {
            let compile_time_binding = &analyzed.compile_time_bindings[0];
            assert_eq!(rendered(&compile_time_binding.expression), "foo");

            assert_eq!(compile_time_binding.arms.len(), 1);
            let (pattern, pairs) = &compile_time_binding.arms[0];
            assert_eq!(rendered(pattern), "bar");
            assert_eq!(rendered_pairs(pairs), vec![pair("binding", "\"baz\"")]);
        }

        {
            let compile_time_binding = &analyzed.compile_time_bindings[1];
            assert_eq!(rendered(&compile_time_binding.expression), "c");

            assert_eq!(compile_time_binding.arms.len(), 1);
            let (pattern, pairs) = &compile_time_binding.arms[0];
            assert_eq!(rendered(pattern), "d");
            assert_eq!(
                rendered_pairs(pairs),
                vec![pair("a", "\"e\""), pair("b", "\"f\"")],
            );
        }
    }

    /// Declaring the same binding name in two statements is rejected.
    #[test]
    fn duplicate_compile_time_bindings() {
        let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
            r##"
                SomeType,
                r#"{#a}"#,
                #a = match _ {
                    _ => "1",
                },
                #a = match _ {
                    _ => "2",
                },
            "##,
        )
        .unwrap();
        let error = analyze(parsed).unwrap_err();

        assert!(matches!(
            error,
            AnalyzeError::DuplicatedCompileTimeBindingsFound { .. }
        ));
    }

    /// Two bindings whose values reference each other are rejected as a cycle.
    #[test]
    fn compile_time_binding_cycle_detected() {
        let parsed = syn::parse_str::<ParsedConditionalQueryAs>(
            r##"
                SomeType,
                r#"{#a}"#,
                #a = match _ {
                    _ => "{#b}",
                },
                #b = match _ {
                    _ => "{#a}",
                },
            "##,
        )
        .unwrap();
        let error = analyze(parsed).unwrap_err();

        assert!(matches!(
            error,
            AnalyzeError::CompileTimeBindingCycleDetected { .. }
        ));
    }
}