sqlx_conditional_queries_core/
analyze.rs1use std::collections::HashSet;
2
3use syn::spanned::Spanned;
4
5use crate::parse::ParsedConditionalQueryAs;
6
7#[derive(Debug, thiserror::Error)]
8pub enum AnalyzeError {
9 #[error("expected string literal")]
10 ExpectedStringLiteral(proc_macro2::Span),
11 #[error("mismatch between number of names ({names}) and values ({values})")]
12 BindingNameValueLengthMismatch {
13 names: usize,
14 names_span: proc_macro2::Span,
15 values: usize,
16 values_span: proc_macro2::Span,
17 },
18 #[error("found two compile-time bindings with the same binding: {first}")]
19 DuplicatedCompileTimeBindingsFound {
20 first: proc_macro2::Ident,
21 second: proc_macro2::Ident,
22 },
23 #[error("found cycle in compile-time bindings: {path}")]
24 CompileTimeBindingCycleDetected {
25 root_ident: proc_macro2::Ident,
26 path: String,
27 },
28}
29
30#[derive(Debug)]
34pub(crate) struct AnalyzedConditionalQueryAs {
35 pub(crate) output_type: syn::Ident,
36 pub(crate) query_string: syn::LitStr,
37 pub(crate) compile_time_bindings: Vec<CompileTimeBinding>,
38}
39
40#[derive(Debug)]
42pub(crate) struct CompileTimeBinding {
43 pub(crate) expression: syn::Expr,
46 pub(crate) arms: Vec<(syn::Pat, Vec<(syn::Ident, syn::LitStr)>)>,
50}
51
52pub(crate) fn analyze(
56 parsed: ParsedConditionalQueryAs,
57) -> Result<AnalyzedConditionalQueryAs, AnalyzeError> {
58 let mut compile_time_bindings = Vec::new();
59
60 let mut known_binding_names = HashSet::new();
61
62 for (names, match_expr) in parsed.compile_time_bindings {
63 let binding_names_span = names.span();
64 let binding_names: Vec<_> = names.into_iter().collect();
67
68 for name in &binding_names {
70 let Some(first) = known_binding_names.get(name) else {
71 known_binding_names.insert(name.clone());
72 continue;
73 };
74 return Err(AnalyzeError::DuplicatedCompileTimeBindingsFound {
75 first: first.clone(),
76 second: name.clone(),
77 });
78 }
79
80 let mut bindings = Vec::new();
81 for arm in match_expr.arms {
82 let arm_span = arm.body.span();
83
84 let binding_values = match *arm.body {
85 syn::Expr::Lit(syn::ExprLit {
87 lit: syn::Lit::Str(literal),
88 ..
89 }) => vec![literal],
90
91 syn::Expr::Tuple(tuple) => {
93 let mut values = Vec::new();
94 for elem in tuple.elems {
95 match elem {
96 syn::Expr::Lit(syn::ExprLit {
97 lit: syn::Lit::Str(literal),
98 ..
99 }) => values.push(literal),
100
101 _ => return Err(AnalyzeError::ExpectedStringLiteral(elem.span())),
102 }
103 }
104 values
105 }
106
107 body => return Err(AnalyzeError::ExpectedStringLiteral(body.span())),
108 };
109
110 if binding_names.len() != binding_values.len() {
113 return Err(AnalyzeError::BindingNameValueLengthMismatch {
114 names: binding_names.len(),
115 names_span: binding_names_span,
116 values: binding_values.len(),
117 values_span: arm_span,
118 });
119 }
120
121 bindings.push((
122 arm.pat,
123 binding_names
124 .iter()
125 .cloned()
126 .zip(binding_values)
127 .collect::<Vec<_>>(),
128 ));
129 }
130
131 compile_time_bindings.push(CompileTimeBinding {
132 expression: *match_expr.expr,
133 arms: bindings,
134 });
135 }
136
137 compile_time_bindings::validate_compile_time_bindings(&compile_time_bindings)?;
138
139 Ok(AnalyzedConditionalQueryAs {
140 output_type: parsed.output_type,
141 query_string: parsed.query_string,
142 compile_time_bindings,
143 })
144}
145
146mod compile_time_bindings {
147 use std::collections::{HashMap, HashSet};
148
149 use super::{AnalyzeError, CompileTimeBinding};
150
151 pub(super) fn validate_compile_time_bindings(
152 compile_time_bindings: &[CompileTimeBinding],
153 ) -> Result<(), AnalyzeError> {
154 let mut bindings = HashMap::new();
155
156 for (_, binding_values) in compile_time_bindings
157 .iter()
158 .flat_map(|bindings| &bindings.arms)
159 {
160 for (binding, value) in binding_values {
161 let name = binding.to_string();
162
163 let (_, references) = bindings
164 .entry(name)
165 .or_insert_with(|| (binding, HashSet::new()));
166 fill_references(references, &value.value());
167 }
168 }
169
170 for (name, (ident, _)) in &bindings {
171 validate_references(&bindings, ident, &[], name)?;
172 }
173
174 Ok(())
175 }
176
177 fn fill_references(references: &mut HashSet<String>, mut fragment: &str) {
178 while let Some(start_idx) = fragment.find("{#") {
179 fragment = &fragment[start_idx + 2..];
180 if let Some(end_idx) = fragment.find("}") {
181 references.insert(fragment[..end_idx].to_string());
182 fragment = &fragment[end_idx + 1..];
183 } else {
184 break;
185 }
186 }
187 }
188
189 fn validate_references(
190 bindings: &HashMap<String, (&syn::Ident, HashSet<String>)>,
191 root_ident: &syn::Ident,
192 path: &[&str],
193 name: &str,
194 ) -> Result<(), AnalyzeError> {
195 let mut path = path.to_vec();
196 path.push(name);
197
198 if path.iter().filter(|component| **component == name).count() > 1 {
199 return Err(AnalyzeError::CompileTimeBindingCycleDetected {
200 root_ident: root_ident.clone(),
201 path: path.join(" -> "),
202 });
203 }
204
205 let Some((_, references)) = bindings.get(name) else {
206 return Ok(());
208 };
209
210 for reference in references {
211 validate_references(bindings, root_ident, &path, reference)?;
212 }
213
214 Ok(())
215 }
216}
217
#[cfg(test)]
mod tests {
    use quote::ToTokens;

