Skip to main content

sql_forge_macro/
lib.rs

1// =============================================================================
2// Imports
3// =============================================================================
4
5use proc_macro::TokenStream;
6use proc_macro2::{Delimiter, Group, Span, TokenStream as TokenStream2, TokenTree};
7use quote::{format_ident, quote};
8use std::collections::{HashMap, HashSet};
9use std::fmt::Write;
10use std::fs;
11use std::path::Path;
12use syn::parse::{Parse, ParseStream};
13use syn::parse_quote;
14use syn::spanned::Spanned;
15use syn::{
16    Expr, ExprBlock, ExprGroup, ExprLit, ExprParen, Fields, Ident, ItemStruct, Lit, LitStr, Pat,
17    Stmt, Token, Type,
18};
19
20// =============================================================================
21// Input types
22// =============================================================================
23
/// Custom keywords recognized by the macro's input grammar.
mod kw {
    // `scalar` precedes a result type and sets its `force_scalar` flag
    // (both for a single result and for `>key = scalar Type` map entries).
    syn::custom_keyword!(scalar);
}
27
28/// A `:name = expr` parameter binding.
29#[derive(Clone)]
30struct ParamAssign {
31    name: Ident,
32    expr: Expr,
33}
34
35impl Parse for ParamAssign {
36    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
37        input.parse::<Token![:]>()?;
38        let name: Ident = input.parse()?;
39        input.parse::<Token![=]>()?;
40        let expr: Expr = input.parse()?;
41        Ok(Self { name, expr })
42    }
43}
44
/// A single section value: SQL string + optional local parameter bindings.
#[derive(Clone)]
struct SectionFragment {
    /// The SQL text taken from the fragment's string literal.
    sql: String,
    /// Span of the string-literal expression, used to point diagnostics at it.
    span: Span,
    /// Section-local parameters; `ParamsSource::None` for a bare string literal.
    params: ParamsSource,
}

/// One arm in a `match { ... }` inside a section assignment.
#[derive(Clone)]
struct SectionMatchArm {
    /// Arm pattern (a leading `|` and `|`-alternatives are accepted).
    pat: Pat,
    /// Optional `if` guard expression.
    guard: Option<Expr>,
    /// Value selected when this arm matches; parsed with the same width as
    /// the enclosing assignment, so it may itself be grouped or a nested match.
    value: SectionValue,
}

/// The right-hand side of a section `#name = ...` assignment.
#[derive(Clone)]
enum SectionValue {
    /// A plain string (or `(string, params)` tuple).
    Single(SectionFragment),
    /// A tuple of values, one per section when using `#(a, b) = ...`.
    Grouped(Vec<SectionValue>),
    /// A `match expr { arm => ..., arm => ... }` expression.
    Match {
        /// The scrutinee expression.
        expr: Expr,
        /// The arms, in source order.
        arms: Vec<SectionMatchArm>,
    },
}
74
75/// A full `#name = value` or `#(a, b) = value` section assignment.
76#[derive(Clone)]
77struct SectionAssign {
78    names: Vec<Ident>,
79    value: SectionValue,
80}
81
82impl Parse for SectionAssign {
83    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
84        input.parse::<Token![#]>()?;
85
86        // Parse `#(a, b)` for grouped sections, or `#name` for single.
87        let names = if input.peek(syn::token::Paren) {
88            let content;
89            syn::parenthesized!(content in input);
90            let mut out = Vec::new();
91            while !content.is_empty() {
92                out.push(content.parse::<Ident>()?);
93                if content.is_empty() {
94                    break;
95                }
96                content.parse::<Token![,]>()?;
97            }
98            if out.is_empty() {
99                return Err(input.error("sql_forge!: grouped section key list cannot be empty"));
100            }
101            out
102        } else {
103            vec![input.parse::<Ident>()?]
104        };
105
106        input.parse::<Token![=]>()?;
107        let value = parse_section_value(input, names.len())?;
108        Ok(Self { names, value })
109    }
110}
111
/// The fully-parsed macro invocation.
struct SqlForgeInput {
    /// Explicit database type, when one was given as a leading argument.
    db: Option<Type>,
    /// What the query produces: nothing, a single model, or a result map.
    result: ResultSpec,
    /// Set when a single result type was preceded by the `scalar` keyword.
    force_scalar: bool,
    /// The SQL template literal.
    sql: SqlTemplate,
    /// Top-level parameter source: a `(:name = expr, ...)` map or a struct expression.
    params: ParamsSource,
    /// Assignments from a `(#name = value, ...)` argument (empty if none was given).
    sections: Vec<SectionAssign>,
    /// Expression from a `..expr` batch argument, if present.
    batch: Option<Expr>,
}
122
/// One entry in a result map `(>key = Model, ...)`.
#[derive(Clone)]
struct ResultAssign {
    /// Result key: the identifier after `>`.
    name: Ident,
    /// Model type for this entry.
    model: Type,
    /// True when the entry was written `>key = scalar Type`.
    force_scalar: bool,
}

/// The result argument of the macro.
#[derive(Clone)]
enum ResultSpec {
    /// Execute-only (no model), e.g. `sql_forge!("SQL", ...)`
    None,
    /// e.g. `sql_forge!(User, ...)`
    Single(Box<Type>),
    /// e.g. `sql_forge!((>a = X, >b = Y), ...)`
    Group(Vec<ResultAssign>),
}

/// Where values for `:name` placeholders come from.
#[derive(Clone)]
enum ParamsSource {
    /// No parameter source was supplied.
    None,
    /// `( :name = expr, ... )`
    Map(Vec<ParamAssign>),
    /// A struct expression whose fields are matched to `:name` placeholders.
    Struct(Box<Expr>),
}
149
150/// The SQL template. Only string literals are supported.
151enum SqlTemplate {
152    Literal(LitStr),
153}
154
155impl SqlTemplate {
156    fn span(&self) -> Span {
157        match self {
158            Self::Literal(lit) => lit.span(),
159        }
160    }
161
162    /// Parse the SQL into a sequence of textual segments and `{#section}` slots.
163    fn into_segments(self) -> Result<Vec<Segment>, String> {
164        match self {
165            Self::Literal(lit) => parse_literal_segments(&lit.value()),
166        }
167    }
168}
169
170fn parse_sql_template(input: ParseStream<'_>) -> syn::Result<SqlTemplate> {
171    if input.peek(LitStr) {
172        Ok(SqlTemplate::Literal(input.parse::<LitStr>()?))
173    } else {
174        Err(input.error("sql_forge!: SQL template must be a string literal"))
175    }
176}
177
/// One piece of the parsed SQL template.
#[derive(Clone)]
enum Segment {
    /// Plain SQL text (may contain `:param` placeholders).
    Text(String),
    /// A `{#section_name}` placeholder. `name` is the raw text between `{#`
    /// and `}`; it is only checked to be non-empty (not trimmed or validated).
    Section { name: String },
    /// A `{( ... )}` batch value template (repeated per batch item).
    /// `parts` is split from the text between `{` and `}`, which includes the
    /// outer parentheses; it never contains list parameters.
    Batch { parts: Vec<TextPart> },
}

