Skip to main content

sql_forge_macro/
lib.rs

1// =============================================================================
2// Imports
3// =============================================================================
4
5use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::spanned::Spanned;
14use syn::{
15    Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Ident, Lit, LitStr, Pat, Stmt, Token, Type,
16};
17
18// =============================================================================
19// Input types
20// =============================================================================
21
/// Custom keywords recognised by the `sql_forge!` parser.
mod kw {
    // Defines the `scalar` keyword token. It is peeked and parsed as `kw::scalar` in front
    // of a model type, both at the top level (`scalar Model`) and inside result maps
    // (`>name = scalar Model`).
    syn::custom_keyword!(scalar);
}
25
26/// A `:name = expr` parameter binding.
27#[derive(Clone)]
28struct ParamAssign {
29    name: Ident,
30    expr: Expr,
31}
32
33impl Parse for ParamAssign {
34    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
35        input.parse::<Token![:]>()?;
36        let name: Ident = input.parse()?;
37        input.parse::<Token![=]>()?;
38        let expr: Expr = input.parse()?;
39        Ok(Self { name, expr })
40    }
41}
42
/// A single section value: SQL string + optional local parameter bindings.
///
/// Built by `parse_section_fragment` from either a bare string literal or a
/// `("sql", params)` tuple.
#[derive(Clone)]
struct SectionFragment {
    /// The string literal's value (unescaped contents, not its source form).
    sql: String,
    /// Span of the expression the SQL literal was read from.
    span: Span,
    /// Section-local parameter source; `ParamsSource::None` for the bare-literal form.
    params: ParamsSource,
}
50
/// One arm in a `match { ... }` inside a section assignment.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern; parsed with `Pat::parse_multi_with_leading_vert`, so `A | B` is allowed.
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Value selected by this arm (may itself be a nested `match` or a grouped tuple).
    value: SectionValue,
}
58
/// The right-hand side of a section `#name = ...` assignment.
#[derive(Clone)]
enum SectionValue {
    /// A plain string (or `(string, params)` tuple).
    Single(SectionFragment),
    /// A tuple of values, one per section when using `#(a, b) = ...`.
    /// `parse_section_value` checks the item count equals the number of names.
    Grouped(Vec<SectionValue>),
    /// A `match expr { arm => ..., arm => ... }` expression.
    Match {
        /// Scrutinee, parsed without eager brace so `{` starts the arm list.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
72
73/// A full `#name = value` or `#(a, b) = value` section assignment.
74#[derive(Clone)]
75struct SectionAssign {
76    names: Vec<Ident>,
77    value: SectionValue,
78}
79
80impl Parse for SectionAssign {
81    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
82        input.parse::<Token![#]>()?;
83
84        // Parse `#(a, b)` for grouped sections, or `#name` for single.
85        let names = if input.peek(syn::token::Paren) {
86            let content;
87            syn::parenthesized!(content in input);
88            let mut out = Vec::new();
89            while !content.is_empty() {
90                out.push(content.parse::<Ident>()?);
91                if content.is_empty() {
92                    break;
93                }
94                content.parse::<Token![,]>()?;
95            }
96            if out.is_empty() {
97                return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
98            }
99            out
100        } else {
101            vec![input.parse::<Ident>()?]
102        };
103
104        input.parse::<Token![=]>()?;
105        let value = parse_section_value(input, names.len())?;
106        Ok(Self { names, value })
107    }
108}
109
/// The fully-parsed macro invocation.
struct SqlForgeInput {
    /// DB type given explicitly as the first macro argument, if any.
    db: Option<Type>,
    /// Requested result shape: none, a single model, or a `(>name = Model, ...)` map.
    result: ResultSpec,
    /// `true` when the single result model was prefixed with the `scalar` keyword.
    force_scalar: bool,
    /// The SQL template string literal.
    sql: SqlTemplate,
    /// Top-level parameter source: a `(:name = expr, ...)` map, a struct expression, or none.
    params: ParamsSource,
    /// Assignments from a `(#name = ..., ...)` argument; empty when none was given.
    sections: Vec<SectionAssign>,
    /// Expression following a `..` argument, if any.
    batch: Option<Expr>,
}
120
/// One entry in a result map `(>key = Model, ...)`.
#[derive(Clone)]
struct ResultAssign {
    /// Result key written after `>`.
    name: Ident,
    /// Model type for this key.
    model: Type,
    /// `true` when written as `>key = scalar Model`.
    force_scalar: bool,
}
128
/// What the invocation asks to be returned.
#[derive(Clone)]
enum ResultSpec {
    /// Execute-only (no model), e.g. `sql_forge!("SQL", ...)`
    None,
    /// e.g. `sql_forge!(User, ...)`
    Single(Box<Type>),
    /// e.g. `sql_forge!((>a = X, >b = Y), ...)`
    Group(Vec<ResultAssign>),
}
138
/// Where `:name` placeholder values come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was supplied.
    None,
    /// `( :name = expr, ... )`
    Map(Vec<ParamAssign>),
    /// A struct expression whose fields are matched to `:name` placeholders.
    Struct(Box<Expr>),
}
147
148/// The SQL template. Only string literals are supported.
149enum SqlTemplate {
150    Literal(LitStr),
151}
152
153impl SqlTemplate {
154    fn span(&self) -> Span {
155        match self {
156            Self::Literal(lit) => lit.span(),
157        }
158    }
159
160    /// Parse the SQL into a sequence of textual segments and `{#section}` slots.
161    fn into_segments(self) -> Result<Vec<Segment>, String> {
162        match self {
163            Self::Literal(lit) => parse_literal_segments(&lit.value()),
164        }
165    }
166}
167
168fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
169    if input.peek(LitStr) {
170        Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
171    } else {
172        Err(input.error("sql_forge!: SQL template must be a string literal"))
173    }
174}
175
/// One piece of the parsed SQL template.
#[derive(Clone)]
enum Segment {
    /// Plain SQL text (may contain `:param` placeholders).
    /// `push_text_segment` merges adjacent text into a single `Text`.
    Text(String),
    /// A `{#section_name}` placeholder.
    Section { name: String },
    /// A `{( ... )}` batch value template (repeated per batch item).
    /// `parse_literal_segments` rejects list parameters (`:name[]`) in `parts`.
    Batch { parts: Vec<TextPart> },
}
186
/// A fragment of SQL text after splitting on `:param`.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:param` placeholder; `is_list` is `true` when it was written `:param[]`.
    Param { name: String, is_list: bool },
}
195
196// =============================================================================
197// Parsing helpers
198// =============================================================================
199
/// Used by `detect_parenthesized_map_kind` to identify what a `(...)` argument is.
///
/// The kind is chosen from the first token inside the parentheses.
enum MapKind {
    /// Starts with `>`: a result map `(>name = Model, ...)`.
    Results,
    /// Starts with `:`: a parameter map `(:name = expr, ...)`.
    Params,
    /// Starts with `#`: a section map `(#name = ..., ...)`.
    Sections,
}
206
207// =============================================================================
// Argument maps and section value parsing (result/param/section maps, fragments, match, grouped tuples)
209// =============================================================================
210
211fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
212    let fork = input.fork();
213    let content;
214    syn::parenthesized!(content in fork);
215
216    if content.is_empty() {
217        return Err(input.error("sql_forge!: map argument cannot be empty"));
218    }
219
220    if content.peek(Token![>]) {
221        Ok(Some(MapKind::Results))
222    } else if content.peek(Token![:]) {
223        Ok(Some(MapKind::Params))
224    } else if content.peek(Token![#]) {
225        Ok(Some(MapKind::Sections))
226    } else {
227        Ok(None)
228    }
229}
230
231impl Parse for ResultAssign {
232    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
233        input.parse::<Token![>]>()?;
234        let name: Ident = input.parse()?;
235        input.parse::<Token![=]>()?;
236        let (force_scalar, model) = if input.peek(kw::scalar) {
237            input.parse::<kw::scalar>()?;
238            (true, input.parse::<Type>()?)
239        } else {
240            (false, input.parse::<Type>()?)
241        };
242        Ok(Self {
243            name,
244            model,
245            force_scalar,
246        })
247    }
248}
249
250fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
251    let content;
252    syn::parenthesized!(content in input);
253
254    let mut results = Vec::new();
255    while !content.is_empty() {
256        results.push(content.parse::<ResultAssign>()?);
257        if content.is_empty() {
258            break;
259        }
260        content.parse::<Token![,]>()?;
261    }
262
263    if results.is_empty() {
264        return Err(input.error("sql_forge!: result map cannot be empty"));
265    }
266
267    Ok(results)
268}
269
270fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
271    let content;
272    syn::parenthesized!(content in input);
273
274    let mut params = Vec::new();
275    while !content.is_empty() {
276        params.push(content.parse::<ParamAssign>()?);
277        if content.is_empty() {
278            break;
279        }
280        content.parse::<Token![,]>()?;
281    }
282
283    Ok(params)
284}
285
286fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
287    let content;
288    syn::parenthesized!(content in input);
289
290    let mut sections = Vec::new();
291    while !content.is_empty() {
292        sections.push(content.parse::<SectionAssign>()?);
293        if content.is_empty() {
294            break;
295        }
296        content.parse::<Token![,]>()?;
297    }
298
299    Ok(sections)
300}
301
302fn parse_params_source_expr(
303    input: ParseStream<'_>,
304    allow_sections: bool,
305) -> syn::Result<ParamsSource> {
306    if input.peek(syn::token::Paren) {
307        match detect_parenthesized_map_kind(input)? {
308            Some(MapKind::Results) => Err(input
309                .error("sql_forge!: result maps are only allowed as the macro result argument")),
310            Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
311            Some(MapKind::Sections) if allow_sections => {
312                Err(input.error("sql_forge!: section maps are not allowed here"))
313            }
314            Some(MapKind::Sections) => Err(input.error(
315                "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
316            )),
317            None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
318        }
319    } else {
320        Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
321    }
322}
323
/// Parses one section fragment. It is either a string literal or a `("sql", params)` tuple.
///
/// In the tuple form, `params` is a `(:name = expr, ...)` map or a struct expression, and
/// one trailing comma is allowed. The tuple shape is first verified on a fork. Any other
/// parenthesized expression therefore falls through to the plain-literal path, which then
/// rejects it.
///
/// # Errors
/// Fails when the value is neither a string literal nor a valid `(literal, params)` tuple.
/// The literal may be wrapped in parentheses or a single-expression block (see
/// `extract_lit_str`). Also fails when the params element is a result map or a
/// `#name = ...` section map.
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Dry run on a fork: check for the `("sql", params[,])` shape without consuming
        // anything from `input`.
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                // Errors from the params source propagate even during the dry run.
                let _ = parse_params_source_expr(&content, false)?;
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                if content.is_empty() {
                    // Shape confirmed: parse the same tokens again, this time from `input`.
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content, false)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    // Guard only; the fork above already established the group is fully
                    // consumed at this point.
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Plain form: the whole value must reduce to a string literal.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
371
372fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
373    if input.peek(Token![match]) {
374        input.parse::<Token![match]>()?;
375        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
376        let content;
377        syn::braced!(content in input);
378        let mut arms = Vec::new();
379        while !content.is_empty() {
380            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
381            let guard = if content.peek(Token![if]) {
382                content.parse::<Token![if]>()?;
383                Some(content.parse::<Expr>()?)
384            } else {
385                None
386            };
387            content.parse::<Token![=>]>()?;
388            let value = parse_section_value(&content, width)?;
389            if content.peek(Token![,]) {
390                content.parse::<Token![,]>()?;
391            }
392            arms.push(SectionMatchArm { pat, guard, value });
393        }
394        return Ok(SectionValue::Match { expr, arms });
395    }
396
397    if width == 1 {
398        return Ok(SectionValue::Single(parse_section_fragment(input)?));
399    }
400
401    let content;
402    syn::parenthesized!(content in input);
403    let mut items = Vec::new();
404    while !content.is_empty() {
405        items.push(parse_section_value(&content, 1)?);
406        if content.is_empty() {
407            break;
408        }
409        content.parse::<Token![,]>()?;
410    }
411
412    if items.len() != width {
413        return Err(input.error(format!(
414            "sql_forge!: grouped section value must provide exactly {} items",
415            width,
416        )));
417    }
418
419    Ok(SectionValue::Grouped(items))
420}
421
422// =============================================================================
423// Top-level macro input parsing (SqlForgeInput::parse)
424// =============================================================================
425
impl Parse for SqlForgeInput {
    /// Parses a complete `sql_forge!` invocation.
    ///
    /// Header forms, tried in this order:
    /// * `"SQL"`: execute-only, no model, no DB type;
    /// * `scalar Model, "SQL"`;
    /// * `(>a = X, ...), "SQL"`: result map;
    /// * `Model, "SQL"`;
    /// * `Db, scalar Model, "SQL"`;
    /// * `Db, (>a = X, ...), "SQL"`;
    /// * `Db, Model, "SQL"`.
    ///
    /// Comma-separated trailing arguments may follow the SQL literal in any order. Each is
    /// one of: `..batch_expr`, a `(:name = expr, ...)` parameter map, a
    /// `(#name = ..., ...)` section map, or any other expression (a struct parameter
    /// source). At most one batch source, one parameter source and one section map are
    /// accepted.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // A leading parenthesized argument must be a result map; e.g. a tuple model
            // type is rejected here.
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // `first_ty` is the model when the SQL literal follows directly; otherwise it
            // is the DB type and the model (or result map) comes next.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) {
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        // Track which argument kinds were already supplied so duplicates can be rejected.
        let mut seen_params = false;
        let mut seen_sections = false;

