Skip to main content

sql_forge_macro/
lib.rs

1// =============================================================================
2// Imports
3// =============================================================================
4
5use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16    Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17    Stmt, Token, Type,
18};
19
20// =============================================================================
21// Input types
22// =============================================================================
23
/// Custom keywords recognized by the `sql_forge!` grammar.
mod kw {
    // `scalar` may prefix a model type (e.g. `scalar i64`); wherever it is parsed the
    // corresponding `force_scalar` flag is set to `true`.
    syn::custom_keyword!(scalar);
}
27
/// A `:name = expr` parameter binding, as written inside a parenthesized parameter map.
#[derive(Clone)]
struct ParamAssign {
    /// Placeholder name, without the leading `:`.
    name: Ident,
    /// Rust expression bound to the placeholder.
    expr: Expr,
}
34
35impl Parse for ParamAssign {
36    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37        input.parse::<Token![:]>()?;
38        let name: Ident = input.parse()?;
39        input.parse::<Token![=]>()?;
40        let expr: Expr = input.parse()?;
41        Ok(Self { name, expr })
42    }
43}
44
/// A single section value: SQL string + optional local parameter bindings.
#[derive(Clone)]
struct SectionFragment {
    /// SQL text extracted from the string literal.
    sql: String,
    /// Span of the literal expression (presumably used to point diagnostics at it).
    span: Span,
    /// Section-local parameters; `ParamsSource::None` for a bare string literal.
    params: ParamsSource,
}
52
/// One arm in a `match { ... }` inside a section assignment.
///
/// `if` / `if let` section values are also lowered into two of these arms: the
/// condition's pattern followed by a `_` fallback.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern; may contain `|` alternatives and a leading `|`.
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Section value produced when this arm is selected.
    value: SectionValue,
}
60
/// The right-hand side of a section `#name = ...` assignment.
///
/// `if` / `if let` values are represented with the `Match` variant as well
/// (see `parse_section_value`).
#[derive(Clone)]
enum SectionValue {
    /// A plain string (or `(string, params)` tuple).
    Single(SectionFragment),
    /// A tuple of values, one per section when using `#(a, b) = ...`.
    Grouped(Vec<SectionValue>),
    /// A `match expr { arm => ..., arm => ... }` expression.
    Match {
        /// The scrutinee expression.
        expr: Expr,
        /// Arms in source order.
        arms: Vec<SectionMatchArm>,
    },
}
74
/// A full `#name = value` or `#(a, b) = value` section assignment.
#[derive(Clone)]
struct SectionAssign {
    /// Assigned section names: one for `#name`, one per entry for `#(a, b)`.
    names: Vec<Ident>,
    /// The assigned value, parsed with a width equal to `names.len()`.
    value: SectionValue,
}
81
82impl Parse for SectionAssign {
83    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84        input.parse::<Token![#]>()?;
85
86        // Parse `#(a, b)` for grouped sections, or `#name` for single.
87        let names = if input.peek(syn::token::Paren) {
88            let content;
89            syn::parenthesized!(content in input);
90            let mut out = Vec::new();
91            while !content.is_empty() {
92                out.push(content.parse::<Ident>()?);
93                if content.is_empty() {
94                    break;
95                }
96                content.parse::<Token![,]>()?;
97            }
98            if out.is_empty() {
99                return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100            }
101            out
102        } else {
103            vec![input.parse::<Ident>()?]
104        };
105
106        input.parse::<Token![=]>()?;
107        let value = parse_section_value(input, names.len())?;
108        Ok(Self { names, value })
109    }
110}
111
/// The fully-parsed macro invocation.
struct SqlForgeInput {
    /// Database type passed explicitly as a macro argument, if any.
    db: Option<Type>,
    /// Requested result shape (none, single model, or result map).
    result: ResultSpec,
    /// Set when the single result model was written as `scalar Model`.
    force_scalar: bool,
    /// The SQL template literal.
    sql: SqlTemplate,
    /// Top-level parameter source; `ParamsSource::None` when omitted.
    params: ParamsSource,
    /// Assignments from a `(#name = ..., ...)` argument; empty when omitted.
    sections: Vec<SectionAssign>,
    /// Expression from a `..expr` batch argument, if any.
    batch: Option<Expr>,
}
122
/// One entry in a result map `(>key = Model, ...)`.
#[derive(Clone)]
struct ResultAssign {
    /// The result key (`key` in `>key = Model`).
    name: Ident,
    /// The model type for this entry.
    model: Type,
    /// Set when written as `>key = scalar Model`.
    force_scalar: bool,
}
130
/// The result shape requested by the macro's leading arguments.
#[derive(Clone)]
enum ResultSpec {
    /// Execute-only (no model), e.g. `sql_forge!("SQL", ...)`
    None,
    /// e.g. `sql_forge!(User, ...)`
    Single(Box<Type>),
    /// e.g. `sql_forge!((>a = X, >b = Y), ...)`
    Group(Vec<ResultAssign>),
}
140
/// Where values for `:name` placeholders come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was given.
    None,
    /// `( :name = expr, ... )`
    Map(Vec<ParamAssign>),
    /// A struct expression whose fields are matched to `:name` placeholders.
    Struct(Box<Expr>),
}
149
/// The SQL template. Only string literals are supported.
///
/// Produced by `parse_sql_template`; a non-literal template is a parse error there.
enum SqlTemplate {
    /// The literal as written in the macro call.
    Literal(LitStr),
}
154
155impl SqlTemplate {
156    fn span(&self) -> Span {
157        match self {
158            Self::Literal(lit) => lit.span(),
159        }
160    }
161
162    /// Parse the SQL into a sequence of textual segments and `{#section}` slots.
163    fn into_segments(self) -> Result<Vec<Segment>, String> {
164        match self {
165            Self::Literal(lit) => parse_literal_segments(&lit.value()),
166        }
167    }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171    if input.peek(LitStr) {
172        Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173    } else {
174        Err(input.error("sql_forge!: SQL template must be a string literal"))
175    }
176}
177
/// One piece of the parsed SQL template (see `parse_literal_segments`).
#[derive(Clone)]
enum Segment {
    /// Plain SQL text (may contain `:param` placeholders).
    Text(String),
    /// A `{#section_name}` placeholder.
    Section { name: String },
    /// A `{( ... )}` batch value template (repeated per batch item).
    Batch { parts: Vec<TextPart> },
}
188
/// A fragment of SQL text after splitting on `:param`.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:param` placeholder; `is_list` is `true` for the `:name[]` list form.
    Param { name: String, is_list: bool },
}
197
198// =============================================================================
199// Parsing helpers
200// =============================================================================
201
/// Used by `detect_parenthesized_map_kind` to identify what a `(...)` argument is.
enum MapKind {
    /// `(>name = Model, ...)`
    Results,
    /// `(:name = expr, ...)`
    Params,
    /// `(#name = value, ...)`
    Sections,
}
208
209// =============================================================================
210// Section value parsing (string fragments, match, grouped tuples)
211// =============================================================================
212
213/// Determines whether a parenthesized argument is a result map, params map, or sections map.
214fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
215    let fork = input.fork();
216    let content;
217    syn::parenthesized!(content in fork);
218
219    if content.is_empty() {
220        return Err(input.error("sql_forge!: map argument cannot be empty"));
221    }
222
223    if content.peek(Token![>]) {
224        Ok(Some(MapKind::Results))
225    } else if content.peek(Token![:]) {
226        Ok(Some(MapKind::Params))
227    } else if content.peek(Token![#]) {
228        Ok(Some(MapKind::Sections))
229    } else {
230        Ok(None)
231    }
232}
233
234impl Parse for ResultAssign {
235    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
236        input.parse::<Token![>]>()?;
237        let name: Ident = input.parse()?;
238        input.parse::<Token![=]>()?;
239        let (force_scalar, model) = if input.peek(kw::scalar) {
240            input.parse::<kw::scalar>()?;
241            (true, input.parse::<Type>()?)
242        } else {
243            (false, input.parse::<Type>()?)
244        };
245        Ok(Self {
246            name,
247            model,
248            force_scalar,
249        })
250    }
251}
252
253/// Parses a parenthesized result map like `(>name = Model, ...)`.
254fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
255    let content;
256    syn::parenthesized!(content in input);
257
258    let mut results = Vec::new();
259    while !content.is_empty() {
260        results.push(content.parse::<ResultAssign>()?);
261        if content.is_empty() {
262            break;
263        }
264        content.parse::<Token![,]>()?;
265    }
266
267    if results.is_empty() {
268        return Err(input.error("sql_forge!: result map cannot be empty"));
269    }
270
271    Ok(results)
272}
273
274/// Parses a parenthesized parameter map like `(:name = expr, ...)`.
275fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
276    let content;
277    syn::parenthesized!(content in input);
278
279    let mut params = Vec::new();
280    while !content.is_empty() {
281        params.push(content.parse::<ParamAssign>()?);
282        if content.is_empty() {
283            break;
284        }
285        content.parse::<Token![,]>()?;
286    }
287
288    Ok(params)
289}
290
291/// Parses a parenthesized section map like `(#name = expr, ...)`.
292fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
293    let content;
294    syn::parenthesized!(content in input);
295
296    let mut sections = Vec::new();
297    while !content.is_empty() {
298        sections.push(content.parse::<SectionAssign>()?);
299        if content.is_empty() {
300            break;
301        }
302        content.parse::<Token![,]>()?;
303    }
304
305    Ok(sections)
306}
307
308/// Parses a parameter source: either a struct expression or a parenthesized map.
309fn parse_params_source_expr(input: ParseStream<'_>) -> syn::Result<ParamsSource> {
310    if input.peek(syn::token::Paren) {
311        match detect_parenthesized_map_kind(input)? {
312            Some(MapKind::Results) => Err(input
313                .error("sql_forge!: result maps are only allowed as the macro result argument")),
314            Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
315            Some(MapKind::Sections) => Err(input.error(
316                "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317            )),
318            None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319        }
320    } else {
321        Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322    }
323}
324
/// Parses a section fragment: a string literal or `(string_literal, params)` tuple.
///
/// The tuple form is first recognized on a fork. A parenthesized value that is not a
/// `("sql", params)` tuple therefore falls through to the plain-expression path below.
/// The real input is consumed as a tuple only when the fork shows
/// `(<string literal>, <params source> [,])` with nothing left over.
///
/// # Errors
/// - A params source that fails to parse inside a string-led tuple is reported immediately
///   (the `?` on the fork).
/// - Otherwise, a value that `extract_lit_str` (defined outside this view) cannot turn into
///   SQL text yields "section values must be string literals or (string literal, params)".
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Look ahead on a fork; `input` is untouched unless the tuple shape is confirmed.
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                // Validate the params source on the fork only; it is re-parsed from `input` below.
                let _ = parse_params_source_expr(&content)?;
                // Allow a trailing comma: `("sql", params,)`.
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                if content.is_empty() {
                    // Shape confirmed: consume the tuple from the real input.
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Plain form: a single expression that `extract_lit_str` turns into SQL text.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
373
/// Parses a section value which may be a single fragment, grouped tuple, or match expression.
///
/// Accepted forms, checked in this order:
/// - `match EXPR { PAT [if GUARD] => VALUE, ... }`. Each arm value is parsed recursively
///   with the same `width`, so arms may themselves be `match` / `if` values. Commas between
///   arms are optional.
/// - `if COND { VALUE } else { VALUE }` or `if let PAT = EXPR { VALUE } else { VALUE }`,
///   including `else if ...` chains. These are lowered to a `SectionValue::Match` with two
///   arms: the condition's pattern (`true` for a plain `if`) and a `_` fallback. The `else`
///   branch is mandatory.
/// - When `width == 1`, a single fragment: `"sql"` or `("sql", params)`.
/// - When `width > 1`, a parenthesized tuple of exactly `width` items, each parsed as a
///   width-1 value.
///
/// `width` is the number of section names on the left-hand side (`#name` is 1,
/// `#(a, b)` is 2).
fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
    if input.peek(Token![match]) {
        input.parse::<Token![match]>()?;
        // Without eager brace parsing, the `{` that follows starts the arm block
        // rather than a struct literal.
        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
        let content;
        syn::braced!(content in input);
        let mut arms = Vec::new();
        while !content.is_empty() {
            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
            let guard = if content.peek(Token![if]) {
                content.parse::<Token![if]>()?;
                Some(content.parse::<Expr>()?)
            } else {
                None
            };
            content.parse::<Token![=>]>()?;
            // Recursive call in the match arm value allows nested match expressions.
            let value = parse_section_value(&content, width)?;
            if content.peek(Token![,]) {
                content.parse::<Token![,]>()?;
            }
            arms.push(SectionMatchArm { pat, guard, value });
        }
        return Ok(SectionValue::Match { expr, arms });
    }

    if input.peek(Token![if]) {
        input.parse::<Token![if]>()?;

        let (pat, expr) = if input.peek(Token![let]) {
            // if let PAT = EXPR { value } else { value }
            input.parse::<Token![let]>()?;
            let pat: Pat = input.call(Pat::parse_single)?;
            input.parse::<Token![=]>()?;
            let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
            (pat, expr)
        } else {
            // if EXPR { value } else { value }
            let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
            let pat: Pat = parse_quote! { true };
            (pat, expr)
        };

        let true_content;
        syn::braced!(true_content in input);
        let true_value = parse_section_value(&true_content, width)?;
        input.parse::<Token![else]>()?;

        // `else if ...` recurses on `input` itself; a plain `else` takes a braced value.
        let false_value = if input.peek(Token![if]) {
            parse_section_value(input, width)?
        } else {
            let false_content;
            syn::braced!(false_content in input);
            parse_section_value(&false_content, width)?
        };

        let wild_pat: Pat = parse_quote! { _ };

        let arms = vec![
            SectionMatchArm {
                pat,
                guard: None,
                value: true_value,
            },
            SectionMatchArm {
                pat: wild_pat,
                guard: None,
                value: false_value,
            },
        ];
        return Ok(SectionValue::Match { expr, arms });
    }

    if width == 1 {
        return Ok(SectionValue::Single(parse_section_fragment(input)?));
    }

    let content;
    syn::parenthesized!(content in input);
    let mut items = Vec::new();
    while !content.is_empty() {
        // Recursive call for a section value tuple item (representing a single section).
        // Allows match expressions returning a single section.
        items.push(parse_section_value(&content, 1)?);
        if content.is_empty() {
            break;
        }
        content.parse::<Token![,]>()?;
    }

    if items.len() != width {
        return Err(input.error(format!(
            "sql_forge!: grouped section value must provide exactly {} items",
            width,
        )));
    }

    Ok(SectionValue::Grouped(items))
}
474
475// =============================================================================
476// Top-level macro input parsing (SqlForgeInput::parse)
477// =============================================================================
478
impl Parse for SqlForgeInput {
    /// Parses a full `sql_forge!` invocation.
    ///
    /// Leading arguments (first matching form wins):
    /// - `"SQL"`: execute-only, no model.
    /// - `scalar Model, "SQL"`: a single result with `force_scalar` set.
    /// - `(>a = X, ...), "SQL"`: a result map. Any other parenthesized first argument is
    ///   an error.
    /// - `Type, "SQL"`: `Type` is the DB type (execute-only) when `is_db_type` accepts it,
    ///   otherwise it is the model.
    /// - `Type, scalar Model, "SQL"`, `Type, (>a = X, ...), "SQL"` and `Type, Model, "SQL"`:
    ///   `Type` is taken as the DB type and the next argument as the result.
    ///
    /// Trailing comma-separated arguments, in any order (a trailing comma is accepted):
    /// - `..expr`: the batch source, at most once;
    /// - `(:name = expr, ...)` or any other expression: the parameter source, at most once;
    /// - `(#name = value, ...)`: section assignments, at most one map.
    ///
    /// Result maps in trailing position and any leftover tokens are errors.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // The first argument is a type; what it means depends on what follows it.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) && is_db_type(&first_ty) {
                let db = first_ty;
                let sql = parse_sql_template(input)?;
                (Some(db), ResultSpec::None, false, sql)
            } else if input.peek(LitStr) {
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // Fallback: `DbType, Model, "SQL"`.
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // Optional trailing arguments after the SQL template.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // Parenthesized but not a map: treat it as a struct expression.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Continue only when another comma follows; otherwise stop and let the
                // leftover-token check below report anything unexpected.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
628
629// =============================================================================
630// Database type resolution
631// =============================================================================
632
633/// Resolves the database type from `SQL_FORGE_DB_TYPE` env var or `Cargo.toml` metadata.
634fn resolve_db_from_env() -> Result<Type, String> {
635    if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
636        return syn::parse_str::<Type>(&val).map_err(|err| {
637            format!(
638                "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
639                val, err
640            )
641        });
642    }
643
644    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
645        Ok(d) => d,
646        Err(_) => {
647            return Err(
648                "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
649                 or configure [package.metadata.sql_forge] in Cargo.toml"
650                    .to_string(),
651            );
652        }
653    };
654    let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
655
656    let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
657        format!(
658            "sql_forge!: failed to read {}: {}",
659            manifest_path.display(),
660            err
661        )
662    })?;
663
664    let value: toml::Value = toml::from_str(&cargo_toml)
665        .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
666
667    let db_str = value
668        .get("package")
669        .and_then(|v| v.get("metadata"))
670        .and_then(|v| v.get("sql_forge"))
671        .and_then(|v| v.get("db"))
672        .and_then(|v| v.as_str())
673        .ok_or({
674            "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
675             SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
676        })?;
677
678    syn::parse_str::<Type>(db_str).map_err(|err| {
679        format!(
680            "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
681            db_str, err
682        )
683    })
684}
685
686/// Returns true if the database type uses dollar-parameter placeholders (Postgres).
687fn uses_dollar_params(db: &Type) -> bool {
688    let Type::Path(type_path) = db else {
689        return false;
690    };
691    type_path
692        .path
693        .segments
694        .last()
695        .is_some_and(|s| s.ident == "Postgres")
696}
697
698/// Returns true if the type is a sqlx database type (sqlx::MySql, sqlx::Postgres, sqlx::Sqlite).
699fn is_db_type(ty: &Type) -> bool {
700    let Type::Path(type_path) = ty else {
701        return false;
702    };
703    if type_path.qself.is_some() {
704        return false;
705    }
706    let segs = &type_path.path.segments;
707    if segs.len() != 2 {
708        return false;
709    }
710    segs[0].ident == "sqlx"
711        && ["MySql", "Postgres", "Sqlite"].contains(&segs[1].ident.to_string().as_str())
712}
713
714/// Returns true if a type is a standard Rust built-in scalar (i32, String, etc.).
715fn is_builtin_scalar_type(ty: &Type) -> bool {
716    let Type::Path(type_path) = ty else {
717        return false;
718    };
719
720    if type_path.qself.is_some()
721        || type_path.path.leading_colon.is_some()
722        || type_path.path.segments.len() != 1
723    {
724        return false;
725    }
726
727    let ident = &type_path.path.segments[0].ident;
728    ident == "i8"
729        || ident == "i16"
730        || ident == "i32"
731        || ident == "i64"
732        || ident == "isize"
733        || ident == "u8"
734        || ident == "u16"
735        || ident == "u32"
736        || ident == "u64"
737        || ident == "usize"
738        || ident == "f32"
739        || ident == "f64"
740        || ident == "bool"
741        || ident == "String"
742}
743
744/// If the given model type is a built-in scalar, returns the type; otherwise None.
745fn scalar_output_type(model: &Type) -> Option<&Type> {
746    if is_builtin_scalar_type(model) {
747        return Some(model);
748    }
749    None
750}
751
752/// Appends text to the last `Segment::Text` or pushes a new one.
753fn push_text_segment(out: &mut Vec<Segment>, text: String) {
754    if text.is_empty() {
755        return;
756    }
757    match out.last_mut() {
758        Some(Segment::Text(existing)) => existing.push_str(&text),
759        _ => out.push(Segment::Text(text)),
760    }
761}
762
/// Parses a SQL literal into segments (text, `{#section}`, `{(batch)}`).
///
/// Scans the template character by character:
/// - `{(` ... `}` becomes a `Segment::Batch`.
///   - The captured text includes the outer `(` and `)`, because the `(` is only peeked
///     before the inner loop.
///   - Parentheses must balance before the closing `}`, and a `{` inside is rejected.
///   - The captured text is split with `parse_text_parts`; any list parameter (`:name[]`)
///     in the result is an error.
/// - `{#name}` becomes a `Segment::Section` whose name is every character up to the next
///   `}`. No trimming or identifier validation is done here.
/// - Any other `{`, and every other character, is kept as plain text. Consecutive text is
///   merged through `push_text_segment`.
///
/// # Errors
/// Returns a message for:
/// - an unterminated `{(` or `{#`;
/// - unbalanced parentheses or a nested `{` inside a batch section;
/// - a list parameter inside a batch section;
/// - an empty section name.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        // Batch template `{( ... )}`.
        if chars.peek() == Some(&'(') {
            push_text_segment(&mut out, std::mem::take(&mut text));

            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            // The opening '(' has only been peeked, so it is the first char seen here.
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                if ch == '}' {
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A '{' not followed by '(' or '#' is ordinary SQL text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // Section placeholder `{#name}`: consume the '#', then read up to the next '}'.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    push_text_segment(&mut out, text);
    Ok(out)
}
856
857// =============================================================================
858// SQL template parsing: {#sections} and :param placeholders
859// =============================================================================
860
/// Returns true if `ch` may begin an identifier (e.g. a `:param` name): an ASCII letter
/// or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}
865
866/// Returns true if the character can continue an identifier.
867fn is_ident_continue(ch: char) -> bool {
868    is_ident_start(ch) || ch.is_ascii_digit()
869}
870
/// Strips suffix markers from a backtick-quoted alias identifier.
///
/// Everything from the first `!`, `?` or `:` onward is removed, along with any whitespace
/// before it. These are presumably sqlx-style type-override suffixes such as `` `id!` `` or
/// `` `n: i64` ``. If nothing would remain, the content is returned unchanged.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    match content.find(|c: char| matches!(c, '!' | '?' | ':')) {
        None => content.to_string(),
        Some(marker_at) => {
            let base = content[..marker_at].trim_end();
            if base.is_empty() {
                content.to_string()
            } else {
                base.to_string()
            }
        }
    }
}
892
893/// Sanitizes backtick-quoted alias identifiers in runtime SQL text.
894fn sanitize_runtime_sql_text(text: &str) -> String {
895    let mut out = String::with_capacity(text.len());
896    let mut chars = text.chars().peekable();
897
898    while let Some(ch) = chars.next() {
899        if ch != '`' {
900            out.push(ch);
901            continue;
902        }
903
904        let mut content = String::new();
905        let mut closed = false;
906
907        for next in chars.by_ref() {
908            if next == '`' {
909                closed = true;
910                break;
911            }
912            content.push(next);
913        }
914
915        if closed {
916            out.push('`');
917            out.push_str(&sanitize_backticked_alias_ident(&content));
918            out.push('`');
919        } else {
920            out.push('`');
921            out.push_str(&content);
922            break;
923        }
924    }
925
926    out
927}
928
/// Splits SQL text into literal text and `:param` placeholder parts.
///
/// A placeholder is a `:` immediately followed by an identifier (an ASCII letter or
/// `_`, then ASCII letters, digits or `_`). It may carry a `[]` suffix marking a list
/// parameter (`:ids[]`). All other text is emitted verbatim as `TextPart::Lit`.
///
/// `::` (for example a PostgreSQL cast such as `value::text`) never produces a
/// placeholder. The first `:` is skipped because it is followed by `:`, and the
/// second because it is preceded by `:`.
///
/// NOTE(review): the scanner does not recognise SQL string literals, quoted
/// identifiers or comments, so `':x'` inside a quoted string is also parsed as a `:x`
/// parameter. Confirm that templates never rely on a literal `:ident` inside quotes.
fn parse_text_parts(text: &str) -> Vec<TextPart> {
    let mut parts = Vec::new();
    // Byte offset where the not-yet-emitted literal text starts.
    let mut last = 0usize;
    let mut iter = text.char_indices().peekable();

    while let Some((idx, ch)) = iter.next() {
        if ch != ':' {
            continue;
        }

        // A `:` at the very end of the input is plain text.
        let Some(&(next_idx, next_ch)) = iter.peek() else {
            continue;
        };

        if !is_ident_start(next_ch) {
            continue;
        }

        // Second half of a `::` cast: not a placeholder.
        if text[..idx].ends_with(':') {
            continue;
        }

        // Flush the literal text that precedes this placeholder.
        if last < idx {
            parts.push(TextPart::Lit(text[last..idx].to_string()));
        }

        // Consume the identifier's first character (already peeked above).
        iter.next();

        let mut name = String::new();
        name.push(next_ch);
        // Byte offset just past the end of the placeholder.
        let mut end = next_idx + next_ch.len_utf8();

        while let Some(&(j, c)) = iter.peek() {
            if is_ident_continue(c) {
                name.push(c);
                end = j + c.len_utf8();
                iter.next();
            } else {
                break;
            }
        }

        // Optional `[]` suffix marks a list parameter. The bracket characters are not
        // consumed from `iter`, but the loop skips them because they are not `:`.
        let mut is_list = false;
        if text[end..].starts_with("[]") {
            is_list = true;
            end += 2;
        }

        parts.push(TextPart::Param { name, is_list });
        last = end;
    }

    // Trailing literal text after the last placeholder.
    if last < text.len() {
        parts.push(TextPart::Lit(text[last..].to_string()));
    }

    parts
}
988
989/// Renders SQL text for the compile-time validator, replacing `:param`/`:param[]`
990/// with DB-specific placeholders (`?` or `$N`).
991///
992/// When `batch_expr` is `Some`, each `:param` is treated as a batch-item field
993/// access (`batch_expr[0].field_ident`) and returned in `batch_args` instead of
994/// `occurrences`. List params (`:param[]`) are not allowed in batch sections
995/// (caught earlier during parsing, but defended here for safety).
996///
997/// Returns `(rendered_sql, occurrences, batch_args)` — `occurrences` is empty
998/// in batch mode and `batch_args` is empty in non-batch mode.
999#[allow(clippy::type_complexity)]
1000fn render_validator_sql(
1001    parts: &[TextPart],
1002    use_dollar_params: bool,
1003    param_offset: &mut usize,
1004    list_count: usize,
1005    batch_expr: Option<TokenStream2>,
1006) -> Result<(String, Vec<(String, bool)>, Vec<TokenStream2>), TokenStream> {
1007    let mut out_sql = String::new();
1008    let mut occurrences = Vec::new();
1009    let mut batch_args = Vec::new();
1010
1011    for part in parts {
1012        match part {
1013            TextPart::Lit(lit) => out_sql.push_str(lit),
1014            TextPart::Param { name, is_list } => {
1015                if let Some(ref batch_expr) = batch_expr {
1016                    if *is_list {
1017                        return Err(syn::Error::new(
1018                            Span::call_site(),
1019                            "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
1020                             batch sections; use plain parameters (:name) instead",
1021                        )
1022                        .to_compile_error()
1023                        .into());
1024                    }
1025                    let field_ident = format_ident!("{}", name);
1026                    if use_dollar_params {
1027                        *param_offset += 1;
1028                        write!(out_sql, "${}", *param_offset).unwrap();
1029                    } else {
1030                        out_sql.push('?');
1031                    }
1032                    batch_args.push(quote! { #batch_expr[0].#field_ident });
1033                } else if *is_list && list_count > 1 {
1034                    let slots: Vec<String> = if use_dollar_params {
1035                        (0..list_count)
1036                            .map(|i| format!("${}", *param_offset + i + 1))
1037                            .collect()
1038                    } else {
1039                        (0..list_count).map(|_| "?".to_string()).collect()
1040                    };
1041                    if use_dollar_params {
1042                        *param_offset += list_count;
1043                    }
1044                    out_sql.push_str(&slots.join(", "));
1045                    occurrences.push((name.clone(), *is_list));
1046                } else {
1047                    if use_dollar_params {
1048                        *param_offset += 1;
1049                        write!(out_sql, "${}", *param_offset).unwrap();
1050                    } else {
1051                        out_sql.push('?');
1052                    }
1053                    occurrences.push((name.clone(), *is_list));
1054                }
1055            }
1056        }
1057    }
1058
1059    Ok((out_sql, occurrences, batch_args))
1060}
1061
1062/// Recursively strips syntactic wrapper expressions to reach the inner expression.
1063///
1064/// # Wrapper kinds
1065///
1066/// | Variant          | Rust syntax           | When it appears                                |
1067/// |------------------|-----------------------|------------------------------------------------|
1068/// | `Expr::Paren`    | `(expr)`              | Parenthesized expression in source code. syn   |
1069/// |                  |                       | preserves parens so tokens round-trip.         |
1070/// | `Expr::Group`    | `{ expr }`            | *Compiler-inserted* invisible group. syn uses  |
1071/// |                  |                       | this when a proc-macro receives braced tokens  |
1072/// |                  |                       | containing a single expression as `proc_macro! |
1073/// |                  |                       | { expr }`. Unlike `Paren`, `Group` does not    |
1074/// |                  |                       | correspond to parens the user wrote.           |
1075/// | `Expr::Block`    | `{ stmts }` or        | A real block expression the user wrote —       |
1076/// |                  | `{ stmts; }`          | `{ x; y }`. Stripped only when it contains     |
1077/// |                  |                       | exactly one `Stmt::Expr` without a trailing    |
1078/// |                  |                       | semicolon. Multi-statement blocks and non-expr |
1079/// |                  |                       | statements would change semantics if stripped, |
1080/// |                  |                       | so they are returned as-is.                    |
1081///
1082/// Returns the original expression unchanged when none of these patterns match, making
1083/// it safe to call unconditionally before pattern-matching against the expression kind.
1084fn strip_expr(expr: &Expr) -> &Expr {
1085    match expr {
1086        Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
1087        Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
1088        Expr::Block(ExprBlock { block, .. }) => {
1089            if block.stmts.len() != 1 {
1090                return expr;
1091            }
1092            match &block.stmts[0] {
1093                Stmt::Expr(inner, None) => strip_expr(inner),
1094                _ => expr,
1095            }
1096        }
1097        _ => expr,
1098    }
1099}
1100
1101/// Extracts the string value from a string literal expression.
1102fn extract_lit_str(expr: &Expr) -> Option<String> {
1103    match strip_expr(expr) {
1104        Expr::Lit(ExprLit {
1105            lit: Lit::Str(lit), ..
1106        }) => Some(lit.value()),
1107        _ => None,
1108    }
1109}
1110
1111// =============================================================================
1112// Preprocessing: {>key} compile-time result flags
1113// =============================================================================
1114
1115/// Generates a `__sql_forge_result_flag_{name}` identifier for `{>name}` result key flags.
1116fn result_flag_ident(name: &str) -> syn::Ident {
1117    format_ident!("__sql_forge_result_flag_{}", name)
1118}
1119
1120/// Replaces `{>key}` tokens inside braced groups with `__sql_forge_result_flag_key`
1121/// identifiers. This is a preprocessing step so that the rest of the parser sees
1122/// plain identifiers instead of braced groups it does not understand.
1123fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
1124    fn walk(stream: TokenStream2) -> TokenStream2 {
1125        let mut out = TokenStream2::new();
1126        let iter = stream.into_iter().peekable();
1127
1128        for token in iter {
1129            match token {
1130                TokenTree::Group(group) => {
1131                    if group.delimiter() == Delimiter::Brace {
1132                        let mut inner = group.stream().into_iter();
1133                        let first = inner.next();
1134                        let second = inner.next();
1135                        let third = inner.next();
1136
1137                        if let (
1138                            Some(TokenTree::Punct(p)),
1139                            Some(TokenTree::Ident(name_ident)),
1140                            None,
1141                        ) = (first, second, third)
1142                        {
1143                            if p.as_char() == '>' {
1144                                let ident = result_flag_ident(&name_ident.to_string());
1145                                out.extend(std::iter::once(TokenTree::Ident(ident)));
1146                                continue;
1147                            }
1148                        }
1149                    }
1150
1151                    let new_inner = walk(group.stream());
1152                    let mut new_group = Group::new(group.delimiter(), new_inner);
1153                    new_group.set_span(group.span());
1154                    out.extend(std::iter::once(TokenTree::Group(new_group)));
1155                }
1156                other => out.extend(std::iter::once(other)),
1157            }
1158        }
1159
1160        out
1161    }
1162
1163    walk(input)
1164}
1165
1166/// Builds `let` bindings for `{>key}` result flag identifiers.
1167fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1168    keys.iter()
1169        .map(|key| {
1170            let ident = result_flag_ident(key);
1171            let enabled = Some(key.as_str()) == active_key;
1172            quote! { let #ident: bool = #enabled; }
1173        })
1174        .collect()
1175}
1176
1177/// Transposes a `[case][section]` matrix to `[section][case]`.
1178fn transpose_section_case_matrix(
1179    case_matrix: Vec<Vec<SectionFragment>>,
1180    width: usize,
1181) -> Result<Vec<Vec<SectionFragment>>, String> {
1182    let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1183
1184    for row in case_matrix {
1185        if row.len() != width {
1186            return Err(
1187                "sql_forge!: grouped sections must return one item per section".to_string(),
1188            );
1189        }
1190        for (section_idx, fragment) in row.into_iter().enumerate() {
1191            per_section[section_idx].push(fragment);
1192        }
1193    }
1194
1195    Ok(per_section)
1196}
1197
1198/// Collects all section fragment cases as a `[case][section]` matrix for a given section value.
1199fn collect_section_case_matrix(
1200    value: SectionValue,
1201    width: usize,
1202    active_key: Option<&str>,
1203) -> Result<Vec<Vec<SectionFragment>>, String> {
1204    match value {
1205        SectionValue::Single(fragment) => {
1206            if width != 1 {
1207                return Err(
1208                    "sql_forge!: grouped sections must return one item per section".to_string(),
1209                );
1210            }
1211            Ok(vec![vec![fragment]])
1212        }
1213        SectionValue::Grouped(values) => {
1214            if values.len() != width {
1215                return Err(
1216                    "sql_forge!: grouped sections must return one item per section".to_string(),
1217                );
1218            }
1219
1220            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1221            let mut nmax = 1usize;
1222
1223            for value in values {
1224                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1225                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1226                for mut row in item_matrix {
1227                    let fragment = row.pop().ok_or_else(|| {
1228                        "sql_forge!: grouped sections must return one item per section".to_string()
1229                    })?;
1230                    if !row.is_empty() {
1231                        return Err(
1232                            "sql_forge!: grouped sections must return one item per section"
1233                                .to_string(),
1234                        );
1235                    }
1236                    item_variants.push(fragment);
1237                }
1238                if item_variants.is_empty() {
1239                    return Err("sql_forge!: section match must have at least one arm".to_string());
1240                }
1241                nmax = nmax.max(item_variants.len());
1242                variants_by_section.push(item_variants);
1243            }
1244
1245            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1246            for case_idx in 0..nmax {
1247                let mut row = Vec::<SectionFragment>::with_capacity(width);
1248                for variants in &variants_by_section {
1249                    row.push(variants[case_idx % variants.len()].clone());
1250                }
1251                case_matrix.push(row);
1252            }
1253
1254            Ok(case_matrix)
1255        }
1256        SectionValue::Match { expr, arms } => {
1257            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1258
1259            if let Some(key) = expr_result_flag_key(&expr) {
1260                let target = active_key == Some(key.as_str());
1261                for arm in arms {
1262                    if arm.guard.is_none() {
1263                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1264                            continue;
1265                        }
1266                    }
1267                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1268                    wrap_section_case_matrix_for_match_arm(
1269                        &mut arm_cases,
1270                        &expr,
1271                        &arm.pat,
1272                        arm.guard.as_ref(),
1273                    );
1274                    case_matrix.extend(arm_cases);
1275                }
1276            } else {
1277                for arm in arms {
1278                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1279                    wrap_section_case_matrix_for_match_arm(
1280                        &mut arm_cases,
1281                        &expr,
1282                        &arm.pat,
1283                        arm.guard.as_ref(),
1284                    );
1285                    case_matrix.extend(arm_cases);
1286                }
1287            }
1288
1289            if case_matrix.is_empty() {
1290                return Err("sql_forge!: section match must have at least one arm".to_string());
1291            }
1292
1293            Ok(case_matrix)
1294        }
1295    }
1296}
1297
/// Rewrites a section-local validator expression so that it is evaluated inside the
/// same match arm that originally introduced it.
///
/// The generated expression has the shape
/// `match &(match_expr) { pat [if guard] => { .. }, _ => unreachable!(..) }`, so any
/// bindings introduced by `pat` are in scope for `expr`. The scrutinee is borrowed, so
/// bindings follow Rust's default binding modes. The catch-all arm panics with
/// "validator arm mismatch".
///
/// Two arm-body shapes are generated:
/// - if `pattern_binds_values` is true, the body is `{ let _ = &b; .. expr }`. Every
///   binding returned by `pat_var_idents` is touched (presumably to silence
///   unused-variable warnings; TODO confirm), and `expr` is returned as-is.
/// - otherwise the body is `{ &(expr) }`, a reference to `expr`.
///
/// NOTE(review): `pattern_binds_values` looks only at the top-level pattern and one
/// level of nesting. It also counts every `Pat::Ident` as binding, including uppercase
/// path-like idents such as `None`, which `pat_var_idents` rejects via `is_pat_binding`.
/// Because the two branches differ in whether `expr` is borrowed, nested bindings
/// (e.g. `Some(Foo { x })`) and idents like `None` can end up with a different
/// reference level than similar patterns. Confirm this is intended.
fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
    let match_expr = match_expr.clone();
    let pat = pat.clone();
    // Shallow check: does the pattern (or one of its direct sub-patterns) contain an
    // identifier pattern?
    let pattern_binds_values = match &pat {
        Pat::Ident(_) => true,
        Pat::Or(pat_or) => pat_or
            .cases
            .iter()
            .any(|case| matches!(case, Pat::Ident(_))),
        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
        Pat::Slice(pat_slice) => pat_slice
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Struct(pat_struct) => pat_struct
            .fields
            .iter()
            .any(|field| matches!(*field.pat, Pat::Ident(_))),
        Pat::Tuple(pat_tuple) => pat_tuple
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
        _ => false,
    };

    if pattern_binds_values {
        // `let _ = &binding;` for each lowercase/underscore binding in the pattern.
        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
            .into_iter()
            .map(|ident| quote! { let _ = &#ident; })
            .collect();
        if let Some(guard) = guard.cloned() {
            parse_quote! {
                match &(#match_expr) {
                    #pat if #guard => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        } else {
            parse_quote! {
                match &(#match_expr) {
                    #pat => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        }
    } else if let Some(guard) = guard.cloned() {
        // No bindings: the arm body yields a reference to `expr`.
        parse_quote! {
            match &(#match_expr) {
                #pat if #guard => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    } else {
        parse_quote! {
            match &(#match_expr) {
                #pat => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    }
}
1367
1368/// Wraps a ParamsSource's expressions so they stay within the match arm scope.
1369fn wrap_params_source_for_match_arm(
1370    params: &mut ParamsSource,
1371    match_expr: &Expr,
1372    pat: &Pat,
1373    guard: Option<&Expr>,
1374) {
1375    match params {
1376        ParamsSource::None => {}
1377        ParamsSource::Map(entries) => {
1378            for entry in entries {
1379                entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1380            }
1381        }
1382        ParamsSource::Struct(expr) => {
1383            **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1384        }
1385    }
1386}
1387
1388/// Wraps all ParamsSource expressions in a section case matrix for the given match arm.
1389fn wrap_section_case_matrix_for_match_arm(
1390    case_matrix: &mut [Vec<SectionFragment>],
1391    match_expr: &Expr,
1392    pat: &Pat,
1393    guard: Option<&Expr>,
1394) {
1395    for row in case_matrix {
1396        for fragment in row {
1397            wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1398        }
1399    }
1400}
1401
// The section variant collectors below return `Vec<Vec<SectionFragment>>` indexed `[section_idx][case_idx]`.
1403// =============================================================================
1404// Section variant collection
1405// =============================================================================
1406
1407/// Collects all possible `SectionFragment` values per section index. Each
1408/// returned `Vec<Vec<SectionFragment>>` is indexed `[section_idx][case_idx]`,
1409/// listing every variant that section can take across all match arms.
1410/// Used for full validation (cycling strategy over all variants).
1411fn collect_section_variants(
1412    value: SectionValue,
1413    width: usize,
1414) -> Result<Vec<Vec<SectionFragment>>, String> {
1415    transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1416}
1417
1418/// Returns the key name if the expression is a `{>key}` result flag; otherwise None.
1419fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1420    match strip_expr(expr) {
1421        Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1422            let name = path.path.segments[0].ident.to_string();
1423            name.strip_prefix("__sql_forge_result_flag_")
1424                .map(|v| v.to_string())
1425        }
1426        _ => None,
1427    }
1428}
1429
1430/// Returns whether a pattern matches a specific boolean value at compile time.
1431fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1432    match pat {
1433        Pat::Lit(expr_lit) => match &expr_lit.lit {
1434            Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1435            _ => None,
1436        },
1437        Pat::Wild(_) => Some(true),
1438        _ => None,
1439    }
1440}
1441
1442/// Like `collect_section_variants`, but filters `match` arms by the active
1443/// result key when the match expression is a `{>key}` flag. When building the
1444/// query for a specific key, only the matching arm (true/false) is included;
1445/// arms with guards or non-flag expressions include all variants as usual.
1446fn collect_section_variants_for_result(
1447    value: SectionValue,
1448    width: usize,
1449    active_key: Option<&str>,
1450) -> Result<Vec<Vec<SectionFragment>>, String> {
1451    transpose_section_case_matrix(
1452        collect_section_case_matrix(value, width, active_key)?,
1453        width,
1454    )
1455}
1456
1457// =============================================================================
1458// Parameter binding generation
1459// =============================================================================
1460
1461/// Generates `let` bindings for all parameters used in the SQL or section
1462/// fragments. Returns a map of param-name → local ident for later reference,
1463/// and a list of `let` statements.
1464fn build_param_bindings(
1465    params: &ParamsSource,
1466    used_param_names: &[String],
1467    prefix: &str,
1468    for_validator: bool,
1469    enforce_usage_check: bool,
1470) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1471    let mut declared_params = HashMap::<String, syn::Ident>::new();
1472    let mut bindings = Vec::<TokenStream2>::new();
1473
1474    match params {
1475        ParamsSource::None => {}
1476        ParamsSource::Map(entries) => {
1477            for entry in entries {
1478                let key = entry.name.to_string();
1479                if declared_params.contains_key(&key) {
1480                    return Err(syn::Error::new(
1481                        entry.name.span(),
1482                        "sql_forge!: duplicated parameter mapping",
1483                    )
1484                    .to_compile_error()
1485                    .into());
1486                }
1487                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1488                    return Err(syn::Error::new(
1489                        entry.name.span(),
1490                        format!(
1491                            "sql_forge!: parameter :{} is unused in the SQL template",
1492                            key,
1493                        ),
1494                    )
1495                    .to_compile_error()
1496                    .into());
1497                }
1498                let local_ident = format_ident!("__sql_forge_{}_{}", prefix, key);
1499                let expr = &entry.expr;
1500                if for_validator {
1501                    bindings.push(quote! {
1502                        let #local_ident = &(#expr);
1503                    });
1504                } else {
1505                    bindings.push(quote! {
1506                        let #local_ident = #expr;
1507                    });
1508                }
1509                declared_params.insert(key, local_ident);
1510            }
1511        }
1512        ParamsSource::Struct(expr) => {
1513            let source_ident = format_ident!("__sql_forge_source_{}", prefix);
1514            bindings.push(quote! {
1515                let #source_ident = &(#expr);
1516            });
1517            for name in used_param_names {
1518                let local_ident = format_ident!("__sql_forge_{}_{}", prefix, name);
1519                let field_ident = format_ident!("{}", name);
1520                if for_validator {
1521                    bindings.push(quote! {
1522                        let #local_ident = &#source_ident.#field_ident;
1523                    });
1524                } else {
1525                    bindings.push(quote! {
1526                        let #local_ident = #source_ident.#field_ident;
1527                    });
1528                }
1529                declared_params.insert(name.to_string(), local_ident);
1530            }
1531        }
1532    }
1533
1534    Ok((declared_params, bindings))
1535}
1536
1537struct ValidatorRenderContext<'a> {
1538    params: &'a HashMap<String, syn::Ident>,
1539    use_dollar_params: bool,
1540    sql_span: Span,
1541    list_count: usize,
1542}
1543
1544/// Builds the placeholders SQL string and argument list for the compile-time
1545/// validator (sqlx::query_as! / query_scalar!). Each `:param` in the SQL is
1546/// replaced by `?` (MySQL/SQLite) or `$1`/`$2`/... (PostgreSQL), and the
1547/// corresponding value expression is collected into the args list.
1548fn render_validator_args(
1549    sql: &str,
1550    param_offset: &mut usize,
1551    arg_index: &mut usize,
1552    context: &ValidatorRenderContext<'_>,
1553) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1554    let parts = parse_text_parts(sql);
1555    let (rendered_sql, occurrences, _batch_args) = render_validator_sql(
1556        &parts,
1557        context.use_dollar_params,
1558        param_offset,
1559        context.list_count,
1560        None,
1561    )?;
1562
1563    let mut setup = Vec::<TokenStream2>::new();
1564    let mut args = Vec::<TokenStream2>::new();
1565
1566    for (name, is_list) in occurrences {
1567        let Some(local_ident) = context.params.get(&name) else {
1568            return Err(syn::Error::new(
1569                context.sql_span,
1570                format!("sql_forge!: parameter :{} has no mapping", name),
1571            )
1572            .to_compile_error()
1573            .into());
1574        };
1575
1576        if is_list {
1577            for _ in 0..context.list_count {
1578                let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1579                *arg_index += 1;
1580                if context.use_dollar_params {
1581                    setup.push(quote! {
1582                        let #value_ident = sql_forge::sql_forge_validator_value(
1583                            (#local_ident)
1584                                .as_slice()
1585                                .first()
1586                                .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1587                        );
1588                    });
1589                } else {
1590                    setup.push(quote! {
1591                        let #value_ident = (#local_ident)
1592                            .as_slice()
1593                            .first()
1594                            .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1595                    });
1596                }
1597                args.push(quote! { #value_ident });
1598            }
1599        } else {
1600            let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1601            *arg_index += 1;
1602            if context.use_dollar_params {
1603                setup.push(quote! {
1604                    let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1605                });
1606            } else {
1607                setup.push(quote! {
1608                    let #value_ident = #local_ident;
1609                });
1610            }
1611            args.push(quote! { #value_ident });
1612        }
1613    }
1614
1615    Ok((rendered_sql, setup, args))
1616}
1617
1618// =============================================================================
1619// Runtime code generation (QueryBuilder-based)
1620// =============================================================================
1621
1622/// Generates the `push()` / `push_bind()` calls for a single section fragment
1623/// at runtime using `sqlx::QueryBuilder`.
1624fn render_runtime_fragment(
1625    fragment: &SectionFragment,
1626    local_params: &HashMap<String, syn::Ident>,
1627) -> Result<TokenStream2, TokenStream> {
1628    let mut steps = Vec::<TokenStream2>::new();
1629
1630    for part in parse_text_parts(&fragment.sql) {
1631        match part {
1632            TextPart::Lit(lit) => {
1633                let lit_str = LitStr::new(&lit, fragment.span);
1634                steps.push(quote! { __builder.push(#lit_str); });
1635            }
1636            TextPart::Param { name, is_list } => {
1637                let Some(local_ident) = local_params.get(&name) else {
1638                    return Err(syn::Error::new(
1639                        fragment.span,
1640                        format!("sql_forge!: parameter :{} has no mapping", name),
1641                    )
1642                    .to_compile_error()
1643                    .into());
1644                };
1645
1646                if is_list {
1647                    steps.push(quote! {
1648                        let __sql_forge_values = #local_ident;
1649                        let mut __separated = __builder.separated(", ");
1650                        for __value in __sql_forge_values {
1651                            __separated.push_bind(__value);
1652                        }
1653                    });
1654                } else {
1655                    steps.push(quote! {
1656                        __builder.push_bind(#local_ident);
1657                    });
1658                }
1659            }
1660        }
1661    }
1662
1663    Ok(quote! { #( #steps )* })
1664}
1665
1666/// Returns true if an ident starts with a lowercase letter or underscore (i.e., is a binding).
1667fn is_pat_binding(ident: &Ident) -> bool {
1668    let name = ident.to_string();
1669    !name.is_empty()
1670        && name
1671            .chars()
1672            .next()
1673            .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1674}
1675
1676/// Collects all binding identifiers from a pattern.
1677fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1678    let mut names = Vec::new();
1679    fn walk(p: &Pat, names: &mut Vec<Ident>) {
1680        match p {
1681            Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1682            Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1683            Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1684            Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1685            Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1686            Pat::Paren(pp) => walk(&pp.pat, names),
1687            Pat::Reference(pr) => walk(&pr.pat, names),
1688            Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1689            Pat::Type(pt) => walk(&pt.pat, names),
1690            _ => {}
1691        }
1692    }
1693    walk(pat, &mut names);
1694    names
1695}
1696
1697/// Returns true if a section value's SQL or params references a given parameter name.
1698fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1699    match value {
1700        SectionValue::Single(f) => {
1701            if collect_used_param_names_in_sql(&f.sql)
1702                .iter()
1703                .any(|n| n == name)
1704            {
1705                return true;
1706            }
1707            if let ParamsSource::Map(entries) = &f.params {
1708                for e in entries {
1709                    let expr = &e.expr;
1710                    let expr_str = quote! { #expr }.to_string();
1711                    if expr_str.trim() == name {
1712                        return true;
1713                    }
1714                }
1715            }
1716            false
1717        }
1718        SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1719        SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1720            let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1721                .into_iter()
1722                .map(|i| i.to_string())
1723                .collect();
1724            if pat_vars.contains(name) {
1725                false
1726            } else {
1727                section_value_refers_to(&arm.value, name)
1728            }
1729        }),
1730    }
1731}
1732
1733/// Builds the runtime `QueryBuilder` action tokens for a section value.
1734fn build_section_runtime_action(
1735    value: &SectionValue,
1736    section_idx: usize,
1737    prefix: &str,
1738) -> Result<TokenStream2, TokenStream> {
1739    match value {
1740        SectionValue::Single(fragment) => {
1741            let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1742            let (local_params, bindings) =
1743                build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1744            let body = render_runtime_fragment(fragment, &local_params)?;
1745            Ok(quote! {{ #( #bindings )* #body }})
1746        }
1747        SectionValue::Grouped(fragments) => build_section_runtime_action(
1748            &fragments[section_idx],
1749            0,
1750            &format!("{}_grouped_{}", prefix, section_idx),
1751        ),
1752        SectionValue::Match { expr, arms } => {
1753            let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1754                .iter()
1755                .enumerate()
1756                .map(|(arm_idx, arm)| {
1757                    let pat = &arm.pat;
1758                    let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1759                    let body = build_section_runtime_action(
1760                        &arm.value,
1761                        section_idx,
1762                        &format!("{}_{}", prefix, arm_idx),
1763                    )?;
1764                    let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1765                        .into_iter()
1766                        .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1767                        .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1768                        .collect();
1769                    Ok::<TokenStream2, TokenStream>(quote! {
1770                        #pat #guard_tokens => {
1771                            #( #noop_refs )*
1772                            #body
1773                        }
1774                    })
1775                })
1776                .collect();
1777            let arm_tokens = arm_tokens?;
1778            Ok(quote! {
1779                match #expr {
1780                    #( #arm_tokens ),*
1781                }
1782            })
1783        }
1784    }
1785}
1786
1787/// Extracts unique parameter names from a list of segments.
1788fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1789    let mut names = Vec::new();
1790    let mut seen = HashSet::<String>::new();
1791
1792    for segment in segments {
1793        match segment {
1794            Segment::Text(text) => {
1795                for name in collect_used_param_names_in_sql(text) {
1796                    if seen.insert(name.clone()) {
1797                        names.push(name);
1798                    }
1799                }
1800            }
1801            Segment::Batch { parts } => {
1802                for part in parts {
1803                    if let TextPart::Param { name, .. } = part {
1804                        if seen.insert(name.clone()) {
1805                            names.push(name.clone());
1806                        }
1807                    }
1808                }
1809            }
1810            _ => {}
1811        }
1812    }
1813
1814    names
1815}
1816
1817/// Extracts unique parameter names from a SQL text string.
1818fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1819    let mut names = Vec::new();
1820    let mut seen = HashSet::<String>::new();
1821    for part in parse_text_parts(sql) {
1822        if let TextPart::Param { name, .. } = part {
1823            if seen.insert(name.to_string()) {
1824                names.push(name);
1825            }
1826        }
1827    }
1828    names
1829}
1830
1831/// Builds a parameterized SQL query with compile-time type-checking and a
1832/// runtime [`sqlx::QueryBuilder`] for dynamic SQL.
1833///
1834/// Combines `sqlx::query_as!` / `sqlx::query_scalar!` validation (never called
1835/// at runtime) with `QueryBuilder::push_bind` for safe value binding.
1836///
1837/// # Syntax
1838///
1839/// ```text
1840/// sql_forge!(
1841///     [DB,]        // optional: sqlx::MySql | sqlx::Postgres | sqlx::Sqlite
1842///     [Model,]     // optional result spec
1843///     SQL,         // string literal
1844///     [params,]    // optional: ( :name = expr, ... )  or  struct_expr
1845///     [(sections),]// optional: ( #name = ..., ... )
1846///     [..batch]    // optional: batch source expression used by {( ... )}
1847/// )
1848/// ```
1849///
1850/// `Model` has three forms:
1851/// - omitted: execute-only query; only `.execute(...)` is available
1852/// - `Type` or `scalar Type`: a single result query
1853/// - `( >key1 = TypeA, >key2 = scalar TypeB )`: a grouped multi-result query
1854///
1855/// The trailing parameter source, section map, and batch source are optional.
1856/// The batch source may appear alongside the others as a single `..expr` argument.
1857///
1858/// The DB type may be omitted when `SQL_FORGE_DB_TYPE` is set (e.g.
1859/// `SQL_FORGE_DB_TYPE=sqlx::MySql`) or when
1860/// `[package.metadata.sql_forge] db = "..."` is set in `Cargo.toml`.
1861/// The env var takes priority over Cargo.toml metadata.
1862///
1863/// # Parameters
1864///
1865/// Named parameters are written `:name` in the SQL. At runtime each occurrence
1866/// is replaced by a `push_bind` call; at compile time it becomes a
1867/// database-specific placeholder: `?` for MySQL and SQLite, and `$1`, `$2`, ...
1868/// for Postgres.
1869///
1870/// **Inline map**: bind individual expressions:
1871/// ```rust,ignore
1872/// sql_forge!(User, "SELECT ... WHERE id <= :max_id", ( :max_id = filter.max_id ))
1873/// ```
1874///
1875/// **Struct source**: field names are matched to `:name` placeholders automatically:
1876/// ```rust,ignore
1877/// sql_forge!(User, "SELECT ... WHERE id <= :max_id LIMIT :limit", filter)
1878/// ```
1879///
1880/// # Sections (`{#name}`)
1881///
1882/// Sections are runtime SQL slots; each section's variants are validated at
1883/// compile time via `query_as!` / `query_scalar!`, though not every combination
1884/// of variants across sections is checked. The section map is a second parenthesised
1885/// argument starting with `#`:
1886///
1887/// ```rust,ignore
1888/// sql_forge!(
1889///     User,
1890///     "SELECT * FROM users {#join_org}",
1891///     (
1892///         #join_org = match include_org {
1893///             true  => " JOIN organisations o ON o.id = users.org_id ",
1894///             false => "",
1895///         }
1896///     )
1897/// )
1898/// ```
1899///
1900/// A section arm can also carry local parameters as a tuple `("sql", params)`:
1901///
1902/// ```rust,ignore
1903/// sql_forge!(
1904///     User,
1905///     "SELECT * FROM users {#filter}",
1906///     (
1907///         #filter = (
1908///             " WHERE id <= :max_id AND status = :status ",
1909///             ( :max_id = max_id, :status = "active" ),
1910///         )
1911///     )
1912/// )
1913/// ```
1914///
1915/// Multiple placeholders driven by one `match` use `#(a, b)` with each arm
1916/// returning a tuple of the same width:
1917///
1918/// ```rust,ignore
1919/// sql_forge!(
1920///     User,
1921///     "SELECT * FROM users {#join_org} {#filter_org}",
1922///     (
1923///         #(join_org, filter_org) = match include_org {
1924///             true  => (
1925///                 " JOIN organisations o ON o.id = users.org_id ",
1926///                 (
1927///                     " AND o.active = :active ",
1928///                     ( :active = true ),
1929///                 ),
1930///             ),
1931///             false => ("", ""),
1932///         }
1933///     )
1934/// )
1935/// ```
1936///
1937/// Grouped section items may themselves use nested `match` expressions. Those
1938/// nested matches use smart cycling within the arm rather than a cartesian
1939/// product. For example, if one grouped arm returns a fixed first item plus two
1940/// nested binary matches for the second and third items, that arm contributes
1941/// two aligned variants `(0, 0)` and `(1, 1)`, not four `(0, 0)`, `(0, 1)`,
1942/// `(1, 0)`, `(1, 1)` combinations.
1943///
1944/// # `IN (...)` with list parameters
1945///
/// Append `[]` to a placeholder (`:name[]`) to expand a `Vec` into multiple
/// comma-separated bound values; write the surrounding `IN (...)` parentheses yourself:
1948///
1949/// ```rust,ignore
1950/// sql_forge!(User, "SELECT * FROM users WHERE id IN (:ids[])", ( :ids = ids ))
1951/// ```
1952///
1953/// **Empty lists** are not rewritten; `IN ()` is a database syntax error.
1954/// Guard against empty inputs explicitly, e.g. with a dynamic section:
1955///
1956/// ```rust,ignore
1957/// sql_forge!(
1958///     User,
1959///     "SELECT id, name FROM users WHERE {#filter}",
1960///     (
1961///         #filter = match ids.is_empty() {
1962///             true  => "1 = 0",
1963///             false => ("id IN (:ids[])", ( :ids = ids )),
1964///         }
1965///     )
1966/// )
1967/// ```
1968///
1969/// # Batch inserts (`{( ... )}`)
1970///
1971/// A batch section `{( ... )}` repeats its content for each item in an iterable
1972/// source passed as `..expr`. Inside the batch, `:name` refers to a field on the
1973/// current item. List parameters (`:name[]`) are **not** allowed inside batch
1974/// sections.
1975///
1976/// ## Struct batch
1977///
1978/// ```rust,ignore
1979/// struct BatchItem { name: String, price: i64 }
1980///
1981/// let items = vec![
1982///     BatchItem { name: "A".into(), price: 100 },
1983///     BatchItem { name: "B".into(), price: 200 },
1984/// ];
1985///
1986/// sql_forge!(
1987///     "INSERT INTO products (name, price, stock, category)
1988///      VALUES {(:name, :price, 10, 'Batch')}",
1989///     ..items
1990/// )
1991/// .execute(&pool)
1992/// .await?;
1993/// ```
1994///
1995/// For compile-time checking, the validator expands the batch to 3 fake copies
1996/// (`(?, ?, 10, 'Batch'), (?, ?, 10, 'Batch'), (?, ?, 10, 'Batch')`).
1997/// At runtime the iterable drives the actual number of rows.
1998///
1999/// # Scalar output
2000///
2001/// When `Model` is a primitive (`i32`, `i64`, `String`, etc.) the macro uses
2002/// `query_scalar!` for validation and `build_query_scalar` for execution.
2003///
2004/// # Multiple results
2005///
2006/// A result map produces a `SqlForgeQueryGroup` with one query per key.
2007/// Each key can be a struct or a primitive (used as a scalar):
2008///
2009/// ```rust,ignore
2010/// sql_forge!(
2011///     (
2012///         >count = i64,
2013///         >items = Item,
2014///     ),
2015///     "SELECT {#fields} FROM items WHERE category_id = :cat",
2016///     ( :cat = category_id ),
2017///     (
2018///         #fields = match {>count} {           // {>key} is true when building
2019///             true  => "COUNT(*) AS total",    // the query for that model/result
2020///             false => "id, name, price",      // key and false otherwise
2021///         }
2022///     )
2023/// )
2024/// ```
2025///
2026/// The generated struct has one field per key (`group.count`, `group.items`),
2027/// each implementing `SqlForgeQuery<T, Db = DB>` and usable with any SQLx
2028/// executor method (`fetch_one`, `fetch_all`, etc.).
2029///
2030/// # Execute-only (no model)
2031///
2032/// When the model type is omitted, the macro produces a value implementing
2033/// `SqlForgeQueryExecute`. Only `.execute(executor)`
2034/// is available and there is no return type to deserialize into. This is useful
2035/// for `INSERT`, `UPDATE`, `DELETE`, and other DML statements.
2036///
2037/// ```rust,ignore
2038/// sql_forge!(
2039///     "UPDATE products SET stock = stock + 1 WHERE id = :id",
2040///     ( :id = 42i64 ),
2041/// )
2042/// .execute(&pool)
2043/// .await?;
2044/// ```
2045///
2046/// Sections and struct parameter sources work the same way as in model-backed queries:
2047///
2048/// ```rust,ignore
2049/// sql_forge!(
2050///     "UPDATE products SET price = :new_price {#filter}",
2051///     ( #filter = "WHERE category = :cat", ( :cat = "Electronics" ) ),
2052/// )
2053/// .execute(&pool)
2054/// .await?;
2055/// ```
2056#[proc_macro]
2057#[allow(clippy::too_many_lines)]
2058pub fn sql_forge(input: TokenStream) -> TokenStream {
2059    // ---- Phase 1: Parse the macro input into structured data ----
2060    let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
2061    let SqlForgeInput {
2062        db,
2063        result,
2064        force_scalar,
2065        sql,
2066        params,
2067        sections,
2068        batch,
2069    } = match syn::parse2::<SqlForgeInput>(preprocessed) {
2070        Ok(v) => v,
2071        Err(err) => return err.to_compile_error().into(),
2072    };
2073
2074    // ---- Phase 2: Resolve database type (from macro arg, env var or Cargo.toml) ----
2075    let db = match db {
2076        Some(db) => db,
2077        None => match resolve_db_from_env() {
2078            Ok(db) => db,
2079            Err(msg) => {
2080                return syn::Error::new(Span::call_site(), msg)
2081                    .to_compile_error()
2082                    .into();
2083            }
2084        },
2085    };
2086
2087    let use_dollar_params = uses_dollar_params(&db);
2088    let list_count: usize = 3;
2089
2090    // ---- Phase 3: Build result case definitions ----
2091    // Each result case is (optional_key, optional_model_type, optional_scalar_type).
2092    // Scalar type is set for primitives and `scalar`-marked types.
2093    let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
2094        ResultSpec::None => {
2095            vec![(None, None, None)]
2096        }
2097        ResultSpec::Single(ref model) => {
2098            let model_ty = (**model).clone();
2099            let scalar = if force_scalar {
2100                Some(model_ty.clone())
2101            } else {
2102                scalar_output_type(model.as_ref()).cloned()
2103            };
2104            vec![(None, Some(model_ty), scalar)]
2105        }
2106        ResultSpec::Group(ref cases) => {
2107            if force_scalar {
2108                return syn::Error::new(
2109                    Span::call_site(),
2110                    "sql_forge!: scalar mode is not supported for grouped result maps",
2111                )
2112                .to_compile_error()
2113                .into();
2114            }
2115
2116            let mut out = Vec::new();
2117            let mut seen = HashSet::new();
2118            for case in cases {
2119                let key = case.name.to_string();
2120                if !seen.insert(key.clone()) {
2121                    return syn::Error::new(
2122                        case.name.span(),
2123                        "sql_forge!: duplicated key in result map",
2124                    )
2125                    .to_compile_error()
2126                    .into();
2127                }
2128
2129                let model = case.model.clone();
2130                let scalar = if case.force_scalar {
2131                    Some(model.clone())
2132                } else {
2133                    scalar_output_type(&case.model).cloned()
2134                };
2135                out.push((Some(key), Some(model), scalar));
2136            }
2137            out
2138        }
2139    };
2140    let group_result_keys: Vec<String> = result_cases
2141        .iter()
2142        .filter_map(|(key, _, _)| key.as_ref().cloned())
2143        .collect();
2144    let is_grouped_result = !group_result_keys.is_empty();
2145    let sql_span = sql.span();
2146
2147    // ---- Phase 4: Parse SQL into segments (text + {#section} + {( ..batch )} slots) ----
2148    let segments = match sql.into_segments() {
2149        Ok(segments) => segments,
2150        Err(msg) => {
2151            return syn::Error::new(sql_span, msg).to_compile_error().into();
2152        }
2153    };
2154
2155    let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2156    match (&batch, has_batch_segment) {
2157        (None, true) => {
2158            return syn::Error::new(
2159                sql_span,
2160                "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2161                 was provided"
2162            )
2163            .to_compile_error()
2164            .into();
2165        }
2166        (Some(_), false) => {
2167            return syn::Error::new(
2168                sql_span,
2169                "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2170                 batch section",
2171            )
2172            .to_compile_error()
2173            .into();
2174        }
2175        _ => {}
2176    }
2177
2178    let used_param_names = collect_used_param_names(&segments);
2179
2180    // Batch-only params come from batch items, not the top-level params map.
2181    // They must be excluded from the usage check so that a param like :category
2182    // that appears only inside {( ... )} is flagged as unused when given in the
2183    // params map, as it would never be read from there at runtime.
2184    //
2185    // However, params that appear in both text and batch segments must NOT be
2186    // excluded — the text-segment occurrence resolves against the params map.
2187    let text_param_names: std::collections::HashSet<String> = segments
2188        .iter()
2189        .filter_map(|s| {
2190            if let Segment::Text(text) = s {
2191                Some(collect_used_param_names_in_sql(text).into_iter())
2192            } else {
2193                None
2194            }
2195        })
2196        .flatten()
2197        .collect();
2198    let top_level_used_names: Vec<String> = used_param_names
2199        .iter()
2200        .filter(|n| text_param_names.contains(*n))
2201        .cloned()
2202        .collect();
2203
2204    // ---- Phase 5: Build parameter bindings for the top-level params ----
2205    let (declared_params, validator_param_bindings) =
2206        match build_param_bindings(&params, &top_level_used_names, "top_level", true, true) {
2207            Ok(v) => v,
2208            Err(err) => return err,
2209        };
2210
2211    let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2212
2213    // ---- Phase 6: Process sections: build runtime actions and collect validation variants ----
2214    for assign in &sections {
2215        let SectionAssign { names, value } = assign;
2216
2217        // Build runtime actions first, while `value` is still available by reference.
2218        let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2219        for (section_idx, name_ident) in names.iter().enumerate() {
2220            let name = name_ident.to_string();
2221            if runtime_section_actions.contains_key(&name) {
2222                return syn::Error::new(
2223                    name_ident.span(),
2224                    "sql_forge!: duplicated section mapping",
2225                )
2226                .to_compile_error()
2227                .into();
2228            }
2229            let action = match build_section_runtime_action(
2230                value,
2231                section_idx,
2232                &format!("section_{}", name),
2233            ) {
2234                Ok(action) => action,
2235                Err(err) => return err,
2236            };
2237            named_actions.push((name, action));
2238        }
2239
        // Validate a clone of `value` so invalid grouped/nested section structures fail early.
2241        if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2242            return syn::Error::new(names[0].span(), msg)
2243                .to_compile_error()
2244                .into();
2245        }
2246
2247        for (name, action) in named_actions {
2248            runtime_section_actions.insert(name, action);
2249        }
2250    }
2251
2252    let sql_section_names: std::collections::HashSet<&str> = segments
2253        .iter()
2254        .filter_map(|seg| {
2255            if let Segment::Section { name } = seg {
2256                Some(name.as_str())
2257            } else {
2258                None
2259            }
2260        })
2261        .collect();
2262    for name in runtime_section_actions.keys() {
2263        if !sql_section_names.contains(name.as_str()) {
2264            return syn::Error::new(
2265                sql_span,
2266                format!(
2267                    "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2268                    name, name,
2269                ),
2270            )
2271            .to_compile_error()
2272            .into();
2273        }
2274    }
2275
2276    // ---- Phase 8: For each result case, generate validator + runtime tokens ----
2277    let mut generated_query_defs = Vec::<TokenStream2>::new();
2278    let mut generated_query_values = Vec::<TokenStream2>::new();
2279    let mut group_field_defs = Vec::<TokenStream2>::new();
2280    let mut group_field_idents = Vec::<syn::Ident>::new();
2281    let mut group_field_tys = Vec::<TokenStream2>::new();
2282    let mut group_trait_impls = Vec::<TokenStream2>::new();
2283
2284    let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2285
2286    for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2287        let suffix = result_key.as_deref().unwrap_or("single");
2288        let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2289        let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2290
2291        let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2292
2293        let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2294        for assign in &sections {
2295            let SectionAssign { names, value } = assign;
2296            let variants_by_section = match collect_section_variants_for_result(
2297                value.clone(),
2298                names.len(),
2299                result_key.as_deref(),
2300            ) {
2301                Ok(v) => v,
2302                Err(msg) => {
2303                    return syn::Error::new(names[0].span(), msg)
2304                        .to_compile_error()
2305                        .into();
2306                }
2307            };
2308
2309            for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2310                section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2311            }
2312        }
2313
2314        let mut nmax = 1usize;
2315        for segment in &segments {
2316            if let Segment::Section { name } = segment {
2317                if let Some(variants) = section_variants_for_validation.get(name) {
2318                    if variants.is_empty() {
2319                        return syn::Error::new(
2320                            sql_span,
2321                            format!("sql_forge!: section {{#{}}} has no possible variants", name),
2322                        )
2323                        .to_compile_error()
2324                        .into();
2325                    }
2326                    nmax = nmax.max(variants.len());
2327                } else {
2328                    return syn::Error::new(
2329                        sql_span,
2330                        format!("sql_forge!: section {{#{}}} has no mapping", name),
2331                    )
2332                    .to_compile_error()
2333                    .into();
2334                }
2335            }
2336        }
2337
2338        let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2339        for case_idx in 0..nmax {
2340            let mut sql_case = String::new();
2341            let mut case_setup = Vec::<TokenStream2>::new();
2342            let mut case_args = Vec::<TokenStream2>::new();
2343            let mut param_offset = 0usize;
2344            let mut arg_index = 0usize;
2345            let root_validator_context = ValidatorRenderContext {
2346                params: &declared_params,
2347                use_dollar_params,
2348                sql_span,
2349                list_count,
2350            };
2351
2352            for segment in &segments {
2353                match segment {
2354                    Segment::Text(text) => {
2355                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2356                            text,
2357                            &mut param_offset,
2358                            &mut arg_index,
2359                            &root_validator_context,
2360                        ) {
2361                            Ok(value) => value,
2362                            Err(err) => return err,
2363                        };
2364                        sql_case.push_str(&chunk_sql);
2365                        case_setup.extend(chunk_setup);
2366                        case_args.extend(chunk_args);
2367                    }
2368                    Segment::Section { name } => {
2369                        let Some(variants) = section_variants_for_validation.get(name) else {
2370                            return syn::Error::new(
2371                                sql_span,
2372                                format!("sql_forge!: section {{#{}}} has no mapping", name),
2373                            )
2374                            .to_compile_error()
2375                            .into();
2376                        };
2377
2378                        let fragment = &variants[case_idx % variants.len()];
2379                        let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2380                        let (local_params, bindings) = match build_param_bindings(
2381                            &fragment.params,
2382                            &used_param_names,
2383                            &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2384                            true,
2385                            true,
2386                        ) {
2387                            Ok(value) => value,
2388                            Err(err) => return err,
2389                        };
2390                        let section_validator_context = ValidatorRenderContext {
2391                            params: &local_params,
2392                            use_dollar_params,
2393                            sql_span: fragment.span,
2394                            list_count,
2395                        };
2396                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2397                            &fragment.sql,
2398                            &mut param_offset,
2399                            &mut arg_index,
2400                            &section_validator_context,
2401                        ) {
2402                            Ok(value) => value,
2403                            Err(err) => return err,
2404                        };
2405                        sql_case.push_str(&chunk_sql);
2406                        case_setup.extend(bindings);
2407                        case_setup.extend(chunk_setup);
2408                        case_args.extend(chunk_args);
2409                    }
2410                    Segment::Batch { parts } => {
2411                        let batch_ts = batch.as_ref().map(|e| quote! { #e });
2412                        let mut first = true;
2413                        for _ in 0..list_count {
2414                            let sep = if first { "" } else { ", " };
2415                            first = false;
2416                            sql_case.push_str(sep);
2417                            let (chunk_sql, _occurrences, chunk_args) = match render_validator_sql(
2418                                parts,
2419                                use_dollar_params,
2420                                &mut param_offset,
2421                                list_count,
2422                                batch_ts.clone(),
2423                            ) {
2424                                Ok(value) => value,
2425                                Err(err) => return err,
2426                            };
2427                            sql_case.push_str(&chunk_sql);
2428                            case_args.extend(chunk_args);
2429                        }
2430                    }
2431                }
2432            }
2433
2434            validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2435        }
2436
2437        let mut validator_invocations = Vec::<TokenStream2>::new();
2438        for (sql_lit, case_setup, args) in &validator_cases {
2439            if model_opt.is_none() {
2440                if args.is_empty() {
2441                    validator_invocations.push(quote! {
2442                        {
2443                            #( #case_setup )*
2444                            let _ = sqlx::query_scalar!(
2445                                #sql_lit,
2446                            );
2447                        }
2448                    });
2449                } else {
2450                    validator_invocations.push(quote! {
2451                        {
2452                            #( #case_setup )*
2453                            let _ = sqlx::query_scalar!(
2454                                #sql_lit,
2455                                #( #args ),*
2456                            );
2457                        }
2458                    });
2459                }
2460            } else if let Some(scalar_ty) = scalar_model_ty {
2461                if args.is_empty() {
2462                    validator_invocations.push(quote! {
2463                        {
2464                            #( #case_setup )*
2465                            let _ = sqlx::query_scalar!(
2466                                #sql_lit,
2467                            );
2468                        }
2469                    });
2470                } else {
2471                    validator_invocations.push(quote! {
2472                        {
2473                            #( #case_setup )*
2474                            let _ = sqlx::query_scalar!(
2475                                #sql_lit,
2476                                #( #args ),*
2477                            );
2478                        }
2479                    });
2480                }
2481                let _ = scalar_ty;
2482            } else if args.is_empty() {
2483                validator_invocations.push(quote! {
2484                    {
2485                        #( #case_setup )*
2486                        let _ = sqlx::query_as!(
2487                            __SqlForgeModel,
2488                            #sql_lit,
2489                        );
2490                    }
2491                });
2492            } else {
2493                validator_invocations.push(quote! {
2494                    {
2495                        #( #case_setup )*
2496                        let _ = sqlx::query_as!(
2497                            __SqlForgeModel,
2498                            #sql_lit,
2499                            #( #args ),*
2500                        );
2501                    }
2502                });
2503            }
2504        }
2505
2506        let model_alias = if let Some(model) = model_opt {
2507            if scalar_model_ty.is_none() {
2508                quote! { type __SqlForgeModel = #model; }
2509            } else {
2510                quote! {}
2511            }
2512        } else {
2513            quote! {}
2514        };
2515        grouped_validator_invocations.push(quote! {
2516            {
2517                #( #flag_bindings )*
2518                #model_alias
2519                #( #validator_invocations )*
2520            }
2521        });
2522
2523        let (runtime_declared_params, runtime_param_bindings) =
2524            match build_param_bindings(&params, &used_param_names, "runtime", false, false) {
2525                Ok(v) => v,
2526                Err(err) => return err,
2527            };
2528
2529        let mut runtime_steps = Vec::<TokenStream2>::new();
2530        for (seg_idx, segment) in segments.iter().enumerate() {
2531            match segment {
2532                Segment::Text(text) => {
2533                    for part in parse_text_parts(text) {
2534                        match part {
2535                            TextPart::Lit(lit) => {
2536                                let lit = sanitize_runtime_sql_text(&lit);
2537                                let lit_str = LitStr::new(&lit, sql_span);
2538                                runtime_steps.push(quote! {
2539                                    __builder.push(#lit_str);
2540                                });
2541                            }
2542                            TextPart::Param { name, is_list } => {
2543                                let Some(local_ident) = runtime_declared_params.get(&name) else {
2544                                    return syn::Error::new(
2545                                        sql_span,
2546                                        format!("sql_forge!: parameter :{} has no mapping", name),
2547                                    )
2548                                    .to_compile_error()
2549                                    .into();
2550                                };
2551
2552                                if is_list {
2553                                    runtime_steps.push(quote! {
2554                                        let __sql_forge_values = #local_ident;
2555                                        let mut __separated = __builder.separated(", ");
2556                                        for __value in __sql_forge_values {
2557                                            __separated.push_bind(__value);
2558                                        }
2559                                    });
2560                                } else {
2561                                    runtime_steps.push(quote! {
2562                                        __builder.push_bind(#local_ident);
2563                                    });
2564                                }
2565                            }
2566                        }
2567                    }
2568                }
2569                Segment::Section { name } => {
2570                    let Some(section_action) = runtime_section_actions.get(name) else {
2571                        let _ = seg_idx;
2572                        return syn::Error::new(
2573                            sql_span,
2574                            format!("sql_forge!: section {{#{}}} has no mapping", name),
2575                        )
2576                        .to_compile_error()
2577                        .into();
2578                    };
2579                    runtime_steps.push(quote! {
2580                        #section_action
2581                    });
2582                }
2583                Segment::Batch { parts } => {
2584                    if let Some(batch_expr) = &batch {
2585                        let mut body = Vec::<TokenStream2>::new();
2586                        for part in parts {
2587                            match part {
2588                                TextPart::Lit(lit) => {
2589                                    let lit_str = LitStr::new(lit, sql_span);
2590                                    body.push(quote! {
2591                                        __builder.push(#lit_str);
2592                                    });
2593                                }
2594                                TextPart::Param { name, .. } => {
2595                                    let field_ident = format_ident!("{}", name);
2596                                    body.push(quote! {
2597                                        __builder.push_bind(__item.#field_ident);
2598                                    });
2599                                }
2600                            }
2601                        }
2602                        runtime_steps.push(quote! {
2603                            {
2604                                let mut __first = true;
2605                                for __item in #batch_expr {
2606                                    if !__first {
2607                                        __builder.push(", ");
2608                                    }
2609                                    __first = false;
2610                                    #( #body )*
2611                                }
2612                            }
2613                        });
2614                    }
2615                }
2616            }
2617        }
2618
2619        let exec_methods = if model_opt.is_none() {
2620            quote! {
2621                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2622                where
2623                    E: sqlx::Executor<'e, Database = #db>,
2624                {
2625                    self.inner.build().execute(executor).await
2626                }
2627            }
2628        } else if let Some(scalar_ty) = scalar_model_ty {
2629            quote! {
2630                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2631                where
2632                    E: sqlx::Executor<'e, Database = #db>,
2633                {
2634                    self.inner
2635                        .build_query_scalar::<#scalar_ty>()
2636                        .fetch_all(executor)
2637                        .await
2638                }
2639
2640                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2641                where
2642                    E: sqlx::Executor<'e, Database = #db>,
2643                {
2644                    self.inner
2645                        .build_query_scalar::<#scalar_ty>()
2646                        .fetch_one(executor)
2647                        .await
2648                }
2649
2650                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2651                where
2652                    E: sqlx::Executor<'e, Database = #db>,
2653                {
2654                    self.inner
2655                        .build_query_scalar::<#scalar_ty>()
2656                        .fetch_optional(executor)
2657                        .await
2658                }
2659
2660                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2661                where
2662                    E: sqlx::Executor<'e, Database = #db>,
2663                {
2664                    self.inner.build().execute(executor).await
2665                }
2666            }
2667        } else {
2668            let model = model_opt.as_ref().unwrap();
2669            quote! {
2670                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2671                where
2672                    E: sqlx::Executor<'e, Database = #db>,
2673                {
2674                    self.inner.build_query_as::<#model>().fetch_all(executor).await
2675                }
2676
2677                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2678                where
2679                    E: sqlx::Executor<'e, Database = #db>,
2680                {
2681                    self.inner.build_query_as::<#model>().fetch_one(executor).await
2682                }
2683
2684                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2685                where
2686                    E: sqlx::Executor<'e, Database = #db>,
2687                {
2688                    self.inner
2689                        .build_query_as::<#model>()
2690                        .fetch_optional(executor)
2691                        .await
2692                }
2693
2694                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2695                where
2696                    E: sqlx::Executor<'e, Database = #db>,
2697                {
2698                    self.inner.build().execute(executor).await
2699                }
2700            }
2701        };
2702
2703        let final_type: TokenStream2 = if let Some(model) = model_opt {
2704            if let Some(scalar_ty) = scalar_model_ty {
2705                quote! { #scalar_ty }
2706            } else {
2707                quote! { #model }
2708            }
2709        } else {
2710            quote! {}
2711        };
2712        let trait_impl = if model_opt.is_none() {
2713            quote! {
2714                impl sql_forge::SqlForgeQueryExecute
2715                    for #query_ident
2716                {
2717                    type Db = #db;
2718
2719                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2720                    where
2721                        Self: Sized + 'e,
2722                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2723                        #db: 'e,
2724                    {
2725                        #query_ident::execute(self, executor)
2726                    }
2727                }
2728            }
2729        } else {
2730            quote! {
2731                impl sql_forge::SqlForgeQuery<#final_type>
2732                    for #query_ident
2733                {
2734                    type Db = #db;
2735
2736                    fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2737                    where
2738                        Self: Sized + 'e,
2739                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2740                        #db: 'e,
2741                    {
2742                        #query_ident::fetch_all(self, executor)
2743                    }
2744
2745                    fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2746                    where
2747                        Self: Sized + 'e,
2748                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2749                        #db: 'e,
2750                    {
2751                        #query_ident::fetch_one(self, executor)
2752                    }
2753
2754                    fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2755                    where
2756                        Self: Sized + 'e,
2757                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2758                        #db: 'e,
2759                    {
2760                        #query_ident::fetch_optional(self, executor)
2761                    }
2762
2763                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2764                    where
2765                        Self: Sized + 'e,
2766                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2767                        #db: 'e,
2768                    {
2769                        #query_ident::execute(self, executor)
2770                    }
2771                }
2772            }
2773        };
2774
2775        generated_query_defs.push(quote! {
2776            struct #query_ident {
2777                inner: sqlx::QueryBuilder<#db>,
2778            }
2779
2780            impl #query_ident {
2781                #exec_methods
2782            }
2783
2784            #trait_impl
2785        });
2786
2787        generated_query_values.push(quote! {
2788            #( #runtime_param_bindings )*
2789            #( #flag_bindings )*
2790            let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2791            #( #runtime_steps )*
2792            let #query_value_ident = #query_ident { inner: __builder };
2793        });
2794
2795        if let Some(key) = result_key {
2796            let method_ident = format_ident!("{}", key);
2797            group_field_defs.push(quote! {
2798                #method_ident: #query_ident
2799            });
2800            group_field_tys.push(quote! { #query_ident });
2801
2802            let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2803            group_trait_impls.push(quote! {
2804                struct #key_ty_ident;
2805
2806                impl sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup {
2807                    type Query = #query_ident;
2808
2809                    fn get(self, _: #key_ty_ident) -> Self::Query {
2810                        self.#method_ident
2811                    }
2812                }
2813            });
2814            group_field_idents.push(method_ident);
2815        }
2816    }
2817
2818    // ---- Phase 9: Emit the final token stream ----
2819    let validator_tokens = quote! {
2820        let _sql_forge_validator = || {
2821            #( #validator_param_bindings )*
2822            #( #grouped_validator_invocations )*
2823        };
2824    };
2825
2826    if !is_grouped_result {
2827        let single_query_value_ident = format_ident!("__sql_forge_value_single");
2828        return quote! {
2829            {
2830                #validator_tokens
2831                #( #generated_query_defs )*
2832                #( #generated_query_values )*
2833                #single_query_value_ident
2834            }
2835        }
2836        .into();
2837    }
2838
2839    let group_field_inits: Vec<TokenStream2> = result_cases
2840        .iter()
2841        .filter_map(|(key, _, _)| key.as_ref())
2842        .map(|key| {
2843            let method_ident = format_ident!("{}", key);
2844            let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2845            quote! { #method_ident: #query_value_ident }
2846        })
2847        .collect();
2848
2849    quote! {
2850        {
2851            #validator_tokens
2852
2853            #( #generated_query_defs )*
2854            #( #generated_query_values )*
2855
2856            struct __SqlForgeQueryGroup {
2857                #( #group_field_defs, )*
2858            }
2859
2860            impl __SqlForgeQueryGroup {
2861                pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2862                    ( #( self.#group_field_idents ),* )
2863                }
2864            }
2865
2866            impl sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup {
2867                type Db = #db;
2868            }
2869
2870            #( #group_trait_impls )*
2871
2872            __SqlForgeQueryGroup {
2873                #( #group_field_inits, )*
2874            }
2875        }
2876    }
2877    .into()
2878}
2879
2880/// Expands to the database type from the `SQL_FORGE_DB_TYPE` environment variable,
2881/// falling back to `[package.metadata.sql_forge]` in `Cargo.toml`.
2882///
2883/// ```rust,ignore
2884/// use sql_forge::db_type;
2885///
2886/// pub type AppDb = db_type!();
2887/// // expands to the type set via SQL_FORGE_DB_TYPE or Cargo.toml metadata
2888/// ```
2889///
2890/// Priority:
2891/// 1. `SQL_FORGE_DB_TYPE` env var (e.g. `sqlx::MySql`, `sqlx::Postgres`)
2892/// 2. `[package.metadata.sql_forge] db = "..."` in `Cargo.toml`
2893#[proc_macro]
2894pub fn db_type(input: TokenStream) -> TokenStream {
2895    if !input.is_empty() {
2896        return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2897            .to_compile_error()
2898            .into();
2899    }
2900
2901    match resolve_db_from_env() {
2902        Ok(db) => quote! { #db }.into(),
2903        Err(msg) => syn::Error::new(Span::call_site(), msg)
2904            .to_compile_error()
2905            .into(),
2906    }
2907}
2908
2909/// Marks a single-value tuple struct as a transparent wrapper for use with
2910/// `sql_forge!` parameters.
2911///
2912/// Expands to `#[derive(sqlx::Type)]` + `#[sqlx(transparent)]` (needed for
2913/// all database backends so the type implements `sqlx::Encode` + `sqlx::Type`)
2914/// and additionally implements `SqlForgeValidatorValue<InnerType>`, which is
2915/// **required for PostgreSQL** to pass compile-time parameter validation in
2916/// `query_as!`. MySQL and SQLite do not need to use the trait.
2917///
2918/// ```rust,ignore
2919/// #[derive(Debug, PartialEq, Eq)]
2920/// #[sql_forge_transparent]
2921/// struct UserId(pub i64);
2922/// ```
2923#[proc_macro_attribute]
2924pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2925    let input: ItemStruct = match syn::parse(item) {
2926        Ok(v) => v,
2927        Err(err) => return err.to_compile_error().into(),
2928    };
2929
2930    let struct_name = &input.ident;
2931    let inner_type = match &input.fields {
2932        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2933        _ => {
2934            return syn::Error::new(
2935                input.span(),
2936                "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2937            )
2938            .to_compile_error()
2939            .into();
2940        }
2941    };
2942
2943    let attrs = input.attrs;
2944    let generics = &input.generics;
2945    let vis = &input.vis;
2946    let struct_token = input.struct_token;
2947    let semi_token = input.semi_token;
2948    let fields = &input.fields;
2949
2950    let expanded = quote! {
2951        #( #attrs )*
2952        #[derive(sqlx::Type)]
2953        #[sqlx(transparent)]
2954        #vis #struct_token #struct_name #generics #fields #semi_token
2955
2956        impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2957            fn sql_forge_validator_value(&self) -> #inner_type {
2958                self.0.clone()
2959            }
2960        }
2961    };
2962
2963    expanded.into()
2964}