sqlx_conditional_queries_core/
expand.rsuse std::collections::HashMap;
use crate::{lower::LoweredConditionalQueryAs, DatabaseType};
/// Errors that can occur while expanding a lowered conditional query into
/// concrete query fragments.
///
/// Every variant carries the [`proc_macro2::Span`] of the query fragment in
/// which the problem was detected, so it can be reported at that location.
#[derive(Debug, thiserror::Error)]
pub enum ExpandError {
/// A `{#name}` reference named a compile-time binding that the current match
/// arm does not define. The `String` is the referenced name.
#[error("missing compile-time binding: {0}")]
MissingCompileTimeBinding(String, proc_macro2::Span),
/// A `{` in a query fragment was not followed by a `}` within the same
/// fragment.
#[error("missing binding closing brace")]
MissingBindingClosingBrace(proc_macro2::Span),
/// The text after `:` in a run-time binding reference (`{name: Type}`) could
/// not be tokenized.
#[error("failed to parse type override in binding reference: {0}")]
BindingReferenceTypeOverrideParseError(proc_macro2::LexError, proc_macro2::Span),
}
/// Expanded form of a [`LoweredConditionalQueryAs`], as produced by [`expand`].
#[derive(Debug)]
pub(crate) struct ExpandedConditionalQueryAs {
/// Output type identifier, passed through unchanged from the lowered input.
pub(crate) output_type: syn::Ident,
/// Expressions being matched on, passed through unchanged from the lowered
/// input.
pub(crate) match_expressions: Vec<syn::Expr>,
/// One expanded arm per lowered match arm, in the same order.
pub(crate) match_arms: Vec<MatchArm>,
}
/// A single match arm whose query has been fully expanded.
#[derive(Debug)]
pub(crate) struct MatchArm {
/// Patterns of the corresponding lowered arm, passed through unchanged.
pub(crate) patterns: Vec<syn::Pat>,
/// The query split into string-literal fragments. Compile-time bindings are
/// substituted, and run-time binding references are replaced by placeholders
/// (`$n` for PostgreSQL, `?` for MySQL/SQLite).
/// NOTE(review): presumably concatenated in order by the code generator —
/// confirm against the caller.
pub(crate) query_fragments: Vec<syn::LitStr>,
/// Arguments for the placeholders, in placeholder order: the binding's
/// identifier plus its optional type override. PostgreSQL yields one entry
/// per distinct binding; MySQL/SQLite yield one entry per reference.
pub(crate) run_time_bindings: Vec<(syn::Ident, Option<proc_macro2::TokenStream>)>,
}
/// Bookkeeping for one named run-time binding within a match arm.
#[derive(Debug)]
struct RunTimeBinding {
/// 1-based placeholder positions assigned to this binding. PostgreSQL records
/// only the first reference's position; MySQL/SQLite record one position per
/// reference.
indices: Vec<usize>,
/// Type override taken from the first reference to this binding, if that
/// reference had one.
type_override: Option<proc_macro2::TokenStream>,
}
/// Collects the run-time bindings referenced by one match arm and assigns
/// placeholder positions to them.
#[derive(Debug)]
struct RunTimeBindings {
/// Selects the placeholder syntax (`$n` or `?`).
database_type: DatabaseType,
/// Last placeholder position handed out; 0 before any has been assigned.
counter: usize,
/// Bindings keyed by their name literal.
bindings: HashMap<syn::LitStr, RunTimeBinding>,
}
impl RunTimeBindings {
fn new(database_type: DatabaseType) -> Self {
Self {
database_type,
counter: 0,
bindings: Default::default(),
}
}
fn get_binding_string(
&mut self,
binding_name: syn::LitStr,
type_override: Option<proc_macro2::TokenStream>,
) -> syn::LitStr {
match self.database_type {
DatabaseType::PostgreSql => {
let span = binding_name.span();
let binding = self.bindings.entry(binding_name).or_insert_with(|| {
self.counter += 1;
RunTimeBinding {
indices: vec![self.counter],
type_override,
}
});
syn::LitStr::new(&format!("${}", binding.indices.first().unwrap()), span)
}
DatabaseType::MySql | DatabaseType::Sqlite => {
let span = binding_name.span();
self.counter += 1;
self.bindings
.entry(binding_name)
.and_modify(|binding| binding.indices.push(self.counter))
.or_insert_with(|| RunTimeBinding {
indices: vec![self.counter],
type_override,
});
syn::LitStr::new("?", span)
}
}
}
fn get_arguments(self) -> Vec<(syn::Ident, Option<proc_macro2::TokenStream>)> {
let mut run_time_bindings: Vec<_> = self
.bindings
.into_iter()
.flat_map(|(name, binding)| {
binding
.indices
.into_iter()
.map(|index| {
(
syn::Ident::new(&name.value(), name.span()),
binding.type_override.clone(),
index,
)
})
.collect::<Vec<_>>()
})
.collect();
run_time_bindings.sort_by_key(|(_, _, index)| *index);
run_time_bindings
.into_iter()
.map(|(ident, type_override, _)| (ident, type_override))
.collect()
}
}
pub(crate) fn expand(
database_type: DatabaseType,
lowered: LoweredConditionalQueryAs,
) -> Result<ExpandedConditionalQueryAs, ExpandError> {
let mut match_arms = Vec::new();
for arm in lowered.match_arms {
let mut fragments = vec![lowered.query_string.clone()];
while fragments
.iter()
.any(|fragment| fragment.value().contains("{#"))
{
fragments = expand_compile_time_bindings(fragments, &arm.compile_time_bindings)?;
}
let mut run_time_bindings = RunTimeBindings::new(database_type);
let expanded = expand_run_time_bindings(fragments, &mut run_time_bindings)?;
match_arms.push(MatchArm {
patterns: arm.patterns,
query_fragments: expanded,
run_time_bindings: run_time_bindings.get_arguments(),
});
}
Ok(ExpandedConditionalQueryAs {
output_type: lowered.output_type,
match_expressions: lowered.match_expressions,
match_arms,
})
}
/// Performs one pass of compile-time binding substitution.
///
/// Each fragment is scanned for `{...}` groups:
///
/// * A `{#name}` group is replaced by the literal bound to `name` in
///   `compile_time_bindings`.
/// * Any other `{...}` group (a run-time binding reference) is kept verbatim
///   as its own fragment, for a later pass.
/// * Non-empty text around groups becomes separate fragments.
///
/// All new literals carry the span of the fragment they were cut from. A
/// substituted literal that itself contains `{#...}` is not re-scanned here.
///
/// # Errors
///
/// * [`ExpandError::MissingBindingClosingBrace`] if a `{` has no later `}` in
///   the same fragment.
/// * [`ExpandError::MissingCompileTimeBinding`] if a `{#name}` group names a
///   binding absent from `compile_time_bindings`.
fn expand_compile_time_bindings(
    unexpanded_fragments: Vec<syn::LitStr>,
    compile_time_bindings: &HashMap<String, syn::LitStr>,
) -> Result<Vec<syn::LitStr>, ExpandError> {
    let mut output = Vec::new();
    for fragment in unexpanded_fragments {
        let span = fragment.span();
        let text = fragment.value();
        let mut remaining = text.as_str();

        while let Some(open) = remaining.find('{') {
            let (literal, from_brace) = remaining.split_at(open);
            if !literal.is_empty() {
                output.push(syn::LitStr::new(literal, span));
            }

            let close = from_brace
                .find('}')
                .ok_or_else(|| ExpandError::MissingBindingClosingBrace(span))?;

            // `from_brace` starts with `{`, so this matches exactly the `{#name` groups.
            match from_brace[..close].strip_prefix("{#") {
                Some(name) => {
                    let value = compile_time_bindings.get(name).ok_or_else(|| {
                        ExpandError::MissingCompileTimeBinding(name.to_string(), span)
                    })?;
                    output.push(value.clone());
                }
                None => output.push(syn::LitStr::new(&from_brace[..=close], span)),
            }

            remaining = &from_brace[close + 1..];
        }

        if !remaining.is_empty() {
            output.push(syn::LitStr::new(remaining, span));
        }
    }
    Ok(output)
}
/// Replaces run-time binding references with database-specific placeholders.
///
/// A reference is written `{name}` or `{name: Type}`. For every `{` found in a
/// fragment:
///
/// * The text before it is emitted as its own fragment, even when empty.
/// * The reference is replaced by the placeholder returned from
///   [`RunTimeBindings::get_binding_string`], which also records the binding
///   and its optional type override.
///
/// Text after the last reference is emitted if non-empty. Every produced
/// literal carries the span of the fragment it came from.
///
/// Whitespace around the binding name is ignored, so `{foo}`, `{ foo }` and
/// `{foo : Type}` all refer to the binding `foo`.
///
/// # Errors
///
/// * [`ExpandError::MissingBindingClosingBrace`] if a `{` is not followed by a
///   `}` within the same fragment.
/// * [`ExpandError::BindingReferenceTypeOverrideParseError`] if the text after
///   the first `:` cannot be tokenized.
fn expand_run_time_bindings(
    unexpanded_fragments: Vec<syn::LitStr>,
    run_time_bindings: &mut RunTimeBindings,
) -> Result<Vec<syn::LitStr>, ExpandError> {
    let mut expanded_query = Vec::new();
    for fragment in unexpanded_fragments {
        let span = fragment.span();
        let fragment_string = fragment.value();
        let mut fragment_str = fragment_string.as_str();

        while let Some(start_of_binding) = fragment_str.find('{') {
            expanded_query.push(syn::LitStr::new(&fragment_str[..start_of_binding], span));
            fragment_str = &fragment_str[start_of_binding + 1..];

            let end_of_binding = fragment_str
                .find('}')
                .ok_or_else(|| ExpandError::MissingBindingClosingBrace(span))?;
            let reference = &fragment_str[..end_of_binding];

            // Only the first `:` separates the name from the override, so path
            // types such as `std::string::String` are passed through intact.
            let (binding_name, type_override) = match reference.split_once(':') {
                Some((binding_name, type_override)) => {
                    let type_override = type_override
                        .parse::<proc_macro2::TokenStream>()
                        .map_err(|err| {
                            ExpandError::BindingReferenceTypeOverrideParseError(err, span)
                        })?;
                    (binding_name, Some(type_override))
                }
                None => (reference, None),
            };

            // Trim in both cases so `{ foo }` names the same binding as `{foo}`.
            // An untrimmed name would later be rejected (with a panic) by
            // `syn::Ident::new` when the argument list is generated.
            let binding = run_time_bindings.get_binding_string(
                syn::LitStr::new(binding_name.trim(), span),
                type_override,
            );
            expanded_query.push(binding);
            fragment_str = &fragment_str[end_of_binding + 1..];
        }

        if !fragment_str.is_empty() {
            expanded_query.push(syn::LitStr::new(fragment_str, span));
        }
    }
    Ok(expanded_query)
}
#[cfg(test)]
mod tests {
    use quote::ToTokens;