        // Trailing arguments are optional; a comma after the SQL literal introduces them.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        // An ordinary parenthesized expression: treat it as a struct source.
                        None => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Arguments are comma-separated; no comma ends the argument list.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        // Leftover tokens mean an argument was not followed by a comma.
        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
571
572// =============================================================================
573// Database type resolution
574// =============================================================================
575
576fn resolve_db_from_env() -> Result<Type, String> {
577    if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
578        return syn::parse_str::<Type>(&val).map_err(|err| {
579            format!(
580                "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
581                val, err
582            )
583        });
584    }
585
586    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
587        Ok(d) => d,
588        Err(_) => {
589            return Err(
590                "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
591                 or configure [package.metadata.sql_forge] in Cargo.toml"
592                    .to_string(),
593            );
594        }
595    };
596    let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
597
598    let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
599        format!(
600            "sql_forge!: failed to read {}: {}",
601            manifest_path.display(),
602            err
603        )
604    })?;
605
606    let value: toml::Value = toml::from_str(&cargo_toml)
607        .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
608
609    let db_str = value
610        .get("package")
611        .and_then(|v| v.get("metadata"))
612        .and_then(|v| v.get("sql_forge"))
613        .and_then(|v| v.get("db"))
614        .and_then(|v| v.as_str())
615        .ok_or({
616            "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
617             SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
618        })?;
619
620    syn::parse_str::<Type>(db_str).map_err(|err| {
621        format!(
622            "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
623            db_str, err
624        )
625    })
626}
627
628fn uses_dollar_params(db: &Type) -> bool {
629    let Type::Path(type_path) = db else {
630        return false;
631    };
632    type_path
633        .path
634        .segments
635        .last()
636        .is_some_and(|s| s.ident == "Postgres")
637}
638
639fn is_builtin_scalar_type(ty: &Type) -> bool {
640    let Type::Path(type_path) = ty else {
641        return false;
642    };
643
644    if type_path.qself.is_some()
645        || type_path.path.leading_colon.is_some()
646        || type_path.path.segments.len() != 1
647    {
648        return false;
649    }
650
651    let ident = &type_path.path.segments[0].ident;
652    ident == "i8"
653        || ident == "i16"
654        || ident == "i32"
655        || ident == "i64"
656        || ident == "isize"
657        || ident == "u8"
658        || ident == "u16"
659        || ident == "u32"
660        || ident == "u64"
661        || ident == "usize"
662        || ident == "f32"
663        || ident == "f64"
664        || ident == "bool"
665        || ident == "String"
666}
667
668fn scalar_output_type(model: &Type) -> Option<&Type> {
669    if is_builtin_scalar_type(model) {
670        return Some(model);
671    }
672    None
673}
674
675fn push_text_segment(out: &mut Vec<Segment>, text: String) {
676    if text.is_empty() {
677        return;
678    }
679    match out.last_mut() {
680        Some(Segment::Text(existing)) => existing.push_str(&text),
681        _ => out.push(Segment::Text(text)),
682    }
683}
684
685fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
686    let mut out = Vec::new();
687    let mut text = String::new();
688    let mut chars = sql.chars().peekable();
689
690    while let Some(ch) = chars.next() {
691        if ch != '{' {
692            text.push(ch);
693            continue;
694        }
695
696        if chars.peek() == Some(&'(') {
697            push_text_segment(&mut out, std::mem::take(&mut text));
698
699            let mut paren_depth = 0u32;
700            let mut content = String::new();
701            let mut found_close = false;
702            for ch in chars.by_ref() {
703                if ch == '{' {
704                    return Err(
705                        "sql_forge!: nested braces not allowed inside batch section".to_string()
706                    );
707                }
708                if ch == '}' {
709                    if paren_depth != 0 {
710                        return Err(
711                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
712                                .to_string(),
713                        );
714                    }
715                    found_close = true;
716                    break;
717                }
718                if ch == '(' {
719                    paren_depth += 1;
720                } else if ch == ')' {
721                    if paren_depth == 0 {
722                        return Err(
723                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
724                                .to_string(),
725                        );
726                    }
727                    paren_depth -= 1;
728                }
729                content.push(ch);
730            }
731            if !found_close {
732                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
733            }
734            let parts = parse_text_parts(&content);
735            for part in &parts {
736                if let TextPart::Param { is_list: true, .. } = part {
737                    return Err(
738                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
739                         batch sections; use plain parameters (:name) instead"
740                            .to_string(),
741                    );
742                }
743            }
744            out.push(Segment::Batch { parts });
745            continue;
746        }
747
748        if chars.peek() != Some(&'#') {
749            text.push(ch);
750            continue;
751        }
752
753        chars.next();
754        push_text_segment(&mut out, std::mem::take(&mut text));
755
756        let mut name = String::new();
757        loop {
758            let Some(next) = chars.next() else {
759                return Err("sql_forge!: section placeholder without closing }".to_string());
760            };
761            if next == '}' {
762                break;
763            }
764            name.push(next);
765        }
766
767        if name.is_empty() {
768            return Err("sql_forge!: empty section placeholder name".to_string());
769        }
770
771        out.push(Segment::Section { name });
772    }
773
774    push_text_segment(&mut out, text);
775    Ok(out)
776}
777
778// =============================================================================
// SQL text helpers: :param placeholders, backticked alias sanitizing, literal extraction
780// =============================================================================
781
/// Returns `true` for characters that may begin a `:param` name: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}