/// A fragment of SQL text after splitting on `:param`.
#[derive(Clone)]
enum TextPart {
    /// Literal SQL text.
    Lit(String),
    /// A `:param` placeholder. `name` excludes the leading `:`; `is_list` is
    /// true when the placeholder was written with a `[]` suffix (`:name[]`).
    Param { name: String, is_list: bool },
}
197
198// =============================================================================
199// Parsing helpers
200// =============================================================================
201
/// Used by `detect_parenthesized_map_kind` to identify what a `(...)` argument is.
enum MapKind {
    /// First token is `>`: a result map `(>name = Model, ...)`.
    Results,
    /// First token is `:`: a parameter map `(:name = expr, ...)`.
    Params,
    /// First token is `#`: a section map `(#name = value, ...)`.
    Sections,
}
208
209// =============================================================================
// Map-argument parsing (result / parameter / section maps) and section values (string fragments, match, grouped tuples)
211// =============================================================================
212
213/// Determines whether a parenthesized argument is a result map, params map, or sections map.
214fn detect_parenthesized_map_kind(input: ParseStream<'_>) -> syn::Result<Option<MapKind>> {
215    let fork = input.fork();
216    let content;
217    syn::parenthesized!(content in fork);
218
219    if content.is_empty() {
220        return Err(input.error("sql_forge!: map argument cannot be empty"));
221    }
222
223    if content.peek(Token![>]) {
224        Ok(Some(MapKind::Results))
225    } else if content.peek(Token![:]) {
226        Ok(Some(MapKind::Params))
227    } else if content.peek(Token![#]) {
228        Ok(Some(MapKind::Sections))
229    } else {
230        Ok(None)
231    }
232}
233
234impl Parse for ResultAssign {
235    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
236        input.parse::<Token![>]>()?;
237        let name: Ident = input.parse()?;
238        input.parse::<Token![=]>()?;
239        let (force_scalar, model) = if input.peek(kw::scalar) {
240            input.parse::<kw::scalar>()?;
241            (true, input.parse::<Type>()?)
242        } else {
243            (false, input.parse::<Type>()?)
244        };
245        Ok(Self {
246            name,
247            model,
248            force_scalar,
249        })
250    }
251}
252
253/// Parses a parenthesized result map like `(>name = Model, ...)`.
254fn parse_result_map(input: ParseStream<'_>) -> syn::Result<Vec<ResultAssign>> {
255    let content;
256    syn::parenthesized!(content in input);
257
258    let mut results = Vec::new();
259    while !content.is_empty() {
260        results.push(content.parse::<ResultAssign>()?);
261        if content.is_empty() {
262            break;
263        }
264        content.parse::<Token![,]>()?;
265    }
266
267    if results.is_empty() {
268        return Err(input.error("sql_forge!: result map cannot be empty"));
269    }
270
271    Ok(results)
272}
273
274/// Parses a parenthesized parameter map like `(:name = expr, ...)`.
275fn parse_param_map(input: ParseStream<'_>) -> syn::Result<Vec<ParamAssign>> {
276    let content;
277    syn::parenthesized!(content in input);
278
279    let mut params = Vec::new();
280    while !content.is_empty() {
281        params.push(content.parse::<ParamAssign>()?);
282        if content.is_empty() {
283            break;
284        }
285        content.parse::<Token![,]>()?;
286    }
287
288    Ok(params)
289}
290
291/// Parses a parenthesized section map like `(#name = expr, ...)`.
292fn parse_section_map(input: ParseStream<'_>) -> syn::Result<Vec<SectionAssign>> {
293    let content;
294    syn::parenthesized!(content in input);
295
296    let mut sections = Vec::new();
297    while !content.is_empty() {
298        sections.push(content.parse::<SectionAssign>()?);
299        if content.is_empty() {
300            break;
301        }
302        content.parse::<Token![,]>()?;
303    }
304
305    Ok(sections)
306}
307
308/// Parses a parameter source: either a struct expression or a parenthesized map.
309fn parse_params_source_expr(input: ParseStream<'_>) -> syn::Result<ParamsSource> {
310    if input.peek(syn::token::Paren) {
311        match detect_parenthesized_map_kind(input)? {
312            Some(MapKind::Results) => Err(input
313                .error("sql_forge!: result maps are only allowed as the macro result argument")),
314            Some(MapKind::Params) => Ok(ParamsSource::Map(parse_param_map(input)?)),
315            Some(MapKind::Sections) => Err(input.error(
316                "sql_forge!: use :name = expr for section-local parameters, not #name = expr",
317            )),
318            None => Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?))),
319        }
320    } else {
321        Ok(ParamsSource::Struct(Box::new(input.parse::<Expr>()?)))
322    }
323}
324
/// Parses a section fragment: a string literal or `(string_literal, params)` tuple.
///
/// A parenthesized value is first probed on a fork. If it has the shape
/// `(literal, params [,])`, it is then re-parsed for real from `input`.
/// Otherwise the whole value is parsed as one expression, which must be a
/// string literal. Errors raised while parsing the params on the fork (e.g. a
/// `#name = ...` map) are propagated immediately rather than falling back.
///
/// NOTE(review): relies on `extract_lit_str` (defined outside this view),
/// presumably returning the literal's value when the expression is a string
/// literal — confirm.
fn parse_section_fragment(input: ParseStream<'_>) -> syn::Result<SectionFragment> {
    if input.peek(syn::token::Paren) {
        // Probe without consuming `input`.
        let fork = input.fork();
        let content;
        syn::parenthesized!(content in fork);

        if let Ok(first_expr) = content.parse::<Expr>() {
            if extract_lit_str(&first_expr).is_some() && content.parse::<Token![,]>().is_ok() {
                let _ = parse_params_source_expr(&content)?;
                if content.peek(Token![,]) {
                    content.parse::<Token![,]>()?;
                }
                // The shape matched: parse the same tokens again from the real stream.
                if content.is_empty() {
                    let content;
                    syn::parenthesized!(content in input);
                    let first_expr: Expr = content.parse()?;
                    let sql = extract_lit_str(&first_expr).ok_or_else(|| {
                        input.error("sql_forge!: section tuple must start with a string literal")
                    })?;
                    let span = first_expr.span();
                    content.parse::<Token![,]>()?;
                    let params = parse_params_source_expr(&content)?;
                    if content.peek(Token![,]) {
                        content.parse::<Token![,]>()?;
                    }
                    if !content.is_empty() {
                        return Err(content.error(
                            "sql_forge!: unexpected tokens after section-local parameter source",
                        ));
                    }
                    return Ok(SectionFragment { sql, span, params });
                }
            }
        }
    }

    // Fallback: a single expression (possibly parenthesized) that must be a string literal.
    let expr: Expr = input.parse()?;
    let sql = extract_lit_str(&expr).ok_or_else(|| {
        input
            .error("sql_forge!: section values must be string literals or (string literal, params)")
    })?;
    Ok(SectionFragment {
        sql,
        span: expr.span(),
        params: ParamsSource::None,
    })
}
373
374/// Parses a section value which may be a single fragment, grouped tuple, or match expression.
375fn parse_section_value(input: ParseStream<'_>, width: usize) -> syn::Result<SectionValue> {
376    if input.peek(Token![match]) {
377        input.parse::<Token![match]>()?;
378        let expr: Expr = input.call(Expr::parse_without_eager_brace)?;
379        let content;
380        syn::braced!(content in input);
381        let mut arms = Vec::new();
382        while !content.is_empty() {
383            let pat = content.call(Pat::parse_multi_with_leading_vert)?;
384            let guard = if content.peek(Token![if]) {
385                content.parse::<Token![if]>()?;
386                Some(content.parse::<Expr>()?)
387            } else {
388                None
389            };
390            content.parse::<Token![=>]>()?;
391            // Recursive call in the match arm value allows nested match expressions.
392            let value = parse_section_value(&content, width)?;
393            if content.peek(Token![,]) {
394                content.parse::<Token![,]>()?;
395            }
396            arms.push(SectionMatchArm { pat, guard, value });
397        }
398        return Ok(SectionValue::Match { expr, arms });
399    }
400
401    if width == 1 {
402        return Ok(SectionValue::Single(parse_section_fragment(input)?));
403    }
404
405    let content;
406    syn::parenthesized!(content in input);
407    let mut items = Vec::new();
408    while !content.is_empty() {
409        // Recursive call for a section value tuple item (representing a single section).
410        // Allows match expressions returning a single section.
411        items.push(parse_section_value(&content, 1)?);
412        if content.is_empty() {
413            break;
414        }
415        content.parse::<Token![,]>()?;
416    }
417
418    if items.len() != width {
419        return Err(input.error(format!(
420            "sql_forge!: grouped section value must provide exactly {} items",
421            width,
422        )));
423    }
424
425    Ok(SectionValue::Grouped(items))
426}
427
428// =============================================================================
429// Top-level macro input parsing (SqlForgeInput::parse)
430// =============================================================================
431
impl Parse for SqlForgeInput {
    /// Parses a full `sql_forge!(...)` invocation.
    ///
    /// Leading arguments, tried in this order:
    /// - `"SQL"`: execute-only, with no DB and no result model.
    /// - `scalar Model, "SQL"`: a single result with `force_scalar` set.
    /// - `(>a = X, ...), "SQL"`: a result map. Any other parenthesized first
    ///   argument is rejected.
    /// - `Type, "SQL"`: `Type` is the DB when `is_db_type` accepts it, otherwise
    ///   the model.
    /// - `Db, scalar Model, "SQL"`, `Db, (>a = X, ...), "SQL"` or
    ///   `Db, Model, "SQL"`: in these forms the first type is taken as the DB
    ///   without an `is_db_type` check.
    ///
    /// Optional trailing arguments after the SQL, comma-separated, in any order:
    /// - `..expr`: the batch source (at most one).
    /// - `(:name = expr, ...)`: a parameter map.
    /// - `(#name = value, ...)`: a section map (at most one).
    /// - any other expression: a struct parameter source.
    ///
    /// At most one parameter source (map or struct) is allowed. Result maps are
    /// rejected in trailing position. A trailing comma is accepted, and leftover
    /// tokens are an error.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        // Leading arguments: optional DB, optional result spec, then the SQL literal.
        let (db, result, force_scalar, sql) = if input.peek(LitStr) {
            // `"SQL"`: execute-only.
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::None, false, sql)
        } else if input.peek(kw::scalar) {
            // `scalar Model, "SQL"`.
            input.parse::<kw::scalar>()?;
            let model: Type = input.parse()?;
            input.parse::<Token![,]>()?;
            let sql = parse_sql_template(input)?;
            (None, ResultSpec::Single(Box::new(model)), true, sql)
        } else if input.peek(syn::token::Paren) {
            // In first position, parentheses must hold a result map.
            let result_map_kind = detect_parenthesized_map_kind(input)?;
            match result_map_kind {
                Some(MapKind::Results) => {
                    let result = ResultSpec::Group(parse_result_map(input)?);
                    input.parse::<Token![,]>()?;
                    let sql = parse_sql_template(input)?;
                    (None, result, false, sql)
                }
                _ => {
                    return Err(input.error(
                        "sql_forge!: expected a result map like (>name = Model, ...) or a model type",
                    ));
                }
            }
        } else {
            // A leading type: DB or model, decided by what follows the comma.
            let first_ty: Type = input.parse()?;
            input.parse::<Token![,]>()?;

            if input.peek(LitStr) && is_db_type(&first_ty) {
                // `sqlx::{MySql,Postgres,Sqlite}, "SQL"`: DB given, execute-only.
                let db = first_ty;
                let sql = parse_sql_template(input)?;
                (Some(db), ResultSpec::None, false, sql)
            } else if input.peek(LitStr) {
                // `Model, "SQL"`.
                let model = first_ty;
                let sql = parse_sql_template(input)?;
                (None, ResultSpec::Single(Box::new(model)), false, sql)
            } else if input.peek(kw::scalar) {
                // `Db, scalar Model, "SQL"`.
                input.parse::<kw::scalar>()?;
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (
                    Some(first_ty),
                    ResultSpec::Single(Box::new(model)),
                    true,
                    sql,
                )
            } else if input.peek(syn::token::Paren)
                && matches!(
                    detect_parenthesized_map_kind(input)?,
                    Some(MapKind::Results)
                )
            {
                // `Db, (>a = X, ...), "SQL"`.
                let result = ResultSpec::Group(parse_result_map(input)?);
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (Some(first_ty), result, false, sql)
            } else {
                // `Db, Model, "SQL"`. A parenthesized type that is not a result
                // map (e.g. a tuple type) also lands here as the model.
                let db = Some(first_ty);
                let model: Type = input.parse()?;
                input.parse::<Token![,]>()?;
                let sql = parse_sql_template(input)?;
                (db, ResultSpec::Single(Box::new(model)), false, sql)
            }
        };

        let mut batch = None;
        let mut params = ParamsSource::None;
        let mut sections = Vec::new();
        let mut seen_params = false;
        let mut seen_sections = false;

        // Trailing arguments, only if a comma follows the SQL literal.
        if input.parse::<Token![,]>().is_ok() {
            while !input.is_empty() {
                if input.peek(Token![..]) {
                    // `..expr`: batch source.
                    if batch.is_some() {
                        return Err(
                            input.error("sql_forge!: only one batch source argument is allowed")
                        );
                    }
                    input.parse::<Token![..]>()?;
                    batch = Some(input.parse::<Expr>()?);
                } else if input.peek(syn::token::Paren) {
                    match detect_parenthesized_map_kind(input)? {
                        Some(MapKind::Results) => {
                            return Err(input.error(
                                "sql_forge!: result maps are only allowed as the macro result argument",
                            ));
                        }
                        Some(MapKind::Params) => {
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Map(parse_param_map(input)?);
                            seen_params = true;
                        }
                        Some(MapKind::Sections) => {
                            if seen_sections {
                                return Err(
                                    input.error("sql_forge!: duplicate section map argument")
                                );
                            }
                            sections = parse_section_map(input)?;
                            seen_sections = true;
                        }
                        None => {
                            // A parenthesized non-map expression is a struct params source.
                            if seen_params {
                                return Err(
                                    input.error("sql_forge!: only one parameter source is allowed")
                                );
                            }
                            params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                            seen_params = true;
                        }
                    }
                } else {
                    if seen_params {
                        return Err(input.error("sql_forge!: only one parameter source is allowed"));
                    }
                    params = ParamsSource::Struct(Box::new(input.parse::<Expr>()?));
                    seen_params = true;
                }

                // Arguments are comma-separated; a missing comma ends the list.
                if input.parse::<Token![,]>().is_ok() {
                    continue;
                }
                break;
            }
        }

        if !input.is_empty() {
            return Err(input.error("sql_forge!: unexpected tokens in macro invocation"));
        }

        Ok(Self {
            db,
            result,
            force_scalar,
            sql,
            params,
            sections,
            batch,
        })
    }
}
581
582// =============================================================================
583// Database type resolution
584// =============================================================================
585
586/// Resolves the database type from `SQL_FORGE_DB_TYPE` env var or `Cargo.toml` metadata.
587fn resolve_db_from_env() -> Result<Type, String> {
588    if let Ok(val) = std::env::var("SQL_FORGE_DB_TYPE") {
589        return syn::parse_str::<Type>(&val).map_err(|err| {
590            format!(
591                "sql_forge!: invalid DB type `{}` in SQL_FORGE_DB_TYPE env var: {}",
592                val, err
593            )
594        });
595    }
596
597    let manifest_dir = match std::env::var("CARGO_MANIFEST_DIR") {
598        Ok(d) => d,
599        Err(_) => {
600            return Err(
601                "sql_forge!: pass DB as first macro argument, set SQL_FORGE_DB_TYPE, \
602                 or configure [package.metadata.sql_forge] in Cargo.toml"
603                    .to_string(),
604            );
605        }
606    };
607    let manifest_path = Path::new(&manifest_dir).join("Cargo.toml");
608
609    let cargo_toml = fs::read_to_string(&manifest_path).map_err(|err| {
610        format!(
611            "sql_forge!: failed to read {}: {}",
612            manifest_path.display(),
613            err
614        )
615    })?;
616
617    let value: toml::Value = toml::from_str(&cargo_toml)
618        .map_err(|err| format!("sql_forge!: failed to parse Cargo.toml: {}", err))?;
619
620    let db_str = value
621        .get("package")
622        .and_then(|v| v.get("metadata"))
623        .and_then(|v| v.get("sql_forge"))
624        .and_then(|v| v.get("db"))
625        .and_then(|v| v.as_str())
626        .ok_or({
627            "sql_forge!: missing [package.metadata.sql_forge] db = \"...\" in Cargo.toml, \
628             SQL_FORGE_DB_TYPE env var, or DB as first macro argument"
629        })?;
630
631    syn::parse_str::<Type>(db_str).map_err(|err| {
632        format!(
633            "sql_forge!: invalid DB type `{}` in Cargo.toml metadata: {}",
634            db_str, err
635        )
636    })
637}
638
639/// Returns true if the database type uses dollar-parameter placeholders (Postgres).
640fn uses_dollar_params(db: &Type) -> bool {
641    let Type::Path(type_path) = db else {
642        return false;
643    };
644    type_path
645        .path
646        .segments
647        .last()
648        .is_some_and(|s| s.ident == "Postgres")
649}
650
651/// Returns true if the type is a sqlx database type (sqlx::MySql, sqlx::Postgres, sqlx::Sqlite).
652fn is_db_type(ty: &Type) -> bool {
653    let Type::Path(type_path) = ty else {
654        return false;
655    };
656    if type_path.qself.is_some() {
657        return false;
658    }
659    let segs = &type_path.path.segments;
660    if segs.len() != 2 {
661        return false;
662    }
663    segs[0].ident == "sqlx"
664        && ["MySql", "Postgres", "Sqlite"].contains(&segs[1].ident.to_string().as_str())
665}
666
667/// Returns true if a type is a standard Rust built-in scalar (i32, String, etc.).
668fn is_builtin_scalar_type(ty: &Type) -> bool {
669    let Type::Path(type_path) = ty else {
670        return false;
671    };
672
673    if type_path.qself.is_some()
674        || type_path.path.leading_colon.is_some()
675        || type_path.path.segments.len() != 1
676    {
677        return false;
678    }
679
680    let ident = &type_path.path.segments[0].ident;
681    ident == "i8"
682        || ident == "i16"
683        || ident == "i32"
684        || ident == "i64"
685        || ident == "isize"
686        || ident == "u8"
687        || ident == "u16"
688        || ident == "u32"
689        || ident == "u64"
690        || ident == "usize"
691        || ident == "f32"
692        || ident == "f64"
693        || ident == "bool"
694        || ident == "String"
695}
696
697/// If the given model type is a built-in scalar, returns the type; otherwise None.
698fn scalar_output_type(model: &Type) -> Option<&Type> {
699    if is_builtin_scalar_type(model) {
700        return Some(model);
701    }
702    None
703}
704
705/// Appends text to the last `Segment::Text` or pushes a new one.
706fn push_text_segment(out: &mut Vec<Segment>, text: String) {
707    if text.is_empty() {
708        return;
709    }
710    match out.last_mut() {
711        Some(Segment::Text(existing)) => existing.push_str(&text),
712        _ => out.push(Segment::Text(text)),
713    }
714}
715
/// Parses a SQL literal into segments (text, `{#section}`, `{(batch)}`).
///
/// - `{( ... )}`: a batch template. Everything between `{` and the matching
///   `}` is kept, including the outer parentheses, and split with
///   `parse_text_parts`. Parentheses must balance, nested `{` is rejected, and
///   list parameters (`:name[]`) are rejected.
/// - `{#name}`: a section slot. `name` is every character up to the next `}`,
///   unmodified, and must be non-empty.
/// - Any other `{` (and everything else) is kept as plain text.
///
/// # Errors
/// Returns a `sql_forge!: ...` message for an unterminated slot, unbalanced
/// or nested delimiters in a batch template, list parameters in a batch
/// template, or an empty section name.
fn parse_literal_segments(sql: &str) -> Result<Vec<Segment>, String> {
    let mut out = Vec::new();
    // Pending plain text, flushed into `out` before each special segment.
    let mut text = String::new();
    let mut chars = sql.chars().peekable();

    while let Some(ch) = chars.next() {
        if ch != '{' {
            text.push(ch);
            continue;
        }

        if chars.peek() == Some(&'(') {
            push_text_segment(&mut out, std::mem::take(&mut text));

            // The peeked `(` is not consumed here, so the loop below sees it
            // first (depth becomes 1) and it is kept in `content`.
            let mut paren_depth = 0u32;
            let mut content = String::new();
            let mut found_close = false;
            for ch in chars.by_ref() {
                if ch == '{' {
                    return Err(
                        "sql_forge!: nested braces not allowed inside batch section".to_string()
                    );
                }
                if ch == '}' {
                    if paren_depth != 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    found_close = true;
                    break;
                }
                if ch == '(' {
                    paren_depth += 1;
                } else if ch == ')' {
                    if paren_depth == 0 {
                        return Err(
                            "sql_forge!: batch section {( ... )} has unbalanced parentheses"
                                .to_string(),
                        );
                    }
                    paren_depth -= 1;
                }
                content.push(ch);
            }
            if !found_close {
                return Err("sql_forge!: batch section {( ... )} without closing }".to_string());
            }
            let parts = parse_text_parts(&content);
            for part in &parts {
                if let TextPart::Param { is_list: true, .. } = part {
                    return Err(
                        "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
                         batch sections; use plain parameters (:name) instead"
                            .to_string(),
                    );
                }
            }
            out.push(Segment::Batch { parts });
            continue;
        }

        // A `{` that starts neither `{(` nor `{#` is ordinary text.
        if chars.peek() != Some(&'#') {
            text.push(ch);
            continue;
        }

        // Consume the `#` of `{#name}`.
        chars.next();
        push_text_segment(&mut out, std::mem::take(&mut text));

        let mut name = String::new();
        loop {
            let Some(next) = chars.next() else {
                return Err("sql_forge!: section placeholder without closing }".to_string());
            };
            if next == '}' {
                break;
            }
            name.push(next);
        }

        if name.is_empty() {
            return Err("sql_forge!: empty section placeholder name".to_string());
        }

        out.push(Segment::Section { name });
    }

    push_text_segment(&mut out, text);
    Ok(out)
}
809
810// =============================================================================
811// SQL template parsing: {#sections} and :param placeholders
812// =============================================================================
813
/// Whether `ch` may begin a `:param` name: an ASCII letter or `_`.
fn is_ident_start(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '_')
}