    use super::*;

    /// Parses `input` as macro arguments, panicking if it is not valid syntax.
    fn parse(input: &str) -> ParsedConditionalQueryAs {
        syn::parse_str::<ParsedConditionalQueryAs>(input).unwrap()
    }

    /// Renders one arm's `(name, value)` pairs as token strings for comparison.
    fn render_pairs(pairs: &[(syn::Ident, syn::LitStr)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|(name, value)| {
                (
                    name.to_token_stream().to_string(),
                    value.to_token_stream().to_string(),
                )
            })
            .collect()
    }

    #[test]
    fn valid_syntax() {
        let parsed = parse(
            r#"
            SomeType,
            "some SQL query",
            #binding = match foo {
                bar => "baz",
            },
            #(a, b) = match c {
                d => ("e", "f"),
            },
            "#,
        );
        let analyzed = analyze(parsed.clone()).unwrap();

        assert_eq!(parsed.output_type, analyzed.output_type);
        assert_eq!(parsed.query_string, analyzed.query_string);
        assert_eq!(analyzed.compile_time_bindings.len(), 2);

        let mut bindings = analyzed.compile_time_bindings.into_iter();

        let single = bindings.next().unwrap();
        assert_eq!(single.expression.to_token_stream().to_string(), "foo");
        assert_eq!(single.arms.len(), 1);
        assert_eq!(single.arms[0].0.to_token_stream().to_string(), "bar");
        assert_eq!(
            render_pairs(&single.arms[0].1),
            [("binding".to_string(), "\"baz\"".to_string())],
        );

        let tuple = bindings.next().unwrap();
        assert_eq!(tuple.expression.to_token_stream().to_string(), "c");
        assert_eq!(tuple.arms.len(), 1);
        assert_eq!(tuple.arms[0].0.to_token_stream().to_string(), "d");
        assert_eq!(
            render_pairs(&tuple.arms[0].1),
            [
                ("a".to_string(), "\"e\"".to_string()),
                ("b".to_string(), "\"f\"".to_string()),
            ],
        );

        assert!(bindings.next().is_none());
    }

    #[test]
    fn name_value_count_mismatch() {
        let parsed = parse(
            r#"
            SomeType,
            "some SQL query",
            #(a, b) = match c {
                d => "e",
            },
            "#,
        );
        let err = analyze(parsed).unwrap_err();

        assert!(matches!(
            err,
            AnalyzeError::BindingNameValueLengthMismatch {
                names: 2,
                values: 1,
                ..
            }
        ));
    }

    #[test]
    fn non_string_literal_arm_body() {
        let parsed = parse(
            r#"
            SomeType,
            "some SQL query",
            #a = match c {
                d => 1,
            },
            "#,
        );
        let err = analyze(parsed).unwrap_err();

        assert!(matches!(err, AnalyzeError::ExpectedStringLiteral(_)));
    }

    #[test]
    fn non_string_literal_in_tuple_arm_body() {
        let parsed = parse(
            r#"
            SomeType,
            "some SQL query",
            #(a, b) = match c {
                d => ("e", 1),
            },
            "#,
        );
        let err = analyze(parsed).unwrap_err();

        assert!(matches!(err, AnalyzeError::ExpectedStringLiteral(_)));
    }

    #[test]
    fn duplicate_compile_time_bindings() {
        let parsed = parse(
            r##"
            SomeType,
            r#"{#a}"#,
            #a = match _ {
                _ => "1",
            },
            #a = match _ {
                _ => "2",
            },
            "##,
        );
        let err = analyze(parsed).unwrap_err();

        assert!(matches!(
            err,
            AnalyzeError::DuplicatedCompileTimeBindingsFound { .. }
        ));
    }

    #[test]
    fn duplicate_names_within_one_binding() {
        let parsed = parse(
            r#"
            SomeType,
            "some SQL query",
            #(a, a) = match c {
                d => ("e", "f"),
            },
            "#,
        );
        let err = analyze(parsed).unwrap_err();

        assert!(matches!(
            err,
            AnalyzeError::DuplicatedCompileTimeBindingsFound { .. }
        ));
    }

    #[test]
    fn compile_time_binding_cycle_detected() {
        let parsed = parse(
            r##"
            SomeType,
            r#"{#a}"#,
            #a = match _ {
                _ => "{#b}",
            },
            #b = match _ {
                _ => "{#a}",
            },
            "##,
        );

        match analyze(parsed).unwrap_err() {
            AnalyzeError::CompileTimeBindingCycleDetected { path, .. } => {
                assert!(path == "a -> b -> a" || path == "b -> a -> b", "{path}");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn self_referencing_compile_time_binding() {
        let parsed = parse(
            r##"
            SomeType,
            r#"{#a}"#,
            #a = match _ {
                _ => "{#a}",
            },
            "##,
        );

        match analyze(parsed).unwrap_err() {
            AnalyzeError::CompileTimeBindingCycleDetected { root_ident, path } => {
                assert_eq!(root_ident.to_string(), "a");
                assert_eq!(path, "a -> a");
            }
            other => panic!("unexpected error: {other:?}"),
        }
    }

    #[test]
    fn acyclic_references_are_accepted() {
        // `a` reaches `d` through both `b` and `c`: shared, but not a cycle.
        let parsed = parse(
            r##"
            SomeType,
            r#"{#a}"#,
            #a = match _ {
                _ => "{#b} {#c}",
            },
            #b = match _ {
                _ => "{#d}",
            },
            #c = match _ {
                _ => "{#d}",
            },
            #d = match _ {
                _ => "d",
            },
            "##,
        );

        analyze(parsed).unwrap();
    }
}