/// Returns `true` for characters that may continue a `:param` name: an ASCII letter,
/// ASCII digit or `_`.
fn is_ident_continue(ch: char) -> bool {
    ch.is_ascii_alphanumeric() || ch == '_'
}
789
/// Strips a type-override suffix from the text inside a backticked alias.
///
/// Everything from the first `!`, `?` or `:` onward is dropped and trailing whitespace is
/// trimmed: `id!` becomes `id` and `total: i64` becomes `total`. When only whitespace
/// precedes the marker (e.g. `!x`), the content is returned unchanged.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(marker_at) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };

    match content[..marker_at].trim_end() {
        "" => content.to_string(),
        base => base.to_string(),
    }
}

/// Passes every closed `` `...` `` span of `text` through `sanitize_backticked_alias_ident`,
/// copying all other characters unchanged.
///
/// An unterminated backtick and everything after it are copied verbatim.
fn sanitize_runtime_sql_text(text: &str) -> String {
    let mut out = String::with_capacity(text.len());
    let mut rest = text;

    while let Some(open) = rest.find('`') {
        out.push_str(&rest[..open]);
        let after_open = &rest[open + 1..];

        let Some(close) = after_open.find('`') else {
            out.push('`');
            out.push_str(after_open);
            return out;
        };

        out.push('`');
        out.push_str(&sanitize_backticked_alias_ident(&after_open[..close]));
        out.push('`');
        rest = &after_open[close + 1..];
    }

    out.push_str(rest);
    out
}
845
846fn parse_text_parts(text: &str) -> Vec<TextPart> {
847    let mut parts = Vec::new();
848    let mut last = 0usize;
849    let mut iter = text.char_indices().peekable();
850
851    while let Some((idx, ch)) = iter.next() {
852        if ch != ':' {
853            continue;
854        }
855
856        let Some(&(next_idx, next_ch)) = iter.peek() else {
857            continue;
858        };
859
860        if !is_ident_start(next_ch) {
861            continue;
862        }
863
864        if text[..idx].ends_with(':') {
865            continue;
866        }
867
868        if last < idx {
869            parts.push(TextPart::Lit(text[last..idx].to_string()));
870        }
871
872        iter.next();
873
874        let mut name = String::new();
875        name.push(next_ch);
876        let mut end = next_idx + next_ch.len_utf8();
877
878        while let Some(&(j, c)) = iter.peek() {
879            if is_ident_continue(c) {
880                name.push(c);
881                end = j + c.len_utf8();
882                iter.next();
883            } else {
884                break;
885            }
886        }
887
888        let mut is_list = false;
889        if text[end..].starts_with("[]") {
890            is_list = true;
891            end += 2;
892        }
893
894        parts.push(TextPart::Param { name, is_list });
895        last = end;
896    }
897
898    if last < text.len() {
899        parts.push(TextPart::Lit(text[last..].to_string()));
900    }
901
902    parts
903}
904
905fn render_validator_text(
906    text: &str,
907    use_dollar_params: bool,
908    param_offset: &mut usize,
909    list_count: usize,
910) -> (String, Vec<(String, bool)>) {
911    let mut out_sql = String::new();
912    let mut occurrences = Vec::new();
913
914    for part in parse_text_parts(text) {
915        match part {
916            TextPart::Lit(lit) => out_sql.push_str(&lit),
917            TextPart::Param { name, is_list } => {
918                if is_list && list_count > 1 {
919                    let slots: Vec<String> = if use_dollar_params {
920                        (0..list_count)
921                            .map(|i| format!("${}", *param_offset + i + 1))
922                            .collect()
923                    } else {
924                        (0..list_count).map(|_| "?".to_string()).collect()
925                    };
926                    if use_dollar_params {
927                        *param_offset += list_count;
928                    }
929                    out_sql.push_str(&slots.join(", "));
930                } else if use_dollar_params {
931                    *param_offset += 1;
932                    write!(out_sql, "${}", *param_offset).unwrap();
933                } else {
934                    out_sql.push('?');
935                }
936                occurrences.push((name, is_list));
937            }
938        }
939    }
940
941    (out_sql, occurrences)
942}
943
944fn strip_expr(expr: &Expr) -> &Expr {
945    match expr {
946        Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
947        Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
948        Expr::Block(ExprBlock { block, .. }) => {
949            if block.stmts.len() != 1 {
950                return expr;
951            }
952            match &block.stmts[0] {
953                Stmt::Expr(inner, None) => strip_expr(inner),
954                _ => expr,
955            }
956        }
957        _ => expr,
958    }
959}
960
961fn extract_lit_str(expr: &Expr) -> Option<String> {
962    match strip_expr(expr) {
963        Expr::Lit(ExprLit {
964            lit: Lit::Str(lit), ..
965        }) => Some(lit.value()),
966        _ => None,
967    }
968}
969
970// =============================================================================
971// Preprocessing: {>key} compile-time result flags
972// =============================================================================
973
974fn result_flag_ident(name: &str) -> syn::Ident {
975    format_ident!("__enhanced_result_flag_{}", name)
976}
977
978/// Replaces `{>key}` tokens inside braced groups with `__enhanced_result_flag_key`
979/// identifiers. This is a preprocessing step so that the rest of the parser sees
980/// plain identifiers instead of braced groups it does not understand.
981fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
982    fn walk(stream: TokenStream2) -> TokenStream2 {
983        let mut out = TokenStream2::new();
984        let iter = stream.into_iter().peekable();
985
986        for token in iter {
987            match token {
988                TokenTree::Group(group) => {
989                    if group.delimiter() == Delimiter::Brace {
990                        let mut inner = group.stream().into_iter();
991                        let first = inner.next();
992                        let second = inner.next();
993                        let third = inner.next();
994
995                        if let (
996                            Some(TokenTree::Punct(p)),
997                            Some(TokenTree::Ident(name_ident)),
998                            None,
999                        ) = (first, second, third)
1000                        {
1001                            if p.as_char() == '>' {
1002                                let ident = result_flag_ident(&name_ident.to_string());
1003                                out.extend(std::iter::once(TokenTree::Ident(ident)));
1004                                continue;
1005                            }
1006                        }
1007                    }
1008
1009                    let new_inner = walk(group.stream());
1010                    let mut new_group = Group::new(group.delimiter(), new_inner);
1011                    new_group.set_span(group.span());
1012                    out.extend(std::iter::once(TokenTree::Group(new_group)));
1013                }
1014                other => out.extend(std::iter::once(other)),
1015            }
1016        }
1017
1018        out
1019    }
1020
1021    walk(input)
1022}
1023
1024fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1025    keys.iter()
1026        .map(|key| {
1027            let ident = result_flag_ident(key);
1028            let enabled = Some(key.as_str()) == active_key;
1029            quote! { let #ident: bool = #enabled; }
1030        })
1031        .collect()
1032}
1033
1034fn transpose_section_case_matrix(
1035    case_matrix: Vec<Vec<SectionFragment>>,
1036    width: usize,
1037) -> Result<Vec<Vec<SectionFragment>>, String> {
1038    let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1039
1040    for row in case_matrix {
1041        if row.len() != width {
1042            return Err(
1043                "sql_forge!: grouped sections must return one item per section".to_string(),
1044            );
1045        }
1046        for (section_idx, fragment) in row.into_iter().enumerate() {
1047            per_section[section_idx].push(fragment);
1048        }
1049    }
1050
1051    Ok(per_section)
1052}
1053
1054fn collect_section_case_matrix(
1055    value: SectionValue,
1056    width: usize,
1057    active_key: Option<&str>,
1058) -> Result<Vec<Vec<SectionFragment>>, String> {
1059    match value {
1060        SectionValue::Single(fragment) => {
1061            if width != 1 {
1062                return Err(
1063                    "sql_forge!: grouped sections must return one item per section".to_string(),
1064                );
1065            }
1066            Ok(vec![vec![fragment]])
1067        }
1068        SectionValue::Grouped(values) => {
1069            if values.len() != width {
1070                return Err(
1071                    "sql_forge!: grouped sections must return one item per section".to_string(),
1072                );
1073            }
1074
1075            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1076            let mut nmax = 1usize;
1077
1078            for value in values {
1079                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1080                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1081                for mut row in item_matrix {
1082                    let fragment = row.pop().ok_or_else(|| {
1083                        "sql_forge!: grouped sections must return one item per section".to_string()
1084                    })?;
1085                    if !row.is_empty() {
1086                        return Err(
1087                            "sql_forge!: grouped sections must return one item per section"
1088                                .to_string(),
1089                        );
1090                    }
1091                    item_variants.push(fragment);
1092                }
1093                if item_variants.is_empty() {
1094                    return Err("sql_forge!: section match must have at least one arm".to_string());
1095                }
1096                nmax = nmax.max(item_variants.len());
1097                variants_by_section.push(item_variants);
1098            }
1099
1100            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1101            for case_idx in 0..nmax {
1102                let mut row = Vec::<SectionFragment>::with_capacity(width);
1103                for variants in &variants_by_section {
1104                    row.push(variants[case_idx % variants.len()].clone());
1105                }
1106                case_matrix.push(row);
1107            }
1108
1109            Ok(case_matrix)
1110        }
1111        SectionValue::Match { expr, arms } => {
1112            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1113
1114            if let Some(key) = expr_result_flag_key(&expr) {
1115                let target = active_key == Some(key.as_str());
1116                for arm in arms {
1117                    if arm.guard.is_none() {
1118                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1119                            continue;
1120                        }
1121                    }
1122                    case_matrix.extend(collect_section_case_matrix(arm.value, width, active_key)?);
1123                }
1124            } else {
1125                for arm in arms {
1126                    case_matrix.extend(collect_section_case_matrix(arm.value, width, active_key)?);
1127                }
1128            }
1129
1130            if case_matrix.is_empty() {
1131                return Err("sql_forge!: section match must have at least one arm".to_string());
1132            }
1133
1134            Ok(case_matrix)
1135        }
1136    }
1137}
1138
1139// Returns Vec<Vec<SectionFragment>> indexed [section_idx][case_idx].
1140// =============================================================================
1141// Section variant collection
1142// =============================================================================
1143
1144/// Collects all possible `SectionFragment` values per section index. Each
1145/// returned `Vec<Vec<SectionFragment>>` is indexed `[section_idx][case_idx]`,
1146/// listing every variant that section can take across all match arms.
1147/// Used for full validation (cycling strategy over all variants).
1148fn collect_section_variants(
1149    value: SectionValue,
1150    width: usize,
1151) -> Result<Vec<Vec<SectionFragment>>, String> {
1152    transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1153}
1154
1155fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1156    match strip_expr(expr) {
1157        Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1158            let name = path.path.segments[0].ident.to_string();
1159            name.strip_prefix("__enhanced_result_flag_")
1160                .map(|v| v.to_string())
1161        }
1162        _ => None,
1163    }
1164}
1165
1166fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1167    match pat {
1168        Pat::Lit(expr_lit) => match &expr_lit.lit {
1169            Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1170            _ => None,
1171        },
1172        Pat::Wild(_) => Some(true),
1173        _ => None,
1174    }
1175}
1176
1177/// Like `collect_section_variants`, but filters `match` arms by the active
1178/// result key when the match expression is a `{>key}` flag. When building the
1179/// query for a specific key, only the matching arm (true/false) is included;
1180/// arms with guards or non-flag expressions include all variants as usual.
1181fn collect_section_variants_for_result(
1182    value: SectionValue,
1183    width: usize,
1184    active_key: Option<&str>,
1185) -> Result<Vec<Vec<SectionFragment>>, String> {
1186    transpose_section_case_matrix(
1187        collect_section_case_matrix(value, width, active_key)?,
1188        width,
1189    )
1190}
1191
1192// =============================================================================
1193// Parameter binding generation
1194// =============================================================================
1195
1196/// Generates `let` bindings for all parameters used in the SQL or section
1197/// fragments. Returns a map of param-name → local ident for later reference,
1198/// and a list of `let` statements.
1199fn build_param_bindings(
1200    params: &ParamsSource,
1201    used_param_names: &[String],
1202    prefix: &str,
1203    for_validator: bool,
1204    enforce_usage_check: bool,
1205) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1206    let mut declared_params = HashMap::<String, syn::Ident>::new();
1207    let mut bindings = Vec::<TokenStream2>::new();
1208
1209    match params {
1210        ParamsSource::None => {}
1211        ParamsSource::Map(entries) => {
1212            for entry in entries {
1213                let key = entry.name.to_string();
1214                if declared_params.contains_key(&key) {
1215                    return Err(syn::Error::new(
1216                        entry.name.span(),
1217                        "sql_forge!: duplicated parameter mapping",
1218                    )
1219                    .to_compile_error()
1220                    .into());
1221                }
1222                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1223                    return Err(syn::Error::new(
1224                        entry.name.span(),
1225                        format!(
1226                            "sql_forge!: parameter :{} is unused in the SQL template",
1227                            key,
1228                        ),
1229                    )
1230                    .to_compile_error()
1231                    .into());
1232                }
1233                let local_ident = format_ident!("__enhanced_{}_{}", prefix, key);
1234                let expr = &entry.expr;
1235                if for_validator {
1236                    bindings.push(quote! {
1237                        let #local_ident = &(#expr);
1238                    });
1239                } else {
1240                    bindings.push(quote! {
1241                        let #local_ident = #expr;
1242                    });
1243                }
1244                declared_params.insert(key, local_ident);
1245            }
1246        }
1247        ParamsSource::Struct(expr) => {
1248            let source_ident = format_ident!("__enhanced_source_{}", prefix);
1249            bindings.push(quote! {
1250                let #source_ident = &(#expr);
1251            });
1252            for name in used_param_names {
1253                let local_ident = format_ident!("__enhanced_{}_{}", prefix, name);
1254                let field_ident = format_ident!("{}", name);
1255                bindings.push(quote! {
1256                    let #local_ident = #source_ident.#field_ident;
1257                });
1258                declared_params.insert(name.to_string(), local_ident);
1259            }
1260        }
1261    }
1262
1263    Ok((declared_params, bindings))
1264}
1265
1266struct ValidatorRenderContext<'a> {
1267    local_params: &'a HashMap<String, syn::Ident>,
1268    top_level_params: &'a HashMap<String, syn::Ident>,
1269    allow_top_level_fallback: bool,
1270    use_dollar_params: bool,
1271    sql_span: Span,
1272    list_count: usize,
1273}
1274
1275/// Builds the placeholders SQL string and argument list for the compile-time
1276/// validator (sqlx::query_as! / query_scalar!). Each `:param` in the SQL is
1277/// replaced by `?` (MySQL/SQLite) or `$1`/`$2`/... (PostgreSQL), and the
1278/// corresponding value expression is collected into the args list.
1279fn render_validator_args(
1280    sql: &str,
1281    param_offset: &mut usize,
1282    context: &ValidatorRenderContext<'_>,
1283) -> Result<(String, Vec<TokenStream2>), TokenStream> {
1284    let (rendered_sql, occurrences) = render_validator_text(
1285        sql,
1286        context.use_dollar_params,
1287        param_offset,
1288        context.list_count,
1289    );
1290    let mut args = Vec::<TokenStream2>::new();
1291
1292    for (name, is_list) in occurrences {
1293        let local_ident = if context.allow_top_level_fallback {
1294            context
1295                .local_params
1296                .get(&name)
1297                .or_else(|| context.top_level_params.get(&name))
1298        } else {
1299            context.local_params.get(&name)
1300        };
1301
1302        let Some(local_ident) = local_ident else {
1303            return Err(syn::Error::new(
1304                context.sql_span,
1305                format!("sql_forge!: parameter :{} has no mapping", name),
1306            )
1307            .to_compile_error()
1308            .into());
1309        };
1310
1311        if is_list {
1312            for _ in 0..context.list_count {
1313                args.push(quote! { *(#local_ident).as_slice().first().unwrap_or(&0i64) });
1314            }
1315        } else {
1316            args.push(quote! { #local_ident });
1317        }
1318    }
1319
1320    Ok((rendered_sql, args))
1321}
1322
1323// =============================================================================
1324// Runtime code generation (QueryBuilder-based)
1325// =============================================================================
1326
1327/// Generates the `push()` / `push_bind()` calls for a single section fragment
1328/// at runtime using `sqlx::QueryBuilder`.
1329fn render_runtime_fragment(
1330    fragment: &SectionFragment,
1331    local_params: &HashMap<String, syn::Ident>,
1332) -> Result<TokenStream2, TokenStream> {
1333    let mut steps = Vec::<TokenStream2>::new();
1334
1335    for part in parse_text_parts(&fragment.sql) {
1336        match part {
1337            TextPart::Lit(lit) => {
1338                let lit_str = LitStr::new(&lit, fragment.span);
1339                steps.push(quote! { __builder.push(#lit_str); });
1340            }
1341            TextPart::Param { name, is_list } => {
1342                let Some(local_ident) = local_params.get(&name) else {
1343                    return Err(syn::Error::new(
1344                        fragment.span,
1345                        format!("sql_forge!: parameter :{} has no mapping", name),
1346                    )
1347                    .to_compile_error()
1348                    .into());
1349                };
1350
1351                if is_list {
1352                    steps.push(quote! {
1353                        let __enhanced_values = #local_ident;
1354                        let mut __separated = __builder.separated(", ");
1355                        for __value in __enhanced_values {
1356                            __separated.push_bind(__value);
1357                        }
1358                    });
1359                } else {
1360                    steps.push(quote! {
1361                        __builder.push_bind(#local_ident);
1362                    });
1363                }
1364            }
1365        }
1366    }
1367
1368    Ok(quote! { #( #steps )* })
1369}
1370
1371fn build_section_runtime_action(
1372    value: &SectionValue,
1373    section_idx: usize,
1374    prefix: &str,
1375) -> Result<TokenStream2, TokenStream> {
1376    match value {
1377        SectionValue::Single(fragment) => {
1378            let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1379            let (local_params, bindings) =
1380                build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1381            let body = render_runtime_fragment(fragment, &local_params)?;
1382            Ok(quote! {{ #( #bindings )* #body }})
1383        }
1384        SectionValue::Grouped(fragments) => build_section_runtime_action(
1385            &fragments[section_idx],
1386            0,
1387            &format!("{}_grouped_{}", prefix, section_idx),
1388        ),
1389        SectionValue::Match { expr, arms } => {
1390            let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1391                .iter()
1392                .enumerate()
1393                .map(|(arm_idx, arm)| {
1394                    let pat = &arm.pat;
1395                    let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1396                    let body = build_section_runtime_action(
1397                        &arm.value,
1398                        section_idx,
1399                        &format!("{}_{}", prefix, arm_idx),
1400                    )?;
1401                    Ok::<TokenStream2, TokenStream>(quote! { #pat #guard_tokens => #body })
1402                })
1403                .collect();
1404            let arm_tokens = arm_tokens?;
1405            Ok(quote! {
1406                match #expr {
1407                    #( #arm_tokens ),*
1408                }
1409            })
1410        }
1411    }
1412}
1413
1414fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1415    let mut names = Vec::new();
1416    let mut seen = HashSet::<String>::new();
1417
1418    for segment in segments {
1419        match segment {
1420            Segment::Text(text) => {
1421                for name in collect_used_param_names_in_sql(text) {
1422                    if seen.insert(name.clone()) {
1423                        names.push(name);
1424                    }
1425                }
1426            }
1427            Segment::Batch { parts } => {
1428                for part in parts {
1429                    if let TextPart::Param { name, .. } = part {
1430                        if seen.insert(name.clone()) {
1431                            names.push(name.clone());
1432                        }
1433                    }
1434                }
1435            }
1436            _ => {}
1437        }
1438    }
1439
1440    names
1441}
1442
1443fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1444    let mut names = Vec::new();
1445    let mut seen = HashSet::<String>::new();
1446    for part in parse_text_parts(sql) {
1447        if let TextPart::Param { name, .. } = part {
1448            if seen.insert(name.to_string()) {
1449                names.push(name);
1450            }
1451        }
1452    }
1453    names
1454}
1455
1456/// Builds a parameterized SQL query with compile-time type-checking and a
1457/// runtime [`sqlx::QueryBuilder`] for dynamic SQL.
1458///
1459/// Combines `sqlx::query_as!` / `sqlx::query_scalar!` validation (never called
1460/// at runtime) with `QueryBuilder::push_bind` for safe value binding.
1461///
1462/// # Syntax
1463///
1464/// ```text
1465/// sql_forge!(
1466///     [DB,]        // optional: sqlx::MySql | sqlx::Postgres | sqlx::Sqlite
1467///     [Model,]     // optional result spec
1468///     SQL,         // string literal
1469///     [params,]    // optional: ( :name = expr, ... )  or  struct_expr
1470///     [(sections),]// optional: ( #name = ..., ... )
1471///     [..batch]    // optional: batch source expression used by {( ... )}
1472/// )
1473///
1474/// `Model` has three forms:
1475/// - omitted: execute-only query; only `.execute(...)` is available
1476/// - `Type` or `scalar Type`: a single result query
1477/// - `( >key1 = TypeA, >key2 = scalar TypeB )`: a grouped multi-result query
1478///
1479/// The trailing parameter source, section map, and batch source are optional.
1480/// The batch source may appear alongside the others as a single `..expr` argument.
1481///
1482/// The DB type may be omitted when `SQL_FORGE_DB_TYPE` is set (e.g.
1483/// `SQL_FORGE_DB_TYPE=sqlx::MySql`) or when
1484/// `[package.metadata.sql_forge] db = "..."` is set in `Cargo.toml`.
1485/// The env var takes priority over Cargo.toml metadata.
1486///
1487/// # Parameters
1488///
1489/// Named parameters are written `:name` in the SQL. At runtime each occurrence
1490/// is replaced by a `push_bind` call; at compile time it becomes `?` for
1491/// `query_as!` / `query_scalar!`.
1492///
1493/// **Inline map** – bind individual expressions:
1494/// ```rust,ignore
1495/// sql_forge!(User, "SELECT ... WHERE id <= :max_id", ( :max_id = filter.max_id ))
1496/// ```
1497///
1498/// **Struct source** – field names are matched to `:name` placeholders automatically:
1499/// ```rust,ignore
1500/// sql_forge!(User, "SELECT ... WHERE id <= :max_id LIMIT :limit", filter)
1501/// ```
1502///
1503/// # Sections (`{#name}`)
1504///
1505/// Sections are runtime SQL slots; each section's variants are validated at
1506/// compile time via `query_as!` / `query_scalar!`, though not every combination
1507/// of variants across sections is checked. The section map is a second parenthesised
1508/// argument starting with `#`:
1509///
1510/// ```rust,ignore
1511/// sql_forge!(
1512///     User,
1513///     "SELECT * FROM users {#join_org}",
1514///     (
1515///         #join_org = match include_org {
1516///             true  => " JOIN organisations o ON o.id = users.org_id ",
1517///             false => "",
1518///         }
1519///     )
1520/// )
1521/// ```
1522///
1523/// A section arm can also carry local parameters as a tuple `("sql", params)`:
1524///
1525/// ```rust,ignore
1526/// sql_forge!(
1527///     User,
1528///     "SELECT * FROM users {#filter}",
1529///     (
1530///         #filter = (
1531///             " WHERE id <= :max_id AND status = :status ",
1532///             ( :max_id = max_id, :status = "active" ),
1533///         )
1534///     )
1535/// )
1536/// ```
1537///
1538/// Multiple placeholders driven by one `match` use `#(a, b)` with each arm
1539/// returning a tuple of the same width:
1540///
1541/// ```rust,ignore
1542/// sql_forge!(
1543///     User,
1544///     "SELECT * FROM users {#join_org} {#filter_org}",
1545///     (
1546///         #(join_org, filter_org) = match include_org {
1547///             true  => (
1548///                 " JOIN organisations o ON o.id = users.org_id ",
1549///                 (
1550///                     " AND o.active = :active ",
1551///                     ( :active = true ),
1552///                 ),
1553///             ),
1554///             false => ("", ""),
1555///         }
1556///     )
1557/// )
1558/// ```
1559///
1560/// Grouped section items may themselves use nested `match` expressions. Those
1561/// nested matches use smart cycling within the arm rather than a cartesian
1562/// product. For example, if one grouped arm returns a fixed first item plus two
1563/// nested binary matches for the second and third items, that arm contributes
1564/// two aligned variants `(0, 0)` and `(1, 1)`, not four `(0, 0)`, `(0, 1)`,
1565/// `(1, 0)`, `(1, 1)` combinations.
1566///
1567/// # `IN (...)` with list parameters
1568///
1569/// Wrap the placeholder in parentheses to expand a `Vec` into multiple bound
1570/// values:
1571///
1572/// ```rust,ignore
1573/// sql_forge!(User, "SELECT * FROM users WHERE id IN (:ids[])", ( :ids = ids ))
1574/// ```
1575///
1576/// **Empty lists** are not rewritten; `IN ()` is a database syntax error.
1577/// Guard against empty inputs explicitly, e.g. with a dynamic section:
1578///
1579/// ```rust,ignore
1580/// sql_forge!(
1581///     User,
1582///     "SELECT id, name FROM users WHERE {#filter}",
1583///     (
1584///         #filter = match ids.is_empty() {
1585///             true  => "1 = 0",
1586///             false => ("id IN (:ids[])", ( :ids = ids )),
1587///         }
1588///     )
1589/// )
1590/// ```
1591///
1592/// # Batch inserts (`{( ... )}`)
1593///
1594/// A batch section `{( ... )}` repeats its content for each item in an iterable
1595/// source passed as `..expr`. Inside the batch, `:name` refers to a field on the
1596/// current item. List parameters (`:name[]`) are **not** allowed inside batch
1597/// sections.
1598///
1599/// ## Struct batch
1600///
1601/// ```rust,ignore
1602/// struct BatchItem { name: String, price: i64 }
1603///
1604/// let items = vec![
1605///     BatchItem { name: "A".into(), price: 100 },
1606///     BatchItem { name: "B".into(), price: 200 },
1607/// ];
1608///
1609/// sql_forge!(
1610///     "INSERT INTO products (name, price, stock, category)
1611///      VALUES {(:name, :price, 10, 'Batch')}",
1612///     ..items
1613/// )
1614/// .execute(&pool)
1615/// .await?;
1616/// ```
1617///
1618/// For compile-time checking, the validator expands the batch to 3 fake copies
1619/// (`(?, ?, 10, 'Batch'), (?, ?, 10, 'Batch')`, (?, ?, 10, 'Batch')`).
1620/// At runtime the iterable drives the actual number of rows.
1621///
1622/// # Scalar output
1623///
1624/// When `Model` is a primitive (`i32`, `i64`, `String`, etc.) the macro uses
1625/// `query_scalar!` for validation and `build_query_scalar` for execution.
1626///
1627/// # Multiple results
1628///
1629/// A result map produces a `SqlForgeQueryGroup` with one query per key.
1630/// Each key can be a struct or a primitive (used as a scalar):
1631///
1632/// ```rust,ignore
1633/// sql_forge!(
1634///     (
1635///         >count = i64,
1636///         >items = Item,
1637///     ),
1638///     "SELECT {#fields} FROM items WHERE category_id = :cat",
1639///     ( :cat = category_id ),
1640///     (
1641///         #fields = match {>count} {           // {>key} is true when building
1642///             true  => "COUNT(*) AS total",    //   the query for that key
1643///             false => "id, name, price",       //   and false otherwise
1644///         }
1645///     )
1646/// )
1647/// ```
1648///
1649/// The generated struct has one field per key (`group.count`, `group.items`),
1650/// each implementing `SqlForgeQuery<T, Db = DB>` and usable with any SQLx
1651/// executor method (`fetch_one`, `fetch_all`, etc.).
1652///
1653/// # Execute-only (no model)
1654///
1655/// When the model type is omitted, the macro produces a value implementing
1656/// `SqlForgeQueryExecute`. Only `.execute(executor)`
1657/// is available and there is no return type to deserialize into. This is useful
1658/// for `INSERT`, `UPDATE`, `DELETE`, and other DML statements.
1659///
1660/// ```rust,ignore
1661/// sql_forge!(
1662///     "UPDATE products SET stock = stock + 1 WHERE id = :id",
1663///     ( :id = 42i64 ),
1664/// )
1665/// .execute(&pool)
1666/// .await?;
1667/// ```
1668///
1669/// Sections and struct parameter sources work the same way as in model-backed queries:
1670///
1671/// ```rust,ignore
1672/// sql_forge!(
1673///     "UPDATE products SET price = :new_price {#filter}",
1674///     ( #filter = "WHERE category = :cat", ( :cat = "Electronics" ) ),
1675/// )
1676/// .execute(&pool)
1677/// .await?;
1678/// ```
1679///
1680/// # Caveats
1681///
1682/// **String literals containing `:`**
1683///
1684/// In this case, the template scanner cannot distinguish
1685/// a colon inside a SQL string literal from a `:name` placeholder. Embedding
1686/// `"abc:def"` directly in the template will fail.
1687/// Pass such values as bind parameters instead:
1688///
1689/// ```rust,ignore
1690/// // ❌  sql_forge!(User, r#"WHERE name = "abc:def""#);
1691/// // ✅
1692/// sql_forge!(User, r#"WHERE name = :name"#, ( :name = "abc:def" ))
1693/// ```
1694///
1695/// **String literals containing `{#`**
1696///
1697/// Similarly, a `{#` sequence inside a
1698/// SQL string literal is treated as a section slot. Pass the value as a bind
1699/// parameter:
1700///
1701/// ```rust,ignore
1702/// // ❌  sql_forge!(User, r#"WHERE name = "abc{#def""#);
1703/// // ✅
1704/// sql_forge!(User, r#"WHERE name = :name"#, ( :name = "abc{#def" ))
1705/// ```
1706#[proc_macro]
1707#[allow(clippy::too_many_lines)]
1708pub fn sql_forge(input: TokenStream) -> TokenStream {
1709    // ---- Phase 1: Parse the macro input into structured data ----
1710    let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
1711    let SqlForgeInput {
1712        db,
1713        result,
1714        force_scalar,
1715        sql,
1716        params,
1717        sections,
1718        batch,
1719    } = match syn::parse2::<SqlForgeInput>(preprocessed) {
1720        Ok(v) => v,
1721        Err(err) => return err.to_compile_error().into(),
1722    };
1723
1724    // ---- Phase 2: Resolve database type (from macro arg or Cargo.toml) ----
1725    let db = match db {
1726        Some(db) => db,
1727        None => match resolve_db_from_env() {
1728            Ok(db) => db,
1729            Err(msg) => {
1730                return syn::Error::new(Span::call_site(), msg)
1731                    .to_compile_error()
1732                    .into();
1733            }
1734        },
1735    };
1736
1737    let use_dollar_params = uses_dollar_params(&db);
1738    let is_sqlite = if let syn::Type::Path(type_path) = &db {
1739        type_path
1740            .path
1741            .segments
1742            .last()
1743            .is_some_and(|s| s.ident == "Sqlite")
1744    } else {
1745        false
1746    };
1747    let list_count: usize = if is_sqlite { 1 } else { 3 };
1748
1749    // ---- Phase 3: Build result case definitions ----
1750    // Each result case is (optional_key, model_type, optional_scalar_type).
1751    // Scalar type is set for primitives and `scalar`-marked types.
1752    let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
1753        ResultSpec::None => {
1754            vec![(None, None, None)]
1755        }
1756        ResultSpec::Single(ref model) => {
1757            let model_ty = (**model).clone();
1758            let scalar = if force_scalar {
1759                Some(model_ty.clone())
1760            } else {
1761                scalar_output_type(model.as_ref()).cloned()
1762            };
1763            vec![(None, Some(model_ty), scalar)]
1764        }
1765        ResultSpec::Group(ref cases) => {
1766            if force_scalar {
1767                return syn::Error::new(
1768                    Span::call_site(),
1769                    "sql_forge!: scalar mode is not supported for grouped result maps",
1770                )
1771                .to_compile_error()
1772                .into();
1773            }
1774
1775            let mut out = Vec::new();
1776            let mut seen = HashSet::new();
1777            for case in cases {
1778                let key = case.name.to_string();
1779                if !seen.insert(key.clone()) {
1780                    return syn::Error::new(
1781                        case.name.span(),
1782                        "sql_forge!: duplicated key in result map",
1783                    )
1784                    .to_compile_error()
1785                    .into();
1786                }
1787
1788                let model = case.model.clone();
1789                let scalar = if case.force_scalar {
1790                    Some(model.clone())
1791                } else {
1792                    scalar_output_type(&case.model).cloned()
1793                };
1794                out.push((Some(key), Some(model), scalar));
1795            }
1796            out
1797        }
1798    };
1799    let group_result_keys: Vec<String> = result_cases
1800        .iter()
1801        .filter_map(|(key, _, _)| key.as_ref().cloned())
1802        .collect();
1803    let is_grouped_result = !group_result_keys.is_empty();
1804    let sql_span = sql.span();
1805
1806    // ---- Phase 4: Parse SQL into segments (text + {#section} slots) ----
1807    let segments = match sql.into_segments() {
1808        Ok(segments) => segments,
1809        Err(msg) => {
1810            return syn::Error::new(sql_span, msg).to_compile_error().into();
1811        }
1812    };
1813
1814    let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
1815    match (&batch, has_batch_segment) {
1816        (None, true) => {
1817            return syn::Error::new(
1818                sql_span,
1819                "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
1820                 was provided"
1821            )
1822            .to_compile_error()
1823            .into();
1824        }
1825        (Some(_), false) => {
1826            return syn::Error::new(
1827                sql_span,
1828                "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
1829                 batch section",
1830            )
1831            .to_compile_error()
1832            .into();
1833        }
1834        _ => {}
1835    }
1836
1837    let used_param_names = collect_used_param_names(&segments);
1838
1839    // Batch-only params come from batch items, not the top-level params map.
1840    // They must be excluded from the usage check so that a param like :category
1841    // that appears only inside {( ... )} is flagged as unused when given in the
1842    // params map, as it would never be read from there at runtime.
1843    let batch_param_names: std::collections::HashSet<String> = segments
1844        .iter()
1845        .filter_map(|s| {
1846            if let Segment::Batch { parts } = s {
1847                Some(parts.iter().filter_map(|p| {
1848                    if let TextPart::Param { name, .. } = p {
1849                        Some(name.clone())
1850                    } else {
1851                        None
1852                    }
1853                }))
1854            } else {
1855                None
1856            }
1857        })
1858        .flatten()
1859        .collect();
1860    let top_level_used_names: Vec<String> = used_param_names
1861        .iter()
1862        .filter(|n| !batch_param_names.contains(*n))
1863        .cloned()
1864        .collect();
1865
1866    // ---- Phase 5: Build parameter bindings for the top-level params ----
1867    let (declared_params, validator_param_bindings) =
1868        match build_param_bindings(&params, &top_level_used_names, "top_level", true, true) {
1869            Ok(v) => v,
1870            Err(err) => return err,
1871        };
1872
1873    let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
1874
1875    // ---- Phase 6: Process sections: build runtime actions and collect validation variants ----
1876    for assign in &sections {
1877        let SectionAssign { names, value } = assign;
1878
1879        // Build runtime actions first, while `value` is still available by reference.
1880        let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
1881        for (section_idx, name_ident) in names.iter().enumerate() {
1882            let name = name_ident.to_string();
1883            if runtime_section_actions.contains_key(&name) {
1884                return syn::Error::new(
1885                    name_ident.span(),
1886                    "sql_forge!: duplicated section mapping",
1887                )
1888                .to_compile_error()
1889                .into();
1890            }
1891            let action = match build_section_runtime_action(
1892                value,
1893                section_idx,
1894                &format!("section_{}", name),
1895            ) {
1896                Ok(action) => action,
1897                Err(err) => return err,
1898            };
1899            named_actions.push((name, action));
1900        }
1901
1902        // Consume `value` here so invalid grouped/nested section structures fail early.
1903        if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
1904            return syn::Error::new(names[0].span(), msg)
1905                .to_compile_error()
1906                .into();
1907        }
1908
1909        for (name, action) in named_actions {
1910            runtime_section_actions.insert(name, action);
1911        }
1912    }
1913
1914    let sql_section_names: std::collections::HashSet<&str> = segments
1915        .iter()
1916        .filter_map(|seg| {
1917            if let Segment::Section { name } = seg {
1918                Some(name.as_str())
1919            } else {
1920                None
1921            }
1922        })
1923        .collect();
1924    for name in runtime_section_actions.keys() {
1925        if !sql_section_names.contains(name.as_str()) {
1926            return syn::Error::new(
1927                sql_span,
1928                format!(
1929                    "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
1930                    name, name,
1931                ),
1932            )
1933            .to_compile_error()
1934            .into();
1935        }
1936    }
1937
1938    // ---- Phase 8: For each result case, generate validator + runtime tokens ----
1939    let mut generated_query_defs = Vec::<TokenStream2>::new();
1940    let mut generated_query_values = Vec::<TokenStream2>::new();
1941    let mut group_field_defs = Vec::<TokenStream2>::new();
1942    let mut group_method_defs = Vec::<TokenStream2>::new();
1943    let mut group_field_idents = Vec::<syn::Ident>::new();
1944    let mut group_field_tys = Vec::<TokenStream2>::new();
1945    let mut group_trait_impls = Vec::<TokenStream2>::new();
1946
1947    let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
1948
1949    for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
1950        let suffix = result_key.as_deref().unwrap_or("single");
1951        let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
1952        let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
1953
1954        let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
1955
1956        let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
1957        for assign in &sections {
1958            let SectionAssign { names, value } = assign;
1959            let variants_by_section = match collect_section_variants_for_result(
1960                value.clone(),
1961                names.len(),
1962                result_key.as_deref(),
1963            ) {
1964                Ok(v) => v,
1965                Err(msg) => {
1966                    return syn::Error::new(names[0].span(), msg)
1967                        .to_compile_error()
1968                        .into();
1969                }
1970            };
1971
1972            for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
1973                section_variants_for_validation.insert(name_ident.to_string(), section_cases);
1974            }
1975        }
1976
1977        let mut nmax = 1usize;
1978        for segment in &segments {
1979            if let Segment::Section { name } = segment {
1980                if let Some(variants) = section_variants_for_validation.get(name) {
1981                    if variants.is_empty() {
1982                        return syn::Error::new(
1983                            sql_span,
1984                            format!("sql_forge!: section {{#{}}} has no possible variants", name),
1985                        )
1986                        .to_compile_error()
1987                        .into();
1988                    }
1989                    nmax = nmax.max(variants.len());
1990                } else {
1991                    return syn::Error::new(
1992                        sql_span,
1993                        format!("sql_forge!: section {{#{}}} has no mapping", name),
1994                    )
1995                    .to_compile_error()
1996                    .into();
1997                }
1998            }
1999        }
2000
2001        let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2002        for case_idx in 0..nmax {
2003            let mut sql_case = String::new();
2004            let mut case_setup = Vec::<TokenStream2>::new();
2005            let mut case_args = Vec::<TokenStream2>::new();
2006            let mut param_offset = 0usize;
2007            let empty_params = HashMap::<String, syn::Ident>::new();
2008            let root_validator_context = ValidatorRenderContext {
2009                local_params: &empty_params,
2010                top_level_params: &declared_params,
2011                allow_top_level_fallback: true,
2012                use_dollar_params,
2013                sql_span,
2014                list_count,
2015            };
2016
2017            for segment in &segments {
2018                match segment {
2019                    Segment::Text(text) => {
2020                        let (chunk_sql, chunk_args) = match render_validator_args(
2021                            text,
2022                            &mut param_offset,
2023                            &root_validator_context,
2024                        ) {
2025                            Ok(value) => value,
2026                            Err(err) => return err,
2027                        };
2028                        sql_case.push_str(&chunk_sql);
2029                        case_args.extend(chunk_args);
2030                    }
2031                    Segment::Section { name } => {
2032                        let Some(variants) = section_variants_for_validation.get(name) else {
2033                            return syn::Error::new(
2034                                sql_span,
2035                                format!("sql_forge!: section {{#{}}} has no mapping", name),
2036                            )
2037                            .to_compile_error()
2038                            .into();
2039                        };
2040
2041                        let fragment = &variants[case_idx % variants.len()];
2042                        let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2043                        let (local_params, bindings) = match build_param_bindings(
2044                            &fragment.params,
2045                            &used_param_names,
2046                            &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2047                            true,
2048                            true,
2049                        ) {
2050                            Ok(value) => value,
2051                            Err(err) => return err,
2052                        };
2053                        let section_validator_context = ValidatorRenderContext {
2054                            local_params: &local_params,
2055                            top_level_params: &declared_params,
2056                            allow_top_level_fallback: false,
2057                            use_dollar_params,
2058                            sql_span: fragment.span,
2059                            list_count,
2060                        };
2061                        let (chunk_sql, chunk_args) = match render_validator_args(
2062                            &fragment.sql,
2063                            &mut param_offset,
2064                            &section_validator_context,
2065                        ) {
2066                            Ok(value) => value,
2067                            Err(err) => return err,
2068                        };
2069                        sql_case.push_str(&chunk_sql);
2070                        case_setup.extend(bindings);
2071                        case_args.extend(chunk_args);
2072                    }
2073                    Segment::Batch { parts } => {
2074                        let mut first = true;
2075                        for _ in 0..list_count {
2076                            let sep = if first { "" } else { ", " };
2077                            first = false;
2078                            sql_case.push_str(sep);
2079                            for tp in parts {
2080                                match tp {
2081                                    TextPart::Lit(lit) => sql_case.push_str(lit),
2082                                    TextPart::Param { name, .. } => {
2083                                        if let Some(batch_expr) = &batch {
2084                                            let field_ident = format_ident!("{}", name);
2085                                            if use_dollar_params {
2086                                                param_offset += 1;
2087                                                write!(sql_case, "${}", param_offset).unwrap();
2088                                            } else {
2089                                                sql_case.push('?');
2090                                            }
2091                                            case_args.push(quote! { #batch_expr[0].#field_ident });
2092                                        } else if use_dollar_params {
2093                                            param_offset += 1;
2094                                            write!(sql_case, "${}", param_offset).unwrap();
2095                                        } else {
2096                                            sql_case.push('?');
2097                                        }
2098                                    }
2099                                }
2100                            }
2101                        }
2102                    }
2103                }
2104            }
2105
2106            validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2107        }
2108
2109        let mut validator_invocations = Vec::<TokenStream2>::new();
2110        for (sql_lit, case_setup, args) in &validator_cases {
2111            if model_opt.is_none() {
2112                if args.is_empty() {
2113                    validator_invocations.push(quote! {
2114                        {
2115                            #( #case_setup )*
2116                            let _ = sqlx::query_scalar!(
2117                                #sql_lit,
2118                            );
2119                        }
2120                    });
2121                } else {
2122                    validator_invocations.push(quote! {
2123                        {
2124                            #( #case_setup )*
2125                            let _ = sqlx::query_scalar!(
2126                                #sql_lit,
2127                                #( #args ),*
2128                            );
2129                        }
2130                    });
2131                }
2132            } else if let Some(scalar_ty) = scalar_model_ty {
2133                if args.is_empty() {
2134                    validator_invocations.push(quote! {
2135                        {
2136                            #( #case_setup )*
2137                            let _ = sqlx::query_scalar!(
2138                                #sql_lit,
2139                            );
2140                        }
2141                    });
2142                } else {
2143                    validator_invocations.push(quote! {
2144                        {
2145                            #( #case_setup )*
2146                            let _ = sqlx::query_scalar!(
2147                                #sql_lit,
2148                                #( #args ),*
2149                            );
2150                        }
2151                    });
2152                }
2153                let _ = scalar_ty;
2154            } else if args.is_empty() {
2155                validator_invocations.push(quote! {
2156                    {
2157                        #( #case_setup )*
2158                        let _ = sqlx::query_as!(
2159                            __EnhancedModel,
2160                            #sql_lit,
2161                        );
2162                    }
2163                });
2164            } else {
2165                validator_invocations.push(quote! {
2166                    {
2167                        #( #case_setup )*
2168                        let _ = sqlx::query_as!(
2169                            __EnhancedModel,
2170                            #sql_lit,
2171                            #( #args ),*
2172                        );
2173                    }
2174                });
2175            }
2176        }
2177
2178        let model_alias = if let Some(model) = model_opt {
2179            if scalar_model_ty.is_none() {
2180                quote! { type __EnhancedModel = #model; }
2181            } else {
2182                quote! {}
2183            }
2184        } else {
2185            quote! {}
2186        };
2187        grouped_validator_invocations.push(quote! {
2188            {
2189                #( #flag_bindings )*
2190                #model_alias
2191                #( #validator_invocations )*
2192            }
2193        });
2194
2195        let (runtime_declared_params, runtime_param_bindings) =
2196            match build_param_bindings(&params, &used_param_names, "runtime", false, false) {
2197                Ok(v) => v,
2198                Err(err) => return err,
2199            };
2200
2201        let mut runtime_steps = Vec::<TokenStream2>::new();
2202        for (seg_idx, segment) in segments.iter().enumerate() {
2203            match segment {
2204                Segment::Text(text) => {
2205                    for part in parse_text_parts(text) {
2206                        match part {
2207                            TextPart::Lit(lit) => {
2208                                let lit = sanitize_runtime_sql_text(&lit);
2209                                let lit_str = LitStr::new(&lit, sql_span);
2210                                runtime_steps.push(quote! {
2211                                    __builder.push(#lit_str);
2212                                });
2213                            }
2214                            TextPart::Param { name, is_list } => {
2215                                let Some(local_ident) = runtime_declared_params.get(&name) else {
2216                                    return syn::Error::new(
2217                                        sql_span,
2218                                        format!("sql_forge!: parameter :{} has no mapping", name),
2219                                    )
2220                                    .to_compile_error()
2221                                    .into();
2222                                };
2223
2224                                if is_list {
2225                                    runtime_steps.push(quote! {
2226                                        let __enhanced_values = #local_ident;
2227                                        let mut __separated = __builder.separated(", ");
2228                                        for __value in __enhanced_values {
2229                                            __separated.push_bind(__value);
2230                                        }
2231                                    });
2232                                } else {
2233                                    runtime_steps.push(quote! {
2234                                        __builder.push_bind(#local_ident);
2235                                    });
2236                                }
2237                            }
2238                        }
2239                    }
2240                }
2241                Segment::Section { name } => {
2242                    let Some(section_action) = runtime_section_actions.get(name) else {
2243                        let _ = seg_idx;
2244                        return syn::Error::new(
2245                            sql_span,
2246                            format!("sql_forge!: section {{#{}}} has no mapping", name),
2247                        )
2248                        .to_compile_error()
2249                        .into();
2250                    };
2251                    runtime_steps.push(quote! {
2252                        #section_action
2253                    });
2254                }
2255                Segment::Batch { parts } => {
2256                    if let Some(batch_expr) = &batch {
2257                        let mut body = Vec::<TokenStream2>::new();
2258                        for part in parts {
2259                            match part {
2260                                TextPart::Lit(lit) => {
2261                                    let lit_str = LitStr::new(lit, sql_span);
2262                                    body.push(quote! {
2263                                        __builder.push(#lit_str);
2264                                    });
2265                                }
2266                                TextPart::Param { name, .. } => {
2267                                    let field_ident = format_ident!("{}", name);
2268                                    body.push(quote! {
2269                                        __builder.push_bind(__item.#field_ident);
2270                                    });
2271                                }
2272                            }
2273                        }
2274                        runtime_steps.push(quote! {
2275                            {
2276                                let mut __first = true;
2277                                for __item in #batch_expr {
2278                                    if !__first {
2279                                        __builder.push(", ");
2280                                    }
2281                                    __first = false;
2282                                    #( #body )*
2283                                }
2284                            }
2285                        });
2286                    }
2287                }
2288            }
2289        }
2290
2291        let exec_methods = if model_opt.is_none() {
2292            quote! {
2293                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2294                where
2295                    E: sqlx::Executor<'e, Database = #db>,
2296                {
2297                    self.inner.build().execute(executor).await
2298                }
2299            }
2300        } else if let Some(scalar_ty) = scalar_model_ty {
2301            quote! {
2302                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2303                where
2304                    E: sqlx::Executor<'e, Database = #db>,
2305                {
2306                    self.inner
2307                        .build_query_scalar::<#scalar_ty>()
2308                        .fetch_all(executor)
2309                        .await
2310                }
2311
2312                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2313                where
2314                    E: sqlx::Executor<'e, Database = #db>,
2315                {
2316                    self.inner
2317                        .build_query_scalar::<#scalar_ty>()
2318                        .fetch_one(executor)
2319                        .await
2320                }
2321
2322                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2323                where
2324                    E: sqlx::Executor<'e, Database = #db>,
2325                {
2326                    self.inner
2327                        .build_query_scalar::<#scalar_ty>()
2328                        .fetch_optional(executor)
2329                        .await
2330                }
2331
2332                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2333                where
2334                    E: sqlx::Executor<'e, Database = #db>,
2335                {
2336                    self.inner.build().execute(executor).await
2337                }
2338            }
2339        } else {
2340            let model = model_opt.as_ref().unwrap();
2341            quote! {
2342                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2343                where
2344                    E: sqlx::Executor<'e, Database = #db>,
2345                {
2346                    self.inner.build_query_as::<#model>().fetch_all(executor).await
2347                }
2348
2349                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2350                where
2351                    E: sqlx::Executor<'e, Database = #db>,
2352                {
2353                    self.inner.build_query_as::<#model>().fetch_one(executor).await
2354                }
2355
2356                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2357                where
2358                    E: sqlx::Executor<'e, Database = #db>,
2359                {
2360                    self.inner
2361                        .build_query_as::<#model>()
2362                        .fetch_optional(executor)
2363                        .await
2364                }
2365
2366                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2367                where
2368                    E: sqlx::Executor<'e, Database = #db>,
2369                {
2370                    self.inner.build().execute(executor).await
2371                }
2372            }
2373        };
2374
2375        let final_type: TokenStream2 = if let Some(model) = model_opt {
2376            if let Some(scalar_ty) = scalar_model_ty {
2377                quote! { #scalar_ty }
2378            } else {
2379                quote! { #model }
2380            }
2381        } else {
2382            quote! {}
2383        };
2384        let trait_impl = if model_opt.is_none() {
2385            quote! {
2386                impl<'args> sql_forge::SqlForgeQueryExecute
2387                    for #query_ident<'args>
2388                {
2389                    type Db = #db;
2390
2391                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2392                    where
2393                        Self: Sized + 'e,
2394                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2395                        #db: 'e,
2396                    {
2397                        #query_ident::execute(self, executor)
2398                    }
2399                }
2400            }
2401        } else {
2402            quote! {
2403                impl<'args> sql_forge::SqlForgeQuery<#final_type>
2404                    for #query_ident<'args>
2405                {
2406                    type Db = #db;
2407
2408                    fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2409                    where
2410                        Self: Sized + 'e,
2411                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2412                        #db: 'e,
2413                    {
2414                        #query_ident::fetch_all(self, executor)
2415                    }
2416
2417                    fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2418                    where
2419                        Self: Sized + 'e,
2420                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2421                        #db: 'e,
2422                    {
2423                        #query_ident::fetch_one(self, executor)
2424                    }
2425
2426                    fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2427                    where
2428                        Self: Sized + 'e,
2429                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2430                        #db: 'e,
2431                    {
2432                        #query_ident::fetch_optional(self, executor)
2433                    }
2434
2435                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2436                    where
2437                        Self: Sized + 'e,
2438                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2439                        #db: 'e,
2440                    {
2441                        #query_ident::execute(self, executor)
2442                    }
2443                }
2444            }
2445        };
2446
2447        generated_query_defs.push(quote! {
2448            struct #query_ident<'args> {
2449                inner: sqlx::QueryBuilder<'args, #db>,
2450            }
2451
2452            impl<'args> #query_ident<'args> {
2453                #exec_methods
2454            }
2455
2456            #trait_impl
2457        });
2458
2459        generated_query_values.push(quote! {
2460            #( #runtime_param_bindings )*
2461            #( #flag_bindings )*
2462            let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2463            #( #runtime_steps )*
2464            let #query_value_ident = #query_ident { inner: __builder };
2465        });
2466
2467        if let Some(key) = result_key {
2468            let method_ident = format_ident!("{}", key);
2469            group_field_defs.push(quote! {
2470                #method_ident: #query_ident<'args>
2471            });
2472            group_field_tys.push(quote! { #query_ident<'args> });
2473            group_method_defs.push(quote! {
2474                pub fn #method_ident(self) -> #query_ident<'args> {
2475                    self.#method_ident
2476                }
2477            });
2478
2479            let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2480            group_trait_impls.push(quote! {
2481                struct #key_ty_ident;
2482
2483                impl<'args> sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup<'args> {
2484                    type Query = #query_ident<'args>;
2485
2486                    fn get(self, _: #key_ty_ident) -> Self::Query {
2487                        self.#method_ident
2488                    }
2489                }
2490            });
2491            group_field_idents.push(method_ident);
2492        }
2493    }
2494
2495    // ---- Phase 8: Emit the final token stream ----
2496    let validator_tokens = quote! {
2497        let _sql_forge_validator = || {
2498            #( #validator_param_bindings )*
2499            #( #grouped_validator_invocations )*
2500        };
2501    };
2502
2503    if !is_grouped_result {
2504        let single_query_value_ident = format_ident!("__sql_forge_value_single");
2505        return quote! {
2506            {
2507                #validator_tokens
2508                #( #generated_query_defs )*
2509                #( #generated_query_values )*
2510                #single_query_value_ident
2511            }
2512        }
2513        .into();
2514    }
2515
2516    let group_field_inits: Vec<TokenStream2> = result_cases
2517        .iter()
2518        .filter_map(|(key, _, _)| key.as_ref())
2519        .map(|key| {
2520            let method_ident = format_ident!("{}", key);
2521            let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2522            quote! { #method_ident: #query_value_ident }
2523        })
2524        .collect();
2525
2526    quote! {
2527        {
2528            #validator_tokens
2529
2530            #( #generated_query_defs )*
2531            #( #generated_query_values )*
2532
2533            struct __SqlForgeQueryGroup<'args> {
2534                #( #group_field_defs, )*
2535            }
2536
2537            impl<'args> __SqlForgeQueryGroup<'args> {
2538                #( #group_method_defs )*
2539
2540                pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2541                    ( #( self.#group_field_idents ),* )
2542                }
2543            }
2544
2545            impl<'args> sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup<'args> {
2546                type Db = #db;
2547            }
2548
2549            #( #group_trait_impls )*
2550
2551            __SqlForgeQueryGroup {
2552                #( #group_field_inits, )*
2553            }
2554        }
2555    }
2556    .into()
2557}
2558
2559/// Expands to the database type from the `SQL_FORGE_DB_TYPE` environment variable,
2560/// falling back to `[package.metadata.sql_forge]` in `Cargo.toml`.
2561///
2562/// ```rust,ignore
2563/// use sql_forge::db_type;
2564///
2565/// pub type AppDb = db_type!();
2566/// // expands to the type set via SQL_FORGE_DB_TYPE or Cargo.toml metadata
2567/// ```
2568///
2569/// Priority:
2570/// 1. `SQL_FORGE_DB_TYPE` env var (e.g. `sqlx::MySql`, `sqlx::Postgres`)
2571/// 2. `[package.metadata.sql_forge] db = "..."` in `Cargo.toml`
2572#[proc_macro]
2573pub fn db_type(input: TokenStream) -> TokenStream {
2574    if !input.is_empty() {
2575        return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2576            .to_compile_error()
2577            .into();
2578    }
2579
2580    match resolve_db_from_env() {
2581        Ok(db) => quote! { #db }.into(),
2582        Err(msg) => syn::Error::new(Span::call_site(), msg)
2583            .to_compile_error()
2584            .into(),
2585    }
2586}