/// Whether `ch` may appear after the first character of a `:param` name:
/// an ASCII letter, an ASCII digit, or `_`.
fn is_ident_continue(ch: char) -> bool {
    matches!(ch, 'a'..='z' | 'A'..='Z' | '0'..='9' | '_')
}
823
/// Strips an alias-override suffix from the contents of a backtick-quoted
/// identifier.
///
/// Everything from the first `!`, `?` or `:` onward is dropped, along with any
/// whitespace just before it (`"id!"` → `"id"`, `"name: String"` → `"name"`).
/// Contents with no marker, or whose remaining base would be empty (`"!x"`),
/// are returned unchanged.
fn sanitize_backticked_alias_ident(content: &str) -> String {
    let Some(marker_at) = content.find(|c: char| matches!(c, '!' | '?' | ':')) else {
        return content.to_string();
    };
    match content[..marker_at].trim_end() {
        "" => content.to_string(),
        base => base.to_string(),
    }
}

/// Runs every closed backtick-quoted span in `text` through
/// `sanitize_backticked_alias_ident`, keeping the backticks and all other text
/// as-is. A final unclosed backtick and everything after it are copied verbatim.
fn sanitize_runtime_sql_text(text: &str) -> String {
    // Splitting on '`' puts quoted contents at odd indices. An odd piece is
    // closed only if another backtick follows it, i.e. its index is below the
    // total backtick count.
    let backtick_count = text.matches('`').count();
    let mut sanitized = String::with_capacity(text.len());

    for (index, piece) in text.split('`').enumerate() {
        if index % 2 == 0 {
            sanitized.push_str(piece);
        } else if index < backtick_count {
            sanitized.push('`');
            sanitized.push_str(&sanitize_backticked_alias_ident(piece));
            sanitized.push('`');
        } else {
            sanitized.push('`');
            sanitized.push_str(piece);
        }
    }

    sanitized
}
881
882/// Splits SQL text into literal text and `:param` placeholder parts.
883fn parse_text_parts(text: &str) -> Vec<TextPart> {
884    let mut parts = Vec::new();
885    let mut last = 0usize;
886    let mut iter = text.char_indices().peekable();
887
888    while let Some((idx, ch)) = iter.next() {
889        if ch != ':' {
890            continue;
891        }
892
893        let Some(&(next_idx, next_ch)) = iter.peek() else {
894            continue;
895        };
896
897        if !is_ident_start(next_ch) {
898            continue;
899        }
900
901        if text[..idx].ends_with(':') {
902            continue;
903        }
904
905        if last < idx {
906            parts.push(TextPart::Lit(text[last..idx].to_string()));
907        }
908
909        iter.next();
910
911        let mut name = String::new();
912        name.push(next_ch);
913        let mut end = next_idx + next_ch.len_utf8();
914
915        while let Some(&(j, c)) = iter.peek() {
916            if is_ident_continue(c) {
917                name.push(c);
918                end = j + c.len_utf8();
919                iter.next();
920            } else {
921                break;
922            }
923        }
924
925        let mut is_list = false;
926        if text[end..].starts_with("[]") {
927            is_list = true;
928            end += 2;
929        }
930
931        parts.push(TextPart::Param { name, is_list });
932        last = end;
933    }
934
935    if last < text.len() {
936        parts.push(TextPart::Lit(text[last..].to_string()));
937    }
938
939    parts
940}
941
942/// Renders SQL text for the compile-time validator, replacing `:param`/`:param[]`
943/// with DB-specific placeholders (`?` or `$N`).
944///
945/// When `batch_expr` is `Some`, each `:param` is treated as a batch-item field
946/// access (`batch_expr[0].field_ident`) and returned in `batch_args` instead of
947/// `occurrences`. List params (`:param[]`) are not allowed in batch sections
948/// (caught earlier during parsing, but defended here for safety).
949///
950/// Returns `(rendered_sql, occurrences, batch_args)` — `occurrences` is empty
951/// in batch mode and `batch_args` is empty in non-batch mode.
952#[allow(clippy::type_complexity)]
953fn render_validator_sql(
954    parts: &[TextPart],
955    use_dollar_params: bool,
956    param_offset: &mut usize,
957    list_count: usize,
958    batch_expr: Option<TokenStream2>,
959) -> Result<(String, Vec<(String, bool)>, Vec<TokenStream2>), TokenStream> {
960    let mut out_sql = String::new();
961    let mut occurrences = Vec::new();
962    let mut batch_args = Vec::new();
963
964    for part in parts {
965        match part {
966            TextPart::Lit(lit) => out_sql.push_str(lit),
967            TextPart::Param { name, is_list } => {
968                if let Some(ref batch_expr) = batch_expr {
969                    if *is_list {
970                        return Err(syn::Error::new(
971                            Span::call_site(),
972                            "sql_forge!: list parameters (:name[]) are not allowed inside {( ... )} \
973                             batch sections; use plain parameters (:name) instead",
974                        )
975                        .to_compile_error()
976                        .into());
977                    }
978                    let field_ident = format_ident!("{}", name);
979                    if use_dollar_params {
980                        *param_offset += 1;
981                        write!(out_sql, "${}", *param_offset).unwrap();
982                    } else {
983                        out_sql.push('?');
984                    }
985                    batch_args.push(quote! { #batch_expr[0].#field_ident });
986                } else if *is_list && list_count > 1 {
987                    let slots: Vec<String> = if use_dollar_params {
988                        (0..list_count)
989                            .map(|i| format!("${}", *param_offset + i + 1))
990                            .collect()
991                    } else {
992                        (0..list_count).map(|_| "?".to_string()).collect()
993                    };
994                    if use_dollar_params {
995                        *param_offset += list_count;
996                    }
997                    out_sql.push_str(&slots.join(", "));
998                    occurrences.push((name.clone(), *is_list));
999                } else {
1000                    if use_dollar_params {
1001                        *param_offset += 1;
1002                        write!(out_sql, "${}", *param_offset).unwrap();
1003                    } else {
1004                        out_sql.push('?');
1005                    }
1006                    occurrences.push((name.clone(), *is_list));
1007                }
1008            }
1009        }
1010    }
1011
1012    Ok((out_sql, occurrences, batch_args))
1013}
1014
1015/// Recursively strips syntactic wrapper expressions to reach the inner expression.
1016///
1017/// # Wrapper kinds
1018///
1019/// | Variant          | Rust syntax           | When it appears                                |
1020/// |------------------|-----------------------|------------------------------------------------|
1021/// | `Expr::Paren`    | `(expr)`              | Parenthesized expression in source code. syn   |
1022/// |                  |                       | preserves parens so tokens round-trip.         |
1023/// | `Expr::Group`    | `{ expr }`            | *Compiler-inserted* invisible group. syn uses  |
1024/// |                  |                       | this when a proc-macro receives braced tokens  |
1025/// |                  |                       | containing a single expression as `proc_macro! |
1026/// |                  |                       | { expr }`. Unlike `Paren`, `Group` does not    |
1027/// |                  |                       | correspond to parens the user wrote.           |
1028/// | `Expr::Block`    | `{ stmts }` or        | A real block expression the user wrote —       |
1029/// |                  | `{ stmts; }`          | `{ x; y }`. Stripped only when it contains     |
1030/// |                  |                       | exactly one `Stmt::Expr` without a trailing    |
1031/// |                  |                       | semicolon. Multi-statement blocks and non-expr |
1032/// |                  |                       | statements would change semantics if stripped, |
1033/// |                  |                       | so they are returned as-is.                    |
1034///
1035/// Returns the original expression unchanged when none of these patterns match, making
1036/// it safe to call unconditionally before pattern-matching against the expression kind.
1037fn strip_expr(expr: &Expr) -> &Expr {
1038    match expr {
1039        Expr::Paren(ExprParen { expr, .. }) => strip_expr(expr),
1040        Expr::Group(ExprGroup { expr, .. }) => strip_expr(expr),
1041        Expr::Block(ExprBlock { block, .. }) => {
1042            if block.stmts.len() != 1 {
1043                return expr;
1044            }
1045            match &block.stmts[0] {
1046                Stmt::Expr(inner, None) => strip_expr(inner),
1047                _ => expr,
1048            }
1049        }
1050        _ => expr,
1051    }
1052}
1053
1054/// Extracts the string value from a string literal expression.
1055fn extract_lit_str(expr: &Expr) -> Option<String> {
1056    match strip_expr(expr) {
1057        Expr::Lit(ExprLit {
1058            lit: Lit::Str(lit), ..
1059        }) => Some(lit.value()),
1060        _ => None,
1061    }
1062}
1063
1064// =============================================================================
1065// Preprocessing: {>key} compile-time result flags
1066// =============================================================================
1067
1068/// Generates a `__sql_forge_result_flag_{name}` identifier for `{>name}` result key flags.
1069fn result_flag_ident(name: &str) -> syn::Ident {
1070    format_ident!("__sql_forge_result_flag_{}", name)
1071}
1072
1073/// Replaces `{>key}` tokens inside braced groups with `__sql_forge_result_flag_key`
1074/// identifiers. This is a preprocessing step so that the rest of the parser sees
1075/// plain identifiers instead of braced groups it does not understand.
1076fn preprocess_result_key_placeholders(input: TokenStream2) -> TokenStream2 {
1077    fn walk(stream: TokenStream2) -> TokenStream2 {
1078        let mut out = TokenStream2::new();
1079        let iter = stream.into_iter().peekable();
1080
1081        for token in iter {
1082            match token {
1083                TokenTree::Group(group) => {
1084                    if group.delimiter() == Delimiter::Brace {
1085                        let mut inner = group.stream().into_iter();
1086                        let first = inner.next();
1087                        let second = inner.next();
1088                        let third = inner.next();
1089
1090                        if let (
1091                            Some(TokenTree::Punct(p)),
1092                            Some(TokenTree::Ident(name_ident)),
1093                            None,
1094                        ) = (first, second, third)
1095                        {
1096                            if p.as_char() == '>' {
1097                                let ident = result_flag_ident(&name_ident.to_string());
1098                                out.extend(std::iter::once(TokenTree::Ident(ident)));
1099                                continue;
1100                            }
1101                        }
1102                    }
1103
1104                    let new_inner = walk(group.stream());
1105                    let mut new_group = Group::new(group.delimiter(), new_inner);
1106                    new_group.set_span(group.span());
1107                    out.extend(std::iter::once(TokenTree::Group(new_group)));
1108                }
1109                other => out.extend(std::iter::once(other)),
1110            }
1111        }
1112
1113        out
1114    }
1115
1116    walk(input)
1117}
1118
1119/// Builds `let` bindings for `{>key}` result flag identifiers.
1120fn build_result_flag_bindings(keys: &[String], active_key: Option<&str>) -> Vec<TokenStream2> {
1121    keys.iter()
1122        .map(|key| {
1123            let ident = result_flag_ident(key);
1124            let enabled = Some(key.as_str()) == active_key;
1125            quote! { let #ident: bool = #enabled; }
1126        })
1127        .collect()
1128}
1129
1130/// Transposes a `[case][section]` matrix to `[section][case]`.
1131fn transpose_section_case_matrix(
1132    case_matrix: Vec<Vec<SectionFragment>>,
1133    width: usize,
1134) -> Result<Vec<Vec<SectionFragment>>, String> {
1135    let mut per_section: Vec<Vec<SectionFragment>> = (0..width).map(|_| Vec::new()).collect();
1136
1137    for row in case_matrix {
1138        if row.len() != width {
1139            return Err(
1140                "sql_forge!: grouped sections must return one item per section".to_string(),
1141            );
1142        }
1143        for (section_idx, fragment) in row.into_iter().enumerate() {
1144            per_section[section_idx].push(fragment);
1145        }
1146    }
1147
1148    Ok(per_section)
1149}
1150
1151/// Collects all section fragment cases as a `[case][section]` matrix for a given section value.
1152fn collect_section_case_matrix(
1153    value: SectionValue,
1154    width: usize,
1155    active_key: Option<&str>,
1156) -> Result<Vec<Vec<SectionFragment>>, String> {
1157    match value {
1158        SectionValue::Single(fragment) => {
1159            if width != 1 {
1160                return Err(
1161                    "sql_forge!: grouped sections must return one item per section".to_string(),
1162                );
1163            }
1164            Ok(vec![vec![fragment]])
1165        }
1166        SectionValue::Grouped(values) => {
1167            if values.len() != width {
1168                return Err(
1169                    "sql_forge!: grouped sections must return one item per section".to_string(),
1170                );
1171            }
1172
1173            let mut variants_by_section = Vec::<Vec<SectionFragment>>::with_capacity(width);
1174            let mut nmax = 1usize;
1175
1176            for value in values {
1177                let item_matrix = collect_section_case_matrix(value, 1, active_key)?;
1178                let mut item_variants = Vec::<SectionFragment>::with_capacity(item_matrix.len());
1179                for mut row in item_matrix {
1180                    let fragment = row.pop().ok_or_else(|| {
1181                        "sql_forge!: grouped sections must return one item per section".to_string()
1182                    })?;
1183                    if !row.is_empty() {
1184                        return Err(
1185                            "sql_forge!: grouped sections must return one item per section"
1186                                .to_string(),
1187                        );
1188                    }
1189                    item_variants.push(fragment);
1190                }
1191                if item_variants.is_empty() {
1192                    return Err("sql_forge!: section match must have at least one arm".to_string());
1193                }
1194                nmax = nmax.max(item_variants.len());
1195                variants_by_section.push(item_variants);
1196            }
1197
1198            let mut case_matrix = Vec::<Vec<SectionFragment>>::with_capacity(nmax);
1199            for case_idx in 0..nmax {
1200                let mut row = Vec::<SectionFragment>::with_capacity(width);
1201                for variants in &variants_by_section {
1202                    row.push(variants[case_idx % variants.len()].clone());
1203                }
1204                case_matrix.push(row);
1205            }
1206
1207            Ok(case_matrix)
1208        }
1209        SectionValue::Match { expr, arms } => {
1210            let mut case_matrix = Vec::<Vec<SectionFragment>>::new();
1211
1212            if let Some(key) = expr_result_flag_key(&expr) {
1213                let target = active_key == Some(key.as_str());
1214                for arm in arms {
1215                    if arm.guard.is_none() {
1216                        if let Some(false) = pattern_matches_bool(&arm.pat, target) {
1217                            continue;
1218                        }
1219                    }
1220                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1221                    wrap_section_case_matrix_for_match_arm(
1222                        &mut arm_cases,
1223                        &expr,
1224                        &arm.pat,
1225                        arm.guard.as_ref(),
1226                    );
1227                    case_matrix.extend(arm_cases);
1228                }
1229            } else {
1230                for arm in arms {
1231                    let mut arm_cases = collect_section_case_matrix(arm.value, width, active_key)?;
1232                    wrap_section_case_matrix_for_match_arm(
1233                        &mut arm_cases,
1234                        &expr,
1235                        &arm.pat,
1236                        arm.guard.as_ref(),
1237                    );
1238                    case_matrix.extend(arm_cases);
1239                }
1240            }
1241
1242            if case_matrix.is_empty() {
1243                return Err("sql_forge!: section match must have at least one arm".to_string());
1244            }
1245
1246            Ok(case_matrix)
1247        }
1248    }
1249}
1250
/// Rewrites a section-local validator expression so it is evaluated inside the same match arm
/// scope that originally introduced it.
///
/// The result has the shape
/// `match &(match_expr) { pat [if guard] => { ... }, _ => unreachable!(...) }`, so any variables
/// bound by `pat` are in scope for `expr`. Because the scrutinee is `&(match_expr)`, those
/// bindings are references under Rust's default binding modes.
///
/// Two arm bodies are produced:
/// - if `pattern_binds_values` (computed below) is true, the arm first emits `let _ = &ident;` for
///   every identifier returned by `pat_var_idents` (presumably to silence unused-variable
///   warnings), then yields `expr` unchanged;
/// - otherwise the arm yields `&(expr)`.
///
/// NOTE(review): `pattern_binds_values` inspects only the top level of the pattern and its direct
/// sub-patterns, and tests for `Pat::Ident`. syn presumably also uses `Pat::Ident` for bare paths
/// such as `None`. A binding nested two levels deep (e.g. `Some(Foo { a })`) therefore takes the
/// `&(expr)` branch. Confirm this asymmetry is intended.
fn wrap_expr_for_match_arm(expr: Expr, match_expr: &Expr, pat: &Pat, guard: Option<&Expr>) -> Expr {
    let match_expr = match_expr.clone();
    let pat = pat.clone();
    // Shallow check: is the pattern itself, or one of its direct sub-patterns, a `Pat::Ident`?
    let pattern_binds_values = match &pat {
        Pat::Ident(_) => true,
        Pat::Or(pat_or) => pat_or
            .cases
            .iter()
            .any(|case| matches!(case, Pat::Ident(_))),
        Pat::Paren(pat_paren) => matches!(pat_paren.pat.as_ref(), Pat::Ident(_)),
        Pat::Reference(pat_reference) => matches!(pat_reference.pat.as_ref(), Pat::Ident(_)),
        Pat::Slice(pat_slice) => pat_slice
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Struct(pat_struct) => pat_struct
            .fields
            .iter()
            .any(|field| matches!(*field.pat, Pat::Ident(_))),
        Pat::Tuple(pat_tuple) => pat_tuple
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::TupleStruct(pat_tuple_struct) => pat_tuple_struct
            .elems
            .iter()
            .any(|elem| matches!(elem, Pat::Ident(_))),
        Pat::Type(pat_type) => matches!(pat_type.pat.as_ref(), Pat::Ident(_)),
        _ => false,
    };

    if pattern_binds_values {
        // One `let _ = &ident;` per binding the pattern introduces.
        let pat_refs: Vec<TokenStream2> = pat_var_idents(&pat)
            .into_iter()
            .map(|ident| quote! { let _ = &#ident; })
            .collect();
        if let Some(guard) = guard.cloned() {
            parse_quote! {
                match &(#match_expr) {
                    #pat if #guard => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        } else {
            parse_quote! {
                match &(#match_expr) {
                    #pat => { #( #pat_refs )* #expr },
                    _ => unreachable!("sql_forge!: validator arm mismatch"),
                }
            }
        }
    } else if let Some(guard) = guard.cloned() {
        // No (shallow) bindings: yield a reference to the expression instead.
        parse_quote! {
            match &(#match_expr) {
                #pat if #guard => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    } else {
        parse_quote! {
            match &(#match_expr) {
                #pat => { &(#expr) },
                _ => unreachable!("sql_forge!: validator arm mismatch"),
            }
        }
    }
}
1320
1321/// Wraps a ParamsSource's expressions so they stay within the match arm scope.
1322fn wrap_params_source_for_match_arm(
1323    params: &mut ParamsSource,
1324    match_expr: &Expr,
1325    pat: &Pat,
1326    guard: Option<&Expr>,
1327) {
1328    match params {
1329        ParamsSource::None => {}
1330        ParamsSource::Map(entries) => {
1331            for entry in entries {
1332                entry.expr = wrap_expr_for_match_arm(entry.expr.clone(), match_expr, pat, guard);
1333            }
1334        }
1335        ParamsSource::Struct(expr) => {
1336            **expr = wrap_expr_for_match_arm((**expr).clone(), match_expr, pat, guard);
1337        }
1338    }
1339}
1340
1341/// Wraps all ParamsSource expressions in a section case matrix for the given match arm.
1342fn wrap_section_case_matrix_for_match_arm(
1343    case_matrix: &mut [Vec<SectionFragment>],
1344    match_expr: &Expr,
1345    pat: &Pat,
1346    guard: Option<&Expr>,
1347) {
1348    for row in case_matrix {
1349        for fragment in row {
1350            wrap_params_source_for_match_arm(&mut fragment.params, match_expr, pat, guard);
1351        }
1352    }
1353}
1354
1355// Returns Vec<Vec<SectionFragment>> indexed [section_idx][case_idx].
1356// =============================================================================
1357// Section variant collection
1358// =============================================================================
1359
1360/// Collects all possible `SectionFragment` values per section index. Each
1361/// returned `Vec<Vec<SectionFragment>>` is indexed `[section_idx][case_idx]`,
1362/// listing every variant that section can take across all match arms.
1363/// Used for full validation (cycling strategy over all variants).
1364fn collect_section_variants(
1365    value: SectionValue,
1366    width: usize,
1367) -> Result<Vec<Vec<SectionFragment>>, String> {
1368    transpose_section_case_matrix(collect_section_case_matrix(value, width, None)?, width)
1369}
1370
1371/// Returns the key name if the expression is a `{>key}` result flag; otherwise None.
1372fn expr_result_flag_key(expr: &Expr) -> Option<String> {
1373    match strip_expr(expr) {
1374        Expr::Path(path) if path.qself.is_none() && path.path.segments.len() == 1 => {
1375            let name = path.path.segments[0].ident.to_string();
1376            name.strip_prefix("__sql_forge_result_flag_")
1377                .map(|v| v.to_string())
1378        }
1379        _ => None,
1380    }
1381}
1382
1383/// Returns whether a pattern matches a specific boolean value at compile time.
1384fn pattern_matches_bool(pat: &Pat, value: bool) -> Option<bool> {
1385    match pat {
1386        Pat::Lit(expr_lit) => match &expr_lit.lit {
1387            Lit::Bool(lit_bool) => Some(lit_bool.value == value),
1388            _ => None,
1389        },
1390        Pat::Wild(_) => Some(true),
1391        _ => None,
1392    }
1393}
1394
1395/// Like `collect_section_variants`, but filters `match` arms by the active
1396/// result key when the match expression is a `{>key}` flag. When building the
1397/// query for a specific key, only the matching arm (true/false) is included;
1398/// arms with guards or non-flag expressions include all variants as usual.
1399fn collect_section_variants_for_result(
1400    value: SectionValue,
1401    width: usize,
1402    active_key: Option<&str>,
1403) -> Result<Vec<Vec<SectionFragment>>, String> {
1404    transpose_section_case_matrix(
1405        collect_section_case_matrix(value, width, active_key)?,
1406        width,
1407    )
1408}
1409
1410// =============================================================================
1411// Parameter binding generation
1412// =============================================================================
1413
1414/// Generates `let` bindings for all parameters used in the SQL or section
1415/// fragments. Returns a map of param-name → local ident for later reference,
1416/// and a list of `let` statements.
1417fn build_param_bindings(
1418    params: &ParamsSource,
1419    used_param_names: &[String],
1420    prefix: &str,
1421    for_validator: bool,
1422    enforce_usage_check: bool,
1423) -> Result<(HashMap<String, syn::Ident>, Vec<TokenStream2>), TokenStream> {
1424    let mut declared_params = HashMap::<String, syn::Ident>::new();
1425    let mut bindings = Vec::<TokenStream2>::new();
1426
1427    match params {
1428        ParamsSource::None => {}
1429        ParamsSource::Map(entries) => {
1430            for entry in entries {
1431                let key = entry.name.to_string();
1432                if declared_params.contains_key(&key) {
1433                    return Err(syn::Error::new(
1434                        entry.name.span(),
1435                        "sql_forge!: duplicated parameter mapping",
1436                    )
1437                    .to_compile_error()
1438                    .into());
1439                }
1440                if enforce_usage_check && !used_param_names.iter().any(|n| n == &key) {
1441                    return Err(syn::Error::new(
1442                        entry.name.span(),
1443                        format!(
1444                            "sql_forge!: parameter :{} is unused in the SQL template",
1445                            key,
1446                        ),
1447                    )
1448                    .to_compile_error()
1449                    .into());
1450                }
1451                let local_ident = format_ident!("__sql_forge_{}_{}", prefix, key);
1452                let expr = &entry.expr;
1453                if for_validator {
1454                    bindings.push(quote! {
1455                        let #local_ident = &(#expr);
1456                    });
1457                } else {
1458                    bindings.push(quote! {
1459                        let #local_ident = #expr;
1460                    });
1461                }
1462                declared_params.insert(key, local_ident);
1463            }
1464        }
1465        ParamsSource::Struct(expr) => {
1466            let source_ident = format_ident!("__sql_forge_source_{}", prefix);
1467            bindings.push(quote! {
1468                let #source_ident = &(#expr);
1469            });
1470            for name in used_param_names {
1471                let local_ident = format_ident!("__sql_forge_{}_{}", prefix, name);
1472                let field_ident = format_ident!("{}", name);
1473                if for_validator {
1474                    bindings.push(quote! {
1475                        let #local_ident = &#source_ident.#field_ident;
1476                    });
1477                } else {
1478                    bindings.push(quote! {
1479                        let #local_ident = #source_ident.#field_ident;
1480                    });
1481                }
1482                declared_params.insert(name.to_string(), local_ident);
1483            }
1484        }
1485    }
1486
1487    Ok((declared_params, bindings))
1488}
1489
1490struct ValidatorRenderContext<'a> {
1491    params: &'a HashMap<String, syn::Ident>,
1492    use_dollar_params: bool,
1493    sql_span: Span,
1494    list_count: usize,
1495}
1496
1497/// Builds the placeholders SQL string and argument list for the compile-time
1498/// validator (sqlx::query_as! / query_scalar!). Each `:param` in the SQL is
1499/// replaced by `?` (MySQL/SQLite) or `$1`/`$2`/... (PostgreSQL), and the
1500/// corresponding value expression is collected into the args list.
1501fn render_validator_args(
1502    sql: &str,
1503    param_offset: &mut usize,
1504    arg_index: &mut usize,
1505    context: &ValidatorRenderContext<'_>,
1506) -> Result<(String, Vec<TokenStream2>, Vec<TokenStream2>), TokenStream> {
1507    let parts = parse_text_parts(sql);
1508    let (rendered_sql, occurrences, _batch_args) = render_validator_sql(
1509        &parts,
1510        context.use_dollar_params,
1511        param_offset,
1512        context.list_count,
1513        None,
1514    )?;
1515
1516    let mut setup = Vec::<TokenStream2>::new();
1517    let mut args = Vec::<TokenStream2>::new();
1518
1519    for (name, is_list) in occurrences {
1520        let Some(local_ident) = context.params.get(&name) else {
1521            return Err(syn::Error::new(
1522                context.sql_span,
1523                format!("sql_forge!: parameter :{} has no mapping", name),
1524            )
1525            .to_compile_error()
1526            .into());
1527        };
1528
1529        if is_list {
1530            for _ in 0..context.list_count {
1531                let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1532                *arg_index += 1;
1533                if context.use_dollar_params {
1534                    setup.push(quote! {
1535                        let #value_ident = sql_forge::sql_forge_validator_value(
1536                            (#local_ident)
1537                                .as_slice()
1538                                .first()
1539                                .expect("sql_forge!: list parameters used in validation must have at least one representative element")
1540                        );
1541                    });
1542                } else {
1543                    setup.push(quote! {
1544                        let #value_ident = (#local_ident)
1545                            .as_slice()
1546                            .first()
1547                            .expect("sql_forge!: list parameters used in validation must have at least one representative element");
1548                    });
1549                }
1550                args.push(quote! { #value_ident });
1551            }
1552        } else {
1553            let value_ident = format_ident!("__sql_forge_validator_arg_{}", *arg_index);
1554            *arg_index += 1;
1555            if context.use_dollar_params {
1556                setup.push(quote! {
1557                    let #value_ident = sql_forge::sql_forge_validator_value(#local_ident);
1558                });
1559            } else {
1560                setup.push(quote! {
1561                    let #value_ident = #local_ident;
1562                });
1563            }
1564            args.push(quote! { #value_ident });
1565        }
1566    }
1567
1568    Ok((rendered_sql, setup, args))
1569}
1570
1571// =============================================================================
1572// Runtime code generation (QueryBuilder-based)
1573// =============================================================================
1574
1575/// Generates the `push()` / `push_bind()` calls for a single section fragment
1576/// at runtime using `sqlx::QueryBuilder`.
1577fn render_runtime_fragment(
1578    fragment: &SectionFragment,
1579    local_params: &HashMap<String, syn::Ident>,
1580) -> Result<TokenStream2, TokenStream> {
1581    let mut steps = Vec::<TokenStream2>::new();
1582
1583    for part in parse_text_parts(&fragment.sql) {
1584        match part {
1585            TextPart::Lit(lit) => {
1586                let lit_str = LitStr::new(&lit, fragment.span);
1587                steps.push(quote! { __builder.push(#lit_str); });
1588            }
1589            TextPart::Param { name, is_list } => {
1590                let Some(local_ident) = local_params.get(&name) else {
1591                    return Err(syn::Error::new(
1592                        fragment.span,
1593                        format!("sql_forge!: parameter :{} has no mapping", name),
1594                    )
1595                    .to_compile_error()
1596                    .into());
1597                };
1598
1599                if is_list {
1600                    steps.push(quote! {
1601                        let __sql_forge_values = #local_ident;
1602                        let mut __separated = __builder.separated(", ");
1603                        for __value in __sql_forge_values {
1604                            __separated.push_bind(__value);
1605                        }
1606                    });
1607                } else {
1608                    steps.push(quote! {
1609                        __builder.push_bind(#local_ident);
1610                    });
1611                }
1612            }
1613        }
1614    }
1615
1616    Ok(quote! { #( #steps )* })
1617}
1618
1619/// Returns true if an ident starts with a lowercase letter or underscore (i.e., is a binding).
1620fn is_pat_binding(ident: &Ident) -> bool {
1621    let name = ident.to_string();
1622    !name.is_empty()
1623        && name
1624            .chars()
1625            .next()
1626            .is_some_and(|c| c.is_ascii_lowercase() || c == '_')
1627}
1628
1629/// Collects all binding identifiers from a pattern.
1630fn pat_var_idents(pat: &Pat) -> Vec<Ident> {
1631    let mut names = Vec::new();
1632    fn walk(p: &Pat, names: &mut Vec<Ident>) {
1633        match p {
1634            Pat::Ident(pi) if is_pat_binding(&pi.ident) => names.push(pi.ident.clone()),
1635            Pat::Tuple(pt) => pt.elems.iter().for_each(|e| walk(e, names)),
1636            Pat::Struct(ps) => ps.fields.iter().for_each(|f| walk(&f.pat, names)),
1637            Pat::TupleStruct(pts) => pts.elems.iter().for_each(|e| walk(e, names)),
1638            Pat::Or(po) => po.cases.iter().for_each(|c| walk(c, names)),
1639            Pat::Paren(pp) => walk(&pp.pat, names),
1640            Pat::Reference(pr) => walk(&pr.pat, names),
1641            Pat::Slice(psl) => psl.elems.iter().for_each(|e| walk(e, names)),
1642            Pat::Type(pt) => walk(&pt.pat, names),
1643            _ => {}
1644        }
1645    }
1646    walk(pat, &mut names);
1647    names
1648}
1649
1650/// Returns true if a section value's SQL or params references a given parameter name.
1651fn section_value_refers_to(value: &SectionValue, name: &str) -> bool {
1652    match value {
1653        SectionValue::Single(f) => {
1654            if collect_used_param_names_in_sql(&f.sql)
1655                .iter()
1656                .any(|n| n == name)
1657            {
1658                return true;
1659            }
1660            if let ParamsSource::Map(entries) = &f.params {
1661                for e in entries {
1662                    let expr = &e.expr;
1663                    let expr_str = quote! { #expr }.to_string();
1664                    if expr_str.trim() == name {
1665                        return true;
1666                    }
1667                }
1668            }
1669            false
1670        }
1671        SectionValue::Grouped(vals) => vals.iter().any(|v| section_value_refers_to(v, name)),
1672        SectionValue::Match { arms, .. } => arms.iter().any(|arm| {
1673            let pat_vars: HashSet<_> = pat_var_idents(&arm.pat)
1674                .into_iter()
1675                .map(|i| i.to_string())
1676                .collect();
1677            if pat_vars.contains(name) {
1678                false
1679            } else {
1680                section_value_refers_to(&arm.value, name)
1681            }
1682        }),
1683    }
1684}
1685
1686/// Builds the runtime `QueryBuilder` action tokens for a section value.
1687fn build_section_runtime_action(
1688    value: &SectionValue,
1689    section_idx: usize,
1690    prefix: &str,
1691) -> Result<TokenStream2, TokenStream> {
1692    match value {
1693        SectionValue::Single(fragment) => {
1694            let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
1695            let (local_params, bindings) =
1696                build_param_bindings(&fragment.params, &used_param_names, prefix, false, true)?;
1697            let body = render_runtime_fragment(fragment, &local_params)?;
1698            Ok(quote! {{ #( #bindings )* #body }})
1699        }
1700        SectionValue::Grouped(fragments) => build_section_runtime_action(
1701            &fragments[section_idx],
1702            0,
1703            &format!("{}_grouped_{}", prefix, section_idx),
1704        ),
1705        SectionValue::Match { expr, arms } => {
1706            let arm_tokens: Result<Vec<TokenStream2>, TokenStream> = arms
1707                .iter()
1708                .enumerate()
1709                .map(|(arm_idx, arm)| {
1710                    let pat = &arm.pat;
1711                    let guard_tokens = arm.guard.as_ref().map(|guard| quote! { if #guard });
1712                    let body = build_section_runtime_action(
1713                        &arm.value,
1714                        section_idx,
1715                        &format!("{}_{}", prefix, arm_idx),
1716                    )?;
1717                    let noop_refs: Vec<TokenStream2> = pat_var_idents(pat)
1718                        .into_iter()
1719                        .filter(|ident| section_value_refers_to(&arm.value, &ident.to_string()))
1720                        .map(|ident| quote! { ::core::hint::black_box(&#ident); })
1721                        .collect();
1722                    Ok::<TokenStream2, TokenStream>(quote! {
1723                        #pat #guard_tokens => {
1724                            #( #noop_refs )*
1725                            #body
1726                        }
1727                    })
1728                })
1729                .collect();
1730            let arm_tokens = arm_tokens?;
1731            Ok(quote! {
1732                match #expr {
1733                    #( #arm_tokens ),*
1734                }
1735            })
1736        }
1737    }
1738}
1739
1740/// Extracts unique parameter names from a list of segments.
1741fn collect_used_param_names(segments: &[Segment]) -> Vec<String> {
1742    let mut names = Vec::new();
1743    let mut seen = HashSet::<String>::new();
1744
1745    for segment in segments {
1746        match segment {
1747            Segment::Text(text) => {
1748                for name in collect_used_param_names_in_sql(text) {
1749                    if seen.insert(name.clone()) {
1750                        names.push(name);
1751                    }
1752                }
1753            }
1754            Segment::Batch { parts } => {
1755                for part in parts {
1756                    if let TextPart::Param { name, .. } = part {
1757                        if seen.insert(name.clone()) {
1758                            names.push(name.clone());
1759                        }
1760                    }
1761                }
1762            }
1763            _ => {}
1764        }
1765    }
1766
1767    names
1768}
1769
1770/// Extracts unique parameter names from a SQL text string.
1771fn collect_used_param_names_in_sql(sql: &str) -> Vec<String> {
1772    let mut names = Vec::new();
1773    let mut seen = HashSet::<String>::new();
1774    for part in parse_text_parts(sql) {
1775        if let TextPart::Param { name, .. } = part {
1776            if seen.insert(name.to_string()) {
1777                names.push(name);
1778            }
1779        }
1780    }
1781    names
1782}
1783
1784/// Builds a parameterized SQL query with compile-time type-checking and a
1785/// runtime [`sqlx::QueryBuilder`] for dynamic SQL.
1786///
1787/// Combines `sqlx::query_as!` / `sqlx::query_scalar!` validation (never called
1788/// at runtime) with `QueryBuilder::push_bind` for safe value binding.
1789///
1790/// # Syntax
1791///
1792/// ```text
1793/// sql_forge!(
1794///     [DB,]        // optional: sqlx::MySql | sqlx::Postgres | sqlx::Sqlite
1795///     [Model,]     // optional result spec
1796///     SQL,         // string literal
1797///     [params,]    // optional: ( :name = expr, ... )  or  struct_expr
1798///     [(sections),]// optional: ( #name = ..., ... )
1799///     [..batch]    // optional: batch source expression used by {( ... )}
1800/// )
1801/// ```
1802///
1803/// `Model` has three forms:
1804/// - omitted: execute-only query; only `.execute(...)` is available
1805/// - `Type` or `scalar Type`: a single result query
1806/// - `( >key1 = TypeA, >key2 = scalar TypeB )`: a grouped multi-result query
1807///
1808/// The trailing parameter source, section map, and batch source are optional.
1809/// The batch source may appear alongside the others as a single `..expr` argument.
1810///
1811/// The DB type may be omitted when `SQL_FORGE_DB_TYPE` is set (e.g.
1812/// `SQL_FORGE_DB_TYPE=sqlx::MySql`) or when
1813/// `[package.metadata.sql_forge] db = "..."` is set in `Cargo.toml`.
1814/// The env var takes priority over Cargo.toml metadata.
1815///
1816/// # Parameters
1817///
1818/// Named parameters are written `:name` in the SQL. At runtime each occurrence
1819/// is replaced by a `push_bind` call; at compile time it becomes a
1820/// database-specific placeholder: `?` for MySQL and SQLite, and `$1`, `$2`, ...
1821/// for Postgres.
1822///
1823/// **Inline map**: bind individual expressions:
1824/// ```rust,ignore
1825/// sql_forge!(User, "SELECT ... WHERE id <= :max_id", ( :max_id = filter.max_id ))
1826/// ```
1827///
1828/// **Struct source**: field names are matched to `:name` placeholders automatically:
1829/// ```rust,ignore
1830/// sql_forge!(User, "SELECT ... WHERE id <= :max_id LIMIT :limit", filter)
1831/// ```
1832///
1833/// # Sections (`{#name}`)
1834///
1835/// Sections are runtime SQL slots; each section's variants are validated at
1836/// compile time via `query_as!` / `query_scalar!`, though not every combination
1837/// of variants across sections is checked. The section map is a second parenthesised
1838/// argument starting with `#`:
1839///
1840/// ```rust,ignore
1841/// sql_forge!(
1842///     User,
1843///     "SELECT * FROM users {#join_org}",
1844///     (
1845///         #join_org = match include_org {
1846///             true  => " JOIN organisations o ON o.id = users.org_id ",
1847///             false => "",
1848///         }
1849///     )
1850/// )
1851/// ```
1852///
1853/// A section arm can also carry local parameters as a tuple `("sql", params)`:
1854///
1855/// ```rust,ignore
1856/// sql_forge!(
1857///     User,
1858///     "SELECT * FROM users {#filter}",
1859///     (
1860///         #filter = (
1861///             " WHERE id <= :max_id AND status = :status ",
1862///             ( :max_id = max_id, :status = "active" ),
1863///         )
1864///     )
1865/// )
1866/// ```
1867///
1868/// Multiple placeholders driven by one `match` use `#(a, b)` with each arm
1869/// returning a tuple of the same width:
1870///
1871/// ```rust,ignore
1872/// sql_forge!(
1873///     User,
1874///     "SELECT * FROM users {#join_org} {#filter_org}",
1875///     (
1876///         #(join_org, filter_org) = match include_org {
1877///             true  => (
1878///                 " JOIN organisations o ON o.id = users.org_id ",
1879///                 (
1880///                     " AND o.active = :active ",
1881///                     ( :active = true ),
1882///                 ),
1883///             ),
1884///             false => ("", ""),
1885///         }
1886///     )
1887/// )
1888/// ```
1889///
1890/// Grouped section items may themselves use nested `match` expressions. Those
1891/// nested matches use smart cycling within the arm rather than a cartesian
1892/// product. For example, if one grouped arm returns a fixed first item plus two
1893/// nested binary matches for the second and third items, that arm contributes
1894/// two aligned variants `(0, 0)` and `(1, 1)`, not four `(0, 0)`, `(0, 1)`,
1895/// `(1, 0)`, `(1, 1)` combinations.
1896///
1897/// # `IN (...)` with list parameters
1898///
/// Write the placeholder with a `[]` suffix inside parentheses, `(:name[])`, to
/// expand a `Vec` into one bound value per element:
1901///
1902/// ```rust,ignore
1903/// sql_forge!(User, "SELECT * FROM users WHERE id IN (:ids[])", ( :ids = ids ))
1904/// ```
1905///
1906/// **Empty lists** are not rewritten; `IN ()` is a database syntax error.
1907/// Guard against empty inputs explicitly, e.g. with a dynamic section:
1908///
1909/// ```rust,ignore
1910/// sql_forge!(
1911///     User,
1912///     "SELECT id, name FROM users WHERE {#filter}",
1913///     (
1914///         #filter = match ids.is_empty() {
1915///             true  => "1 = 0",
1916///             false => ("id IN (:ids[])", ( :ids = ids )),
1917///         }
1918///     )
1919/// )
1920/// ```
1921///
1922/// # Batch inserts (`{( ... )}`)
1923///
1924/// A batch section `{( ... )}` repeats its content for each item in an iterable
1925/// source passed as `..expr`. Inside the batch, `:name` refers to a field on the
1926/// current item. List parameters (`:name[]`) are **not** allowed inside batch
1927/// sections.
1928///
1929/// ## Struct batch
1930///
1931/// ```rust,ignore
1932/// struct BatchItem { name: String, price: i64 }
1933///
1934/// let items = vec![
1935///     BatchItem { name: "A".into(), price: 100 },
1936///     BatchItem { name: "B".into(), price: 200 },
1937/// ];
1938///
1939/// sql_forge!(
1940///     "INSERT INTO products (name, price, stock, category)
1941///      VALUES {(:name, :price, 10, 'Batch')}",
1942///     ..items
1943/// )
1944/// .execute(&pool)
1945/// .await?;
1946/// ```
1947///
1948/// For compile-time checking, the validator expands the batch to 3 fake copies
1949/// (`(?, ?, 10, 'Batch'), (?, ?, 10, 'Batch'), (?, ?, 10, 'Batch')`).
1950/// At runtime the iterable drives the actual number of rows.
1951///
1952/// # Scalar output
1953///
1954/// When `Model` is a primitive (`i32`, `i64`, `String`, etc.) the macro uses
1955/// `query_scalar!` for validation and `build_query_scalar` for execution.
1956///
1957/// # Multiple results
1958///
1959/// A result map produces a `SqlForgeQueryGroup` with one query per key.
1960/// Each key can be a struct or a primitive (used as a scalar):
1961///
1962/// ```rust,ignore
1963/// sql_forge!(
1964///     (
1965///         >count = i64,
1966///         >items = Item,
1967///     ),
1968///     "SELECT {#fields} FROM items WHERE category_id = :cat",
1969///     ( :cat = category_id ),
1970///     (
1971///         #fields = match {>count} {           // {>key} is true when building
1972///             true  => "COUNT(*) AS total",    // the query for that model/result
1973///             false => "id, name, price",      // key and false otherwise
1974///         }
1975///     )
1976/// )
1977/// ```
1978///
1979/// The generated struct has one field per key (`group.count`, `group.items`),
1980/// each implementing `SqlForgeQuery<T, Db = DB>` and usable with any SQLx
1981/// executor method (`fetch_one`, `fetch_all`, etc.).
1982///
1983/// # Execute-only (no model)
1984///
1985/// When the model type is omitted, the macro produces a value implementing
1986/// `SqlForgeQueryExecute`. Only `.execute(executor)`
1987/// is available and there is no return type to deserialize into. This is useful
1988/// for `INSERT`, `UPDATE`, `DELETE`, and other DML statements.
1989///
1990/// ```rust,ignore
1991/// sql_forge!(
1992///     "UPDATE products SET stock = stock + 1 WHERE id = :id",
1993///     ( :id = 42i64 ),
1994/// )
1995/// .execute(&pool)
1996/// .await?;
1997/// ```
1998///
1999/// Sections and struct parameter sources work the same way as in model-backed queries:
2000///
2001/// ```rust,ignore
/// sql_forge!(
///     "UPDATE products SET price = :new_price {#filter}",
///     update, // struct source: its `new_price` field binds `:new_price`
///     ( #filter = ("WHERE category = :cat", ( :cat = "Electronics" )) ),
2005/// )
2006/// .execute(&pool)
2007/// .await?;
2008/// ```
2009#[proc_macro]
2010#[allow(clippy::too_many_lines)]
2011pub fn sql_forge(input: TokenStream) -> TokenStream {
2012    // ---- Phase 1: Parse the macro input into structured data ----
2013    let preprocessed = preprocess_result_key_placeholders(TokenStream2::from(input));
2014    let SqlForgeInput {
2015        db,
2016        result,
2017        force_scalar,
2018        sql,
2019        params,
2020        sections,
2021        batch,
2022    } = match syn::parse2::<SqlForgeInput>(preprocessed) {
2023        Ok(v) => v,
2024        Err(err) => return err.to_compile_error().into(),
2025    };
2026
2027    // ---- Phase 2: Resolve database type (from macro arg, env var or Cargo.toml) ----
2028    let db = match db {
2029        Some(db) => db,
2030        None => match resolve_db_from_env() {
2031            Ok(db) => db,
2032            Err(msg) => {
2033                return syn::Error::new(Span::call_site(), msg)
2034                    .to_compile_error()
2035                    .into();
2036            }
2037        },
2038    };
2039
2040    let use_dollar_params = uses_dollar_params(&db);
2041    let list_count: usize = 3;
2042
2043    // ---- Phase 3: Build result case definitions ----
2044    // Each result case is (optional_key, optional_model_type, optional_scalar_type).
2045    // Scalar type is set for primitives and `scalar`-marked types.
2046    let result_cases: Vec<(Option<String>, Option<Type>, Option<Type>)> = match result {
2047        ResultSpec::None => {
2048            vec![(None, None, None)]
2049        }
2050        ResultSpec::Single(ref model) => {
2051            let model_ty = (**model).clone();
2052            let scalar = if force_scalar {
2053                Some(model_ty.clone())
2054            } else {
2055                scalar_output_type(model.as_ref()).cloned()
2056            };
2057            vec![(None, Some(model_ty), scalar)]
2058        }
2059        ResultSpec::Group(ref cases) => {
2060            if force_scalar {
2061                return syn::Error::new(
2062                    Span::call_site(),
2063                    "sql_forge!: scalar mode is not supported for grouped result maps",
2064                )
2065                .to_compile_error()
2066                .into();
2067            }
2068
2069            let mut out = Vec::new();
2070            let mut seen = HashSet::new();
2071            for case in cases {
2072                let key = case.name.to_string();
2073                if !seen.insert(key.clone()) {
2074                    return syn::Error::new(
2075                        case.name.span(),
2076                        "sql_forge!: duplicated key in result map",
2077                    )
2078                    .to_compile_error()
2079                    .into();
2080                }
2081
2082                let model = case.model.clone();
2083                let scalar = if case.force_scalar {
2084                    Some(model.clone())
2085                } else {
2086                    scalar_output_type(&case.model).cloned()
2087                };
2088                out.push((Some(key), Some(model), scalar));
2089            }
2090            out
2091        }
2092    };
2093    let group_result_keys: Vec<String> = result_cases
2094        .iter()
2095        .filter_map(|(key, _, _)| key.as_ref().cloned())
2096        .collect();
2097    let is_grouped_result = !group_result_keys.is_empty();
2098    let sql_span = sql.span();
2099
2100    // ---- Phase 4: Parse SQL into segments (text + {#section} + {( ..batch )} slots) ----
2101    let segments = match sql.into_segments() {
2102        Ok(segments) => segments,
2103        Err(msg) => {
2104            return syn::Error::new(sql_span, msg).to_compile_error().into();
2105        }
2106    };
2107
2108    let has_batch_segment = segments.iter().any(|s| matches!(s, Segment::Batch { .. }));
2109    match (&batch, has_batch_segment) {
2110        (None, true) => {
2111            return syn::Error::new(
2112                sql_span,
2113                "sql_forge!: SQL contains {( ... )} batch section but no batch source argument (..expr) \
2114                 was provided"
2115            )
2116            .to_compile_error()
2117            .into();
2118        }
2119        (Some(_), false) => {
2120            return syn::Error::new(
2121                sql_span,
2122                "sql_forge!: batch source argument (..expr) provided but SQL has no {( ... )} \
2123                 batch section",
2124            )
2125            .to_compile_error()
2126            .into();
2127        }
2128        _ => {}
2129    }
2130
2131    let used_param_names = collect_used_param_names(&segments);
2132
2133    // Batch-only params come from batch items, not the top-level params map.
2134    // They must be excluded from the usage check so that a param like :category
2135    // that appears only inside {( ... )} is flagged as unused when given in the
2136    // params map, as it would never be read from there at runtime.
2137    //
2138    // However, params that appear in both text and batch segments must NOT be
2139    // excluded — the text-segment occurrence resolves against the params map.
2140    let text_param_names: std::collections::HashSet<String> = segments
2141        .iter()
2142        .filter_map(|s| {
2143            if let Segment::Text(text) = s {
2144                Some(collect_used_param_names_in_sql(text).into_iter())
2145            } else {
2146                None
2147            }
2148        })
2149        .flatten()
2150        .collect();
2151    let top_level_used_names: Vec<String> = used_param_names
2152        .iter()
2153        .filter(|n| text_param_names.contains(*n))
2154        .cloned()
2155        .collect();
2156
2157    // ---- Phase 5: Build parameter bindings for the top-level params ----
2158    let (declared_params, validator_param_bindings) =
2159        match build_param_bindings(&params, &top_level_used_names, "top_level", true, true) {
2160            Ok(v) => v,
2161            Err(err) => return err,
2162        };
2163
2164    let mut runtime_section_actions = HashMap::<String, TokenStream2>::new();
2165
2166    // ---- Phase 6: Process sections: build runtime actions and collect validation variants ----
2167    for assign in &sections {
2168        let SectionAssign { names, value } = assign;
2169
2170        // Build runtime actions first, while `value` is still available by reference.
2171        let mut named_actions: Vec<(String, TokenStream2)> = Vec::new();
2172        for (section_idx, name_ident) in names.iter().enumerate() {
2173            let name = name_ident.to_string();
2174            if runtime_section_actions.contains_key(&name) {
2175                return syn::Error::new(
2176                    name_ident.span(),
2177                    "sql_forge!: duplicated section mapping",
2178                )
2179                .to_compile_error()
2180                .into();
2181            }
2182            let action = match build_section_runtime_action(
2183                value,
2184                section_idx,
2185                &format!("section_{}", name),
2186            ) {
2187                Ok(action) => action,
2188                Err(err) => return err,
2189            };
2190            named_actions.push((name, action));
2191        }
2192
2193        // Consume `value` here so invalid grouped/nested section structures fail early.
2194        if let Err(msg) = collect_section_variants(value.clone(), names.len()) {
2195            return syn::Error::new(names[0].span(), msg)
2196                .to_compile_error()
2197                .into();
2198        }
2199
2200        for (name, action) in named_actions {
2201            runtime_section_actions.insert(name, action);
2202        }
2203    }
2204
2205    let sql_section_names: std::collections::HashSet<&str> = segments
2206        .iter()
2207        .filter_map(|seg| {
2208            if let Segment::Section { name } = seg {
2209                Some(name.as_str())
2210            } else {
2211                None
2212            }
2213        })
2214        .collect();
2215    for name in runtime_section_actions.keys() {
2216        if !sql_section_names.contains(name.as_str()) {
2217            return syn::Error::new(
2218                sql_span,
2219                format!(
2220                    "sql_forge!: section `#{}` is declared in the section map but `{{#{}}}` never appears in the SQL",
2221                    name, name,
2222                ),
2223            )
2224            .to_compile_error()
2225            .into();
2226        }
2227    }
2228
2229    // ---- Phase 8: For each result case, generate validator + runtime tokens ----
2230    let mut generated_query_defs = Vec::<TokenStream2>::new();
2231    let mut generated_query_values = Vec::<TokenStream2>::new();
2232    let mut group_field_defs = Vec::<TokenStream2>::new();
2233    let mut group_field_idents = Vec::<syn::Ident>::new();
2234    let mut group_field_tys = Vec::<TokenStream2>::new();
2235    let mut group_trait_impls = Vec::<TokenStream2>::new();
2236
2237    let mut grouped_validator_invocations = Vec::<TokenStream2>::new();
2238
2239    for (result_key, model_opt, scalar_model_ty) in result_cases.iter() {
2240        let suffix = result_key.as_deref().unwrap_or("single");
2241        let query_ident = format_ident!("__SqlForgeQuery_{}", suffix);
2242        let query_value_ident = format_ident!("__sql_forge_value_{}", suffix);
2243
2244        let flag_bindings = build_result_flag_bindings(&group_result_keys, result_key.as_deref());
2245
2246        let mut section_variants_for_validation = HashMap::<String, Vec<SectionFragment>>::new();
2247        for assign in &sections {
2248            let SectionAssign { names, value } = assign;
2249            let variants_by_section = match collect_section_variants_for_result(
2250                value.clone(),
2251                names.len(),
2252                result_key.as_deref(),
2253            ) {
2254                Ok(v) => v,
2255                Err(msg) => {
2256                    return syn::Error::new(names[0].span(), msg)
2257                        .to_compile_error()
2258                        .into();
2259                }
2260            };
2261
2262            for (name_ident, section_cases) in names.iter().zip(variants_by_section) {
2263                section_variants_for_validation.insert(name_ident.to_string(), section_cases);
2264            }
2265        }
2266
2267        let mut nmax = 1usize;
2268        for segment in &segments {
2269            if let Segment::Section { name } = segment {
2270                if let Some(variants) = section_variants_for_validation.get(name) {
2271                    if variants.is_empty() {
2272                        return syn::Error::new(
2273                            sql_span,
2274                            format!("sql_forge!: section {{#{}}} has no possible variants", name),
2275                        )
2276                        .to_compile_error()
2277                        .into();
2278                    }
2279                    nmax = nmax.max(variants.len());
2280                } else {
2281                    return syn::Error::new(
2282                        sql_span,
2283                        format!("sql_forge!: section {{#{}}} has no mapping", name),
2284                    )
2285                    .to_compile_error()
2286                    .into();
2287                }
2288            }
2289        }
2290
2291        let mut validator_cases = Vec::<(LitStr, Vec<TokenStream2>, Vec<TokenStream2>)>::new();
2292        for case_idx in 0..nmax {
2293            let mut sql_case = String::new();
2294            let mut case_setup = Vec::<TokenStream2>::new();
2295            let mut case_args = Vec::<TokenStream2>::new();
2296            let mut param_offset = 0usize;
2297            let mut arg_index = 0usize;
2298            let root_validator_context = ValidatorRenderContext {
2299                params: &declared_params,
2300                use_dollar_params,
2301                sql_span,
2302                list_count,
2303            };
2304
2305            for segment in &segments {
2306                match segment {
2307                    Segment::Text(text) => {
2308                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2309                            text,
2310                            &mut param_offset,
2311                            &mut arg_index,
2312                            &root_validator_context,
2313                        ) {
2314                            Ok(value) => value,
2315                            Err(err) => return err,
2316                        };
2317                        sql_case.push_str(&chunk_sql);
2318                        case_setup.extend(chunk_setup);
2319                        case_args.extend(chunk_args);
2320                    }
2321                    Segment::Section { name } => {
2322                        let Some(variants) = section_variants_for_validation.get(name) else {
2323                            return syn::Error::new(
2324                                sql_span,
2325                                format!("sql_forge!: section {{#{}}} has no mapping", name),
2326                            )
2327                            .to_compile_error()
2328                            .into();
2329                        };
2330
2331                        let fragment = &variants[case_idx % variants.len()];
2332                        let used_param_names = collect_used_param_names_in_sql(&fragment.sql);
2333                        let (local_params, bindings) = match build_param_bindings(
2334                            &fragment.params,
2335                            &used_param_names,
2336                            &format!("section_case_{}_{}_{}", suffix, case_idx, name),
2337                            true,
2338                            true,
2339                        ) {
2340                            Ok(value) => value,
2341                            Err(err) => return err,
2342                        };
2343                        let section_validator_context = ValidatorRenderContext {
2344                            params: &local_params,
2345                            use_dollar_params,
2346                            sql_span: fragment.span,
2347                            list_count,
2348                        };
2349                        let (chunk_sql, chunk_setup, chunk_args) = match render_validator_args(
2350                            &fragment.sql,
2351                            &mut param_offset,
2352                            &mut arg_index,
2353                            &section_validator_context,
2354                        ) {
2355                            Ok(value) => value,
2356                            Err(err) => return err,
2357                        };
2358                        sql_case.push_str(&chunk_sql);
2359                        case_setup.extend(bindings);
2360                        case_setup.extend(chunk_setup);
2361                        case_args.extend(chunk_args);
2362                    }
2363                    Segment::Batch { parts } => {
2364                        let batch_ts = batch.as_ref().map(|e| quote! { #e });
2365                        let mut first = true;
2366                        for _ in 0..list_count {
2367                            let sep = if first { "" } else { ", " };
2368                            first = false;
2369                            sql_case.push_str(sep);
2370                            let (chunk_sql, _occurrences, chunk_args) = match render_validator_sql(
2371                                parts,
2372                                use_dollar_params,
2373                                &mut param_offset,
2374                                list_count,
2375                                batch_ts.clone(),
2376                            ) {
2377                                Ok(value) => value,
2378                                Err(err) => return err,
2379                            };
2380                            sql_case.push_str(&chunk_sql);
2381                            case_args.extend(chunk_args);
2382                        }
2383                    }
2384                }
2385            }
2386
2387            validator_cases.push((LitStr::new(&sql_case, sql_span), case_setup, case_args));
2388        }
2389
2390        let mut validator_invocations = Vec::<TokenStream2>::new();
2391        for (sql_lit, case_setup, args) in &validator_cases {
2392            if model_opt.is_none() {
2393                if args.is_empty() {
2394                    validator_invocations.push(quote! {
2395                        {
2396                            #( #case_setup )*
2397                            let _ = sqlx::query_scalar!(
2398                                #sql_lit,
2399                            );
2400                        }
2401                    });
2402                } else {
2403                    validator_invocations.push(quote! {
2404                        {
2405                            #( #case_setup )*
2406                            let _ = sqlx::query_scalar!(
2407                                #sql_lit,
2408                                #( #args ),*
2409                            );
2410                        }
2411                    });
2412                }
2413            } else if let Some(scalar_ty) = scalar_model_ty {
2414                if args.is_empty() {
2415                    validator_invocations.push(quote! {
2416                        {
2417                            #( #case_setup )*
2418                            let _ = sqlx::query_scalar!(
2419                                #sql_lit,
2420                            );
2421                        }
2422                    });
2423                } else {
2424                    validator_invocations.push(quote! {
2425                        {
2426                            #( #case_setup )*
2427                            let _ = sqlx::query_scalar!(
2428                                #sql_lit,
2429                                #( #args ),*
2430                            );
2431                        }
2432                    });
2433                }
2434                let _ = scalar_ty;
2435            } else if args.is_empty() {
2436                validator_invocations.push(quote! {
2437                    {
2438                        #( #case_setup )*
2439                        let _ = sqlx::query_as!(
2440                            __SqlForgeModel,
2441                            #sql_lit,
2442                        );
2443                    }
2444                });
2445            } else {
2446                validator_invocations.push(quote! {
2447                    {
2448                        #( #case_setup )*
2449                        let _ = sqlx::query_as!(
2450                            __SqlForgeModel,
2451                            #sql_lit,
2452                            #( #args ),*
2453                        );
2454                    }
2455                });
2456            }
2457        }
2458
2459        let model_alias = if let Some(model) = model_opt {
2460            if scalar_model_ty.is_none() {
2461                quote! { type __SqlForgeModel = #model; }
2462            } else {
2463                quote! {}
2464            }
2465        } else {
2466            quote! {}
2467        };
2468        grouped_validator_invocations.push(quote! {
2469            {
2470                #( #flag_bindings )*
2471                #model_alias
2472                #( #validator_invocations )*
2473            }
2474        });
2475
2476        let (runtime_declared_params, runtime_param_bindings) =
2477            match build_param_bindings(&params, &used_param_names, "runtime", false, false) {
2478                Ok(v) => v,
2479                Err(err) => return err,
2480            };
2481
2482        let mut runtime_steps = Vec::<TokenStream2>::new();
2483        for (seg_idx, segment) in segments.iter().enumerate() {
2484            match segment {
2485                Segment::Text(text) => {
2486                    for part in parse_text_parts(text) {
2487                        match part {
2488                            TextPart::Lit(lit) => {
2489                                let lit = sanitize_runtime_sql_text(&lit);
2490                                let lit_str = LitStr::new(&lit, sql_span);
2491                                runtime_steps.push(quote! {
2492                                    __builder.push(#lit_str);
2493                                });
2494                            }
2495                            TextPart::Param { name, is_list } => {
2496                                let Some(local_ident) = runtime_declared_params.get(&name) else {
2497                                    return syn::Error::new(
2498                                        sql_span,
2499                                        format!("sql_forge!: parameter :{} has no mapping", name),
2500                                    )
2501                                    .to_compile_error()
2502                                    .into();
2503                                };
2504
2505                                if is_list {
2506                                    runtime_steps.push(quote! {
2507                                        let __sql_forge_values = #local_ident;
2508                                        let mut __separated = __builder.separated(", ");
2509                                        for __value in __sql_forge_values {
2510                                            __separated.push_bind(__value);
2511                                        }
2512                                    });
2513                                } else {
2514                                    runtime_steps.push(quote! {
2515                                        __builder.push_bind(#local_ident);
2516                                    });
2517                                }
2518                            }
2519                        }
2520                    }
2521                }
2522                Segment::Section { name } => {
2523                    let Some(section_action) = runtime_section_actions.get(name) else {
2524                        let _ = seg_idx;
2525                        return syn::Error::new(
2526                            sql_span,
2527                            format!("sql_forge!: section {{#{}}} has no mapping", name),
2528                        )
2529                        .to_compile_error()
2530                        .into();
2531                    };
2532                    runtime_steps.push(quote! {
2533                        #section_action
2534                    });
2535                }
2536                Segment::Batch { parts } => {
2537                    if let Some(batch_expr) = &batch {
2538                        let mut body = Vec::<TokenStream2>::new();
2539                        for part in parts {
2540                            match part {
2541                                TextPart::Lit(lit) => {
2542                                    let lit_str = LitStr::new(lit, sql_span);
2543                                    body.push(quote! {
2544                                        __builder.push(#lit_str);
2545                                    });
2546                                }
2547                                TextPart::Param { name, .. } => {
2548                                    let field_ident = format_ident!("{}", name);
2549                                    body.push(quote! {
2550                                        __builder.push_bind(__item.#field_ident);
2551                                    });
2552                                }
2553                            }
2554                        }
2555                        runtime_steps.push(quote! {
2556                            {
2557                                let mut __first = true;
2558                                for __item in #batch_expr {
2559                                    if !__first {
2560                                        __builder.push(", ");
2561                                    }
2562                                    __first = false;
2563                                    #( #body )*
2564                                }
2565                            }
2566                        });
2567                    }
2568                }
2569            }
2570        }
2571
2572        let exec_methods = if model_opt.is_none() {
2573            quote! {
2574                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2575                where
2576                    E: sqlx::Executor<'e, Database = #db>,
2577                {
2578                    self.inner.build().execute(executor).await
2579                }
2580            }
2581        } else if let Some(scalar_ty) = scalar_model_ty {
2582            quote! {
2583                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#scalar_ty>, sqlx::Error>
2584                where
2585                    E: sqlx::Executor<'e, Database = #db>,
2586                {
2587                    self.inner
2588                        .build_query_scalar::<#scalar_ty>()
2589                        .fetch_all(executor)
2590                        .await
2591                }
2592
2593                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#scalar_ty, sqlx::Error>
2594                where
2595                    E: sqlx::Executor<'e, Database = #db>,
2596                {
2597                    self.inner
2598                        .build_query_scalar::<#scalar_ty>()
2599                        .fetch_one(executor)
2600                        .await
2601                }
2602
2603                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#scalar_ty>, sqlx::Error>
2604                where
2605                    E: sqlx::Executor<'e, Database = #db>,
2606                {
2607                    self.inner
2608                        .build_query_scalar::<#scalar_ty>()
2609                        .fetch_optional(executor)
2610                        .await
2611                }
2612
2613                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2614                where
2615                    E: sqlx::Executor<'e, Database = #db>,
2616                {
2617                    self.inner.build().execute(executor).await
2618                }
2619            }
2620        } else {
2621            let model = model_opt.as_ref().unwrap();
2622            quote! {
2623                async fn fetch_all<'e, E>(mut self, executor: E) -> Result<Vec<#model>, sqlx::Error>
2624                where
2625                    E: sqlx::Executor<'e, Database = #db>,
2626                {
2627                    self.inner.build_query_as::<#model>().fetch_all(executor).await
2628                }
2629
2630                async fn fetch_one<'e, E>(mut self, executor: E) -> Result<#model, sqlx::Error>
2631                where
2632                    E: sqlx::Executor<'e, Database = #db>,
2633                {
2634                    self.inner.build_query_as::<#model>().fetch_one(executor).await
2635                }
2636
2637                async fn fetch_optional<'e, E>(mut self, executor: E) -> Result<Option<#model>, sqlx::Error>
2638                where
2639                    E: sqlx::Executor<'e, Database = #db>,
2640                {
2641                    self.inner
2642                        .build_query_as::<#model>()
2643                        .fetch_optional(executor)
2644                        .await
2645                }
2646
2647                async fn execute<'e, E>(mut self, executor: E) -> Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>
2648                where
2649                    E: sqlx::Executor<'e, Database = #db>,
2650                {
2651                    self.inner.build().execute(executor).await
2652                }
2653            }
2654        };
2655
2656        let final_type: TokenStream2 = if let Some(model) = model_opt {
2657            if let Some(scalar_ty) = scalar_model_ty {
2658                quote! { #scalar_ty }
2659            } else {
2660                quote! { #model }
2661            }
2662        } else {
2663            quote! {}
2664        };
2665        let trait_impl = if model_opt.is_none() {
2666            quote! {
2667                impl sql_forge::SqlForgeQueryExecute
2668                    for #query_ident
2669                {
2670                    type Db = #db;
2671
2672                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2673                    where
2674                        Self: Sized + 'e,
2675                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2676                        #db: 'e,
2677                    {
2678                        #query_ident::execute(self, executor)
2679                    }
2680                }
2681            }
2682        } else {
2683            quote! {
2684                impl sql_forge::SqlForgeQuery<#final_type>
2685                    for #query_ident
2686                {
2687                    type Db = #db;
2688
2689                    fn fetch_all<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Vec<#final_type>, sqlx::Error>> + Send + 'e
2690                    where
2691                        Self: Sized + 'e,
2692                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2693                        #db: 'e,
2694                    {
2695                        #query_ident::fetch_all(self, executor)
2696                    }
2697
2698                    fn fetch_one<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<#final_type, sqlx::Error>> + Send + 'e
2699                    where
2700                        Self: Sized + 'e,
2701                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2702                        #db: 'e,
2703                    {
2704                        #query_ident::fetch_one(self, executor)
2705                    }
2706
2707                    fn fetch_optional<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<Option<#final_type>, sqlx::Error>> + Send + 'e
2708                    where
2709                        Self: Sized + 'e,
2710                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2711                        #db: 'e,
2712                    {
2713                        #query_ident::fetch_optional(self, executor)
2714                    }
2715
2716                    fn execute<'e, E>(self, executor: E) -> impl std::future::Future<Output = Result<<#db as sqlx::Database>::QueryResult, sqlx::Error>> + Send + 'e
2717                    where
2718                        Self: Sized + 'e,
2719                        E: sqlx::Executor<'e, Database = #db> + Send + 'e,
2720                        #db: 'e,
2721                    {
2722                        #query_ident::execute(self, executor)
2723                    }
2724                }
2725            }
2726        };
2727
2728        generated_query_defs.push(quote! {
2729            struct #query_ident {
2730                inner: sqlx::QueryBuilder<#db>,
2731            }
2732
2733            impl #query_ident {
2734                #exec_methods
2735            }
2736
2737            #trait_impl
2738        });
2739
2740        generated_query_values.push(quote! {
2741            #( #runtime_param_bindings )*
2742            #( #flag_bindings )*
2743            let mut __builder: sqlx::QueryBuilder<#db> = sqlx::QueryBuilder::new("");
2744            #( #runtime_steps )*
2745            let #query_value_ident = #query_ident { inner: __builder };
2746        });
2747
2748        if let Some(key) = result_key {
2749            let method_ident = format_ident!("{}", key);
2750            group_field_defs.push(quote! {
2751                #method_ident: #query_ident
2752            });
2753            group_field_tys.push(quote! { #query_ident });
2754
2755            let key_ty_ident = format_ident!("__SqlForgeQueryGroupKey_{}", key);
2756            group_trait_impls.push(quote! {
2757                struct #key_ty_ident;
2758
2759                impl sql_forge::SqlForgeQueryGroupGet<#key_ty_ident, #final_type> for __SqlForgeQueryGroup {
2760                    type Query = #query_ident;
2761
2762                    fn get(self, _: #key_ty_ident) -> Self::Query {
2763                        self.#method_ident
2764                    }
2765                }
2766            });
2767            group_field_idents.push(method_ident);
2768        }
2769    }
2770
2771    // ---- Phase 9: Emit the final token stream ----
2772    let validator_tokens = quote! {
2773        let _sql_forge_validator = || {
2774            #( #validator_param_bindings )*
2775            #( #grouped_validator_invocations )*
2776        };
2777    };
2778
2779    if !is_grouped_result {
2780        let single_query_value_ident = format_ident!("__sql_forge_value_single");
2781        return quote! {
2782            {
2783                #validator_tokens
2784                #( #generated_query_defs )*
2785                #( #generated_query_values )*
2786                #single_query_value_ident
2787            }
2788        }
2789        .into();
2790    }
2791
2792    let group_field_inits: Vec<TokenStream2> = result_cases
2793        .iter()
2794        .filter_map(|(key, _, _)| key.as_ref())
2795        .map(|key| {
2796            let method_ident = format_ident!("{}", key);
2797            let query_value_ident = format_ident!("__sql_forge_value_{}", key);
2798            quote! { #method_ident: #query_value_ident }
2799        })
2800        .collect();
2801
2802    quote! {
2803        {
2804            #validator_tokens
2805
2806            #( #generated_query_defs )*
2807            #( #generated_query_values )*
2808
2809            struct __SqlForgeQueryGroup {
2810                #( #group_field_defs, )*
2811            }
2812
2813            impl __SqlForgeQueryGroup {
2814                pub fn into_parts(self) -> ( #( #group_field_tys ),* ) {
2815                    ( #( self.#group_field_idents ),* )
2816                }
2817            }
2818
2819            impl sql_forge::SqlForgeQueryGroup for __SqlForgeQueryGroup {
2820                type Db = #db;
2821            }
2822
2823            #( #group_trait_impls )*
2824
2825            __SqlForgeQueryGroup {
2826                #( #group_field_inits, )*
2827            }
2828        }
2829    }
2830    .into()
2831}
2832
2833/// Expands to the database type from the `SQL_FORGE_DB_TYPE` environment variable,
2834/// falling back to `[package.metadata.sql_forge]` in `Cargo.toml`.
2835///
2836/// ```rust,ignore
2837/// use sql_forge::db_type;
2838///
2839/// pub type AppDb = db_type!();
2840/// // expands to the type set via SQL_FORGE_DB_TYPE or Cargo.toml metadata
2841/// ```
2842///
2843/// Priority:
2844/// 1. `SQL_FORGE_DB_TYPE` env var (e.g. `sqlx::MySql`, `sqlx::Postgres`)
2845/// 2. `[package.metadata.sql_forge] db = "..."` in `Cargo.toml`
2846#[proc_macro]
2847pub fn db_type(input: TokenStream) -> TokenStream {
2848    if !input.is_empty() {
2849        return syn::Error::new(Span::call_site(), "db_type!() takes no arguments")
2850            .to_compile_error()
2851            .into();
2852    }
2853
2854    match resolve_db_from_env() {
2855        Ok(db) => quote! { #db }.into(),
2856        Err(msg) => syn::Error::new(Span::call_site(), msg)
2857            .to_compile_error()
2858            .into(),
2859    }
2860}
2861
2862/// Marks a single-value tuple struct as a transparent wrapper for use with
2863/// `sql_forge!` parameters.
2864///
2865/// Expands to `#[derive(sqlx::Type)]` + `#[sqlx(transparent)]` (needed for
2866/// all database backends so the type implements `sqlx::Encode` + `sqlx::Type`)
2867/// and additionally implements `SqlForgeValidatorValue<InnerType>`, which is
2868/// **required for PostgreSQL** to pass compile-time parameter validation in
2869/// `query_as!`. MySQL and SQLite do not need to use the trait.
2870///
2871/// ```rust,ignore
2872/// #[derive(Debug, PartialEq, Eq)]
2873/// #[sql_forge_transparent]
2874/// struct UserId(pub i64);
2875/// ```
2876#[proc_macro_attribute]
2877pub fn sql_forge_transparent(_attr: TokenStream, item: TokenStream) -> TokenStream {
2878    let input: ItemStruct = match syn::parse(item) {
2879        Ok(v) => v,
2880        Err(err) => return err.to_compile_error().into(),
2881    };
2882
2883    let struct_name = &input.ident;
2884    let inner_type = match &input.fields {
2885        Fields::Unnamed(fields) if fields.unnamed.len() == 1 => &fields.unnamed.first().unwrap().ty,
2886        _ => {
2887            return syn::Error::new(
2888                input.span(),
2889                "#[sql_forge_transparent] expects a tuple struct with exactly one field",
2890            )
2891            .to_compile_error()
2892            .into();
2893        }
2894    };
2895
2896    let attrs = input.attrs;
2897    let generics = &input.generics;
2898    let vis = &input.vis;
2899    let struct_token = input.struct_token;
2900    let semi_token = input.semi_token;
2901    let fields = &input.fields;
2902
2903    let expanded = quote! {
2904        #( #attrs )*
2905        #[derive(sqlx::Type)]
2906        #[sqlx(transparent)]
2907        #vis #struct_token #struct_name #generics #fields #semi_token
2908
2909        impl #generics sql_forge::SqlForgeValidatorValue<#inner_type> for #struct_name #generics {
2910            fn sql_forge_validator_value(&self) -> #inner_type {
2911                self.0.clone()
2912            }
2913        }
2914    };
2915
2916    expanded.into()
2917}