    use super::*;

    /// Runs the parse → analyze → lower → expand pipeline on `input`,
    /// panicking on any failure.
    fn expand_input(database_type: DatabaseType, input: &str) -> ExpandedConditionalQueryAs {
        let parsed = syn::parse_str::<crate::parse::ParsedConditionalQueryAs>(input).unwrap();
        let analyzed = crate::analyze::analyze(parsed).unwrap();
        let lowered = crate::lower::lower(analyzed);
        expand(database_type, lowered).unwrap()
    }

    /// Renders each query fragment of `arm` as its literal token text.
    fn fragment_strings(arm: &MatchArm) -> Vec<String> {
        arm.query_fragments
            .iter()
            .map(|fragment| fragment.to_token_stream().to_string())
            .collect()
    }

    #[rstest::rstest]
    #[case(DatabaseType::PostgreSql)]
    #[case(DatabaseType::MySql)]
    #[case(DatabaseType::Sqlite)]
    fn expands_compile_time_bindings(#[case] database_type: DatabaseType) {
        let expanded = expand_input(
            database_type,
            r#"
                SomeType,
                "some {#a} {#b} {#j} query",
                #(a, b) = match c {
                    d => ("e", "f"),
                    g => ("h", "i"),
                },
                #j = match i {
                    k => "l",
                    m => "n",
                },
            "#,
        );

        assert_eq!(
            fragment_strings(&expanded.match_arms[0]),
            &[
                "\"some \"",
                "\"e\"",
                "\" \"",
                "\"f\"",
                "\" \"",
                "\"l\"",
                "\" query\""
            ],
        );
    }

    #[rstest::rstest]
    #[case(DatabaseType::PostgreSql)]
    #[case(DatabaseType::MySql)]
    #[case(DatabaseType::Sqlite)]
    fn expands_run_time_bindings(#[case] database_type: DatabaseType) {
        let expanded = expand_input(
            database_type,
            r#"
                SomeType,
                "some {foo:ty} {bar} {foo} query",
            "#,
        );
        let arm = &expanded.match_arms[0];

        assert_eq!(
            fragment_strings(arm),
            match database_type {
                DatabaseType::PostgreSql => &[
                    "\"some \"",
                    "\"$1\"",
                    "\" \"",
                    "\"$2\"",
                    "\" \"",
                    "\"$1\"",
                    "\" query\""
                ],
                DatabaseType::MySql | DatabaseType::Sqlite => &[
                    "\"some \"",
                    "\"?\"",
                    "\" \"",
                    "\"?\"",
                    "\" \"",
                    "\"?\"",
                    "\" query\""
                ],
            }
        );

        let actual: Vec<(String, Option<String>)> = arm
            .run_time_bindings
            .iter()
            .map(|(ident, ty)| (ident.to_string(), ty.as_ref().map(ToString::to_string)))
            .collect();

        let foo = ("foo".to_string(), Some("ty".to_string()));
        let bar = ("bar".to_string(), None);
        let expected = match database_type {
            // PostgreSQL reuses `$1`, so `foo` is only passed once.
            DatabaseType::PostgreSql => vec![foo.clone(), bar],
            // Positional `?` placeholders need one argument per reference.
            DatabaseType::MySql | DatabaseType::Sqlite => vec![foo.clone(), bar, foo],
        };
        assert_eq!(actual, expected);
    }
}