Skip to main content

sql_forge_macro/
lib.rs

1// =============================================================================
2// Imports
3// =============================================================================
4
5use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16    Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Ident, Lit, LitStr, Pat, Stmt, Token, Type,
17};
18
19// =============================================================================
20// Input types
21// =============================================================================
22
/// Custom keywords understood by the macro's input parser.
mod kw {
    // `scalar`: optional prefix on a result type (`scalar T` as a macro argument,
    // or `>key = scalar T` inside a result map) that sets the matching
    // `force_scalar` flag.
    syn::custom_keyword!(scalar);
}
26
/// A `:name = expr` parameter binding.
///
/// Appears inside a parameter map `( :a = x, :b = y )`, either as the macro's
/// parameter source or as section-local parameters.
#[derive(Clone)]
struct ParamAssign {
    /// Placeholder name, without the leading `:`.
    name: Ident,
    /// Rust expression bound to the placeholder.
    expr: Expr,
}
33
34impl Parse for ParamAssign {
35    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
36        input.parse::<Token![:]>()?;
37        let name: Ident = input.parse()?;
38        input.parse::<Token![=]>()?;
39        let expr: Expr = input.parse()?;
40        Ok(Self { name, expr })
41    }
42}
43
/// A single section value: SQL string + optional local parameter bindings.
///
/// Built from either a bare string literal or a `(string literal, params)`
/// tuple.
#[derive(Clone)]
struct SectionFragment {
    /// SQL text taken from the string literal.
    sql: String,
    /// Span of the expression the SQL text was extracted from.
    span: Span,
    /// Section-local parameters; `ParamsSource::None` for a bare literal.
    params: ParamsSource,
}
51
/// One arm in a `match { ... }` inside a section assignment.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern; parsed with `Pat::parse_multi_with_leading_vert`, so
    /// or-patterns (`A | B`) and a leading `|` are accepted.
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Value selected by this arm. It is parsed with the same width (number of
    /// section names) as the enclosing assignment.
    value: SectionValue,
}
59
/// The right-hand side of a section `#name = ...` assignment.
///
/// Values nest: match arms hold `SectionValue`s, and grouped tuples hold one
/// `SectionValue` per section name.
#[derive(Clone)]
enum SectionValue {
    /// A plain string (or `(string, params)` tuple).
    Single(SectionFragment),
    /// A tuple of values, one per section when using `#(a, b) = ...`.
    Grouped(Vec<SectionValue>),
    /// A `match expr { arm => ..., arm => ... }` expression.
    Match {
        /// Scrutinee, parsed without eager brace consumption so the arm block
        /// is not mistaken for a struct literal.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
73
/// A full `#name = value` or `#(a, b) = value` section assignment.
#[derive(Clone)]
struct SectionAssign {
    /// Section names on the left-hand side; one element for `#name`, and at
    /// least one for `#(a, b, ...)`.
    names: Vec<Ident>,
    /// Right-hand side. Grouped tuples within it are expected to supply exactly
    /// `names.len()` items.
    value: SectionValue,
}
80
81impl Parse for SectionAssign {
82    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
83        input.parse::<Token![#]>()?;
84
85        // Parse `#(a, b)` for grouped sections, or `#name` for single.
86        let names = if input.peek(syn::token::Paren) {
87            let content;
88            syn::parenthesized!(content in input);
89            let mut out = Vec::new();
90            while !content.is_empty() {
91                out.push(content.parse::<Ident>()?);
92                if content.is_empty() {
93                    break;
94                }
95                content.parse::<Token![,]>()?;
96            }
97            if out.is_empty() {
98                return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
99            }
100            out
101        } else {
102            vec![input.parse::<Ident>()?]
103        };
104
105        input.parse::<Token![=]>()?;
106        let value = parse_section_value(input, names.len())?;
107        Ok(Self { names, value })
108    }
109}
110
/// The fully-parsed macro invocation.
struct SqlForgeInput {
    /// DB type passed explicitly as the first macro argument, if any.
    /// NOTE(review): when `None`, the DB type presumably comes from
    /// `resolve_db_from_env` — confirm at the code-generation site.
    db: Option<Type>,
    /// What the invocation returns (nothing, one model, or a result map).
    result: ResultSpec,
    /// `true` when the top-level result type was prefixed with `scalar`.
    force_scalar: bool,
    /// The SQL template literal.
    sql: SqlTemplate,
    /// Parameter source from the trailing arguments; `ParamsSource::None` if absent.
    params: ParamsSource,
    /// Assignments from a trailing `(#name = ..., ...)` map; empty if absent.
    sections: Vec<SectionAssign>,
    /// Expression following `..` in the trailing arguments, if present.
    batch: Option<Expr>,
}
121
/// One entry in a result map `(>key = Model, ...)`.
#[derive(Clone)]
struct ResultAssign {
    /// Result key (the identifier after `>`).
    name: Ident,
    /// Model type for this key.
    model: Type,
    /// `true` when written as `>key = scalar Model`.
    force_scalar: bool,
}
129
/// What a macro invocation returns, as determined by its leading arguments.
#[derive(Clone)]
enum ResultSpec {
    /// Execute-only (no model), e.g. `sql_forge!("SQL", ...)`
    None,
    /// e.g. `sql_forge!(User, ...)`
    Single(Box<Type>),
    /// e.g. `sql_forge!((>a = X, >b = Y), ...)`
    Group(Vec<ResultAssign>),
}
139
/// Where `:name` placeholder values come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was given.
    None,
    /// `( :name = expr, ... )`
    Map(Vec<ParamAssign>),
    /// A struct expression whose fields are matched to `:name` placeholders.
    /// Any trailing argument that is not a recognised map is stored here.
    Struct(Box<Expr>),
}
148
/// The SQL template. Only string literals are supported.
///
/// Kept as an enum (currently single-variant) so other template sources can be
/// added without changing callers.
enum SqlTemplate {
    /// A string literal template, e.g. `"SELECT * FROM t WHERE id = :id"`.
    Literal(LitStr),
}
153
154impl SqlTemplate {
155    fn span(&self) -> Span {
156        match self {
157            Self::Literal(lit) => lit.span(),
158        }
159    }
160
161    /// Parse the SQL into a sequence of textual segments and `{#section}` slots.
162    fn into_segments(self) -> Result<Vec<Segment>, String> {
163        match self {
164            Self::Literal(lit) => parse_literal_segments(&lit.value()),
165        }
166    }
167}
168
169fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
170    if input.peek(LitStr) {
171        Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
172    } else {
173        Err(input.error("sql_forge!: SQL template must be a string literal"))
174    }
175}
176
/// One piece of the parsed SQL template.
#[derive(Clone)]
enum Segment {
    /// Plain SQL text (may contain `:param` placeholders).
    Text(String),
    /// A `{#section_name}` placeholder.
    Section {
        /// Raw text between `{#` and `}` (not trimmed).
        name: String,
    },
    /// A `{( ... )}` batch value template (repeated per batch item).
    Batch {
        /// The text inside `{` ... `}` (outer parentheses included), split into
        /// literal text and plain (non-list) parameters.
        parts: Vec<TextPart>,
    },
}
187
/// A fragment of SQL text after splitting on `:param`.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:param` placeholder.
    Param {
        /// Placeholder name without the leading `:` or any `[]` suffix.
        name: String,
        /// `true` when written as `:name[]` (a list parameter).
        is_list: bool,
    },
}
196
197// =============================================================================
198// Parsing helpers
199// =============================================================================
200
/// Used by `detect_parenthesized_map_kind` to identify what a `(...)` argument is.
///
/// Classification looks only at the first token inside the parentheses.
enum MapKind {
    /// First token is `>`: a result map `(>key = Model, ...)`.
    Results,
    /// First token is `:`: a parameter map `(:name = expr, ...)`.
    Params,
    /// First token is `#`: a section map `(#name = value, ...)`.
    Sections,
}
207
208// =============================================================================
209// Section value parsing (string fragments, match, grouped tuples)
210// =============================================================================
211
212fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
213    let fork = input.fork();
214    let content;
215    syn::parenthesized!(content in fork);
216
217    if content.is_empty() {
218        return Err(input.error("sql_forge!: map argument cannot be empty"));
219    }
220
221    if content.peek(Token![>]) {
222        Ok(Some(MapKind::Results))
223    } else if content.peek(Token![:]) {
224        Ok(Some(MapKind::Params))
225    } else if content.peek(Token![#]) {
226        Ok(Some(MapKind::Sections))
227    } else {
228        Ok(None)
229    }
230}
231
232impl Parse for ResultAssign {
233    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
234        input.parse::<Token![>]>()?;
235        let name: Ident = input.parse()?;
236        input.parse::<Token![=]>()?;
237        let (force_scalar, model) = if input.peek(kw::scalar) {
238            input.parse::<kw::scalar>()?;
239            (true, input.parse::<Type>()?)
240        } else {
241            (false, input.parse::<Type>()?)
242        };
243        Ok(Self {
244            name,
245            model,
246            force_scalar,
247        })
248    }
249}
250
251fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
252    let content;
253    syn::parenthesized!(content in input);
254
255    let mut results = Vec::new();
256    while !content.is_empty() {
257        results.push(content.parse::<ResultAssign>()?);
258        if content.is_empty() {
259            break;
260        }
261        content.parse::<Token![,]>()?;
262    }
263
264    if results.is_empty() {
265        return Err(input.error("sql_forge!: result map cannot be empty"));
266    }
267
268    Ok(results)
269}
270
271fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
272    let content;
273    syn::parenthesized!(content in input);
274
275    let mut params = Vec::new();
276    while !content.is_empty() {
277        params.push(content.parse::<ParamAssign>()?);
278        if content.is_empty() {
279            break;
280        }
281        content.parse::<Token![,]>()?;
282    }
283
284    Ok(params)
285}
286
287fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
288    let content;
289    syn::parenthesized!(content in input);
290
291    let mut sections = Vec::new();
292    while !content.is_empty() {
293        sections.push(content.parse::<SectionAssign>()?);
294        if content.is_empty() {
295            break;
296        }
297        content.parse::<Token![,]>()?;
298    }
299
300    Ok(sections)
301}
302
303fn parse_params_source_expr(
304    input: ParseStream<'_>,
305    allow_sections: bool,
306) -> syn::Result<ParamsSource> {
307    if input.peek(syn::token::Paren) {
308        match detect_parenthesized_map_kind(input)? {
309            Some(MapKind::Results) => Err(input
310                .error("sql_forge!: result maps are only allowed as the macro result argument")),
311            Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
312            Some(MapKind::Sections) if allow_sections => {
313                Err(input.error("sql_forge!: section maps are not allowed here"))
314            }
315            Some(MapKind::Sections) => Err(input.error(
316                "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317            )),
318            None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319        }
320    } else {
321        Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322    }
323}
324
325fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
326    if input.peek(syn::token::Paren) {
327        let fork = input.fork();
328        let content;
329        syn::parenthesized!(content in fork);
330
331        if let Ok(first_expr) = content.parse::<Expr>() {
332            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
333                let _ = parse_params_source_expr(&content, false)?;
334                if content.peek(Token![,]) {
335                    content.parse::<Token![,]>()?;
336                }
337                if content.is_empty() {
338                    let content;
339                    syn::parenthesized!(content in input);
340                    let first_expr: Expr = content.parse()?;
341                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
342                        input.error("sql_forge!: section tuple must start with a string literal")
343                    })?;
344                    let span = first_expr.span();
345                    content.parse::<Token![,]>()?;
346                    let params = parse_params_source_expr(&content, false)?;
347                    if content.peek(Token![,]) {
348                        content.parse::<Token![,]>()?;
349                    }
350                    if !content.is_empty() {
351                        return Err(content.error(
352                            "sql_forge!: unexpected tokens after section-local parameter source",
353                        ));
354                    }
355                    return Ok(SectionFragment { sql, span, params });
356                }
357            }
358        }
359    }
360
361    let expr: Expr = input.parse()?;
362    let sql = extract_lit_str(&expr).ok_or_else(|| {
363        input
364            .error("sql_forge!: section values must be string literals or (string literal, params)")
365    })?;
366    Ok(SectionFragment {
367        sql,
368        span: expr.span(),
369        params: ParamsSource::None,
370    })
371}
372
373fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
374    if input.peek(Token![match]) {
375        input.parse::<Token![match]>()?;
376        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
377        let content;
378        syn::braced!(content in input);
379        let mut arms = Vec::new();
380        while !content.is_empty() {
381            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
382            let guard = if content.peek(Token![if]) {
383                content.parse::<Token![if]>()?;
384                Some(content.parse::<Expr>()?)
385            } else {
386                None
387            };
388            content.parse::<Token![=>]>()?;
389            let value = parse_section_value(&content, width)?;
390            if content.peek(Token![,]) {
391                content.parse::<Token![,]>()?;
392            }
393            arms.push(SectionMatchArm { pat, guard, value });
394        }
395        return Ok(SectionValue::Match { expr, arms });
396    }
397
398    if width == 1 {
399        return Ok(SectionValue::Single(parse_section_fragment(input)?));
400    }
401
402    let content;
403    syn::parenthesized!(content in input);
404    let mut items = Vec::new();
405    while !content.is_empty() {
406        items.push(parse_section_value(&content, 1)?);
407        if content.is_empty() {
408            break;
409        }
410        content.parse::<Token![,]>()?;
411    }
412
413    if items.len() != width {
414        return Err(input.error(format!(
415            "sql_forge!: grouped section value must provide exactly {} items",
416            width,
417        )));
418    }
419
420    Ok(SectionValue::Grouped(items))
421}
422
423// =============================================================================
424// Top-level macro input parsing (SqlForgeInput::parse)
425// =============================================================================
426
impl Parse for SqlForgeInput {
    /// Parses the complete macro input.
    ///
    /// Leading arguments (everything before the SQL literal) select the DB type
    /// and the result spec. Accepted shapes:
    /// - `"SQL"`: no DB, execute-only.
    /// - `scalar T, "SQL"`: scalar result, no DB.
    /// - `(>a = X, ...), "SQL"`: result map, no DB.
    /// - `Model, "SQL"`: single model, no DB.
    /// - `DB, scalar T, "SQL"`, `DB, (>a = X, ...), "SQL"`, `DB, Model, "SQL"`:
    ///   the same shapes with an explicit DB type.
    ///
    /// Because `X, "SQL"` always treats `X` as the model, an explicit DB type
    /// cannot be combined with an execute-only invocation.
    ///
    /// After the SQL literal, comma-separated trailing arguments may follow in
    /// any order:
    /// - at most one `..batch_expr`;
    /// - at most one parameter source: a `(:name = expr, ...)` map, or any other
    ///   expression, stored as a struct source;
    /// - at most one section map `(#name = ..., ...)`.
    ///
    /// # Errors
    /// Returns a `syn::Error` for an unexpected leading `( ... )`, duplicate
    /// trailing arguments, a result map in trailing position, or leftover tokens.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        // ---- Leading arguments: DB type, result spec, SQL template ----
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            // `"SQL"`
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            // `scalar T, "SQL"`
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // A leading `( ... )` must be a result map; other parenthesized
            // forms (including tuple types) are rejected here.
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // A leading type: either the model (if the SQL follows directly)
            // or the DB type (if another result argument follows).
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) {
                // `Model, "SQL"`
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                // `DB, scalar T, "SQL"`
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                // `DB, (>a = X, ...), "SQL"`
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // `DB, Model, "SQL"`
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        // ---- Trailing arguments: batch source, params, sections ----
        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // A failed comma parse does not consume anything; with no comma, any
        // remaining tokens are reported by the final emptiness check.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    // `..batch_expr`
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    // Classify `( ... )` by its first inner token.
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // Some other parenthesized expression: struct source.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    // Any other expression: struct source.
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Arguments are comma-separated; a trailing comma is allowed.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
572
573// =============================================================================
574// Database type resolution
575// =============================================================================
576
577fn resolve_db_from_env() -> Result<Type, String> {
578    if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
579        return syn::parse_str::<Type>(&val).map_err(|err| {
580            format!(
581                "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
582                val, err
583            )
584        });
585    }
586
587    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
588        Ok(d) => d,
589        Err(_) => {
590            return Err(
591                "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
592                 or configure [package.metadata.sql_forge] in Cargo.toml"
593                    .to_string(),
594            );
595        }
596    };
597    let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
598
599    let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
600        format!(
601            "sql_forge!: failed to read {}: {}",
602            manifest_path.display(),
603            err
604        )
605    })?;
606
607    let value: toml::Value = toml::from_str(&cargo_toml)
608        .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
609
610    let db_str = value
611        .get("package")
612        .and_then(|v| v.get("metadata"))
613        .and_then(|v| v.get("sql_forge"))
614        .and_then(|v| v.get("db"))
615        .and_then(|v| v.as_str())
616        .ok_or({
617            "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
618             SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
619        })?;
620
621    syn::parse_str::<Type>(db_str).map_err(|err| {
622        format!(
623            "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
624            db_str, err
625        )
626    })
627}
628
629fn uses_dollar_params(db: &Type) -> bool {
630    let Type::Path(type_path) = db else {
631        return false;
632    };
633    type_path
634        .path
635        .segments
636        .last()
637        .is_some_and(|s| s.ident == "Postgres")
638}
639
640fn is_builtin_scalar_type(ty: &Type) -> bool {
641    let Type::Path(type_path) = ty else {
642        return false;
643    };
644
645    if type_path.qself.is_some()
646        || type_path.path.leading_colon.is_some()
647        || type_path.path.segments.len() != 1
648    {
649        return false;
650    }
651
652    let ident = &type_path.path.segments[0].ident;
653    ident == "i8"
654        || ident == "i16"
655        || ident == "i32"
656        || ident == "i64"
657        || ident == "isize"
658        || ident == "u8"
659        || ident == "u16"
660        || ident == "u32"
661        || ident == "u64"
662        || ident == "usize"
663        || ident == "f32"
664        || ident == "f64"
665        || ident == "bool"
666        || ident == "String"
667}
668
669fn scalar_output_type(model: &Type) -> Option<&Type> {
670    if is_builtin_scalar_type(model) {
671        return Some(model);
672    }
673    None
674}
675
676fn push_text_segment(out: &mut Vec<Segment>, text: String) {
677    if text.is_empty() {
678        return;
679    }
680    match out.last_mut() {
681        Some(Segment::Text(existing)) => existing.push_str(&text),
682        _ => out.push(Segment::Text(text)),
683    }
684}
685
/// Splits a SQL template string into a sequence of `Segment`s.
///
/// Recognised constructs:
/// - `{( ... )}`: a batch value template. The captured text includes the outer
///   parentheses and is split with `parse_text_parts`. List parameters
///   (`:name[]`) are rejected inside it.
/// - `{#name}`: a section placeholder. Everything up to the next `}` is taken
///   verbatim as the name, with no trimming or identifier validation.
/// - Any other `{` is kept as plain text.
///
/// Adjacent plain text is merged into one `Segment::Text` via `push_text_segment`.
///
/// # Errors
/// Returns a message string when:
/// - a batch template contains `{`, has unbalanced parentheses, or lacks its
///   closing `}`;
/// - a section placeholder is not closed;
/// - a section name is empty.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    // Pending plain text, not yet flushed into `out`.
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        // `{(` starts a batch template. The `(` is only peeked here, so the loop
        // below sees it first and includes it in `content`.
        if chars.peek() == Some(&'(') {
            push_text_segment(&mut out, std::mem::take(&mut text));

            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                // `}` ends the batch template only when parentheses are balanced.
                if ch == '}' {
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            // Each batch item supplies one value per placeholder, so list
            // expansion (`:name[]`) is not supported here.
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A `{` not followed by `(` or `#` is ordinary SQL text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // `{#name}` section placeholder: consume the `#`, then read up to `}`.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    // Flush any trailing plain text.
    push_text_segment(&mut out, text);
    Ok(out)
}
778
779// =============================================================================
780// SQL template parsing: {#sections} and :param placeholders
781// =============================================================================
782
/// Returns `true` if `ch` may begin a `:param` name: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, '_' | 'a'..='z' | 'A'..='Z')
}
786
/// Returns `true` if `ch` may appear after the first character of a `:param`
/// name: an ASCII letter, an ASCII digit, or `_`.
fn is_ident_continue(ch: char) -> bool {
    ch == '_' || ch.is_ascii_alphanumeric()
}
790
/// Strips an annotation suffix from the contents of a backticked identifier.
///
/// Everything from the first `!`, `?` or `:` onward is removed, along with any
/// whitespace before it (e.g. `id!` becomes `id`, `user_id: i64` becomes
/// `user_id`). The input is returned unchanged when it has no such marker, or
/// when stripping would leave nothing but whitespace.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(marker) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };

    match content[..marker].trim_end() {
        "" => content.to_string(),
        base => base.to_string(),
    }
}
811
812fn sanitize_runtime_sql_text(text: &str) -> String {
813    let mut out = String::with_capacity(text.len());
814    let mut chars = text.chars().peekable();
815
816    while let Some(ch) = chars.next() {
817        if ch != '`' {
818            out.push(ch);
819            continue;
820        }
821
822        let mut content = String::new();
823        let mut closed = false;
824
825        for next in chars.by_ref() {
826            if next == '`' {
827                closed = true;
828                break;
829            }
830            content.push(next);
831        }
832
833        if closed {
834            out.push('`');
835            out.push_str(&sanitize_backticked_alias_ident(&content));
836            out.push('`');
837        } else {
838            out.push('`');
839            out.push_str(&content);
840            break;
841        }
842    }
843
844    out
845}
846
/// Splits SQL text into literal chunks and `:name` / `:name[]` placeholders.
///
/// A placeholder is a `:` followed by an ASCII letter or `_`, then any run of
/// ASCII letters, digits or `_`. A `:` directly preceded by another `:` does not
/// start a placeholder, so `::name` (e.g. a `::type` cast) stays literal text.
/// A trailing `[]` marks a list parameter and is consumed: it appears in neither
/// the name nor the following literal chunk.
///
/// Detection is purely lexical. `:word` inside SQL string literals or comments is
/// also reported as a parameter.
fn parse_text_parts(text: &str) -> Vec<TextPart> {
    let mut parts = Vec::new();
    // Byte offset of the first character not yet emitted as a part.
    let mut last = 0usize;
    let mut iter = text.char_indices().peekable();

    while let Some((idx, ch)) = iter.next() {
        if ch != ':' {
            continue;
        }

        let Some(&(next_idx, next_ch)) = iter.peek() else {
            continue;
        };

        if !is_ident_start(next_ch) {
            continue;
        }

        // Second `:` of a `::` pair: not a placeholder.
        if text[..idx].ends_with(':') {
            continue;
        }

        // Flush the literal text preceding this `:`.
        if last < idx {
            parts.push(TextPart::Lit(text[last..idx].to_string()));
        }

        // Consume the (already peeked) first identifier character.
        iter.next();

        let mut name = String::new();
        name.push(next_ch);
        // Byte offset just past the identifier parsed so far.
        let mut end = next_idx + next_ch.len_utf8();

        while let Some(&(j, c)) = iter.peek() {
            if is_ident_continue(c) {
                name.push(c);
                end = j + c.len_utf8();
                iter.next();
            } else {
                break;
            }
        }

        // Optional `[]` list suffix. `iter` is not advanced past it; those two
        // characters are skipped by the loop because they are not `:`, and
        // `last` already excludes them from the next literal chunk.
        let mut is_list = false;
        if text[end..].starts_with("[]") {
            is_list = true;
            end += 2;
        }

        parts.push(TextPart::Param { name, is_list });
        last = end;
    }

    // Trailing literal text after the last placeholder (or the whole input).
    if last < text.len() {
        parts.push(TextPart::Lit(text[last..].to_string()));
    }

    parts
}
905
906fn render_validator_text(
907    text: &str,
908    use_dollar_params: bool,
909    param_offset: &mut usize,
910    list_count: usize,
911) -> (String, Vec<(String, bool)>) {
912    let mut out_sql = String::new();
913    let mut occurrences = Vec::new();
914
915    for part in parse_text_parts(text) {
916        match part {
917            TextPart::Lit(lit) => out_sql.push_str(&lit),
918            TextPart::Param { name, is_list } => {
919                if is_list && list_count > 1 {
920                    let slots: Vec<String> = if use_dollar_params {
921                        (0..list_count)
922                            .map(|i| format!("${}", *param_offset + i + 1))
923                            .collect()
924                    } else {
925                        (0..list_count).map(|_| "?".to_string()).collect()
926                    };
927                    if use_dollar_params {
928                        *param_offset += list_count;
929                    }
930                    out_sql.push_str(&slots.join(", "));
931                } else if use_dollar_params {
932                    *param_offset += 1;
933                    write!(out_sql, "${}", *param_offset).unwrap();
934                } else {
935                    out_sql.push('?');
936                }
937                occurrences.push((name, is_list));
938            }
939        }
940    }
941
942    (out_sql, occurrences)
943}
944
945fn strip_expr(expr: &Expr) -> &Expr {
946    match expr {
947        Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
948        Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
949        Expr::Block(ExprBlock { block, .. }) => {
950            if block.stmts.len() != 1 {
951                return expr;
952            }
953            match &block.stmts[0] {
954                Stmt::Expr(inner, None) => strip_expr(inner),
955                _ => expr,
956            }
957        }
958        _ => expr,
959    }
960}
961
962fn extract_lit_str(expr: &Expr) -> Option<String> {
963    match strip_expr(expr) {
964        Expr::Lit(ExprLit {
965            lit: Lit::Str(lit), ..
966        }) => Some(lit.value()),
967        _ => None,
968    }
969}
970
971// =============================================================================
972// Preprocessing: {>key} compile-time result flags
973// =============================================================================
974
975fn result_flag_ident(name: &str) -> syn::Ident {
976    format_ident!("__enhanced_result_flag_{}", name)
977}
978
979/// Replaces `{>key}` tokens inside braced groups with `__enhanced_result_flag_key`
980/// identifiers. This is a preprocessing step so that the rest of the parser sees
981/// plain identifiers instead of braced groups it does not understand.
982fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
983    fn walk(stream: TokenStream2) -> TokenStream2 {
984        let mut out = TokenStream2::new();
985        let iter = stream.into_iter().peekable();
986
987        for token in iter {
988            match token {
989                TokenTree::Group(group) => {
990                    if group.delimiter() == Delimiter::Brace {
991                        let mut inner = group.stream().into_iter();
992                        let first = inner.next();
993                        let second = inner.next();
994                        let third = inner.next();
995
996                        if let (
997                            Some(TokenTree::Punct(p)),
998                            Some(TokenTree::Ident(name_ident)),
999                            None,
1000                        ) = (first, second, third)
1001                        {
1002                            if p.as_char() == '>' {
1003                                let ident = result_flag_ident(&name_ident.to_string());
1004                                out.extend(std::iter::once(TokenTree::Ident(ident)));
1005                                continue;
1006                            }
1007                        }
1008                    }
1009
1010                    let new_inner = walk(group.stream());
1011                    let mut new_group = Group::new(group.delimiter(), new_inner);
1012                    new_group.set_span(group.span());
1013                    out.extend(std::iter::once(TokenTree::Group(new_group)));
1014                }
1015                other => out.extend(std::iter::once(other)),
1016            }
1017        }
1018
1019        out
1020    }
1021
1022    walk(input)
1023}
1024
1025fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1026    keys.iter()
1027        .map(|key| {
1028            let ident = result_flag_ident(key);
1029            let enabled = Some(key.as_str()) == active_key;
1030            quote! { let #ident: bool = #enabled; }
1031        })
1032        .collect()
1033}
1034
1035fn transpose_section_case_matrix(
1036    case_matrix: Vec<Vec<SectionFragment>>,
1037    width: usize,
1038) -> Result<Vec<Vec<SectionFragment>>, String> {
1039    let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1040
1041    for row in case_matrix {
1042        if row.len() != width {
1043            return Err(
1044                "sql_forge!: grouped sections must return one item per section".to_string(),
1045            );
1046        }
1047        for (section_idx, fragment) in row.into_iter().enumerate() {
1048            per_section[section_idx].push(fragment);
1049        }
1050    }
1051
1052    Ok(per_section)
1053}
1054
1055fn collect_section_case_matrix(
1056    value: SectionValue,
1057    width: usize,
1058    active_key: Option<&str>,
1059) -> Result<Vec<Vec<SectionFragment>>, String> {
1060    match value {
1061        SectionValue::Single(fragment) => {
1062            if width != 1 {
1063                return Err(
1064                    "sql_forge!: grouped sections must return one item per section".to_string(),
1065                );
1066            }
1067            Ok(vec![vec![fragment]])
1068        }
1069        SectionValue::Grouped(values) => {
1070            if values.len() != width {
1071                return Err(
1072                    "sql_forge!: grouped sections must return one item per section".to_string(),
1073                );
1074            }
1075
1076            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1077            let mut nmax = 1usize;
1078
1079            for value in values {
1080                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1081                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1082                for mut row in item_matrix {
1083                    let fragment = row.pop().ok_or_else(|| {
1084                        "sql_forge!: grouped sections must return one item per section".to_string()
1085                    })?;
1086                    if !row.is_empty() {
1087                        return Err(
1088                            "sql_forge!: grouped sections must return one item per section"
1089                                .to_string(),
1090                        );
1091                    }
1092                    item_variants.push(fragment);
1093                }
1094                if item_variants.is_empty() {
1095                    return Err("sql_forge!: section match must have at least one arm".to_string());
1096                }
1097                nmax = nmax.max(item_variants.len());
1098                variants_by_section.push(item_variants);
1099            }
1100
1101            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1102            for case_idx in 0..nmax {
1103                let mut row = Vec::<SectionFragment>::with_capacity(width);
1104                for variants in &variants_by_section {
1105                    row.push(variants[case_idx % variants.len()].clone());
1106                }
1107                case_matrix.push(row);
1108            }
1109
1110            Ok(case_matrix)
1111        }
1112        SectionValue::Match { expr, arms } => {
1113            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1114
1115            if let Some(key) = expr_result_flag_key(&expr) {
1116                let target = active_key == Some(key.as_str());
1117                for arm in arms {
1118                    if arm.guard.is_none() {
1119                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1120                            continue;
1121                        }
1122                    }
1123                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1124                    wrap_section_case_matrix_for_match_arm(
1125                        &mut arm_cases,
1126                        &expr,
1127                        &arm.pat,
1128                        arm.guard.as_ref(),
1129                    );
1130                    case_matrix.extend(arm_cases);
1131                }
1132            } else {
1133                for arm in arms {
1134                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1135                    wrap_section_case_matrix_for_match_arm(
1136                        &mut arm_cases,
1137                        &expr,
1138                        &arm.pat,
1139                        arm.guard.as_ref(),
1140                    );
1141                    case_matrix.extend(arm_cases);
1142                }
1143            }
1144
1145            if case_matrix.is_empty() {
1146                return Err("sql_forge!: section match must have at least one arm".to_string());
1147            }
1148
1149            Ok(case_matrix)
1150        }
1151    }
1152}
1153
1154/// This function rewrites a section-local validator expression so it stays inside the same
1155/// match arm scope that originally introduced it
1156fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
1157    let match_expr = match_expr.clone();
1158    let pat = pat.clone();
1159    let pattern_binds_values = match &pat {
1160        Pat::Ident(_) => true,
1161        Pat::Or(pat_or) => pat_or
1162            .cases
1163            .iter()
1164            .any(|case| matches!(case, Pat::Ident(_))),
1165        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
1166        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
1167        Pat::Slice(pat_slice) => pat_slice
1168            .elems
1169            .iter()
1170            .any(|elem| matches!(elem, Pat::Ident(_))),
1171        Pat::Struct(pat_struct) => pat_struct
1172            .fields
1173            .iter()
1174            .any(|field| matches!(*field.pat, Pat::Ident(_))),
1175        Pat::Tuple(pat_tuple) => pat_tuple
1176            .elems
1177            .iter()
1178            .any(|elem| matches!(elem, Pat::Ident(_))),
1179        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
1180            .elems
1181            .iter()
1182            .any(|elem| matches!(elem, Pat::Ident(_))),
1183        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
1184        _ => false,
1185    };
1186
1187    if pattern_binds_values {
1188        if let Some(guard) = guard.cloned() {
1189            parse_quote! {
1190                match &(#match_expr) {
1191                    #pat if #guard => { #expr },
1192                    _ => unreachable!("sql_forge!: validator arm mismatch"),
1193                }
1194            }
1195        } else {
1196            parse_quote! {
1197                match &(#match_expr) {
1198                    #pat => { #expr },
1199                    _ => unreachable!("sql_forge!: validator arm mismatch"),
1200                }
1201            }
1202        }
1203    } else if let Some(guard) = guard.cloned() {
1204        parse_quote! {
1205            match &(#match_expr) {
1206                #pat if #guard => { &(#expr) },
1207                _ => unreachable!("sql_forge!: validator arm mismatch"),
1208            }
1209        }
1210    } else {
1211        parse_quote! {
1212            match &(#match_expr) {
1213                #pat => { &(#expr) },
1214                _ => unreachable!("sql_forge!: validator arm mismatch"),
1215            }
1216        }
1217    }
1218}
1219
1220fn wrap_params_source_for_match_arm(
1221    params: &mut ParamsSource,
1222    match_expr: &Expr,
1223    pat: &Pat,
1224    guard: Option<&Expr>,
1225) {
1226    match params {
1227        ParamsSource::None => {}
1228        ParamsSource::Map(entries) => {
1229            for entry in entries {
1230                entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1231            }
1232        }
1233        ParamsSource::Struct(expr) => {
1234            **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1235        }
1236    }
1237}
1238
1239fn wrap_section_case_matrix_for_match_arm(
1240    case_matrix: &mut [Vec<SectionFragment>],
1241    match_expr: &Expr,
1242    pat: &Pat,
1243    guard: Option<&Expr>,
1244) {
1245    for row in case_matrix {
1246        for fragment in row {
1247            wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1248        }
1249    }
1250}
1251
// Note: the section-variant collectors below return Vec<Vec<SectionFragment>> indexed [section_idx][case_idx].
1253// =============================================================================
1254// Section variant collection
1255// =============================================================================
1256
1257/// Collects all possible `SectionFragment` values per section index. Each
1258/// returned `Vec<Vec<SectionFragment>>` is indexed `[section_idx][case_idx]`,
1259/// listing every variant that section can take across all match arms.
1260/// Used for full validation (cycling strategy over all variants).
1261fn collect_section_variants(
1262    value: SectionValue,
1263    width: usize,
1264) -> Result<Vec<Vec<SectionFragment>>, String> {
1265    transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1266}
1267
1268fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1269    match strip_expr(expr) {
1270        Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1271            let name = path.path.segments[0].ident.to_string();
1272            name.strip_prefix("__enhanced_result_flag_")
1273                .map(|v| v.to_string())
1274        }
1275        _ => None,
1276    }
1277}
1278
1279fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1280    match pat {
1281        Pat::Lit(expr_lit) => match &expr_lit.lit {
1282            Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1283            _ => None,
1284        },
1285        Pat::Wild(_) => Some(true),
1286        _ => None,
1287    }
1288}
1289
1290/// Like `collect_section_variants`, but filters `match` arms by the active
1291/// result key when the match expression is a `{>key}` flag. When building the
1292/// query for a specific key, only the matching arm (true/false) is included;
1293/// arms with guards or non-flag expressions include all variants as usual.
1294fn collect_section_variants_for_result(
1295    value: SectionValue,
1296    width: usize,
1297    active_key: Option<&str>,
1298) -> Result<Vec<Vec<SectionFragment>>, String> {
1299    transpose_section_case_matrix(
1300        collect_section_case_matrix(value, width, active_key)?,
1301        width,
1302    )
1303}
1304
1305// =============================================================================
1306// Parameter binding generation
1307// =============================================================================
1308
1309/// Generates `let` bindings for all parameters used in the SQL or section
1310/// fragments. Returns a map of param-name → local ident for later reference,
1311/// and a list of `let` statements.
1312fn build_param_bindings(
1313    params: &ParamsSource,
1314    used_param_names: &[String],
1315    prefix: &str,
1316    for_validator: bool,
1317    enforce_usage_check: bool,
1318) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1319    let mut declared_params = HashMap::<String, syn::Ident>::new();
1320    let mut bindings = Vec::<TokenStream2>::new();
1321
1322    match params {
1323        ParamsSource::None => {}
1324        ParamsSource::Map(entries) => {
1325            for entry in entries {
1326                let key = entry.name.to_string();
1327                if declared_params.contains_key(&key) {
1328                    return Err(syn::Error::new(
1329                        entry.name.span(),
1330                        "sql_forge!: duplicated parameter mapping",
1331                    )
1332                    .to_compile_error()
1333                    .into());
1334                }
1335                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1336                    return Err(syn::Error::new(
1337                        entry.name.span(),
1338                        format!(
1339                            "sql_forge!: parameter :{} is unused in the SQL template",
1340                            key,
1341                        ),
1342                    )
1343                    .to_compile_error()
1344                    .into());
1345                }
1346                let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
1347                let expr = &entry.expr;
1348                if for_validator {
1349                    bindings.push(quote! {
1350                        let #local_ident = &(#expr);
1351                    });
1352                } else {
1353                    bindings.push(quote! {
1354                        let #local_ident = #expr;
1355                    });
1356                }
1357                declared_params.insert(key, local_ident);
1358            }
1359        }
1360        ParamsSource::Struct(expr) => {
1361            let source_ident = format_ident!("__enhanced_source_{}", prefix);
1362            bindings.push(quote! {
1363                let #source_ident = &(#expr);
1364            });
1365            for name in used_param_names {
1366                let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
1367                let field_ident = format_ident!("{}", name);
1368                if for_validator {
1369                    bindings.push(quote! {
1370                        let #local_ident = &#source_ident.#field_ident;
1371                    });
1372                } else {
1373                    bindings.push(quote! {
1374                        let #local_ident = #source_ident.#field_ident;
1375                    });
1376                }
1377                declared_params.insert(name.to_string(), local_ident);
1378            }
1379        }
1380    }
1381
1382    Ok((declared_params, bindings))
1383}
1384
1385struct ValidatorRenderContext<'a> {
1386    local_params: &'a HashMap<String, syn::Ident>,
1387    top_level_params: &'a HashMap<String, syn::Ident>,
1388    allow_top_level_fallback: bool,
1389    use_dollar_params: bool,
1390    sql_span: Span,
1391    list_count: usize,
1392}
1393
1394/// Builds the placeholders SQL string and argument list for the compile-time
1395/// validator (sqlx::query_as! / query_scalar!). Each `:param` in the SQL is
1396/// replaced by `?` (MySQL/SQLite) or `$1`/`$2`/... (PostgreSQL), and the
1397/// corresponding value expression is collected into the args list.
1398fn render_validator_args(
1399    sql: &str,
1400    param_offset: &mut usize,
1401    arg_index: &mut usize,
1402    context: &ValidatorRenderContext<'_>,
1403) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1404    let (rendered_sql, occurrences) = render_validator_text(
1405        sql,
1406        context.use_dollar_params,
1407        param_offset,
1408        context.list_count,
1409    );
1410    let mut setup = Vec::<TokenStream2>::new();
1411    let mut args = Vec::<TokenStream2>::new();
1412
1413    for (name, is_list) in occurrences {
1414        let local_ident = if context.allow_top_level_fallback {
1415            context
1416                .local_params
1417                .get(&name)
1418                .or_else(|| context.top_level_params.get(&name))
1419        } else {
1420            context.local_params.get(&name)
1421        };
1422
1423        let Some(local_ident) = local_ident else {
1424            return Err(syn::Error::new(
1425                context.sql_span,
1426                format!("sql_forge!: parameter :{} has no mapping", name),
1427            )
1428            .to_compile_error()
1429            .into());
1430        };
1431
1432        if is_list {
1433            for _ in 0..context.list_count {
1434                let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1435                *arg_index += 1;
1436                setup.push(quote! {
1437                    let #value_ident = sql_forge::sql_forge_validator_value(
1438                        (#local_ident)
1439                            .as_slice()
1440                            .first()
1441                            .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1442                    );
1443                });
1444                args.push(quote! { #value_ident });
1445            }
1446        } else {
1447            let value_ident = format_ident!("__enhanced_validator_arg_{}", *arg_index);
1448            *arg_index += 1;
1449            setup.push(quote! {
1450                let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1451            });
1452            args.push(quote! { #value_ident });
1453        }
1454    }
1455
1456    Ok((rendered_sql, setup, args))
1457}
1458
1459// =============================================================================
1460// Runtime code generation (QueryBuilder-based)
1461// =============================================================================
1462
1463/// Generates the `push()` / `push_bind()` calls for a single section fragment
1464/// at runtime using `sqlx::QueryBuilder`.
1465fn render_runtime_fragment(
1466    fragment: &SectionFragment,
1467    local_params: &HashMap<String, syn::Ident>,
1468) -> Result<TokenStream2, TokenStream> {
1469    let mut steps = Vec::<TokenStream2>::new();
1470
1471    for part in parse_text_parts(&fragment.sql) {
1472        match part {
1473            TextPart::Lit(lit) => {
1474                let lit_str = LitStr::new(&lit, fragment.span);
1475                steps.push(quote! { __builder.push(#lit_str); });
1476            }
1477            TextPart::Param { name, is_list } => {
1478                let Some(local_ident) = local_params.get(&name) else {
1479                    return Err(syn::Error::new(
1480                        fragment.span,
1481                        format!("sql_forge!: parameter :{} has no mapping", name),
1482                    )
1483                    .to_compile_error()
1484                    .into());
1485                };
1486
1487                if is_list {
1488                    steps.push(quote! {
1489                        let __enhanced_values = #local_ident;
1490                        let mut __separated = __builder.separated(", ");
1491                        for __value in __enhanced_values {
1492                            __separated.push_bind(__value);
1493                        }
1494                    });
1495                } else {
1496                    steps.push(quote! {
1497                        __builder.push_bind(#local_ident);
1498                    });
1499                }
1500            }
1501        }
1502    }
1503
1504    Ok(quote! { #( #steps )* })
1505}
1506
1507fn build_section_runtime_action(
1508    value: &SectionValue,
1509    section_idx: usize,
1510    prefix: &str,
1511) -> Result<TokenStream2, TokenStream> {
1512    match value {
1513        SectionValue::Single(fragment) => {
1514            let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1515            let (local_params, bindings) =
1516                build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1517            let body = render_runtime_fragment(fragment, &local_params)?;
1518            Ok(quote! {{ #( #bindings )* #body }})
1519        }
1520        SectionValue::Grouped(fragments) => build_section_runtime_action(
1521            &fragments[section_idx],
1522            0,
1523            &format!("{}_grouped_{}", prefix, section_idx),
1524        ),
1525        SectionValue::Match { expr, arms } => {
1526            let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1527                .iter()
1528                .enumerate()
1529                .map(|(arm_idx, arm)| {
1530                    let pat = &arm.pat;
1531                    let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1532                    let body = build_section_runtime_action(
1533                        &arm.value,
1534                        section_idx,
1535                        &format!("{}_{}", prefix, arm_idx),
1536                    )?;
1537                    Ok::<TokenStream2, TokenStream>(quote! { #pat #guard_tokens => #body })
1538                })
1539                .collect();
1540            let arm_tokens = arm_tokens?;
1541            Ok(quote! {
1542                match #expr {
1543                    #( #arm_tokens ),*
1544                }
1545            })
1546        }
1547    }
1548}
1549
1550fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1551    let mut names = Vec::new();
1552    let mut seen = HashSet::<String>::new();
1553
1554    for segment in segments {
1555        match segment {
1556            Segment::Text(text) => {
1557                for name in collect_used_param_names_in_sql(text) {
1558                    if seen.insert(name.clone()) {
1559                        names.push(name);
1560                    }
1561                }
1562            }
1563            Segment::Batch { parts } => {
1564                for part in parts {
1565                    if let TextPart::Param { name, .. } = part {
1566                        if seen.insert(name.clone()) {
1567                            names.push(name.clone());
1568                        }
1569                    }
1570                }
1571            }
1572            _ => {}
1573        }
1574    }
1575
1576    names
1577}
1578
1579fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1580    let mut names = Vec::new();
1581    let mut seen = HashSet::<String>::new();
1582    for part in parse_text_parts(sql) {
1583        if let TextPart::Param { name, .. } = part {
1584            if seen.insert(name.to_string()) {
1585                names.push(name);
1586            }
1587        }
1588    }
1589    names
1590}
1591
1592/// Builds a parameterized SQL query with compile-time type-checking and a
1593/// runtime [`sqlx::QueryBuilder`] for dynamic SQL.
1594///
1595/// Combines `sqlx::query_as!` / `sqlx::query_scalar!` validation (never called
1596/// at runtime) with `QueryBuilder::push_bind` for safe value binding.
1597///
1598/// # Syntax
1599///
1600/// ```text
1601/// sql_forge!(
1602///     [DB,]        // optional: sqlx::MySql | sqlx::Postgres | sqlx::Sqlite
1603///     [Model,]     // optional result spec
1604///     SQL,         // string literal
1605///     [params,]    // optional: ( :name = expr, ... )  or  struct_expr
1606///     [(sections),]// optional: ( #name = ..., ... )
1607///     [..batch]    // optional: batch source expression used by {( ... )}
1608/// )
1609/// ```
1610///
1611/// `Model` has three forms:
1612/// - omitted: execute-only query; only `.execute(...)` is available
1613/// - `Type` or `scalar Type`: a single result query
1614/// - `( >key1 = TypeA, >key2 = scalar TypeB )`: a grouped multi-result query
1615///
1616/// The trailing parameter source, section map, and batch source are optional.
1617/// The batch source may appear alongside the others as a single `..expr` argument.
1618///
1619/// The DB type may be omitted when `SQL_FORGE_DB_TYPE` is set (e.g.
1620/// `SQL_FORGE_DB_TYPE=sqlx::MySql`) or when
1621/// `[package.metadata.sql_forge] db = "..."` is set in `Cargo.toml`.
1622/// The env var takes priority over Cargo.toml metadata.
1623///
1624/// # Parameters
1625///
1626/// Named parameters are written `:name` in the SQL. At runtime each occurrence
1627/// is replaced by a `push_bind` call; at compile time it becomes a
1628/// database-specific placeholder: `?` for MySQL and SQLite, and `$1`, `$2`, ...
1629/// for Postgres.
1630///
1631/// **Inline map** – bind individual expressions:
1632/// ```rust,ignore
1633/// sql_forge!(User, "SELECT ... WHERE id <= :max_id", ( :max_id = filter.max_id ))
1634/// ```
1635///
1636/// **Struct source** – field names are matched to `:name` placeholders automatically:
1637/// ```rust,ignore
1638/// sql_forge!(User, "SELECT ... WHERE id <= :max_id LIMIT :limit", filter)
1639/// ```
1640///
1641/// # Sections (`{#name}`)
1642///
1643/// Sections are runtime SQL slots; each section's variants are validated at
1644/// compile time via `query_as!` / `query_scalar!`, though not every combination
1645/// of variants across sections is checked. The section map is a second parenthesised
1646/// argument starting with `#`:
1647///
1648/// ```rust,ignore
1649/// sql_forge!(
1650///     User,
1651///     "SELECT * FROM users {#join_org}",
1652///     (
1653///         #join_org = match include_org {
1654///             true  => " JOIN organisations o ON o.id = users.org_id ",
1655///             false => "",
1656///         }
1657///     )
1658/// )
1659/// ```
1660///
1661/// A section arm can also carry local parameters as a tuple `("sql", params)`:
1662///
1663/// ```rust,ignore
1664/// sql_forge!(
1665///     User,
1666///     "SELECT * FROM users {#filter}",
1667///     (
1668///         #filter = (
1669///             " WHERE id <= :max_id AND status = :status ",
1670///             ( :max_id = max_id, :status = "active" ),
1671///         )
1672///     )
1673/// )
1674/// ```
1675///
1676/// Multiple placeholders driven by one `match` use `#(a, b)` with each arm
1677/// returning a tuple of the same width:
1678///
1679/// ```rust,ignore
1680/// sql_forge!(
1681///     User,
1682///     "SELECT * FROM users {#join_org} {#filter_org}",
1683///     (
1684///         #(join_org, filter_org) = match include_org {
1685///             true  => (
1686///                 " JOIN organisations o ON o.id = users.org_id ",
1687///                 (
1688///                     " AND o.active = :active ",
1689///                     ( :active = true ),
1690///                 ),
1691///             ),
1692///             false => ("", ""),
1693///         }
1694///     )
1695/// )
1696/// ```
1697///
1698/// Grouped section items may themselves use nested `match` expressions. Those
1699/// nested matches use smart cycling within the arm rather than a cartesian
1700/// product. For example, if one grouped arm returns a fixed first item plus two
1701/// nested binary matches for the second and third items, that arm contributes
1702/// two aligned variants `(0, 0)` and `(1, 1)`, not four `(0, 0)`, `(0, 1)`,
1703/// `(1, 0)`, `(1, 1)` combinations.
1704///
1705/// # `IN (...)` with list parameters
1706///
/// Append `[]` to a placeholder (`:name[]`) to expand a `Vec` into several
/// comma-separated bound values (the surrounding parentheses are ordinary SQL):
1709///
1710/// ```rust,ignore
1711/// sql_forge!(User, "SELECT * FROM users WHERE id IN (:ids[])", ( :ids = ids ))
1712/// ```
1713///
/// **Empty lists** bind no values and leave `IN ()` in the SQL, which MySQL and
/// PostgreSQL reject as a syntax error.
1715/// Guard against empty inputs explicitly, e.g. with a dynamic section:
1716///
1717/// ```rust,ignore
1718/// sql_forge!(
1719///     User,
1720///     "SELECT id, name FROM users WHERE {#filter}",
1721///     (
1722///         #filter = match ids.is_empty() {
1723///             true  => "1 = 0",
1724///             false => ("id IN (:ids[])", ( :ids = ids )),
1725///         }
1726///     )
1727/// )
1728/// ```
1729///
1730/// # Batch inserts (`{( ... )}`)
1731///
1732/// A batch section `{( ... )}` repeats its content for each item in an iterable
1733/// source passed as `..expr`. Inside the batch, `:name` refers to a field on the
1734/// current item. List parameters (`:name[]`) are **not** allowed inside batch
1735/// sections.
1736///
1737/// ## Struct batch
1738///
1739/// ```rust,ignore
1740/// struct BatchItem { name: String, price: i64 }
1741///
1742/// let items = vec![
1743///     BatchItem { name: "A".into(), price: 100 },
1744///     BatchItem { name: "B".into(), price: 200 },
1745/// ];
1746///
1747/// sql_forge!(
1748///     "INSERT INTO products (name, price, stock, category)
1749///      VALUES {(:name, :price, 10, 'Batch')}",
1750///     ..items
1751/// )
1752/// .execute(&pool)
1753/// .await?;
1754/// ```
1755///
1756/// For compile-time checking, the validator expands the batch into 3 copies (1 on
1757/// SQLite): `(?, ?, 10, 'Batch'), (?, ?, 10, 'Batch'), (?, ?, 10, 'Batch')`, and
1758/// binds fields of `expr[0]`. At runtime the iterable drives the actual row count.
1759///
1760/// # Scalar output
1761///
1762/// When `Model` is a primitive (`i32`, `i64`, `String`, etc.) the macro uses
1763/// `query_scalar!` for validation and `build_query_scalar` for execution.
1764///
1765/// # Multiple results
1766///
1767/// A result map produces a `SqlForgeQueryGroup` with one query per key.
1768/// Each key can be a struct or a primitive (used as a scalar):
1769///
1770/// ```rust,ignore
1771/// sql_forge!(
1772///     (
1773///         >count = i64,
1774///         >items = Item,
1775///     ),
1776///     "SELECT {#fields} FROM items WHERE category_id = :cat",
1777///     ( :cat = category_id ),
1778///     (
1779///         #fields = match {>count} {           // {>key} is true when building
1780///             true  => "COUNT(*) AS total",    // the query for that model/result
1781///             false => "id, name, price",      // key and false otherwise
1782///         }
1783///     )
1784/// )
1785/// ```
1786///
1787/// The generated struct has one field per key (`group.count`, `group.items`),
1788/// each implementing `SqlForgeQuery<T, Db = DB>` and usable with any SQLx
1789/// executor method (`fetch_one`, `fetch_all`, etc.).
1790///
1791/// # Execute-only (no model)
1792///
1793/// When the model type is omitted, the macro produces a value implementing
1794/// `SqlForgeQueryExecute`. Only `.execute(executor)` is available, since there
1795/// is no row type to deserialize into. This is useful for `INSERT`, `UPDATE`,
1796/// `DELETE`, and other DML statements.
1797///
1798/// ```rust,ignore
1799/// sql_forge!(
1800///     "UPDATE products SET stock = stock + 1 WHERE id = :id",
1801///     ( :id = 42i64 ),
1802/// )
1803/// .execute(&pool)
1804/// .await?;
1805/// ```
1806///
1807/// Sections and struct parameter sources work the same way as in model-backed queries:
1808///
1809/// ```rust,ignore
1810/// sql_forge!(
1811///     "UPDATE products SET price = :new_price {#filter}",
1812///     ( :new_price = 999i64 ),
1813///     ( #filter = ("WHERE category = :cat", ( :cat = "Electronics" )) ),
1814/// )
1815/// .execute(&pool).await?;
1816/// ```
1817#[proc_macro]
1818#[allow(clippy::too_many_lines)]
1819pub fn sql_forge(input: TokenStream) -> TokenStream {
1820    // ---- Phase 1: Parse the macro input into structured data ----
1821    let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1822    let SqlForgeInput {
1823        db,
1824        result,
1825        force_scalar,
1826        sql,
1827        params,
1828        sections,
1829        batch,
1830    } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1831        Ok(v) => v,
1832        Err(err) => return err.to_compile_error().into(),
1833    };
1834
1835    // ---- Phase 2: Resolve database type (from macro arg or Cargo.toml) ----
1836    let db = match db {
1837        Some(db) => db,
1838        None => match resolve_db_from_env() {
1839            Ok(db) => db,
1840            Err(msg) => {
1841                return syn::Error::new(Span::call_site(), msg)
1842                    .to_compile_error()
1843                    .into();
1844            }
1845        },
1846    };
1847
1848    let use_dollar_params = uses_dollar_params(&db);
1849    let is_sqlite = if let syn::Type::Path(type_path) = &db {
1850        type_path
1851            .path
1852            .segments
1853            .last()
1854            .is_some_and(|s| s.ident == "Sqlite")
1855    } else {
1856        false
1857    };
1858    let list_count: usize = if is_sqlite { 1 } else { 3 };
1859
1860    // ---- Phase 3: Build result case definitions ----
1861    // Each result case is (optional_key, model_type, optional_scalar_type).
1862    // Scalar type is set for primitives and `scalar`-marked types.
1863    let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1864        ResultSpec::None => {
1865            vec![(None, None, None)]
1866        }
1867        ResultSpec::Single(ref model) => {
1868            let model_ty = (**model).clone();
1869            let scalar = if force_scalar {
1870                Some(model_ty.clone())
1871            } else {
1872                scalar_output_type(model.as_ref()).cloned()
1873            };
1874            vec![(None, Some(model_ty), scalar)]
1875        }
1876        ResultSpec::Group(ref cases) => {
1877            if force_scalar {
1878                return syn::Error::new(
1879                    Span::call_site(),
1880                    "sql_forge!: scalar mode is not supported for grouped result maps",
1881                )
1882                .to_compile_error()
1883                .into();
1884            }
1885
1886            let mut out = Vec::new();
1887            let mut seen = HashSet::new();
1888            for case in cases {
1889                let key = case.name.to_string();
1890                if !seen.insert(key.clone()) {
1891                    return syn::Error::new(
1892                        case.name.span(),
1893                        "sql_forge!: duplicated key in result map",
1894                    )
1895                    .to_compile_error()
1896                    .into();
1897                }
1898
1899                let model = case.model.clone();
1900                let scalar = if case.force_scalar {
1901                    Some(model.clone())
1902                } else {
1903                    scalar_output_type(&case.model).cloned()
1904                };
1905                out.push((Some(key), Some(model), scalar));
1906            }
1907            out
1908        }
1909    };
1910    let group_result_keys: Vec<String> = result_cases
1911        .iter()
1912        .filter_map(|(key, _, _)| key.as_ref().cloned())
1913        .collect();
1914    let is_grouped_result = !group_result_keys.is_empty();
1915    let sql_span = sql.span();
1916
1917    // ---- Phase 4: Parse SQL into segments (text + {#section} slots) ----
1918    let segments = match sql.into_segments() {
1919        Ok(segments) => segments,
1920        Err(msg) => {
1921            return syn::Error::new(sql_span, msg).to_compile_error().into();
1922        }
1923    };
1924
1925    let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
1926    match (&batch, has_batch_segment) {
1927        (None, true) => {
1928            return syn::Error::new(
1929                sql_span,
1930                "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
1931                 was provided"
1932            )
1933            .to_compile_error()
1934            .into();
1935        }
1936        (Some(_), false) => {
1937            return syn::Error::new(
1938                sql_span,
1939                "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
1940                 batch section",
1941            )
1942            .to_compile_error()
1943            .into();
1944        }
1945        _ => {}
1946    }
1947
1948    let used_param_names = collect_used_param_names(&segments);
1949
1950    // Batch-only params come from batch items, not the top-level params map.
1951    // They must be excluded from the usage check so that a param like :category
1952    // that appears only inside {( ... )} is flagged as unused when given in the
1953    // params map, as it would never be read from there at runtime.
1954    let batch_param_names: std::collections::HashSet<String> = segments
1955        .iter()
1956        .filter_map(|s| {
1957            if let Segment::Batch { parts } = s {
1958                Some(parts.iter().filter_map(|p| {
1959                    if let TextPart::Param { name, .. } = p {
1960                        Some(name.clone())
1961                    } else {
1962                        None
1963                    }
1964                }))
1965            } else {
1966                None
1967            }
1968        })
1969        .flatten()
1970        .collect();
1971    let top_level_used_names: Vec<String> = used_param_names
1972        .iter()
1973        .filter(|n| !batch_param_names.contains(*n))
1974        .cloned()
1975        .collect();
1976
1977    // ---- Phase 5: Build parameter bindings for the top-level params ----
1978    let (declared_params, validator_param_bindings) =
1979        match build_param_bindings(&params, &top_level_used_names, "top_level", true, true) {
1980            Ok(v) => v,
1981            Err(err) => return err,
1982        };
1983
1984    let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
1985
1986    // ---- Phase 6: Process sections: build runtime actions and collect validation variants ----
1987    for assign in &sections {
1988        let SectionAssign { names, value } = assign;
1989
1990        // Build runtime actions first, while `value` is still available by reference.
1991        let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
1992        for (section_idx, name_ident) in names.iter().enumerate() {
1993            let name = name_ident.to_string();
1994            if runtime_section_actions.contains_key(&name) {
1995                return syn::Error::new(
1996                    name_ident.span(),
1997                    "sql_forge!: duplicated section mapping",
1998                )
1999                .to_compile_error()
2000                .into();
2001            }
2002            let action = match build_section_runtime_action(
2003                value,
2004                section_idx,
2005                &format!("section_{}", name),
2006            ) {
2007                Ok(action) => action,
2008                Err(err) => return err,
2009            };
2010            named_actions.push((name, action));
2011        }
2012
2013        // Consume `value` here so invalid grouped/nested section structures fail early.
2014        if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2015            return syn::Error::new(names[0].span(), msg)
2016                .to_compile_error()
2017                .into();
2018        }
2019
2020        for (name, action) in named_actions {
2021            runtime_section_actions.insert(name, action);
2022        }
2023    }
2024
2025    let sql_section_names: std::collections::HashSet<&str> = segments
2026        .iter()
2027        .filter_map(|seg| {
2028            if let Segment::Section { name } = seg {
2029                Some(name.as_str())
2030            } else {
2031                None
2032            }
2033        })
2034        .collect();
2035    for name in runtime_section_actions.keys() {
2036        if !sql_section_names.contains(name.as_str()) {
2037            return syn::Error::new(
2038                sql_span,
2039                format!(
2040                    "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2041                    name, name,
2042                ),
2043            )
2044            .to_compile_error()
2045            .into();
2046        }
2047    }
2048
2049    // ---- Phase 8: For each result case, generate validator + runtime tokens ----
2050    let mut generated_query_defs = Vec::<TokenStream2>::new();
2051    let mut generated_query_values = Vec::<TokenStream2>::new();
2052    let mut group_field_defs = Vec::<TokenStream2>::new();
2053    let mut group_method_defs = Vec::<TokenStream2>::new();
2054    let mut group_field_idents = Vec::<syn::Ident>::new();
2055    let mut group_field_tys = Vec::<TokenStream2>::new();
2056    let mut group_trait_impls = Vec::<TokenStream2>::new();
2057
2058    let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2059
2060    for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2061        let suffix = result_key.as_deref().unwrap_or("single");
2062        let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2063        let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2064
2065        let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2066
2067        let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2068        for assign in &sections {
2069            let SectionAssign { names, value } = assign;
2070            let variants_by_section = match collect_section_variants_for_result(
2071                value.clone(),
2072                names.len(),
2073                result_key.as_deref(),
2074            ) {
2075                Ok(v) => v,
2076                Err(msg) => {
2077                    return syn::Error::new(names[0].span(), msg)
2078                        .to_compile_error()
2079                        .into();
2080                }
2081            };
2082
2083            for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2084                section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2085            }
2086        }
2087
2088        let mut nmax = 1usize;
2089        for segment in &segments {
2090            if let Segment::Section { name } = segment {
2091                if let Some(variants) = section_variants_for_validation.get(name) {
2092                    if variants.is_empty() {
2093                        return syn::Error::new(
2094                            sql_span,
2095                            format!("sql_forge!: section {{#{}}} has no possible variants", name),
2096                        )
2097                        .to_compile_error()
2098                        .into();
2099                    }
2100                    nmax = nmax.max(variants.len());
2101                } else {
2102                    return syn::Error::new(
2103                        sql_span,
2104                        format!("sql_forge!: section {{#{}}} has no mapping", name),
2105                    )
2106                    .to_compile_error()
2107                    .into();
2108                }
2109            }
2110        }
2111
2112        let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2113        for case_idx in 0..nmax {
2114            let mut sql_case = String::new();
2115            let mut case_setup = Vec::<TokenStream2>::new();
2116            let mut case_args = Vec::<TokenStream2>::new();
2117            let mut param_offset = 0usize;
2118            let mut arg_index = 0usize;
2119            let empty_params = HashMap::<String, syn::Ident>::new();
2120            let root_validator_context = ValidatorRenderContext {
2121                local_params: &empty_params,
2122                top_level_params: &declared_params,
2123                allow_top_level_fallback: true,
2124                use_dollar_params,
2125                sql_span,
2126                list_count,
2127            };
2128
2129            for segment in &segments {
2130                match segment {
2131                    Segment::Text(text) => {
2132                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2133                            text,
2134                            &mut param_offset,
2135                            &mut arg_index,
2136                            &root_validator_context,
2137                        ) {
2138                            Ok(value) => value,
2139                            Err(err) => return err,
2140                        };
2141                        sql_case.push_str(&chunk_sql);
2142                        case_setup.extend(chunk_setup);
2143                        case_args.extend(chunk_args);
2144                    }
2145                    Segment::Section { name } => {
2146                        let Some(variants) = section_variants_for_validation.get(name) else {
2147                            return syn::Error::new(
2148                                sql_span,
2149                                format!("sql_forge!: section {{#{}}} has no mapping", name),
2150                            )
2151                            .to_compile_error()
2152                            .into();
2153                        };
2154
2155                        let fragment = &variants[case_idx % variants.len()];
2156                        let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2157                        let (local_params, bindings) = match build_param_bindings(
2158                            &fragment.params,
2159                            &used_param_names,
2160                            &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2161                            true,
2162                            true,
2163                        ) {
2164                            Ok(value) => value,
2165                            Err(err) => return err,
2166                        };
2167                        let section_validator_context = ValidatorRenderContext {
2168                            local_params: &local_params,
2169                            top_level_params: &declared_params,
2170                            allow_top_level_fallback: false,
2171                            use_dollar_params,
2172                            sql_span: fragment.span,
2173                            list_count,
2174                        };
2175                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2176                            &fragment.sql,
2177                            &mut param_offset,
2178                            &mut arg_index,
2179                            &section_validator_context,
2180                        ) {
2181                            Ok(value) => value,
2182                            Err(err) => return err,
2183                        };
2184                        sql_case.push_str(&chunk_sql);
2185                        case_setup.extend(bindings);
2186                        case_setup.extend(chunk_setup);
2187                        case_args.extend(chunk_args);
2188                    }
2189                    Segment::Batch { parts } => {
2190                        let mut first = true;
2191                        for _ in 0..list_count {
2192                            let sep = if first { "" } else { ", " };
2193                            first = false;
2194                            sql_case.push_str(sep);
2195                            for tp in parts {
2196                                match tp {
2197                                    TextPart::Lit(lit) => sql_case.push_str(lit),
2198                                    TextPart::Param { name, .. } => {
2199                                        if let Some(batch_expr) = &batch {
2200                                            let field_ident = format_ident!("{}", name);
2201                                            if use_dollar_params {
2202                                                param_offset += 1;
2203                                                write!(sql_case, "${}", param_offset).unwrap();
2204                                            } else {
2205                                                sql_case.push('?');
2206                                            }
2207                                            case_args.push(quote! { #batch_expr[0].#field_ident });
2208                                        } else if use_dollar_params {
2209                                            param_offset += 1;
2210                                            write!(sql_case, "${}", param_offset).unwrap();
2211                                        } else {
2212                                            sql_case.push('?');
2213                                        }
2214                                    }
2215                                }
2216                            }
2217                        }
2218                    }
2219                }
2220            }
2221
2222            validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2223        }
2224
2225        let mut validator_invocations = Vec::<TokenStream2>::new();
2226        for (sql_lit, case_setup, args) in &validator_cases {
2227            if model_opt.is_none() {
2228                if args.is_empty() {
2229                    validator_invocations.push(quote! {
2230                        {
2231                            #( #case_setup )*
2232                            let _ = sqlx::query_scalar!(
2233                                #sql_lit,
2234                            );
2235                        }
2236                    });
2237                } else {
2238                    validator_invocations.push(quote! {
2239                        {
2240                            #( #case_setup )*
2241                            let _ = sqlx::query_scalar!(
2242                                #sql_lit,
2243                                #( #args ),*
2244                            );
2245                        }
2246                    });
2247                }
2248            } else if let Some(scalar_ty) = scalar_model_ty {
2249                if args.is_empty() {
2250                    validator_invocations.push(quote! {
2251                        {
2252                            #( #case_setup )*
2253                            let _ = sqlx::query_scalar!(
2254                                #sql_lit,
2255                            );
2256                        }
2257                    });
2258                } else {
2259                    validator_invocations.push(quote! {
2260                        {
2261                            #( #case_setup )*
2262                            let _ = sqlx::query_scalar!(
2263                                #sql_lit,
2264                                #( #args ),*
2265                            );
2266                        }
2267                    });
2268                }
2269                let _ = scalar_ty;
2270            } else if args.is_empty() {
2271                validator_invocations.push(quote! {
2272                    {
2273                        #( #case_setup )*
2274                        let _ = sqlx::query_as!(
2275                            __EnhancedModel,
2276                            #sql_lit,
2277                        );
2278                    }
2279                });
2280            } else {
2281                validator_invocations.push(quote! {
2282                    {
2283                        #( #case_setup )*
2284                        let _ = sqlx::query_as!(
2285                            __EnhancedModel,
2286                            #sql_lit,
2287                            #( #args ),*
2288                        );
2289                    }
2290                });
2291            }
2292        }
2293
2294        let model_alias = if let Some(model) = model_opt {
2295            if scalar_model_ty.is_none() {
2296                quote! { type __EnhancedModel = #model; }
2297            } else {
2298                quote! {}
2299            }
2300        } else {
2301            quote! {}
2302        };
2303        grouped_validator_invocations.push(quote! {
2304            {
2305                #( #flag_bindings )*
2306                #model_alias
2307                #( #validator_invocations )*
2308            }
2309        });
2310
2311        let (runtime_declared_params, runtime_param_bindings) =
2312            match build_param_bindings(&params, &used_param_names, "runtime", false, false) {
2313                Ok(v) => v,
2314                Err(err) => return err,
2315            };
2316
2317        let mut runtime_steps = Vec::<TokenStream2>::new();
2318        for (seg_idx, segment) in segments.iter().enumerate() {
2319            match segment {
2320                Segment::Text(text) => {
2321                    for part in parse_text_parts(text) {
2322                        match part {
2323                            TextPart::Lit(lit) => {
2324                                let lit = sanitize_runtime_sql_text(&lit);
2325                                let lit_str = LitStr::new(&lit, sql_span);
2326                                runtime_steps.push(quote! {
2327                                    __builder.push(#lit_str);
2328                                });
2329                            }
2330                            TextPart::Param { name, is_list } => {
2331                                let Some(local_ident) = runtime_declared_params.get(&name) else {
2332                                    return syn::Error::new(
2333                                        sql_span,
2334                                        format!("sql_forge!: parameter :{} has no mapping", name),
2335                                    )
2336                                    .to_compile_error()
2337                                    .into();
2338                                };
2339
2340                                if is_list {
2341                                    runtime_steps.push(quote! {
2342                                        let __enhanced_values = #local_ident;
2343                                        let mut __separated = __builder.separated(", ");
2344                                        for __value in __enhanced_values {
2345                                            __separated.push_bind(__value);
2346                                        }
2347                                    });
2348                                } else {
2349                                    runtime_steps.push(quote! {
2350                                        __builder.push_bind(#local_ident);
2351                                    });
2352                                }
2353                            }
2354                        }
2355                    }
2356                }
2357                Segment::Section { name } => {
2358                    let Some(section_action) = runtime_section_actions.get(name) else {
2359                        let _ = seg_idx;
2360                        return syn::Error::new(
2361                            sql_span,
2362                            format!("sql_forge!: section {{#{}}} has no mapping", name),
2363                        )
2364                        .to_compile_error()
2365                        .into();
2366                    };
2367                    runtime_steps.push(quote! {
2368                        #section_action
2369                    });
2370                }
2371                Segment::Batch { parts } => {
2372                    if let Some(batch_expr) = &batch {
2373                        let mut body = Vec::<TokenStream2>::new();
2374                        for part in parts {
2375                            match part {
2376                                TextPart::Lit(lit) => {
2377                                    let lit_str = LitStr::new(lit, sql_span);
2378                                    body.push(quote! {
2379                                        __builder.push(#lit_str);
2380                                    });
2381                                }
2382                                TextPart::Param { name, .. } => {
2383                                    let field_ident = format_ident!("{}", name);
2384                                    body.push(quote! {
2385                                        __builder.push_bind(__item.#field_ident);
2386                                    });
2387                                }
2388                            }
2389                        }
2390                        runtime_steps.push(quote! {
2391                            {
2392                                let mut __first = true;
2393                                for __item in #batch_expr {
2394                                    if !__first {
2395                                        __builder.push(", ");
2396                                    }
2397                                    __first = false;
2398                                    #( #body )*
2399                                }
2400                            }
2401                        });
2402                    }
2403                }
2404            }
2405        }
2406
2407        let exec_methods = if model_opt.is_none() {
2408            quote! {
2409                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2410                where
2411                    E: sqlx::Executor<'e, Database = #db>,
2412                {
2413                    self.inner.build().execute(executor).await
2414                }
2415            }
2416        } else if let Some(scalar_ty) = scalar_model_ty {
2417            quote! {
2418                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2419                where
2420                    E: sqlx::Executor<'e, Database = #db>,
2421                {
2422                    self.inner
2423                        .build_query_scalar::<#scalar_ty>()
2424                        .fetch_all(executor)
2425                        .await
2426                }
2427
2428                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2429                where
2430                    E: sqlx::Executor<'e, Database = #db>,
2431                {
2432                    self.inner
2433                        .build_query_scalar::<#scalar_ty>()
2434                        .fetch_one(executor)
2435                        .await
2436                }
2437
2438                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2439                where
2440                    E: sqlx::Executor<'e, Database = #db>,
2441                {
2442                    self.inner
2443                        .build_query_scalar::<#scalar_ty>()
2444                        .fetch_optional(executor)
2445                        .await
2446                }
2447
2448                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2449                where
2450                    E: sqlx::Executor<'e, Database = #db>,
2451                {
2452                    self.inner.build().execute(executor).await
2453                }
2454            }
2455        } else {
2456            let model = model_opt.as_ref().unwrap();
2457            quote! {
2458                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2459                where
2460                    E: sqlx::Executor<'e, Database = #db>,
2461                {
2462                    self.inner.build_query_as::<#model>().fetch_all(executor).await
2463                }
2464
2465                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2466                where
2467                    E: sqlx::Executor<'e, Database = #db>,
2468                {
2469                    self.inner.build_query_as::<#model>().fetch_one(executor).await
2470                }
2471
2472                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2473                where
2474                    E: sqlx::Executor<'e, Database = #db>,
2475                {
2476                    self.inner
2477                        .build_query_as::<#model>()
2478                        .fetch_optional(executor)
2479                        .await
2480                }
2481
2482                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2483                where
2484                    E: sqlx::Executor<'e, Database = #db>,
2485                {
2486                    self.inner.build().execute(executor).await
2487                }
2488            }
2489        };
2490
2491        let final_type: TokenStream2 = if let Some(model) = model_opt {
2492            if let Some(scalar_ty) = scalar_model_ty {
2493                quote! { #scalar_ty }
2494            } else {
2495                quote! { #model }
2496            }
2497        } else {
2498            quote! {}
2499        };
2500        let trait_impl = if model_opt.is_none() {
2501            quote! {
2502                impl<'args> sql_forge::SqlForgeQueryExecute
2503                    for #query_ident<'args>
2504                {
2505                    type Db = #db;
2506
2507                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2508                    where
2509                        Self: Sized + 'e,
2510                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2511                        #db: 'e,
2512                    {
2513                        #query_ident::execute(self, executor)
2514                    }
2515                }
2516            }
2517        } else {
2518            quote! {
2519                impl<'args> sql_forge::SqlForgeQuery<#final_type>
2520                    for #query_ident<'args>
2521                {
2522                    type Db = #db;
2523
2524                    fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2525                    where
2526                        Self: Sized + 'e,
2527                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2528                        #db: 'e,
2529                    {
2530                        #query_ident::fetch_all(self, executor)
2531                    }
2532
2533                    fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2534                    where
2535                        Self: Sized + 'e,
2536                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2537                        #db: 'e,
2538                    {
2539                        #query_ident::fetch_one(self, executor)
2540                    }
2541
2542                    fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2543                    where
2544                        Self: Sized + 'e,
2545                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2546                        #db: 'e,
2547                    {
2548                        #query_ident::fetch_optional(self, executor)
2549                    }
2550
2551                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2552                    where
2553                        Self: Sized + 'e,
2554                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2555                        #db: 'e,
2556                    {
2557                        #query_ident::execute(self, executor)
2558                    }
2559                }
2560            }
2561        };
2562
2563        generated_query_defs.push(quote! {
2564            struct #query_ident<'args> {
2565                inner: sqlx::QueryBuilder<'args, #db>,
2566            }
2567
2568            impl<'args> #query_ident<'args> {
2569                #exec_methods
2570            }
2571
2572            #trait_impl
2573        });
2574
2575        generated_query_values.push(quote! {
2576            #( #runtime_param_bindings )*
2577            #( #flag_bindings )*
2578            let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2579            #( #runtime_steps )*
2580            let #query_value_ident = #query_ident { inner: __builder };
2581        });
2582
2583        if let Some(key) = result_key {
2584            let method_ident = format_ident!("{}", key);
2585            group_field_defs.push(quote! {
2586                #method_ident: #query_ident<'args>
2587            });
2588            group_field_tys.push(quote! { #query_ident<'args> });
2589            group_method_defs.push(quote! {
2590                pub fn #method_ident(self) -> #query_ident<'args> {
2591                    self.#method_ident
2592                }
2593            });
2594
2595            let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2596            group_trait_impls.push(quote! {
2597                struct #key_ty_ident;
2598
2599                impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2600                    type Query = #query_ident<'args>;
2601
2602                    fn get(self, _: #key_ty_ident) -> Self::Query {
2603                        self.#method_ident
2604                    }
2605                }
2606            });
2607            group_field_idents.push(method_ident);
2608        }
2609    }
2610
2611    // ---- Phase 8: Emit the final token stream ----
2612    let validator_tokens = quote! {
2613        let _sql_forge_validator = || {
2614            #( #validator_param_bindings )*
2615            #( #grouped_validator_invocations )*
2616        };
2617    };
2618
2619    if !is_grouped_result {
2620        let single_query_value_ident = format_ident!("__sql_forge_value_single");
2621        return quote! {
2622            {
2623                #validator_tokens
2624                #( #generated_query_defs )*
2625                #( #generated_query_values )*
2626                #single_query_value_ident
2627            }
2628        }
2629        .into();
2630    }
2631
2632    let group_field_inits: Vec<TokenStream2> = result_cases
2633        .iter()
2634        .filter_map(|(key, _, _)| key.as_ref())
2635        .map(|key| {
2636            let method_ident = format_ident!("{}", key);
2637            let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2638            quote! { #method_ident: #query_value_ident }
2639        })
2640        .collect();
2641
2642    quote! {
2643        {
2644            #validator_tokens
2645
2646            #( #generated_query_defs )*
2647            #( #generated_query_values )*
2648
2649            struct __SqlForgeQueryGroup<'args> {
2650                #( #group_field_defs, )*
2651            }
2652
2653            impl<'args> __SqlForgeQueryGroup<'args> {
2654                #( #group_method_defs )*
2655
2656                pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2657                    ( #( self.#group_field_idents ),* )
2658                }
2659            }
2660
2661            impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2662                type Db = #db;
2663            }
2664
2665            #( #group_trait_impls )*
2666
2667            __SqlForgeQueryGroup {
2668                #( #group_field_inits, )*
2669            }
2670        }
2671    }
2672    .into()
2673}
2674
2675/// Expands to the database type from the `SQL_FORGE_DB_TYPE` environment variable,
2676/// falling back to `[package.metadata.sql_forge]` in `Cargo.toml`.
2677///
2678/// ```rust,ignore
2679/// use sql_forge::db_type;
2680///
2681/// pub type AppDb = db_type!();
2682/// // expands to the type set via SQL_FORGE_DB_TYPE or Cargo.toml metadata
2683/// ```
2684///
2685/// Priority:
2686/// 1. `SQL_FORGE_DB_TYPE` env var (e.g. `sqlx::MySql`, `sqlx::Postgres`)
2687/// 2. `[package.metadata.sql_forge] db = "..."` in `Cargo.toml`
2688#[proc_macro]
2689pub fn db_type(input: TokenStream) -> TokenStream {
2690    if !input.is_empty() {
2691        return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2692            .to_compile_error()
2693            .into();
2694    }
2695
2696    match resolve_db_from_env() {
2697        Ok(db) => quote! { #db }.into(),
2698        Err(msg) => syn::Error::new(Span::call_site(), msg)
2699            .to_compile_error()
2700            .into(),
2701    }
2702}